24 #if defined(__MINGW32__)
29 #include <Eigen/SparseCore>
37 "Eigen matrix support in pybind11 requires Eigen >= 3.2.7");
52 #if EIGEN_VERSION_AT_LEAST(3, 3, 0)
54 template <
typename Scalar,
int Flags,
typename StorageIndex>
58 template <
typename Scalar,
int Flags,
typename StorageIndex>
65 std::is_base_of<Eigen::MapBase<T, Eigen::ReadOnlyAccessors>,
T>>;
67 using is_eigen_mutable_map = std::is_base_of<Eigen::MapBase<T, Eigen::WriteAccessors>,
T>;
69 using is_eigen_dense_plain
83 template <
bool EigenRowMajor>
84 struct EigenConformable {
85 bool conformable =
false;
87 EigenDStride stride{0, 0};
88 bool negativestrides =
false;
91 EigenConformable(
bool fits =
false) : conformable{fits} {}
93 EigenConformable(EigenIndex r, EigenIndex
c, EigenIndex rstride, EigenIndex cstride)
97 stride{EigenRowMajor ? (rstride > 0 ? rstride : 0)
98 : (cstride > 0 ? cstride : 0) ,
99 EigenRowMajor ? (cstride > 0 ? cstride : 0)
100 : (rstride > 0 ? rstride : 0) },
101 negativestrides{rstride < 0 || cstride < 0} {}
103 EigenConformable(EigenIndex r, EigenIndex
c, EigenIndex stride)
104 : EigenConformable(r,
c, r == 1 ?
c * stride : stride,
c == 1 ? r : r * stride) {}
106 template <
typename props>
107 bool stride_compatible()
const {
112 if (negativestrides) {
118 return (props::inner_stride ==
Eigen::Dynamic || props::inner_stride == stride.inner()
119 || (EigenRowMajor ?
cols :
rows) == 1)
120 && (props::outer_stride ==
Eigen::Dynamic || props::outer_stride == stride.outer()
121 || (EigenRowMajor ?
rows :
cols) == 1);
124 operator bool()
const {
return conformable; }
127 template <
typename Type>
128 struct eigen_extract_stride {
131 template <
typename PlainObjectType,
int MapOptions,
typename Str
ideType>
132 struct eigen_extract_stride<
Eigen::Map<PlainObjectType, MapOptions, StrideType>> {
133 using type = StrideType;
135 template <
typename PlainObjectType,
int Options,
typename Str
ideType>
136 struct eigen_extract_stride<
Eigen::Ref<PlainObjectType, Options, StrideType>> {
137 using type = StrideType;
141 template <
typename Type_>
146 static constexpr EigenIndex
rows = Type::RowsAtCompileTime,
cols = Type::ColsAtCompileTime,
147 size = Type::SizeAtCompileTime;
148 static constexpr
bool row_major = Type::IsRowMajor,
150 = Type::IsVectorAtCompileTime,
153 dynamic = !fixed_rows && !fixed_cols;
155 template <EigenIndex i, EigenIndex ifzero>
156 using if_zero = std::integral_constant<EigenIndex, i == 0 ? ifzero : i>;
157 static constexpr EigenIndex inner_stride
159 outer_stride = if_zero < StrideType::OuterStrideAtCompileTime,
163 static constexpr
bool dynamic_stride
165 static constexpr
bool requires_row_major
166 = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1;
167 static constexpr
bool requires_col_major
168 = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1;
173 static EigenConformable<row_major> conformable(
const array &
a) {
174 const auto dims =
a.ndim();
175 if (dims < 1 || dims > 2) {
181 EigenIndex np_rows =
a.shape(0), np_cols =
a.shape(1),
184 if ((fixed_rows && np_rows !=
rows) || (fixed_cols && np_cols !=
cols)) {
188 return {np_rows, np_cols, np_rstride, np_cstride};
193 const EigenIndex
n =
a.shape(0),
197 if (fixed &&
size !=
n) {
200 return {
rows == 1 ? 1 :
n,
cols == 1 ? 1 :
n, stride};
212 return {1,
n, stride};
214 if (fixed_rows &&
rows !=
n) {
217 return {
n, 1, stride};
220 static constexpr
bool show_writeable
223 static constexpr
bool show_c_contiguous = show_order && requires_row_major;
224 static constexpr
bool show_f_contiguous
225 = !show_c_contiguous && show_order && requires_col_major;
227 static constexpr
auto descriptor
238 const_name<show_writeable>(
", flags.writeable",
"")
239 + const_name<show_c_contiguous>(
", flags.c_contiguous",
"")
240 + const_name<show_f_contiguous>(
", flags.f_contiguous",
"") +
const_name(
"]");
245 template <
typename props>
251 a =
array({src.size()}, {elem_size * src.innerStride()}, src.data(),
base);
253 a =
array({src.rows(), src.cols()},
254 {elem_size * src.rowStride(), elem_size * src.colStride()},
270 template <
typename props,
typename Type>
284 return eigen_ref_array<props>(*src,
base);
289 template <
typename Type>
294 using props = EigenProps<Type>;
309 auto dims = buf.ndim();
310 if (dims < 1 || dims > 2) {
314 auto fits = props::conformable(buf);
321 auto ref = reinterpret_steal<array>(eigen_ref_array<props>(
value));
324 }
else if (
ref.ndim() == 1) {
340 template <
typename CType>
345 return eigen_encapsulate<props>(src);
347 return eigen_encapsulate<props>(
new CType(std::move(*src)));
349 return eigen_array_cast<props>(*src);
352 return eigen_ref_array<props>(*src);
354 return eigen_ref_array<props>(*src, parent);
356 throw cast_error(
"unhandled return_value_policy: should not happen!");
375 return cast_impl(&src, policy, parent);
383 return cast(&src, policy, parent);
387 return cast_impl(src, policy, parent);
391 return cast_impl(src, policy, parent);
394 static constexpr
auto name = props::descriptor;
401 operator Type &&() && {
return std::move(
value); }
402 template <
typename T>
410 template <
typename MapType>
411 struct eigen_map_caster {
416 using props = EigenProps<MapType>;
428 return eigen_array_cast<props>(src);
437 pybind11_fail(
"Invalid return_value_policy for Eigen Map/Ref/Block type");
441 static constexpr
auto name = props::descriptor;
446 bool load(
handle,
bool) =
delete;
453 template <
typename Type>
458 template <
typename PlainObjectType,
typename Str
ideType>
460 Eigen::Ref<PlainObjectType, 0, StrideType>,
461 enable_if_t<is_eigen_dense_map<Eigen::Ref<PlainObjectType, 0, StrideType>>::value>>
462 :
public eigen_map_caster<Eigen::Ref<PlainObjectType, 0, StrideType>> {
465 using props = EigenProps<Type>;
473 | ((props::row_major ? props::inner_stride : props::outer_stride) == 1
475 : (props::row_major ? props::outer_stride : props::inner_stride) == 1
480 std::unique_ptr<MapType> map;
481 std::unique_ptr<Type>
ref;
494 bool need_copy = !isinstance<Array>(src);
496 EigenConformable<props::row_major> fits;
500 auto aref = reinterpret_borrow<Array>(src);
502 if (aref && (!need_writeable || aref.writeable())) {
503 fits = props::conformable(aref);
507 if (!fits.template stride_compatible<props>()) {
510 copy_or_ref = std::move(aref);
521 if (!
convert || need_writeable) {
525 Array
copy = Array::ensure(src);
529 fits = props::conformable(
copy);
530 if (!fits || !fits.template stride_compatible<props>()) {
533 copy_or_ref = std::move(
copy);
541 make_stride(fits.stride.outer(), fits.stride.inner())));
548 operator Type *() {
return ref.get(); }
550 operator Type &() {
return *
ref; }
551 template <
typename _T>
552 using cast_op_type = pybind11::detail::cast_op_type<_T>;
557 return a.mutable_data();
567 template <
typename S>
573 template <
typename S>
574 using stride_ctor_dual
579 template <
typename S>
580 using stride_ctor_outer
585 template <
typename S>
586 using stride_ctor_inner
593 static S make_stride(EigenIndex, EigenIndex) {
597 static S make_stride(EigenIndex outer, EigenIndex inner) {
598 return S(outer, inner);
601 static S make_stride(EigenIndex outer, EigenIndex) {
605 static S make_stride(EigenIndex, EigenIndex inner) {
614 template <
typename Type>
622 using props = EigenProps<Matrix>;
630 return cast(*src, policy, parent);
633 static constexpr
auto name = props::descriptor;
639 operator Type() =
delete;
644 template <
typename Type>
651 static constexpr
bool rowMajor = Type::IsRowMajor;
658 auto obj = reinterpret_borrow<object>(src);
660 object matrix_type = sparse_module.attr(rowMajor ?
"csr_matrix" :
"csc_matrix");
664 obj = matrix_type(obj);
673 auto shape = pybind11::tuple((pybind11::object) obj.attr(
"shape"));
674 auto nnz = obj.attr(
"nnz").cast<
Index>();
676 if (!
values || !innerIndices || !outerIndices) {
682 StorageIndex>(shape[0].cast<Index>(),
683 shape[1].cast<
Index>(),
685 outerIndices.mutable_data(),
686 innerIndices.mutable_data(),
696 =
module_::import(
"scipy.sparse").attr(rowMajor ?
"csr_matrix" :
"csc_matrix");
698 array data(src.nonZeros(), src.valuePtr());
699 array outerIndices((rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
700 array innerIndices(src.nonZeros(), src.innerIndexPtr());
703 std::move(
data), std::move(innerIndices), std::move(outerIndices)),
709 const_name<(Type::IsRowMajor) != 0>(
"scipy.sparse.csr_matrix[",
710 "scipy.sparse.csc_matrix[")