23 #if defined(__MINGW32__)
28 #include <Eigen/SparseCore>
36 "Eigen matrix support in pybind11 requires Eigen >= 3.2.7");
51 #if EIGEN_VERSION_AT_LEAST(3, 3, 0)
53 template <
typename Scalar,
int Flags,
typename StorageIndex>
57 template <
typename Scalar,
int Flags,
typename StorageIndex>
64 std::is_base_of<Eigen::MapBase<T, Eigen::ReadOnlyAccessors>,
T>>;
66 using is_eigen_mutable_map = std::is_base_of<Eigen::MapBase<T, Eigen::WriteAccessors>,
T>;
68 using is_eigen_dense_plain
82 template <
bool EigenRowMajor>
83 struct EigenConformable {
84 bool conformable =
false;
86 EigenDStride stride{0, 0};
87 bool negativestrides =
false;
90 EigenConformable(
bool fits =
false) : conformable{fits} {}
92 EigenConformable(EigenIndex r, EigenIndex
c, EigenIndex rstride, EigenIndex cstride)
96 stride{EigenRowMajor ? (rstride > 0 ? rstride : 0)
97 : (cstride > 0 ? cstride : 0) ,
98 EigenRowMajor ? (cstride > 0 ? cstride : 0)
99 : (rstride > 0 ? rstride : 0) },
100 negativestrides{rstride < 0 || cstride < 0} {}
102 EigenConformable(EigenIndex r, EigenIndex
c, EigenIndex stride)
103 : EigenConformable(r,
c, r == 1 ?
c * stride : stride,
c == 1 ? r : r * stride) {}
105 template <
typename props>
106 bool stride_compatible()
const {
111 if (negativestrides) {
117 return (props::inner_stride ==
Eigen::Dynamic || props::inner_stride == stride.inner()
118 || (EigenRowMajor ?
cols :
rows) == 1)
119 && (props::outer_stride ==
Eigen::Dynamic || props::outer_stride == stride.outer()
120 || (EigenRowMajor ?
rows :
cols) == 1);
123 operator bool()
const {
return conformable; }
126 template <
typename Type>
127 struct eigen_extract_stride {
130 template <
typename PlainObjectType,
int MapOptions,
typename Str
ideType>
131 struct eigen_extract_stride<
Eigen::Map<PlainObjectType, MapOptions, StrideType>> {
132 using type = StrideType;
134 template <
typename PlainObjectType,
int Options,
typename Str
ideType>
135 struct eigen_extract_stride<
Eigen::Ref<PlainObjectType, Options, StrideType>> {
136 using type = StrideType;
140 template <
typename Type_>
145 static constexpr EigenIndex
rows = Type::RowsAtCompileTime,
cols = Type::ColsAtCompileTime,
146 size = Type::SizeAtCompileTime;
147 static constexpr
bool row_major = Type::IsRowMajor,
149 = Type::IsVectorAtCompileTime,
152 dynamic = !fixed_rows && !fixed_cols;
154 template <EigenIndex i, EigenIndex ifzero>
155 using if_zero = std::integral_constant<EigenIndex, i == 0 ? ifzero : i>;
156 static constexpr EigenIndex inner_stride
158 outer_stride = if_zero < StrideType::OuterStrideAtCompileTime,
162 static constexpr
bool dynamic_stride
164 static constexpr
bool requires_row_major
165 = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1;
166 static constexpr
bool requires_col_major
167 = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1;
172 static EigenConformable<row_major> conformable(
const array &
a) {
173 const auto dims =
a.ndim();
174 if (dims < 1 || dims > 2) {
180 EigenIndex np_rows =
a.shape(0), np_cols =
a.shape(1),
183 if ((fixed_rows && np_rows !=
rows) || (fixed_cols && np_cols !=
cols)) {
187 return {np_rows, np_cols, np_rstride, np_cstride};
192 const EigenIndex
n =
a.shape(0),
196 if (fixed &&
size !=
n) {
199 return {
rows == 1 ? 1 :
n,
cols == 1 ? 1 :
n, stride};
211 return {1,
n, stride};
213 if (fixed_rows &&
rows !=
n) {
216 return {
n, 1, stride};
219 static constexpr
bool show_writeable
222 static constexpr
bool show_c_contiguous = show_order && requires_row_major;
223 static constexpr
bool show_f_contiguous
224 = !show_c_contiguous && show_order && requires_col_major;
226 static constexpr
auto descriptor
237 const_name<show_writeable>(
", flags.writeable",
"")
238 + const_name<show_c_contiguous>(
", flags.c_contiguous",
"")
239 + const_name<show_f_contiguous>(
", flags.f_contiguous",
"") +
const_name(
"]");
244 template <
typename props>
250 a =
array({src.size()}, {elem_size * src.innerStride()}, src.data(),
base);
252 a =
array({src.rows(), src.cols()},
253 {elem_size * src.rowStride(), elem_size * src.colStride()},
269 template <
typename props,
typename Type>
283 return eigen_ref_array<props>(*src,
base);
288 template <
typename Type>
293 using props = EigenProps<Type>;
308 auto dims = buf.ndim();
309 if (dims < 1 || dims > 2) {
313 auto fits = props::conformable(buf);
320 auto ref = reinterpret_steal<array>(eigen_ref_array<props>(
value));
323 }
else if (
ref.ndim() == 1) {
339 template <
typename CType>
344 return eigen_encapsulate<props>(src);
346 return eigen_encapsulate<props>(
new CType(std::move(*src)));
348 return eigen_array_cast<props>(*src);
351 return eigen_ref_array<props>(*src);
353 return eigen_ref_array<props>(*src, parent);
355 throw cast_error(
"unhandled return_value_policy: should not happen!");
374 return cast_impl(&src, policy, parent);
382 return cast(&src, policy, parent);
386 return cast_impl(src, policy, parent);
390 return cast_impl(src, policy, parent);
393 static constexpr
auto name = props::descriptor;
400 operator Type &&() && {
return std::move(
value); }
401 template <
typename T>
409 template <
typename MapType>
410 struct eigen_map_caster {
415 using props = EigenProps<MapType>;
427 return eigen_array_cast<props>(src);
436 pybind11_fail(
"Invalid return_value_policy for Eigen Map/Ref/Block type");
440 static constexpr
auto name = props::descriptor;
445 bool load(
handle,
bool) =
delete;
452 template <
typename Type>
457 template <
typename PlainObjectType,
typename Str
ideType>
459 Eigen::Ref<PlainObjectType, 0, StrideType>,
460 enable_if_t<is_eigen_dense_map<Eigen::Ref<PlainObjectType, 0, StrideType>>::value>>
461 :
public eigen_map_caster<Eigen::Ref<PlainObjectType, 0, StrideType>> {
464 using props = EigenProps<Type>;
472 | ((props::row_major ? props::inner_stride : props::outer_stride) == 1
474 : (props::row_major ? props::outer_stride : props::inner_stride) == 1
479 std::unique_ptr<MapType> map;
480 std::unique_ptr<Type>
ref;
493 bool need_copy = !isinstance<Array>(src);
495 EigenConformable<props::row_major> fits;
499 auto aref = reinterpret_borrow<Array>(src);
501 if (aref && (!need_writeable || aref.writeable())) {
502 fits = props::conformable(aref);
506 if (!fits.template stride_compatible<props>()) {
509 copy_or_ref = std::move(aref);
520 if (!
convert || need_writeable) {
524 Array
copy = Array::ensure(src);
528 fits = props::conformable(
copy);
529 if (!fits || !fits.template stride_compatible<props>()) {
532 copy_or_ref = std::move(
copy);
540 make_stride(fits.stride.outer(), fits.stride.inner())));
547 operator Type *() {
return ref.get(); }
549 operator Type &() {
return *
ref; }
550 template <
typename _T>
551 using cast_op_type = pybind11::detail::cast_op_type<_T>;
556 return a.mutable_data();
566 template <
typename S>
572 template <
typename S>
573 using stride_ctor_dual
578 template <
typename S>
579 using stride_ctor_outer
584 template <
typename S>
585 using stride_ctor_inner
592 static S make_stride(EigenIndex, EigenIndex) {
596 static S make_stride(EigenIndex outer, EigenIndex inner) {
597 return S(outer, inner);
600 static S make_stride(EigenIndex outer, EigenIndex) {
604 static S make_stride(EigenIndex, EigenIndex inner) {
613 template <
typename Type>
621 using props = EigenProps<Matrix>;
629 return cast(*src, policy, parent);
632 static constexpr
auto name = props::descriptor;
638 operator Type() =
delete;
643 template <
typename Type>
650 static constexpr
bool rowMajor = Type::IsRowMajor;
657 auto obj = reinterpret_borrow<object>(src);
659 object matrix_type = sparse_module.attr(rowMajor ?
"csr_matrix" :
"csc_matrix");
663 obj = matrix_type(obj);
672 auto shape = pybind11::tuple((pybind11::object) obj.attr(
"shape"));
673 auto nnz = obj.attr(
"nnz").cast<
Index>();
675 if (!
values || !innerIndices || !outerIndices) {
681 StorageIndex>(shape[0].cast<Index>(),
682 shape[1].cast<
Index>(),
684 outerIndices.mutable_data(),
685 innerIndices.mutable_data(),
695 =
module_::import(
"scipy.sparse").attr(rowMajor ?
"csr_matrix" :
"csc_matrix");
697 array data(src.nonZeros(), src.valuePtr());
698 array outerIndices((rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
699 array innerIndices(src.nonZeros(), src.innerIndexPtr());
702 std::move(
data), std::move(innerIndices), std::move(outerIndices)),
708 const_name<(Type::IsRowMajor) != 0>(
"scipy.sparse.csr_matrix[",
709 "scipy.sparse.csc_matrix[")