00001 #ifndef CUDA_SOLVER_H 00002 #define CUDA_SOLVER_H 00003 00004 #include "parallel_solver.h" 00005 00006 namespace parallel_ode 00007 { 00008 00009 template<typename CudaT,typename ParamsT> 00010 class CudaPGSSolver : public ParallelPGSSolver<CudaT,ParamsT,ParallelTypes::CUDA> 00011 { 00012 public: 00013 typedef typename vec4<CudaT>::Type Vec4T; 00014 typedef const CudaT* CudaTPtr; 00015 typedef CudaT* CudaTMutablePtr; 00016 typedef MemManager<CudaT,ParallelTypes::CUDA> PMemManager; 00017 typedef typename PMemManager::mem_flags MemFlags; 00018 00027 CudaPGSSolver( int parallelFlags = DEFAULT_FLAGS, 00028 BatchType batchType = BatchTypes::DEFAULT_BATCH_TYPE, 00029 ReduceType reduceType = ReduceTypes::DEFAULT_REDUCE_TYPE, 00030 uint numBatches = ParallelOptions::MAXBATCHES ) 00031 00032 : ParallelPGSSolver<CudaT,ParamsT,ParallelTypes::CUDA>( parallelFlags, batchType, reduceType, numBatches ) { 00033 00034 } 00035 00036 virtual ~CudaPGSSolver( ) { 00037 } 00038 00042 virtual void initialize( ); 00043 00044 protected: 00045 00046 virtual void preProcessDevice( const CudaT sorParam, const CudaT stepsize ); 00047 00048 virtual void solveAndReduce( const int offset, const int batchSize ); 00049 00050 virtual void loadConstraints( ); 00051 }; 00052 00053 } 00054 #endif