33 #ifndef STATE_SIM_BASE_HPP 34 #define STATE_SIM_BASE_HPP 38 #include <eigen3/Eigen/Eigen> 46 template <
class TDerived>
47 struct StateMapBaseTraits;
48 template <
class TDerived>
49 struct StateMapBaseTraits
52 using StateType =
typename StateMapBaseTraits<typename TDerived::StateMapBaseType>::StateType;
55 using StateForSimType =
typename StateMapBaseTraits<typename TDerived::StateMapBaseType>::StateForSimType;
58 using NumType =
typename StateMapBaseTraits<typename TDerived::StateMapBaseType>::NumType;
61 using StateCfType =
typename StateMapBaseTraits<typename TDerived::StateMapBaseType>::StateCfType;
64 using StateNmType =
typename StateMapBaseTraits<typename TDerived::StateMapBaseType>::StateNmType;
67 using StateVirtType =
typename StateMapBaseTraits<typename TDerived::StateMapBaseType>::StateVirtType;
70 using StateNmNumType =
typename StateMapBaseTraits<typename TDerived::StateMapBaseType>::StateNmNumType;
73 using StateWithGradNmType =
typename StateMapBaseTraits<typename TDerived::StateMapBaseType>::StateWithGradNmType;
76 using StateWithGradCfType =
typename StateMapBaseTraits<typename TDerived::StateMapBaseType>::StateWithGradCfType;
79 using StateWithGradNmNumType =
80 typename StateMapBaseTraits<typename TDerived::StateMapBaseType>::StateWithGradNmNumType;
83 template <
class TDerived>
84 struct StateSimBaseCRTPTraits;
92 template <
class TDerived>
96 using StateType =
typename StateSimBaseCRTPTraits<TDerived>::StateType;
99 using StateForSimType =
typename StateSimBaseCRTPTraits<TDerived>::StateForSimType;
102 using NumType =
typename StateSimBaseCRTPTraits<TDerived>::NumType;
105 using StateNmType =
typename StateSimBaseCRTPTraits<TDerived>::StateNmType;
108 static constexpr
const bool hasStateGrad =
109 !std::is_same<EmptyGradType, typename StateSimBaseCRTPTraits<TDerived>::StateWithGradNmType>
::value;
130 template <
bool stateGradientRepresentation = hasStateGrad,
131 typename std::enable_if<(stateGradientRepresentation)>::type* =
nullptr>
134 thisDerived().advanceWithGradImplCRTP(_arc);
140 thisDerived().advanceImplCRTP(_arc);
146 thisDerived().simToTImplCRTP(_arcEnd, _dt);
152 thisDerived().toState0ImplCRTP();
160 return thisDerived().stateImplCRTP();
166 return thisDerived().stateImplCRTP();
172 thisDerived().advanceSet0ImplCRTP(_state0, _tEnd, _dt);
178 return static_cast<TDerived&
>(*this);
184 return static_cast<const TDerived&
>(*this);
188 template <
class TNumType,
class StateVirtType>
213 advanceImplVirt(_arc);
219 advanceWithGradImplVirt(_arc);
236 virtual void advanceImplVirt(
const TNumType& _arc) = 0;
239 virtual void advanceWithGradImplVirt(
const TNumType& _arc) = 0;
242 virtual void toState0ImplVirt() = 0;
245 template <
class NumType,
class StateWithGradNmNumType,
template <
class>
class TDiscretizationType>
258 odeint::vector_space_algebra>;
263 template <
class TDerived2,
class TParamType2,
class TStateType2,
template <
class>
class TDiscretizationType2,
264 class... TFuncsType2>
272 template <
class TDerived2,
class TParamType2,
class TStateType2,
template <
class>
class TDiscretizationType2,
273 class... TFuncsType2>
277 template <
class TDerived,
class TParamType,
class TStateType,
template <
class>
class TDiscretizationType,
280 :
public StateSimBaseCRTP<StateSimBase<TDerived, TParamType, TStateType, TDiscretizationType, TFuncsType...>>,
281 public StateSimBaseVirt<typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::NumType,
282 typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateVirtType>,
283 public std::conditional<
284 !std::is_same<EmptyGradType,
285 typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateWithGradNmType>::value,
286 OdeStateSolverRealAlias<
287 typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::NumType,
288 typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateWithGradNmNumType,
289 TDiscretizationType>,
290 OdeStateSolverDummyAlias>::type
296 using StateForSimType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateForSimType;
305 using NumType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::NumType;
308 using StateVirtType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateVirtType;
311 using StateNmType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateNmType;
314 using StateCfType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateCfType;
317 using StateNmNumType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateNmNumType;
320 using StateWithGradNmType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateWithGradNmType;
323 using StateWithGradCfType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateWithGradCfType;
327 typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateWithGradNmNumType;
332 odeint::vector_space_algebra>;
338 using StateNumSimType =
typename std::conditional<hasStateGrad, StateWithGradNmType, StateNmType>::type;
377 advanceImplCRTP(_arc);
383 advanceWithGradImplVirtDispatch(_arc);
393 template <
bool stateGradientRepresentation = hasStateGrad,
394 typename std::enable_if<(stateGradientRepresentation)>::type* =
nullptr>
397 advanceWithGradImplCRTP(_arc);
401 template <
bool stateGradientRepresentation = hasStateGrad,
402 typename std::enable_if<(!stateGradientRepresentation)>::type* =
nullptr>
405 throw std::runtime_error(
"Cannot advance with gradient info (State class not suited)");
424 [
this](
const StateNmNumType& _x, StateNmNumType& _dxdt,
const NumType _t)
426 static NumType* memStartRefNm;
427 memStartRefNm = state_.stateNm().memStartRef();
428 state_.stateNm().bindToMemory((NumType*)_x.data());
431 _dxdt = stateWithGradNmDotCache_.state().data();
432 state_.stateNm().bindToMemory(memStartRefNm);
434 state_.stateNm().data(), arcOld_, _arc - arcOld_);
439 template <
bool stateGradientRepresentation = hasStateGrad,
440 typename std::enable_if<(stateGradientRepresentation)>::type* =
nullptr>
443 this->rkGrad_.do_step(
446 static NumType* memStartRefNm;
447 memStartRefNm = state_.stateWithGradNm().memStartRef();
448 state_.stateWithGradNm().bindToMemory((NumType*)_x.data());
452 _dxdt = stateWithGradNmDotCache_.data();
453 state_.stateWithGradNm().bindToMemory(memStartRefNm);
455 state_.stateWithGradNm().data(), arcOld_, _arc - arcOld_);
460 template <
bool stateGradientRepresentation = hasStateGrad,
461 typename std::enable_if<(!stateGradientRepresentation)>::type* =
nullptr>
468 funcI.precompute(*
this);
472 stateWithGradNmDotCache_ = state_.stateNm();
478 rk_.adjust_size(state_.stateNm().data().size());
483 template <
bool stateGradientRepresentation = hasStateGrad,
484 typename std::enable_if<(stateGradientRepresentation)>::type* =
nullptr>
492 funcI.precompute(*
this);
497 stateWithGradNmDotCache_ = state_.stateWithGradNm();
505 rk_.adjust_size(state_.stateNm().data().size());
506 this->rkGrad_.adjust_size(state_.stateWithGradNm().data().size());
527 simToTImplCRTP(_tEnd, _dt);
528 _state0.stateNm().data() = state_.stateNm().data();
529 _state0.stateCf().data() = state_.stateCf().data();
534 void setXCf(
const NumType& _arc,
const PfEaG& _eAG)
536 thisSimDerived().setXCfImpl(state_.stateCf(), _arc, _eAG);
542 thisSimDerived().setXCfDotImpl(stateCfDotCache_, state_.stateCf(), _arc, _eAG);
548 thisSimDerived().setXNmDotImpl(stateWithGradNmDotCache_.state(), state_.stateCf(), state_.stateNm(), _arc, _eAG);
551 funcI.computeFuncDot(stateWithGradNmDotCache_.state(), state_.stateNm(), state_.stateCf(), *
this,
557 template <
bool stateGradientRepresentation = hasStateGrad,
558 typename std::enable_if<(stateGradientRepresentation)>::type* =
nullptr>
561 thisSimDerived().setGradXCfImpl(state_.stateGradCf(), state_.stateCf(), _arc, _eAG);
565 template <
bool stateGradientRepresentation = hasStateGrad,
566 typename std::enable_if<(stateGradientRepresentation)>::type* =
nullptr>
569 thisSimDerived().setGradXNmDotImpl(stateWithGradNmDotCache_.stateGrad(), state_.stateWithGradCf(),
570 state_.stateWithGradNm(), _arc, _eAG);
573 funcI.computeGradFuncDot(stateWithGradNmDotCache_.stateGrad(), state_.stateWithGradNm(),
579 void init(std::shared_ptr<TParamType> _paramStructPtr)
581 paramStruct = _paramStructPtr;
587 return static_cast<TDerived&
>(*this);
593 return static_cast<const TDerived&
>(*this);
599 return stateWithGradNmDotCache_.state();
605 return stateCfDotCache_;
617 thisSimDerived().setXNm0Impl(state_.stateNm());
623 thisSimDerived().setGradXNm0Impl(state_.stateGradNm(), state_.stateNm());
629 thisSimDerived().adjustXSizeImpl(state_.stateNm(), state_.stateCf());
635 thisSimDerived().adjustGradXSizeImpl(state_.stateGradNm(), state_.stateGradCf());
659 template <
class TNumType2,
class StateVirtType2>
661 template <
class TDerived2>
667 template <
class TDerived,
class TParamType,
class TStateType,
template <
class>
class TDiscretizationType,
669 struct StateSimBaseCRTPTraits<
StateSimBase<TDerived, TParamType, TStateType, TDiscretizationType, TFuncsType...>>
675 using StateForSimType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateForSimType;
678 using NumType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::NumType;
681 using StateCfType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateCfType;
684 using StateNmType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateNmType;
687 using StateVirtType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateVirtType;
690 using StateNmNumType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateNmNumType;
693 using StateWithGradNmType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateWithGradNmType;
696 using StateWithGradCfType =
typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateWithGradCfType;
700 typename StateMapBaseTraits<typename TStateType::StateMapBaseType>::StateWithGradNmNumType;
706 #endif // STATE_SIM_BASE_HPP
ParamType< TNumType, TMapDataType > ParamStructType
typename StateMapBaseTraits< typename StateWithLWithGradE8< TNumType >::StateMapBaseType >::StateWithGradNmNumType StateWithGradNmNumType
StateNumSimType stateWithGradNmDotCache_
std::tuple< TFuncsType... > funcs_
void advanceWithGrad(const TNumType &_arc)
typename StateMapBaseTraits< typename StateWithLWithGradE8< TNumType >::StateMapBaseType >::StateWithGradNmType StateWithGradNmType
const TDerived & thisSimDerived() const
typename StateSimBaseCRTPTraits< StateSimBase< StateSimE8Base< TNumType, MapDataType, TStateType, TDiscretizationType, TFuncsType... >, ParamType< TNumType, MapDataType >, TStateType, TDiscretizationType, TFuncsType... > >::NumType NumType
void simToT(const NumType &_arcEnd, const NumType &_dt)
void setXNmDot(const NumType &_arc, const PfEaG &_eAG)
StateForSimType & stateImplCRTP()
void setXCfDot(const NumType &_arc, const PfEaG &_eAG)
std::shared_ptr< TParamType > paramStruct
void init(std::shared_ptr< TParamType > _paramStructPtr)
EmptyGradType OdeStateGradSolverType
void advanceWithGrad(const NumType &_arc)
typename StateSimBaseCRTPTraits< StateSimBase< StateSimE8Base< TNumType, MapDataType, TStateType, TDiscretizationType, TFuncsType... >, ParamType< TNumType, MapDataType >, TStateType, TDiscretizationType, TFuncsType... > >::StateType StateType
const StateNmType & stateNmDotCache() const
typename StateMapBaseTraits< typename StateWithLWithGradE8< TNumType >::StateMapBaseType >::StateNmNumType StateNmNumType
constexpr std::enable_if< II==sizeof...(Tp), void >::type for_each_tuple(std::tuple< Tp... > &, FuncT)
void toState0ImplVirt() overridefinal
void advanceImplCRTP(const NumType &_arc)
void setGradXNmDot(const NumType &_arc, const PfEaG &_eAG)
StateCfType stateCfDotCache_
typename std::conditional< hasStateGrad, StateWithGradNmType, StateNmType >::type StateNumSimType
const StateForSimType & stateImplCRTP() const
constexpr std::enable_if< II< sizeof...(Tp), void >::type for_each_tuple_class(std::tuple< Tp... > &t, FuncT &f){f[II]=[&t](const auto &_a, const auto &_b, const auto &_c, auto &_d){std::get< II >t).computeDArcIdPImpl(_a, _b, _c, _d);};for_each_tuple_class< II+1, FuncT, Tp... >t, f);}template< std::size_t IIMax, std::size_t II=0 >constexpr inline typename std::enable_if<(II==IIMax), void >::type get_costs_sizes(auto &partLatI, const size_t &_i, auto &_sizeCostsPerPartLattice, auto &_sizeCostsPerType){}template< std::size_t IIMax, std::size_t II=0 >constexpr inline typename std::enable_if<(II< IIMax), void >::type get_costs_sizes(auto &partLatI, const size_t &_i, auto &_sizeCostsPerPartLattice, auto &_sizeCostsPerType){size_t partLatICostsTypeJ=partLatI.lattice.size()*partLatI.template costFuncsNr< II >);_sizeCostsPerPartLattice[_i]+=partLatICostsTypeJ;_sizeCostsPerType[II]+=partLatICostsTypeJ;get_costs_sizes< IIMax, II+1 >partLatI, _i, _sizeCostsPerPartLattice, _sizeCostsPerType);}template< typename TNumType, typename TSimType, bool TUseStateNm, template< typename, typename > class...TLatticeTypes >class TrajectorySimulator{public:using StateSimSPtr=std::shared_ptr< TSimType >;public:using StateType=typename TSimType::StateType;public:using StateSPtr=std::shared_ptr< StateType >;public:using StateForSimType=typename TSimType::StateForSimType;public:using LatticePointType=LatticePoint< TNumType, StateType >;public:static constexpr const bool CanComputeStateGrad=!std::is_same< EmptyGradType, typename StateMapBaseTraits< StateType >::StateWithGradNmType >::value;public:static constexpr const size_t CostFuncsTypesNr=std::tuple_element< 0, std::tuple< TLatticeTypes< TNumType, TSimType >... > >::type::costFuncsTypesNr();public:TrajectorySimulator():stateSim_(std::shared_ptr< TSimType >new TSimType)){for_each_tuple_class(partialLattices_, correctStateGradFunc);for(size_t i=0;i< gradCostsMap_.size();++i){gradCostsMap_[i]=std::shared_ptr< Eigen::Map< Eigen::Matrix< TNumType,-1,-1, Eigen::RowMajor > > >new Eigen::Map< Eigen::Matrix< TNumType,-1,-1, Eigen::RowMajor > >nullptr, 0, 0));}}public:TrajectorySimulator(StateSimSPtr &_stateSim):stateSim_(_stateSim){for_each_tuple_class(partialLattices_, correctStateGradFunc);for(size_t i=0;i< gradCostsMap_.size();++i){gradCostsMap_[i]=std::shared_ptr< Eigen::Map< Eigen::Matrix< TNumType,-1,-1, Eigen::RowMajor > > >new Eigen::Map< Eigen::Matrix< TNumType,-1,-1, Eigen::RowMajor > >nullptr, 0, 0));}}public:StateSimSPtr &stateSim(){return stateSim_;}public:const StateSimSPtr &stateSim() const {return stateSim_;}public:LatticePointType &simLatticeI(const size_t &_i){return simulationLattice_[_i];}public:const LatticePointType &simLatticeI(const size_t &_i) const {return simulationLattice_[_i];}public:size_t simLatticeSize() const {return simulationLatticeActiveSize_;}public:template< template< typename, typename > class TLatticeType > TLatticeType< TNumType, TSimType > &partialLattice(){return std::get< TLatticeType< TNumType, TSimType > >partialLattices_);}public:template< template< typename, typename > class TLatticeType > const TLatticeType< TNumType, TSimType > &partialLattice() const {return std::get< TLatticeType< TNumType, TSimType > >partialLattices_);}public:template< size_t TLatticeIdx > auto &partialLattice(){return std::get< TLatticeIdx >partialLattices_);}public:template< size_t TLatticeIdx > const auto &partialLattice() const {return std::get< TLatticeIdx >partialLattices_);}private:using AdvanceFunction=void(TrajectorySimulator< TNumType, TSimType, TUseStateNm, TLatticeTypes... >::*)(const TNumType &);private:AdvanceFunction advanceFunc;private:void advanceFuncSimEmpty(const TNumType &_arcNow){}private:void advanceFuncSim(const TNumType &_arcNow){stateSim_-> advance(_arcNow)
typename StateMapBaseTraits< typename StateWithLWithGradE8< TNumType >::StateMapBaseType >::StateVirtType StateVirtType
typename StateMapBaseTraits< typename StateWithLWithGradE8< TNumType >::StateMapBaseType >::StateCfType StateCfType
TDerived & thisSimDerived()
typename StateMapBaseTraits< typename StateWithLWithGradE8< TNumType >::StateMapBaseType >::StateWithGradCfType StateWithGradCfType
void setXCf(const NumType &_arc, const PfEaG &_eAG)
static constexpr auto value
typename StateSimBaseCRTPTraits< StateSimBase< StateSimE8Base< TNumType, MapDataType, TStateType, TDiscretizationType, TFuncsType... >, ParamType< TNumType, MapDataType >, TStateType, TDiscretizationType, TFuncsType... > >::StateForSimType StateForSimType
const TDerived & thisDerived() const
const StateCfType & stateCfDotCache() const
EvalArcGuarantee
Flags if any guarantees about evaluation arc relative to last evaluation arc are present.
void advanceSet0ImplCRTP(auto &_state0, const NumType &_tEnd, const NumType &_dt)
void advanceWithGradImplVirt(const NumType &_arc) overridefinal
OdeStateGradSolverType rkGrad_
void advanceImplVirt(const NumType &_arc) overridefinal
void setXNmDot(const TNumType &_arcNow)
close to previous evaluation arc
StateForSimType & state()
void advanceWithGradImplVirtDispatch(const NumType &_arc)
void advance(const TNumType &_arc)
void advanceWithGradImplCRTP(const NumType &_arc)
void simToTImplCRTP(const NumType &_tEnd, const NumType &_dt)
void advance(const NumType &_arc)
this evaluation arc is at the arc parametrization begin
typename StateSimBaseCRTPTraits< StateSimBase< StateSimE8Base< TNumType, MapDataType, TStateType, TDiscretizationType, TFuncsType... >, ParamType< TNumType, MapDataType >, TStateType, TDiscretizationType, TFuncsType... > >::StateNmType StateNmType
void advanceSet0(auto &_state0, const NumType &_tEnd, const NumType &_dt)
void setGradXCf(const NumType &_arc, const PfEaG &_eAG)
const StateForSimType & state() const