// Heteroscedastic Gaussian process regression: declares the class template GPHetReg.
00001 #ifndef _GP_HET_REGRESSION_HPP_ 00002 #define _GP_HET_REGRESSION_HPP_ 00003 00004 #include <gpReg/types.hpp> 00005 00006 #ifdef SPARSE_REGRESSION 00007 #include "gpSparseRegression.hpp" 00008 #else 00009 #include "gpRegression.hpp" 00010 #endif 00011 00012 #include <assert.h> 00013 #include "covarianceFunction.hpp" 00014 #ifdef DO_PROFILING 00015 #include "profiler.hpp" 00016 #endif 00017 00018 00019 00020 /* 00021 * m_primGP: GP to predict mean of data process 00022 * m_secGP: GP to predict input-dependent noise variances (in log space) 00023 * m_combGP: GP for data process (using log noise variances predicted by m_secGP) 00024 * 00025 * defines: 00026 * 00027 * WRITE_DEBUG write debug data files 00028 * SPARSE_REGRESSION use efficient, sparse matrix libraries (currently disabled!) 00029 * DO_PROFILING enable the benchmarking code (global functions for profiling the algorithms) 00030 */ 00031 00032 // ----------------------------------------------------------------------- 00033 template <class TInput> 00034 class GPHetReg 00035 { 00036 00037 public: 00038 00039 00040 00041 // ----------------------------------------------------------------------- 00042 GPHetReg( 00043 CovFunc<TInput> &covPrimGP, 00044 CovFunc<TInput> &covSecGP, 00045 CovFunc<TInput> &covCombGP, 00046 double noisePrimGP, 00047 double noiseSecGP, 00048 double noiseCombGP ) : 00049 m_numDataPoints(0), m_dataPointsPrimGP(0), m_dataPointsSecGP(0), m_useEmpiricalVariances(false), 00050 m_primTargets(0), m_secTargets(0), m_covPrimGP(covPrimGP), m_covSecGP(covSecGP), m_covCombGP(covCombGP), 00051 m_noisePrimGP(noisePrimGP),m_noiseSecGP(noiseSecGP), m_noiseCombGP(noiseCombGP), 00052 m_primGP(m_covPrimGP,m_noisePrimGP), m_secGP(m_covSecGP,m_noiseSecGP), 00053 m_combGP(m_covCombGP,m_noiseCombGP, &m_secGP) 00054 {} 00055 00056 00057 00058 // ----------------------------------------------------------------------- 00059 ~GPHetReg() 00060 {} 00061 00062 00063 00064 // 
----------------------------------------------------------------------- 00065 void setDataPoints( TVector<TInput> &dataPoints, TVector<double> &dataTargets ) 00066 { 00067 m_dataPointsPrimGP = dataPoints; 00068 m_numDataPoints = dataPoints.size(); 00069 m_primTargets = dataTargets; 00070 } 00071 00072 00073 00074 // ----------------------------------------------------------------------- 00075 void setEmpiricalVariances( TVector<TInput> &varPoints, TVector<double> &variances ) 00076 { 00077 m_dataPointsSecGP = varPoints; 00078 m_secTargets = variances; 00079 00080 #ifdef WRITE_DEBUG 00081 FILE *trainDatFile = fopen( "trainSecGP.dat", "w" ); 00082 #endif 00083 00084 for (unsigned int i=0; i<m_secTargets.size(); i++) { 00085 m_secTargets[i] = log( m_secTargets[i] ); 00086 if (m_secTargets[i] < -5.0) { 00087 m_secTargets[i] = -5.0; 00088 } 00089 00090 #ifdef WRITE_DEBUG 00091 fprintf( trainDatFile, "%i %f %f\n", i, m_dataPointsSecGP[i][0], m_secTargets[i] ); 00092 #endif 00093 00094 } 00095 #ifdef WRITE_DEBUG 00096 fclose( trainDatFile ); 00097 #endif 00098 00099 m_useEmpiricalVariances = true; 00100 } 00101 00102 00103 00104 // ----------------------------------------------------------------------- 00105 void buildGP() 00106 { 00107 //------------------------------- 00108 // build primGP 00109 //------------------------------- 00110 00111 m_primGP.setDataPoints( m_dataPointsPrimGP, m_primTargets ); 00112 m_primGP.buildGP(); 00113 00114 //------------------------------- 00115 // build secGP 00116 //------------------------------- 00117 00118 if (!m_useEmpiricalVariances) { 00119 double mean; 00120 double target; 00121 00122 #ifdef WRITE_DEBUG 00123 FILE *trainDatFile = fopen( "trainSecGP.dat", "w" ); 00124 #endif 00125 00126 m_dataPointsSecGP.resize( m_dataPointsPrimGP.size() ); 00127 m_secTargets.resize( m_dataPointsPrimGP.size() ); 00128 for (unsigned int i=0; i<m_primTargets.size(); i++) { 00129 m_dataPointsSecGP[i] = m_dataPointsPrimGP[i]; 00130 00131 
m_primGP.evalGP( m_dataPointsSecGP[i], mean ); 00132 target = log( pow( mean-m_primTargets[i], 2.0 ) / 0.33 ); 00133 //target = log( pow( mean-m_primTargets[i], 2.0 ) / 1.0 ); 00134 00135 #ifdef WRITE_DEBUG 00136 fprintf( trainDatFile, "%i %f %f\n", i, m_dataPointsSecGP[i][0], target ); 00137 #endif 00138 m_secTargets[i] = target; 00139 } 00140 #ifdef WRITE_DEBUG 00141 fclose( trainDatFile ); 00142 #endif 00143 } 00144 m_secGP.setDataPoints( m_dataPointsSecGP, m_secTargets ); 00145 m_secGP.buildGP(); 00146 00147 //------------------------------- 00148 // build combGP 00149 //------------------------------- 00150 00151 m_combGP.setDataPoints( m_dataPointsPrimGP, m_primTargets ); 00152 m_combGP.buildGP(); 00153 } 00154 00155 00156 00157 // ----------------------------------------------------------------------- 00158 void evalGP( const TInput &x, double &mean, double &var ) 00159 { 00160 m_combGP.evalGP( x, mean, var ); 00161 } 00162 00163 00164 00165 // ----------------------------------------------------------------------- 00166 void evalGP( const TInput &x, double &mean ) 00167 { 00168 m_combGP.evalGP( x, mean ); 00169 } 00170 00171 00172 00173 // ----------------------------------------------------------------------- 00174 /* 00175 double getObservationLikelihood( TVector<Vector> &dataPoints, TVector<double> &observations ) 00176 { 00177 return m_combGP.getObservationLikelihood( dataPoints, observations ); 00178 } 00179 */ 00180 00181 00182 00183 // ----------------------------------------------------------------------- 00184 public: 00185 int m_numDataPoints; 00186 TVector<TInput> m_dataPointsPrimGP; 00187 TVector<TInput> m_dataPointsSecGP; 00188 bool m_useEmpiricalVariances; 00189 TVector<double> m_primTargets; 00190 TVector<double> m_secTargets; 00191 00192 CovFunc<TInput> &m_covPrimGP; 00193 CovFunc<TInput> &m_covSecGP; 00194 CovFunc<TInput> &m_covCombGP; 00195 00196 double m_noisePrimGP; 00197 double m_noiseSecGP; 00198 double m_noiseCombGP; 00199 00200 
#ifdef SPARSE_REGRESSION 00201 GPSparseRegression m_primGP; 00202 GPSparseRegression m_secGP; 00203 GPSparseRegression m_combGP; 00204 #else 00205 GPReg<TInput> m_primGP; 00206 GPReg<TInput> m_secGP; 00207 GPReg<TInput> m_combGP; 00208 #endif 00209 }; 00210 00211 00212 00213 #endif //_GP_HET_REGRESSION_HPP_