How do I make a Cython function accept float or double array input?
Suppose I have the following (MCVE…) Cython function:
cimport cython
from scipy.linalg.cython_blas cimport dnrm2

cpdef double func(int n, double[:] x):
    cdef int inc = 1
    return dnrm2(&n, &x[0], &inc)
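Then I can't call it on an np.float32 array x: because the argument is typed as a double[:] memoryview, Cython rejects the float32 buffer with a dtype-mismatch error.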
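The rejection comes down to element size: a double[:] memoryview needs 8-byte float64 elements, while an np.float32 array stores 4-byte elements, so Cython cannot reinterpret the buffer in place. The small NumPy sketch below (the array values are made up for illustration) shows the mismatch and the obvious stopgap of converting first. Note that astype copies the data and moves the computation to double precision, which is exactly what a float-aware func would avoid.

```python
import numpy as np

# Hypothetical input: the kind of float32 array the question passes to func.
x = np.array([3.0, 4.0], dtype=np.float32)

# double[:] expects 8-byte float64 elements; this buffer holds 4-byte float32
# elements, so it cannot be viewed as double[:] without conversion.
print(x.dtype, x.itemsize)

# Stopgap for the double-only func: convert first. astype allocates a new
# float64 array, so the data is copied and the norm is computed in double.
x64 = x.astype(np.float64)
print(x64.dtype, x64.itemsize)
```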
How do I get func to accept either a double[:] or a float[:] and call dnrm2 or snrm2 accordingly?
My only solution at the moment is to write two functions, which means a lot of duplicated code.
Solution
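You can use a fused type. Cython then generates one specialization of func for double and one for float. The `if anyfloat is double:` test is resolved at compile time, so each specialization calls only the matching BLAS routine: dnrm2 for double[:] input and snrm2 for float[:] input.

# cython: infer_types=True
cimport cython
from scipy.linalg.cython_blas cimport dnrm2, snrm2

ctypedef fused anyfloat:
    double
    float

cpdef anyfloat func(int n, anyfloat[:] x):
    cdef int inc = 1
    # Evaluated at compile time: each specialization keeps only one branch.
    if anyfloat is double:
        return dnrm2(&n, &x[0], &inc)
    else:
        return snrm2(&n, &x[0], &inc)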
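For comparison, SciPy makes the same s/d prefix choice at the Python level: scipy.linalg.get_blas_funcs inspects the dtype of the arrays you pass and returns the matching BLAS wrapper. The sketch below assumes NumPy and SciPy are installed; the helper name blas_norm is hypothetical. Unlike the fused Cython version, where the choice is fixed per specialization at compile time, this lookup happens at run time on every call.

```python
import numpy as np
from scipy.linalg import get_blas_funcs


def blas_norm(x):
    """Hypothetical helper: return (typecode, 2-norm) of a 1-D float array.

    get_blas_funcs picks snrm2 for float32 input and dnrm2 for float64
    input, mirroring the double/float split of the fused-type function.
    """
    x = np.asarray(x)
    if x.ndim != 1 or x.dtype not in (np.float32, np.float64):
        raise TypeError(f"expected a 1-D float32 or float64 array, got {x.dtype}")
    # Passing (x,) lets SciPy choose the routine prefix from x's dtype.
    nrm2 = get_blas_funcs("nrm2", (x,))
    return nrm2.typecode, nrm2(x)


if __name__ == "__main__":
    for dtype in (np.float32, np.float64):
        code, value = blas_norm(np.array([3.0, 4.0], dtype=dtype))
        print(np.dtype(dtype).name, code, value)
```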