$search
00001 #ifndef PARALLEL_SOLVER_H 00002 #define PARALLEL_SOLVER_H 00003 00004 #include "parallel_array.h" 00005 #include "parallel_batch.h" 00006 #include "parallel_math.h" 00007 #include "parallel_reduce.h" 00008 #include "parallel_timer.h" 00009 00010 #include "util.h" 00011 00012 namespace parallel_ode 00013 { 00014 00015 using ::parallel_utils::ParallelHDArray; 00016 using ::parallel_utils::MemManager; 00017 using ::parallel_utils::CopyType; 00018 00019 00020 namespace ParallelFlags 00021 { 00023 enum ParallelFlag { PARALLEL_NONE = 0x000, 00024 PARALLEL_PREPROCESS = 0x001, 00025 PARALLEL_ASYNC = 0x002, 00026 PARALLEL_RANDOMIZE = 0x004, 00027 PARALLEL_ATOMICS = 0x008, 00028 PARALLEL_ALIGN = 0x010, 00029 PARALLEL_PINNED = 0x020, 00030 PARALLEL_WC = 0x040, 00031 PARALLEL_REDUCE = 0x080, 00032 PARALLEL_COMPACT = 0x100}; 00033 } 00034 typedef ParallelFlags::ParallelFlag ParallelFlag; 00035 00037 const int DEFAULT_FLAGS = 00038 ParallelFlags::PARALLEL_ALIGN 00039 | ParallelFlags::PARALLEL_PINNED 00040 | ParallelFlags::PARALLEL_REDUCE 00041 | ParallelFlags::PARALLEL_ASYNC 00042 | ParallelFlags::PARALLEL_RANDOMIZE 00043 ;//| PARALLEL_ATOMICS; 00044 00046 /* ParallelPGSSolver is the abstract base class encapsulating functionality common to each of the parallel 00047 * Quickstep solvers. It takes the data from ODE, batches it appropriately, and relies on derived classes 00048 * to actually drive the solver. From there the data is transferred back to ODE. 00049 */ 00050 template<typename CudaT, typename ParamsT, ParallelType PType> 00051 class ParallelPGSSolver 00052 { 00053 public: 00054 struct SolverParams; 00055 typedef typename vec4<CudaT>::Type Vec4T; 00056 typedef const CudaT* CudaTPtr; 00057 typedef CudaT* CudaTMutablePtr; 00058 typedef const ParamsT* ParamsTPtr; 00059 typedef ParamsT* ParamsTMutablePtr; 00060 typedef MemManager<CudaT,PType> PMemManager; 00061 typedef typename PMemManager::mem_flags MemFlags; 00062 00071 ParallelPGSSolver( int parallelFlags = DEFAULT_FLAGS, 00072 BatchType batchType = BatchTypes::DEFAULT_BATCH_TYPE, 00073 ReduceType reduceType = ReduceTypes::DEFAULT_REDUCE_TYPE, 00074 uint numBatches = ParallelOptions::MAXBATCHES ) 00075 : parallelFlags_( parallelFlags ), 00076 parallelParams_( NULL ), 00077 batchStrategyType_( batchType ), 00078 batchStrategy_( NULL ), 00079 reduceStrategyType_( reduceType), 00080 reduceStrategy_( NULL ), 00081 batchRepetitionCount_( numBatches, 0 ), 00082 batchIndices_( numBatches, -1 ), 00083 batchSizes_( numBatches, 0 ), 00084 numBatches_( numBatches ), 00085 m_(0), 00086 constraintStride_(0), 00087 n_(0), 00088 bodyStride_(0), 00089 reduceStride_(0), 00090 bInit_(false) { 00091 } 00092 00096 virtual ~ParallelPGSSolver( ) { 00097 delete batchStrategy_; 00098 delete reduceStrategy_; 00099 } 00100 00106 virtual void worldSolve( SolverParams* params ); 00107 00111 virtual void initialize( ); 00112 00113 protected: 00114 00116 00117 void syncODEToDevice( ); 00118 void syncDeviceToODE( ); 00119 void loadBatches( int* jb ); 00120 00121 void setConstraintSize( int m, bool bResizeArrays = true ); 00122 void setBodySize( int n, bool bResizeArrays = true ); 00123 00125 00126 virtual void preProcessHost( ); 00127 virtual void preProcessDevice( const CudaT sorParam, const CudaT stepsize ); 00128 virtual void solveAndReduce( const int offset, const int batchSize ) = 0; 00129 00130 virtual void loadBodies( ); 00131 virtual void loadConstraints( ); 00132 virtual void loadConstants( ); 00133 virtual void loadKernels( ); 00134 virtual void loadSolution( ); 00135 00136 virtual void setMemFlags( MemFlags flags ); 00137 00139 00140 inline void checkInit( ) { if(!bInit_) this->initialize(); } 00141 inline void setBodyStride( int bodyStride ) { bodyStride_ = bodyStride; } 00142 inline void setConstraintStride( int constraintStride ) { constraintStride_ = constraintStride; } 00143 inline void setReduceStride( int reduceStride ) { reduceStride_ = reduceStride; } 00144 00145 inline int getNumBatches( ) const { return (int)batchSizes_.size(); } 00146 inline int getNumBodies( ) const { return n_; } 00147 inline int getBodyStride( ) const { return bodyStride_; } 00148 inline int getNumConstraints( ) const { return m_; } 00149 inline int getConstraintStride( ) const { return constraintStride_; } 00150 inline int getReduceStride( ) const { return reduceStride_; } 00151 inline int getDefaultAlign( ) const { return ParallelOptions::DEFAULTALIGN; } 00152 inline CopyType getCopyType( ) const { return asyncEnabled( ) ? parallel_utils::CopyTypes::COPY_ASYNC : parallel_utils::CopyTypes::COPY_SYNC; } 00153 00154 inline int bID(const int index, const int offset) const { return index + offset * getBodyStride(); } 00155 inline int cID(const int index, const int offset) const { return index + offset * getConstraintStride(); } 00156 inline int rID(const int index, const int offset) const { return index + offset * getReduceStride(); } 00157 00158 inline bool preprocessEnabled( ) const { return isEnabled( ParallelFlags::PARALLEL_PREPROCESS ); } 00159 inline bool asyncEnabled( ) const { return isEnabled( ParallelFlags::PARALLEL_ASYNC ); } 00160 inline bool randomizeEnabled( ) const { return isEnabled( ParallelFlags::PARALLEL_RANDOMIZE ); } 00161 inline bool atomicsEnabled( ) const { return isEnabled( ParallelFlags::PARALLEL_ATOMICS ); } 00162 inline bool alignEnabled( ) const { return isEnabled( ParallelFlags::PARALLEL_ALIGN ); } 00163 inline bool pinnedMemEnabled( ) const { return isEnabled( ParallelFlags::PARALLEL_PINNED ); } 00164 inline bool wcMemEnabled( ) const { return isEnabled( ParallelFlags::PARALLEL_WC ); } 00165 inline bool reduceEnabled( ) const { return isEnabled( ParallelFlags::PARALLEL_REDUCE ); } 00166 inline bool compactScalarsEnabled( ) const { return isEnabled( ParallelFlags::PARALLEL_COMPACT ); } 00167 00168 inline bool isEnabled( const ParallelFlag flag ) const { return parallelFlags_ & flag; } 00169 00170 void printConfig( ); 00171 00173 00174 int parallelFlags_; 00175 SolverParams* parallelParams_; 00177 BatchType batchStrategyType_; 00178 BatchStrategy *batchStrategy_; 00179 ReduceType reduceStrategyType_; 00180 ReduceStrategy *reduceStrategy_; 00182 IntVector constraintIndices_; 00183 IntVector reductionOffsets0_; 00184 IntVector reductionOffsets1_; 00185 IntVector batchRepetitionCount_; 00186 IntVector batchIndices_; 00187 IntVector batchSizes_; 00189 00190 00191 int numBatches_; 00193 int m_; 00194 int constraintStride_; 00195 int n_; 00196 int bodyStride_; 00198 int reduceStride_; 00200 bool bInit_; 00202 ParallelHDArray<Vec4T,PType> j0; 00203 ParallelHDArray<Vec4T,PType> ij0; 00204 ParallelHDArray<CudaT,PType> lambda0; 00205 ParallelHDArray<CudaT,PType> adcfm; 00206 ParallelHDArray<CudaT,PType> rhs; 00207 ParallelHDArray<CudaT,PType> lohiD; 00208 ParallelHDArray<int4,PType> bodyIDs; 00209 ParallelHDArray<int,PType> fIDs; 00211 ParallelHDArray<CudaT,PType> iMass; 00212 ParallelHDArray<Vec4T,PType> i0; 00214 ParallelHDArray<Vec4T,PType> bodyFAcc; 00215 ParallelHDArray<Vec4T,PType> bodyTAcc; 00216 ParallelHDArray<Vec4T,PType> bodyFAccReduction; 00217 ParallelHDArray<Vec4T,PType> bodyTAccReduction; 00219 public: 00220 00222 struct SolverParams { 00223 SolverParams( dxWorldProcessContext *contextIn, const dxQuickStepParameters *qsIn, 00224 const int mIn, const int nbIn, ParamsTMutablePtr JIn, int *jbIn, dxBody * const *bodyIn, 00225 ParamsTPtr invIIn, ParamsTMutablePtr lambdaIn, ParamsTMutablePtr fcIn, ParamsTMutablePtr bIn, 00226 ParamsTPtr loIn, ParamsTPtr hiIn, ParamsTPtr cfmIn, ParamsTMutablePtr iMJIn, ParamsTMutablePtr AdcfmIn, 00227 const int *findexIn, const ParamsT stepsizeIn ) 00228 : context(contextIn), qs(qsIn), m(mIn), nb(nbIn), J(JIn), jb(jbIn), body(bodyIn), 00229 invI(invIIn), lambda(lambdaIn), fc(fcIn), b(bIn), lo(loIn), hi(hiIn), cfm(cfmIn), 00230 iMJ(iMJIn), Adcfm(AdcfmIn), findex(findexIn), stepsize(stepsizeIn) { } 00231 00232 dxWorldProcessContext *context; 00233 const dxQuickStepParameters *qs; 00234 const int m; 00235 const int nb; 00236 ParamsTMutablePtr J; 00237 int *jb; 00238 dxBody * const *body; 00239 ParamsTPtr invI; 00240 ParamsTMutablePtr lambda; 00241 ParamsTMutablePtr fc; 00242 ParamsTMutablePtr b; 00243 ParamsTPtr lo; 00244 ParamsTPtr hi; 00245 ParamsTPtr cfm; 00246 ParamsTMutablePtr iMJ; 00247 ParamsTMutablePtr Adcfm; 00248 const int *findex; 00249 const ParamsT stepsize; 00250 }; 00251 00252 }; 00253 00254 } 00255 00256 #endif