22 #ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP 23 #define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP 26 namespace TensorSycl {
47 template <
template <
class,
class>
class UnaryCategory,
typename OP,
typename RHSExpr,
typename Dev>
52 : rhsExpr(expr.impl()), func(expr.functor()) {}
56 template <
template <
class,
class>
class UnaryCategory,
typename OP,
typename RHSExpr,
typename Dev>
58 :
FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> >{};
62 template <
template<
class,
class,
class>
class BinaryCategory,
typename OP,
typename LHSExpr,
typename RHSExpr,
typename Dev>
68 : lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.functor()) {}
73 template <
template <
class,
class,
class>
class BinaryCategory,
typename OP,
typename LHSExpr,
typename RHSExpr,
typename Dev>
75 :
FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> >{};
79 template <
template <
class,
class,
class,
class>
class TernaryCategory,
typename OP,
typename Arg1Expr,
typename Arg2Expr,
typename Arg3Expr,
typename Dev>
86 : arg1Expr(expr.arg1Impl()), arg2Expr(expr.arg2Impl()), arg3Expr(expr.arg3Impl()), func(expr.functor()) {}
91 template <
template <
class,
class,
class,
class>
class TernaryCategory,
typename OP,
typename Arg1Expr,
typename Arg2Expr,
typename Arg3Expr,
typename Dev>
93 :
FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >{};
97 template <
typename IfExpr,
typename ThenExpr,
typename ElseExpr,
typename Dev>
103 : ifExpr(expr.cond_impl()), thenExpr(expr.then_impl()), elseExpr(expr.else_impl()) {}
108 template <
typename IfExpr,
typename ThenExpr,
typename ElseExpr,
typename Dev>
110 :
FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {};
114 template <
typename LHSExpr,
typename RHSExpr,
typename Dev>
119 : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {}
124 template <
typename LHSExpr,
typename RHSExpr,
typename Dev>
126 :
FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> >{};
131 template <
typename RHSExpr,
typename Dev>
135 : rhsExpr(expr.impl()) {}
140 template <
typename RHSExpr,
typename Dev>
142 :
FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {};
144 template<
typename Dim,
size_t NumOutputDim>
struct DimConstr {
145 template<
typename InDim>
146 static inline Dim
getDim(InDim dims ) {
return dims;}
150 template<
typename InDim>
151 static inline Dim
getDim(InDim dims ) {
return Dim(dims.TotalSize());}
154 template<
typename Op,
typename Dims,
typename ArgType,
template <
class>
class MakePointer_,
typename Device>
161 : m_dimensions(
DimConstr<Dimensions, Evaluator::NumOutputDims>::getDim(expr.
dimensions())) {}
165 template<
typename Op,
typename Dims,
typename ArgType,
template <
class>
class MakePointer_,
typename Device>
167 :
FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{};
169 template <
typename Evaluator>
177 #endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP static Dim getDim(InDim dims)
A cost model used to limit the number of threads used for evaluating tensor expression.
auto extractFunctors(const Evaluator &evaluator) -> FunctorExtractor< Evaluator >
template deduction function for FunctorExtractor
static Dim getDim(InDim dims)