-
-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathgemm_cython.pyx
More file actions
32 lines (27 loc) · 677 Bytes
/
gemm_cython.pyx
File metadata and controls
32 lines (27 loc) · 677 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Authors: Jake Vanderplas, Olivier Grisel
# License: MIT
import numpy as np
cimport cython
from libc.math cimport sqrt
@cython.boundscheck(False)
@cython.wraparound(False)
def gemm_cython_for_loops(
    double alpha,
    double[:, ::1] A,
    double[:, ::1] B,
    double beta,
    double[:, ::1] C,
):
    """Update ``C`` in place with ``alpha * (A @ B) + beta * C``.

    The product is evaluated with three nested loops over typed
    memoryviews; no BLAS routine is called.

    Parameters
    ----------
    alpha : double
        Scale factor applied to the matrix product ``A @ B``.
    A : C-contiguous 2D buffer of doubles, shape (M, K)
        Left operand.
    B : C-contiguous 2D buffer of doubles, shape (K, N)
        Right operand.
    beta : double
        Scale factor applied to the existing contents of ``C``.
    C : C-contiguous 2D buffer of doubles, shape (M, N)
        Accumulator; overwritten element by element with the result.

    Returns
    -------
    None
        The result is written into ``C``.

    Raises
    ------
    ValueError
        If the shapes of ``A``, ``B`` and ``C`` are not compatible.

    Notes
    -----
    ``C`` is written while ``A`` and ``B`` are still being read, so ``C``
    should not share memory with either operand.
    """
    # Memoryview extents are Py_ssize_t; a plain C int could truncate them.
    cdef Py_ssize_t M = C.shape[0]
    cdef Py_ssize_t N = C.shape[1]
    cdef Py_ssize_t K = A.shape[1]
    # Typed indices keep the loops free of Python object arithmetic.
    cdef Py_ssize_t i, j, k
    cdef double d

    # Bounds checking is switched off for this function, so inconsistent
    # shapes would silently access memory outside the buffers. Validate
    # once up front instead of on every element access.
    if A.shape[0] != M or B.shape[0] != K or B.shape[1] != N:
        raise ValueError(
            "incompatible shapes: A is (%d, %d), B is (%d, %d), C is (%d, %d)"
            % (A.shape[0], A.shape[1], B.shape[0], B.shape[1], M, N)
        )

    for i in range(M):
        for j in range(N):
            # Dot product of row i of A with column j of B.
            d = 0.0
            for k in range(K):
                d += A[i, k] * B[k, j]
            C[i, j] = alpha * d + beta * C[i, j]
# Tuple collecting the GEMM implementations defined in this module
# (presumably picked up by an external benchmark runner -- confirm with caller).
benchmarks = (gemm_cython_for_loops,)