Commit 635e475b authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

Properly propagate ndarray strides into multi_array_ref in numpy_bind.hh

This fixes problems with transposed arrays, or arrays with arbitrary strides.
parent 7d0eda91
......@@ -138,6 +138,22 @@ boost::python::object wrap_multi_array_not_owned(boost::multi_array<ValueType,Di
// get multi_array_ref from numpy ndarrays
template <class ValueType, size_t dim>
class numpy_multi_array: public boost::multi_array_ref<ValueType,dim>
{
typedef boost::multi_array_ref<ValueType,dim> base_t;
public:
template <class ExtentList, class StrideList>
explicit numpy_multi_array(typename base_t::element* data,
const ExtentList& sizes,
const StrideList& strides)
:base_t(data, sizes)
{
for (int i = 0; i < dim; ++i)
base_t::stride_list_[i] = strides[i];
}
};
struct invalid_numpy_conversion:
public std::exception
{
......@@ -169,14 +185,23 @@ boost::multi_array_ref<ValueType,dim> get_array(boost::python::object points)
throw invalid_numpy_conversion(error);
}
vector<size_t> shape(PyArray_NDIM(pa));
for (int i = 0; i < PyArray_NDIM(pa); ++i)
vector<size_t> shape(dim);
for (int i = 0; i < dim; ++i)
shape[i] = PyArray_DIMS(pa)[i];
if ((PyArray_FLAGS(pa) ^ NPY_ARRAY_C_CONTIGUOUS) != 0)
return boost::multi_array_ref<ValueType,dim>((ValueType *) PyArray_DATA(pa), shape);
else
return boost::multi_array_ref<ValueType,dim>((ValueType *) PyArray_DATA(pa), shape,
boost::fortran_storage_order());
vector<size_t> stride(dim);
for (size_t i = 0; i < dim; ++i)
stride[i] = PyArray_STRIDE(pa, i) / sizeof(ValueType);
return numpy_multi_array<ValueType,dim>((ValueType *) PyArray_DATA(pa),
shape, stride);
// boost::storage_order_type store;
// if ((PyArray_FLAGS(pa) ^ NPY_ARRAY_C_CONTIGUOUS) != 0)
// store = boost::c_storage_order();
// if (PyArray_FLAGS(pa) ^ NPY_ARRAY_F_CONTIGUOUS) != 0)
// store = boost::fortran_storage_order();
}
#endif
......
......@@ -124,6 +124,22 @@ boost::python::object wrap_multi_array_not_owned(boost::multi_array<ValueType,Di
// get multi_array_ref from numpy ndarrays
template <class ValueType, size_t dim>
class numpy_multi_array: public boost::multi_array_ref<ValueType,dim>
{
typedef boost::multi_array_ref<ValueType,dim> base_t;
public:
template <class ExtentList, class StrideList>
explicit numpy_multi_array(typename base_t::element* data,
const ExtentList& sizes,
const StrideList& strides)
:base_t(data, sizes)
{
for (int i = 0; i < dim; ++i)
base_t::stride_list_[i] = strides[i];
}
};
struct invalid_numpy_conversion:
public std::exception
{
......@@ -158,11 +174,13 @@ boost::multi_array_ref<ValueType,dim> get_array(boost::python::object points)
vector<size_t> shape(pa->nd);
for (int i = 0; i < pa->nd; ++i)
shape[i] = pa->dimensions[i];
if ((pa->flags ^ NPY_C_CONTIGUOUS) != 0)
return boost::multi_array_ref<ValueType,dim>((ValueType *) pa->data, shape);
else
return boost::multi_array_ref<ValueType,dim>((ValueType *) pa->data, shape,
boost::fortran_storage_order());
vector<size_t> stride(dim);
for (size_t i = 0; i < dim; ++i)
stride[i] = pa->strides[i] / sizeof(ValueType);
return numpy_multi_array<ValueType,dim>((ValueType *) pa->data,
shape, stride);
}
#endif // NUMPY_BIND_OLD_HH
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment