I'm new to mixed Python/C++ programming. I'm trying to build a Python module that calls a C++ function in which fast Fourier transforms are invoked, so I link the module against the FFTW library. The module builds successfully and the test function (in Python) runs without errors, but the test result indicates that the FFT is never actually invoked (constructing the fftw_plan takes 0 s). If I compile the same code directly with g++, the test function (in C++) prints exactly the correct result.
The following are some snippets from my setup.py file:
from setuptools import setup, Extension

sources = ['pyfot/fot.cpp']
INCLUDE_DIR = 'pyfot'
include_dirs = [INCLUDE_DIR]
library_dirs = [INCLUDE_DIR]
link_arg = ['fftw3', 'm']  # libraries to link against: FFTW3 and libm

setup(
    name='pyfot',
    version=__version__,
    author='',
    author_email='',
    url='',
    description='Fast optimal transport '
                'with C++ backend.',
    # ext_modules expects a list of Extension objects
    ext_modules=[
        Extension(
            'fot',
            sources=sources,
            include_dirs=include_dirs,
            library_dirs=library_dirs,
            libraries=link_arg,
            language='c++',
        ),
    ],
    install_requires=['pybind11>=2.3'],
    setup_requires=['pybind11>=2.3'],
    # cmdclass=cmdclass,
)
The file fot.cpp is shown below. The FFT is invoked in the function compute_l2_ot2d, which is implemented in fot.hpp.
#include <cstdio>     // printf
#include <stdexcept>  // std::runtime_error
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include "fot.hpp"

namespace py = pybind11;

py::array_t<double> l2_ot(py::array_t<double> &mu, py::array_t<double> &nu, double sigma,
                          int maxIters, bool verbose)
{
    // Access the raw buffers of the two input densities.
    py::buffer_info buf1 = mu.request();
    py::buffer_info buf2 = nu.request();

    if (buf1.ndim != 2 || buf2.ndim != 2)
        throw std::runtime_error("Number of dimensions must be two.");
    if (buf1.size != buf2.size)
        throw std::runtime_error("Density shapes must match.");

    int n1 = static_cast<int>(buf1.shape[0]);
    int n2 = static_cast<int>(buf1.shape[1]);
    printf("Shape n1:%d, n2:%d\n", n1, n2);

    double *ptr1 = static_cast<double *>(buf1.ptr),
           *ptr2 = static_cast<double *>(buf2.ptr);

    // Allocate the output arrays on the Python side.
    auto phi = py::array_t<double>({n1, n2});
    auto dual = py::array_t<double>({n1, n2});
    auto values = py::array_t<double>({maxIters * 2, 1});

    py::buffer_info buf3 = phi.request();
    py::buffer_info buf4 = dual.request();
    py::buffer_info buf5 = values.request();

    double *ptr3 = static_cast<double *>(buf3.ptr),
           *ptr4 = static_cast<double *>(buf4.ptr);
    double *ptrValues = static_cast<double *>(buf5.ptr);

    // The FFTs are performed inside compute_l2_ot2d (see fot.hpp).
    compute_l2_ot2d(ptr1, ptr2, ptr3, ptr4, ptrValues, sigma, maxIters, n2, n1, verbose);

    return dual;
}

PYBIND11_MODULE(pyfot, m) {
    m.doc() = "fast optimal transport";
    m.def("l2ot", &l2_ot, py::return_value_policy::reference);
}
Can anyone help me out with this? By the way, if I want the function l2_ot to return multiple arrays, e.g. phi, dual, and values, what should I do?
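For the second question, one possible direction is to return a tuple. With pybind11, l2_ot could be declared to return py::tuple and end with `return py::make_tuple(phi, dual, values);`, which would arrive in Python as a 3-tuple of NumPy arrays. The sketch below shows the same idea in plain C++ so that it compiles without pybind11, fot.hpp, or FFTW. The names `fake_solver`, `OtResult`, and `l2_ot_all` are hypothetical and not part of the original code, and `fake_solver` only writes placeholder numbers in place of compute_l2_ot2d.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical stand-in for compute_l2_ot2d: it fills the outputs with
// placeholder numbers so the example builds without fot.hpp or FFTW.
inline void fake_solver(const std::vector<double> &mu, const std::vector<double> &nu,
                        std::vector<double> &phi, std::vector<double> &dual,
                        std::vector<double> &values)
{
    for (std::size_t i = 0; i < mu.size(); ++i) {
        phi[i] = mu[i] - nu[i];
        dual[i] = nu[i] - mu[i];
    }
    for (std::size_t k = 0; k < values.size(); ++k)
        values[k] = static_cast<double>(k);
}

// Hypothetical result type: (phi, dual, values).
using OtResult = std::tuple<std::vector<double>, std::vector<double>, std::vector<double>>;

// Returns all three outputs together instead of only dual.
inline OtResult l2_ot_all(const std::vector<double> &mu, const std::vector<double> &nu,
                          int maxIters)
{
    assert(mu.size() == nu.size());
    assert(maxIters >= 0);

    std::vector<double> phi(mu.size()), dual(mu.size());
    std::vector<double> values(2 * static_cast<std::size_t>(maxIters));

    fake_solver(mu, nu, phi, dual, values);

    // Move the buffers into the tuple to avoid copying them.
    return std::make_tuple(std::move(phi), std::move(dual), std::move(values));
}

// With <pybind11/stl.h> included, a binding such as
//   m.def("l2ot_all", &l2_ot_all);
// would hand Python a 3-tuple of lists.
```

On the Python side the result could then be unpacked as `phi, dual, values = pyfot.l2ot_all(mu, nu, 10)`, assuming the function is bound under that (hypothetical) name.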
Thank you so much.