Python – How do I make a Cython function accept float or double array input?

Suppose I have the following (MCVE…) Cython function:

cimport cython

from scipy.linalg.cython_blas cimport dnrm2

cpdef double func(int n, double[:] x):
   cdef int inc = 1
   return dnrm2(&n, &x[0], &inc)

However, I can't call it on an np.float32 array x, because the argument is declared as double[:].

How do I get func to accept either double[:] or float[:] and call dnrm2 or snrm2 accordingly? My only solution at the moment is to write two functions, which creates a lot of duplicate code.

Solution

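You can use a fused type. Note that ddot and sdot take five arguments (n, x, incx, y, incy), so they cannot be called with three; the code below uses dnrm2 and snrm2 from the question instead, which take three:

# cython: infer_types=True
cimport cython

from scipy.linalg.cython_blas cimport dnrm2, snrm2

# anyfloat stands for either double or float; Cython compiles one
# specialization of func per member type.
ctypedef fused anyfloat:
   double
   float

cpdef anyfloat func(int n, anyfloat[:] x):
   cdef int inc = 1
   # This branch is resolved at compile time for each specialization.
   if anyfloat is double:
      return dnrm2(&n, &x[0], &inc)
   else:
      return snrm2(&n, &x[0], &inc)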
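
To see roughly what happens when such a fused-type function is called from Python, here is a pure-Python sketch of the dispatch idea. It is not the code Cython generates: nrm2_double and nrm2_float are hypothetical stand-ins for dnrm2 and snrm2, and the standard-library array module stands in for NumPy arrays. The point is only that the routine is chosen from the item type of the buffer that was passed in.

```python
import math
import struct
from array import array


def nrm2_double(n, x):
    """Hypothetical stand-in for dnrm2: Euclidean norm of the first n items."""
    return math.sqrt(sum(v * v for v in x[:n]))


def nrm2_float(n, x):
    """Hypothetical stand-in for snrm2: same norm, rounded to single precision."""
    result = math.sqrt(sum(v * v for v in x[:n]))
    # Round-trip through a 4-byte float to mimic a single-precision result.
    return struct.unpack("f", struct.pack("f", result))[0]


def func(n, x):
    """Choose a routine from the buffer's item type, mimicking fused-type dispatch."""
    view = memoryview(x)
    if view.format == "d":    # 8-byte C double
        return nrm2_double(n, view)
    if view.format == "f":    # 4-byte C float
        return nrm2_float(n, view)
    raise TypeError(f"unsupported buffer format: {view.format!r}")


print(func(2, array("d", [3.0, 4.0])))  # double-precision path
print(func(2, array("f", [3.0, 4.0])))  # single-precision path
```

With the compiled Cython version, the caller likewise writes a single func(n, x) call for both np.float64 and np.float32 arrays, and a buffer of any other item type is rejected with a TypeError.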
