00001 #ifndef CUDA_SOLVER_H
00002 #define CUDA_SOLVER_H
00003
00004 #include "parallel_solver.h"
00005
00006 namespace parallel_ode
00007 {
00008
00009 template<typename CudaT,typename ParamsT>
00010 class CudaPGSSolver : public ParallelPGSSolver<CudaT,ParamsT,ParallelTypes::CUDA>
00011 {
00012 public:
00013 typedef typename vec4<CudaT>::Type Vec4T;
00014 typedef const CudaT* CudaTPtr;
00015 typedef CudaT* CudaTMutablePtr;
00016 typedef MemManager<CudaT,ParallelTypes::CUDA> PMemManager;
00017 typedef typename PMemManager::mem_flags MemFlags;
00018
00027 CudaPGSSolver( int parallelFlags = DEFAULT_FLAGS,
00028 BatchType batchType = BatchTypes::DEFAULT_BATCH_TYPE,
00029 ReduceType reduceType = ReduceTypes::DEFAULT_REDUCE_TYPE,
00030 uint numBatches = ParallelOptions::MAXBATCHES )
00031
00032 : ParallelPGSSolver<CudaT,ParamsT,ParallelTypes::CUDA>( parallelFlags, batchType, reduceType, numBatches ) {
00033
00034 }
00035
00036 virtual ~CudaPGSSolver( ) {
00037 }
00038
00042 virtual void initialize( );
00043
00044 protected:
00045
00046 virtual void preProcessDevice( const CudaT sorParam, const CudaT stepsize );
00047
00048 virtual void solveAndReduce( const int offset, const int batchSize );
00049
00050 virtual void loadConstraints( );
00051 };
00052
00053 }
00054 #endif