Go to the documentation of this file.00001
00002
00003
00004
00005
#include <cstddef>
#include <complex>

#include <boost/numpy.hpp>
#include <boost/mpl/vector.hpp>
#include <boost/mpl/vector_c.hpp>
00009
00010 namespace p = boost::python;
00011 namespace np = boost::numpy;
00012
00013 struct ArrayFiller
00014 {
00015
00016 typedef boost::mpl::vector< short, int, float, std::complex<double> > TypeSequence;
00017 typedef boost::mpl::vector_c< int, 1, 2 > DimSequence;
00018
00019 explicit ArrayFiller(np::ndarray const & arg) : argument(arg) {}
00020
00021 template <typename T, int N>
00022 void apply() const
00023 {
00024 if (N == 1)
00025 {
00026 char * p = argument.get_data();
00027 int stride = argument.strides(0);
00028 int size = argument.shape(0);
00029 for (int n = 0; n != size; ++n, p += stride)
00030 *reinterpret_cast<T*>(p) = static_cast<T>(n);
00031 }
00032 else
00033 {
00034 char * row_p = argument.get_data();
00035 int row_stride = argument.strides(0);
00036 int col_stride = argument.strides(1);
00037 int rows = argument.shape(0);
00038 int cols = argument.shape(1);
00039 int i = 0;
00040 for (int n = 0; n != rows; ++n, row_p += row_stride)
00041 {
00042 char * col_p = row_p;
00043 for (int m = 0; m != cols; ++i, ++m, col_p += col_stride)
00044 *reinterpret_cast<T*>(col_p) = static_cast<T>(i);
00045 }
00046 }
00047 }
00048
00049 np::ndarray argument;
00050 };
00051
00052 void fill(np::ndarray const & arg)
00053 {
00054 ArrayFiller filler(arg);
00055 np::invoke_matching_array<ArrayFiller::TypeSequence, ArrayFiller::DimSequence >(arg, filler);
00056 }
00057
00058 BOOST_PYTHON_MODULE(templates_mod)
00059 {
00060 np::initialize();
00061 p::def("fill", fill);
00062 }