Skip to content

Commit 04c9a8c

Browse files
committed
introduce explicit transpose function for dense matrix
1 parent 85f456a commit 04c9a8c

3 files changed

Lines changed: 17 additions & 16 deletions

File tree

include/utils/dense_matrix.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,6 @@ Matrix *new_dense_matrix(int m, int n, const double *data);
1717
/* Transpose helper */
1818
Matrix *dense_matrix_trans(const Dense_Matrix *self);
1919

20+
void A_transpose(double *AT, const double *A, int m, int n);
21+
2022
#endif /* DENSE_MATRIX_H */

src/atoms/affine/left_matmul.c

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -193,18 +193,12 @@ static void refresh_dense_left(left_matmul_expr *lnode)
193193
Dense_Matrix *dm_AT = (Dense_Matrix *) lnode->AT;
194194
int m = dm_A->base.m;
195195
int n = dm_A->base.n;
196-
const double *vals = lnode->param_source->value;
197196

198-
/* Column-major A is row-major AT; copy directly into AT */
199-
memcpy(dm_AT->x, vals, m * n * sizeof(double));
200-
/* Transpose AT to get row-major A */
201-
for (int i = 0; i < m; i++)
202-
{
203-
for (int j = 0; j < n; j++)
204-
{
205-
dm_A->x[i * n + j] = dm_AT->x[j * m + i];
206-
}
207-
}
197+
/* The parameter represents the A in left_matmul_dense(A, x) in column-major.
198+
In this diffengine, we store A in row-major order. Hence, param->vals
199+
actually corresponds to the transpose of A, and we transpose AT to get A. */
200+
memcpy(dm_AT->x, lnode->param_source->value, m * n * sizeof(double));
201+
A_transpose(dm_A->x, dm_AT->x, m, n);
208202
}
209203

210204
expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A)

src/utils/dense_matrix.c

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,20 @@ Matrix *dense_matrix_trans(const Dense_Matrix *A)
7878
int n = A->base.n;
7979
double *AT_x = (double *) SP_MALLOC(m * n * sizeof(double));
8080

81+
A_transpose(AT_x, A->x, m, n);
82+
83+
Matrix *result = new_dense_matrix(n, m, AT_x);
84+
free(AT_x);
85+
return result;
86+
}
87+
88+
void A_transpose(double *AT, const double *A, int m, int n)
89+
{
8190
for (int i = 0; i < m; i++)
8291
{
8392
for (int j = 0; j < n; j++)
8493
{
85-
AT_x[j * m + i] = A->x[i * n + j];
94+
AT[j * m + i] = A[i * n + j];
8695
}
8796
}
88-
89-
Matrix *result = new_dense_matrix(n, m, AT_x);
90-
free(AT_x);
91-
return result;
9297
}

0 commit comments

Comments
 (0)