CombinatorialTrainer.cpp
Go to the documentation of this file.
1 
18 #include "CombinatorialTrainer.hpp"
19 
20 #include <boost/filesystem/path.hpp>
21 #include <boost/lexical_cast.hpp>
22 #include "boost/date_time/posix_time/posix_time.hpp"
23 #include "boost/date_time/posix_time/posix_time_io.hpp"
24 #include "boost/algorithm/string.hpp"
25 
26 #include "utility/LogHelper.hpp"
27 #include "utility/SVGHelper.hpp"
29 
30 #include <iostream>
31 #include <limits>
32 #include <ctime>
33 
35 
36 namespace ISM
37 {
38  using boost::filesystem::path;
39 
41  {
// Constructor body of CombinatorialTrainer. The signature line is not part of this listing;
// the declaration shown on this page is CombinatorialTrainer(CombinatorialTrainerParameters params).
//
// Visible steps:
//  1. Store params and throw std::runtime_error if the training database
//     params.general.dbfilename does not exist.
//  2. Throw std::runtime_error if mOutputDataPath is not an existing directory, then append
//     "<db file stem>_<genTimeString()>" to it as this run's output folder.
//  3. Drop the model tables of the training database and cache its recorded pattern names.
//  4. Initialise LogHelper to write <output>/Logfile/Log_<runName>.txt at LOG_INFO.
//  5. If params.general.storeValidTestSetsTo / storeInvalidTestSetsTo are non-empty,
//     persist the corresponding half of mTestSets to those databases.
// NOTE(review): several lines of this constructor are missing from the listing
// (e.g. 52, 64, 67-70, 77-80). mOutputDataPath, tableHelper and mTestSets are not
// assigned in any visible line; presumably that happens there. Confirm against the real source.
42  mParams = params;
43  mDBPath = params.general.dbfilename;
44 
45  if (!boost::filesystem::exists(mDBPath))
46  {
47  std::stringstream ss;
48  ss << mDBPath << " does not exist!";
49  throw std::runtime_error(ss.str());
50  }
51 
53  if(!boost::filesystem::is_directory(mOutputDataPath))
54  {
55  std::stringstream ss;
56  ss << mOutputDataPath << " does not exist or is not a directory!";
57  throw std::runtime_error(ss.str());
58  }
59 
// NOTE(review): genTimeString() formats the time with "%H:%M:%S". ':' is not allowed in
// Windows path components, so this run directory (and the log file name) is not portable.
60  const std::string runName = mDBPath.stem().string() + "_" + genTimeString();
61  mOutputDataPath /= runName;
62 
63 
// Drops the model tables of the training database (per the method name) before any learning.
65  tableHelper->dropModelTables();
66 
69  patternNames = tableHelper->getRecordedPatternNames();
71 
72  //Set up LogHelper
73  std::string logFileName = "Log_" + runName + ".txt";
74  path logFilePath = mOutputDataPath / "Logfile" / logFileName;
75  LogHelper::init(logFilePath, LOG_INFO);
76 
78 
81 
// An empty target path in the parameters means "do not store that kind of test set".
82  bool storeValidTestSets = !params.general.storeValidTestSetsTo.empty();
83  bool storeInvalidTestSets = !params.general.storeInvalidTestSetsTo.empty();
84 
85  if (storeValidTestSets || storeInvalidTestSets)
86  {
87  LogHelper::logMessage("Storing test sets: \n");
88  if (storeValidTestSets)
89  {
90  path dbFilePath(params.general.storeValidTestSetsTo);
91  storeTestSetsToDB(mTestSets.first, dbFilePath, "valid");
92  }
93  if (storeInvalidTestSets)
94  {
95  path dbFilePath(params.general.storeInvalidTestSetsTo);
96  storeTestSetsToDB(mTestSets.second, dbFilePath, "invalid");
97  }
99  }
100  }
101 
// Runs the optimisation for every recorded pattern and returns, per pattern name, the best
// (evaluation result, tree) pair found.
// - Patterns whose best tree pointer is null are left out of the returned map.
// - Every accepted ISM is passed to mDocumentationHelper, whose writeResult() is called once
//   after all patterns have been processed.
// NOTE(review): optimizationRunner is set up on lines missing from this listing (104-112).
// The condition on line 123, which selects between "try every start topology" and "one
// default run", is also missing; presumably it tests whether start topologies were loaded
// for the pattern. Confirm against the real source.
102  std::map<std::string, std::pair<double, TreePtr> >CombinatorialTrainer::learn()
103  {
106 
113 
114  LogHelper::logMessage("Training has started!");
115 
116  std::map<std::string, std::pair<double, TreePtr> > bestPerPattern;
// NOTE(review): bestISMPerPattern is filled below but never read in the visible code.
117  std::map<std::string, IsmPtr> bestISMPerPattern;
118 
119  for (std::string& patternName : patternNames)
120  {
121  //Learn best Tree
// Default-constructed: a null tree means "no usable result" and the pattern is skipped below.
122  std::pair<double, TreePtr> currentBest;
124  {
// Smaller evaluation results win. Negative results are rejected by the test further down.
125  double bestEvaluationResult = std::numeric_limits<double>::max();
// NOTE(review): operator[] copies the vector and inserts an empty entry for an unknown
// pattern. A const reference obtained via find() would avoid both.
126  std::vector<TopologyPtr> startTopologies = mStartTopologiesPerPattern[patternName];
127  for (unsigned int i = 0; i < startTopologies.size(); ++i)
128  {
129  std::pair<double, TreePtr> result =
130  optimizationRunner->runOptimization(patternName, startTopologies[i]);
131 
132  if (result.first >= 0 && result.first < bestEvaluationResult)
133  {
134  currentBest = result;
135  bestEvaluationResult = result.first;
136  }
137  }
138  }
139  else
140  {
141  currentBest = optimizationRunner->runOptimization(patternName);
142  }
143 
144  if (currentBest.second)
145  {
146  IsmPtr bestISM = currentBest.second->getISM();
147 
148  bestPerPattern.insert(std::make_pair(patternName, currentBest));
149  bestISMPerPattern.insert(std::make_pair(patternName, bestISM));
150 
151  mDocumentationHelper->storeIsm("optimized", bestISM);
152  mDocumentationHelper->storeIsm(bestISM);
153  }
154  }
155 
156  mDocumentationHelper->writeResult();
157 
158  LogHelper::logMessage("Training is done!");
160 
161  return bestPerPattern;
162  }
163 
164  const std::map<std::string, ISM::TracksPtr> CombinatorialTrainer::getRecordedObjectsTracks()
165  {
166  std::vector<std::string> patternNames = tableHelper->getRecordedPatternNames();
167  std::vector<ISM::ObjectSetPtr> objectsInCurrentPattern;
168  ISM::TracksPtr tracksInCurrentPattern;
169  std::map<std::string, ISM::TracksPtr> objectTracksPerPattern;
170  for (std::string& patternNameIt : patternNames)
171  {
172  objectsInCurrentPattern = tableHelper->getRecordedPattern(patternNameIt)->objectSets;
173  tracksInCurrentPattern = ISM::TracksPtr(new ISM::Tracks(objectsInCurrentPattern));
174  objectTracksPerPattern.insert(std::make_pair(patternNameIt, tracksInCurrentPattern));
175  for (size_t it = 0; it < tracksInCurrentPattern->tracks.size(); ++it)
176  {
177  if (tracksInCurrentPattern->tracks[it]->objects.size() > 0)
178  {
179  bool foundObjectModel = false;
180  for(ObjectPtr& o : tracksInCurrentPattern->tracks[it]->objects)
181  {
182  if(o)
183  {
184  objectModelsPerPattern[patternNameIt].insert(
185  std::make_pair(o->type, o->ressourcePath));
186  foundObjectModel = true;
187  break;
188  }
189  }
190  if(!foundObjectModel)
191  {
192  std::cerr << "CombinatorialTrainer::getRecordedObjectsTracks: "
193  << " Not one Object in track." << std::endl;
194  }
195  }
196  if (it < tracksInCurrentPattern->tracks.size() - 1 &&
197  tracksInCurrentPattern->tracks[it]->objects.size() !=
198  tracksInCurrentPattern->tracks[it + 1]->objects.size())
199  {
200  std::cerr<<"Corrupt database\n";
201  exit(-6);
202  }
203  }
204  }
205 
206  return objectTracksPerPattern;
207  }
208 
// For every pattern in objectTracksPerPattern, builds the pairwise relations between tracks
// with differing (type, observedId).
// - Each unordered pair gets one ObjectRelation. Relations are keyed 0, 1, 2, ... in the
//   order in which pairs are first encountered.
// - Comparing a track with a track of identical (type, observedId), e.g. itself, instead
//   yields a self relation, stored in the member allSelfRelationsPerPattern.
// Returns the pairwise relations per pattern name.
// NOTE(review): line 260 is missing from this listing and no return statement is visible;
// presumably it is `return allObjectRelationsPerPattern;`. Confirm.
// NOTE(review): the loop variables below are declared as std::pair<std::string, ...> and
// std::pair<unsigned int, ...>. The maps' value_types are std::pair<const K, ...>, so each
// iteration binds the reference to a converted temporary copy (string / shared_ptr copies).
// `const auto&` would avoid this.
209  const std::map<std::string, ISM::ObjectRelations> CombinatorialTrainer::calculateAllObjectRelations()
210  {
211  std::map<std::string, ISM::ObjectRelations> allObjectRelationsPerPattern;
212  for (const std::pair<std::string, ISM::TracksPtr>& patternIt : objectTracksPerPattern)
213  {
214  ISM::ObjectRelations inCurrentPattern;
215  std::vector<ISM::ObjectRelationPtr> selfiesInCurrentPattern;
216  unsigned relationId = 0;
217  for (ISM::TrackPtr& tracksIt : patternIt.second->tracks)
218  {
219  for (ISM::TrackPtr& otherTracksIt : patternIt.second->tracks)
220  {
221  if (otherTracksIt->observedId != tracksIt->observedId || otherTracksIt->type != tracksIt->type)
222  {
// Linear scan over the relations created so far, checking both orientations (A,B) and (B,A).
// The loop does not stop at the first match.
223  bool alreadyThere = false;
224  for (const std::pair<unsigned int, ISM::ObjectRelationPtr>& inCurrentPatternIt : inCurrentPattern)
225  {
226  if ((inCurrentPatternIt.second->getObjectIdA() == tracksIt->observedId &&
227  inCurrentPatternIt.second->getObjectIdB() == otherTracksIt->observedId &&
228  inCurrentPatternIt.second->getObjectTypeA() == tracksIt->type &&
229  inCurrentPatternIt.second->getObjectTypeB() == otherTracksIt->type)
230  || (inCurrentPatternIt.second->getObjectIdA() == otherTracksIt->observedId &&
231  inCurrentPatternIt.second->getObjectIdB() == tracksIt->observedId &&
232  inCurrentPatternIt.second->getObjectTypeA() == otherTracksIt->type &&
233  inCurrentPatternIt.second->getObjectTypeB() == tracksIt->type))
234  {
235  alreadyThere = true;
236  }
237  }
238  if (!alreadyThere)
239  {
240  //ObjectRelation constructs relations between objects at certain snapshots
241  ObjectRelationPtr objectRelation = ObjectRelationPtr(
242  new ObjectRelation(tracksIt, otherTracksIt, patternIt.first));
243  inCurrentPattern.insert(std::make_pair(relationId++, objectRelation));
244  }
245  }
246  else
247  {
// Same (type, observedId): record a single-track ("self") relation.
248  ObjectRelationPtr objectRelation = ISM::ObjectRelationPtr(
249  new ISM::ObjectRelation(tracksIt, patternIt.first));
250  selfiesInCurrentPattern.push_back(objectRelation);
251  }
252  }
253  }
254 
255  allObjectRelationsPerPattern.insert(std::make_pair(patternIt.first, inCurrentPattern));
256  this->allSelfRelationsPerPattern.insert(std::make_pair(patternIt.first, selfiesInCurrentPattern));
257  }
258  //We have to assert that object appearances are clustered by object. This is needed for the correct order of ISM construction
// Note: assert() is compiled out when NDEBUG is defined, so this check only runs in debug builds.
259  assert(checkCorrectOrder(allObjectRelationsPerPattern));
261  }
262 
// Checks that, within each pattern, all relations sharing the same object A
// (getObjectTypeA(), getObjectIdA()) form one contiguous run in the key order of the
// ObjectRelations map.
// Returns false as soon as some object A reappears after a relation with a different
// object A. Before returning, it prints every relation of that pattern and logs an error.
// Returns true otherwise.
// NOTE(review): the map parameter is taken by value, copying every pattern's relations; a
// const reference would suffice. Line 297 (the remaining logMessage arguments) is missing
// from this listing.
263  bool CombinatorialTrainer::checkCorrectOrder(std::map<std::string, ISM::ObjectRelations> allObjectRelationsPerPattern)
264  {
265  for (const std::pair<std::string, ISM::ObjectRelations>& pattern : allObjectRelationsPerPattern)
266  {
267  for (const std::pair<unsigned int, ISM::ObjectRelationPtr>& relation : pattern.second)
268  {
// State machine over the pattern's relations: appeared -> disappeared -> reappeared.
269  bool appeared = false;
270  bool disappeared = false;
271  bool reappeared = false;
272  for (const std::pair<unsigned int, ISM::ObjectRelationPtr>& otherRelation : pattern.second)
273  {
274  if (otherRelation.second->getObjectTypeA() == relation.second->getObjectTypeA() &&
275  otherRelation.second->getObjectIdA() == relation.second->getObjectIdA())
276  {
277  appeared = true;
278  if (disappeared)
279  {
280  reappeared = true;
281  }
282  }
283  else if (appeared)
284  {
285  disappeared = true;
286  }
// NOTE(review): the two identical `if (reappeared)` tests below could be merged. it.second
// is a shared_ptr: unless an operator<< overload for ObjectRelationPtr exists elsewhere,
// this prints pointer values rather than relation contents.
287  if (reappeared)
288  {
289  for (const std::pair<unsigned int, ISM::ObjectRelationPtr>& it : pattern.second)
290  {
291  std::cout<<it.second<<std::endl;
292  }
293  }
294  if (reappeared)
295  {
296  LogHelper::logMessage("Wrong order of relations. We have to ASSERT the correct order\n",
298  return false;
299  }
300  }
301  }
302  }
303  return true;
304  }
305 
307  const std::string& patternName) const
308  {
// Body of containsAllObjects(const ISM::ObjectRelations& topology, const std::string& patternName) const.
// The first signature line is not in this listing.
// Returns true iff every track recorded for patternName is contained, by (type, observedId),
// in at least one relation of topology.
// objectTracksPerPattern.at() throws std::out_of_range for an unknown pattern name.
309  const ISM::TracksPtr allObjects = objectTracksPerPattern.at(patternName);
310  for (ISM::TrackPtr& object : allObjects->tracks)
311  {
// NOTE(review): the inner loop keeps scanning after a match; a `break` would suffice.
312  bool hasAppearance = false;
313  for (const std::pair<unsigned int, ISM::ObjectRelationPtr>& relation : topology)
314  {
315  if (relation.second->containsObject(object->type, object->observedId))
316  {
317  hasAppearance = true;
318  }
319  }
320  if (hasAppearance == false)
321  {
322  return false;
323  }
324  }
325  return true;
326  }
327 
// For every recorded pattern, builds a Tree from all object relations of that pattern
// (allObjectRelationsPerPattern) and returns the resulting ISM, keyed by pattern name.
// `naive` is forwarded unchanged to the Tree constructor.
// Throws std::out_of_range (via at()) if a pattern has no entry in allObjectRelationsPerPattern.
// NOTE(review): lines 330 and 341 are missing from this listing. Line 337 ends with a stray
// second ';' (a harmless empty statement).
328  std::map<std::string, IsmPtr> CombinatorialTrainer::learnFullyMeshedTopologyPerPattern(bool naive)
329  {
331  std::map<std::string, IsmPtr> fullyMeshedTopologyPerPattern;
332  for (std::string& patternNameIt : patternNames)
333  {
334  LogHelper::logMessage("Learning fully meshed topology for pattern " + patternNameIt);
335  TreePtr fullyMeshedTopology = TreePtr(new Tree(patternNameIt,
336  allObjectRelationsPerPattern.at(patternNameIt), naive));
337  fullyMeshedTopologyPerPattern[patternNameIt] = fullyMeshedTopology->getISM();;
338  }
339 
340  LogHelper::logMessage("Learning is done");
342 
343  return fullyMeshedTopologyPerPattern;
344  }
345 
347  const path dbFilePath,
348  const std::string & type)
349  {
// Body of storeTestSetsToDB(PatternNameToObjectSet testSet, const path dbFilePath,
// const std::string& type). The first signature line is not in this listing.
// Writes every object set of testSet into the database at dbFilePath through a TableHelper,
// creating the parent directory first if it does not exist.
// - Existing tables in that database are dropped (dropTables()) before writing, so previous
//   content is lost.
// - `type` is only used in log output.
// - soci::soci_error is caught, logged and swallowed, so callers are not told that storing failed.
// NOTE(review): the error message misspells "Probably" and names storeISMToDB rather than
// this function. Lines 346, 363, 370 and 381 are missing from this listing.
350  if(!boost::filesystem::exists(dbFilePath.parent_path()))
351  boost::filesystem::create_directories(dbFilePath.parent_path());
352  try
353  {
354  TableHelperPtr localTableHelper(new TableHelper(dbFilePath.string()));
355  localTableHelper->dropTables();
356  localTableHelper->createTablesIfNecessary();
357  for (PatternNameToObjectSet::iterator testSetIt = testSet.begin();
358  testSetIt != testSet.end();
359  ++testSetIt)
360  {
361  LogHelper::logMessage("Storing the " + type + " object sets for pattern "
362  + testSetIt->first + " to " + dbFilePath.string());
364  localTableHelper->insertRecordedPattern(testSetIt->first);
// NOTE(review): copies the vector of object sets; a const reference would do.
365  std::vector<ISM::ObjectSetPtr> objectSets = testSetIt->second;
366  for (unsigned int i = 0; i < objectSets.size(); ++i) {
367  LogHelper::displayProgress(((double) i + 1) / objectSets.size());
368  localTableHelper->insertRecordedObjectSet(objectSets[i], testSetIt->first);
369  }
// Ends the line written by displayProgress().
371  std::cout << std::endl;
372 
373  }
// NOTE(review): prefer catching by const reference.
374  } catch (soci::soci_error& e)
375  {
376  LogHelper::logMessage("Probablay the filepath " + dbFilePath.string()
377  + " used to store evaluation results in storeISMToDB in CombinatorialTrainer and"
378  + " setTestSets in Tester does not exist on your system", LOG_ERROR);
379  std::cerr << "soci error\n" << e.what() << std::endl;
380  }
382  }
383 
385  {
// Body of loadTestSetsFromDB(std::string fileName). The signature line is not in this listing.
// Opens the database fileName and returns, per pattern name, the object sets recorded there.
// NOTE(review): the pattern names come from the member tableHelper (the training database),
// not from localTableHelper, so only patterns known to the training database are loaded.
// If the test database lacks one of them, getRecordedPattern()'s result is dereferenced
// without a check. Confirm what it returns in that case.
386  path dbfilename = fileName;
387  TableHelperPtr localTableHelper(new TableHelper(dbfilename.string()));
388  std::vector<std::string> patternNames = tableHelper->getRecordedPatternNames();
389  PatternNameToObjectSet testSet;
390 
391  for (unsigned int i = 0; i < patternNames.size(); ++i)
392  {
393  testSet[patternNames[i]] = localTableHelper->getRecordedPattern(patternNames[i])->objectSets;
394  LogHelper::logMessage("Loaded " + std::to_string(testSet[patternNames[i]].size()) +
395  " test sets for pattern " + patternNames[i] + " from DB " + fileName);
396  }
397 
398  return testSet;
399  }
400 
// Generates test sets for every pattern in objectTracksPerPattern.
// - Builds a Recognizer from binSize and maxAngleDeviation, and an ObjectSetValidator
//   from confidenceThreshold.
// - Calls generateTestSets() with the pattern's tracks, the pattern's ISM from
//   mFullyMeshedTopologyPerPattern, and testSetCount.
// Returns (valid test sets per pattern, invalid test sets per pattern).
// Throws std::out_of_range (via at()) if a pattern has no entry in mFullyMeshedTopologyPerPattern.
// NOTE(review): line 413, which initialises testSetGenerator, is missing from this listing.
// objectSetValidator is presumably passed to it there. Confirm.
401  std::pair<PatternNameToObjectSet, PatternNameToObjectSet> CombinatorialTrainer::createTestSets(
402  double binSize, double maxAngleDeviation, double confidenceThreshold, unsigned int testSetCount)
403  {
404  PatternNameToObjectSet validTestSetsPerPattern;
405  PatternNameToObjectSet invalidTestSetsPerPattern;
406 
407  RecognizerPtr recognizer = ISM::RecognizerPtr(new ISM::Recognizer("", binSize, maxAngleDeviation, false));
408 
409  ObjectSetValidatorPtr objectSetValidator =
410  ObjectSetValidatorPtr(new ObjectSetValidator(recognizer, confidenceThreshold));
411 
412  TestSetGeneratorPtr testSetGenerator =
414 
415  for (std::map<std::string, ISM::TracksPtr>::iterator patternIt = this->objectTracksPerPattern.begin();
416  patternIt != this->objectTracksPerPattern.end();
417  ++patternIt)
418  {
419  std::pair<std::vector<ObjectSetPtr>, std::vector<ObjectSetPtr>> testSets =
420  testSetGenerator->generateTestSets(patternIt->first, patternIt->second,
421  mFullyMeshedTopologyPerPattern.at(patternIt->first), testSetCount);
422 
423  validTestSetsPerPattern[patternIt->first] = testSets.first;
424  invalidTestSetsPerPattern[patternIt->first] = testSets.second;
425  }
426 
427  return std::make_pair(validTestSetsPerPattern, invalidTestSetsPerPattern);
428  }
429 
// Initialises mTestSets as (valid, invalid) test sets per pattern.
// Each half is loaded from the database named by loadValidTestSetsFrom /
// loadInvalidTestSetsFrom when that string is non-empty; otherwise it is generated with
// createTestSets(binSize, maxAngleDeviation, confidenceThreshold, testSetCount).
// NOTE(review): if only one half needs generating, createTestSets() still produces both and
// the other half is discarded. Line 459 is missing from this listing.
430  void CombinatorialTrainer::initTestSets(double binSize, double maxAngleDeviation,
431  double confidenceThreshold, std::string loadValidTestSetsFrom,
432  std::string loadInvalidTestSetsFrom, unsigned int testSetCount)
433  {
434  PatternNameToObjectSet validTestSetsPerPattern;
435  PatternNameToObjectSet invalidTestSetsPerPattern;
436 
437  bool createInvalidTestSet = true;
438  bool createValidTestSet = true;
439 
// A non-empty path means "load from this database instead of generating".
440  if (loadValidTestSetsFrom.compare("") != 0) {
441  createValidTestSet = false;
442  validTestSetsPerPattern = loadTestSetsFromDB(loadValidTestSetsFrom);
443  }
444 
445  if (loadInvalidTestSetsFrom.compare("") != 0) {
446  createInvalidTestSet = false;
447  invalidTestSetsPerPattern = loadTestSetsFromDB(loadInvalidTestSetsFrom);
448  }
449 
450  if (createInvalidTestSet || createValidTestSet)
451  {
452  std::pair<PatternNameToObjectSet, PatternNameToObjectSet> testSets =
453  createTestSets(binSize, maxAngleDeviation, confidenceThreshold, testSetCount);
454 
455  if (createValidTestSet) validTestSetsPerPattern = testSets.first;
456  if (createInvalidTestSet) invalidTestSetsPerPattern = testSets.second;
457  }
458 
460 
461  mTestSets = std::make_pair(validTestSetsPerPattern, invalidTestSetsPerPattern);
462  }
463 
465  {
// Body of genTimeString(). The signature line is not in this listing; the name is taken from
// its call in the constructor.
// Returns the current local time formatted as "%d-%b-%Y_%H:%M:%S", e.g.
// "08-Jan-2020_04:02:40" in the C locale.
// NOTE(review): std::localtime returns a pointer to shared static storage and is not
// thread-safe. %b is locale-dependent. ':' is invalid in Windows file names, yet the result
// becomes part of the output directory and log file names. strftime's return value (0 on
// overflow) is not checked.
466  time_t rawtime;
467  struct tm * timeinfo;
468  char buffer[80];
469 
470  time(&rawtime);
471  timeinfo = localtime(&rawtime);
472 
473  strftime(buffer, 80,"%d-%b-%Y_%H:%M:%S", timeinfo);
474  return std::string(buffer);
475  }
476 
// Loads start topologies from a text file when loadStartTopologiesFrom is non-empty.
// Each line is split on ':' as "<patternName>:<i1>:<i2>:...", where every i is an index into
// allObjectRelations. allObjectRelations is declared on line 499, which is missing from this
// listing; presumably it is allObjectRelationsPerPattern[patternName]. Confirm.
// For each line, a Topology is created that contains exactly the listed relations, plus an
// identifier string with one '1'/'0' character per entry of allObjectRelations. The Topology
// is appended to mStartTopologiesPerPattern[patternName].
// Throws std::runtime_error if the file does not exist.
// NOTE(review):
// - file.open() is not checked, so an unreadable file silently yields no topologies.
// - An empty line yields a single empty token, i.e. a topology for pattern "" with no relations.
477  void CombinatorialTrainer::initStartTopologiesPerPattern(std::string loadStartTopologiesFrom)
478  {
479  if (loadStartTopologiesFrom.compare("") != 0)
480  {
481  path from = path(loadStartTopologiesFrom);
482  if (!boost::filesystem::exists(from))
483  {
484  std::stringstream ss;
485  ss << loadStartTopologiesFrom << " does not exist!";
486  throw std::runtime_error(ss.str());
487  }
488 
489  std::string line;
490  std::ifstream file;
491  file.open(from.string());
492 
493  while (std::getline(file, line))
494  {
495  std::vector<std::string> tokens;
496  boost::split(tokens, line, boost::is_any_of(":"));
497 
498  std::string patternName = tokens[0];
500  ObjectRelations objectRelations;
501 
502  std::vector<bool> bitvector(allObjectRelations.size(), 0) ;
503 
504  for (unsigned int i = 1; i < tokens.size(); ++i)
505  {
// NOTE(review): boost::lexical_cast throws boost::bad_lexical_cast for empty or non-numeric
// tokens, e.g. a trailing ':' or a '\r' left by CRLF line endings. `index` is not
// range-checked, so index >= allObjectRelations.size() writes past the end of bitvector
// (undefined behaviour).
506  unsigned int index = boost::lexical_cast<unsigned int>(tokens[i]);
507  objectRelations[index] = allObjectRelations[index];
508  bitvector[index] = 1;
509  }
510 
// Identifier: one character per relation of the pattern, '1' if it is part of this topology.
511  std::string identifier = "";
512  for (unsigned int i = 0; i < bitvector.size(); ++i)
513  {
514  identifier += bitvector[i] ? "1" : "0";
515  }
516 
517  if (mStartTopologiesPerPattern.find(patternName) == mStartTopologiesPerPattern.end())
518  {
519  mStartTopologiesPerPattern[patternName] = std::vector<TopologyPtr>();
520  }
521  TopologyPtr topology = TopologyPtr(new Topology());
522  topology->objectRelations = objectRelations;
523  topology->identifier = identifier;
524  mStartTopologiesPerPattern[patternName].push_back(topology);
525  }
526  }
527  }
528 
529 }
boost::shared_ptr< Recognizer > RecognizerPtr
Definition: Recognizer.hpp:171
std::pair< PatternNameToObjectSet, PatternNameToObjectSet > mTestSets
CombinatorialTrainerParameters mParams
static void logLine(LogLevel logLevel=LOG_INFO)
Definition: LogHelper.cpp:101
std::map< std::string, ISM::ObjectRelations > allObjectRelationsPerPattern
struct TopologyGeneratorParameters topologyGenerator
boost::shared_ptr< ISM::ObjectRelation > ObjectRelationPtr
static void displayProgress(double progress)
Definition: LogHelper.cpp:147
struct OptimizationAlgorithmParameters optimizationAlgorithm
bool checkCorrectOrder(std::map< std::string, ISM::ObjectRelations > allObjectRelationsPerPattern)
std::string patternName
struct ISM::CombinatorialTrainerParameters::@0 general
bool containsAllObjects(const ISM::ObjectRelations &topology, const std::string &patternName) const
boost::shared_ptr< TableHelper > TableHelperPtr
std::pair< PatternNameToObjectSet, PatternNameToObjectSet > createTestSets(double binSize, double maxAngleDeviation, double confidenceThreshold, unsigned int testSetCount)
std::map< std::string, std::map< std::string, boost::filesystem::path > > objectModelsPerPattern
boost::shared_ptr< Topology > TopologyPtr
Definition: Topology.hpp:51
std::map< std::string, std::vector< ObjectSetPtr > > PatternNameToObjectSet
Definition: typedef.hpp:68
void initStartTopologiesPerPattern(std::string loadStartTopologiesFrom)
void storeTestSetsToDB(PatternNameToObjectSet testSet, const path dbFilePath, const std::string &type)
static void logMessage(const std::string &message, LogLevel logLevel=LOG_INFO, const char *logColor=LOG_COLOR_DEFAULT)
Definition: LogHelper.cpp:96
boost::shared_ptr< ObjectSetValidator > ObjectSetValidatorPtr
const std::map< std::string, ISM::TracksPtr > getRecordedObjectsTracks()
std::map< unsigned int, ISM::ObjectRelationPtr, std::less< unsigned > > ObjectRelations
std::map< std::string, std::vector< ISM::ObjectRelationPtr > > allSelfRelationsPerPattern
boost::shared_ptr< ImplicitShapeModel > IsmPtr
boost::shared_ptr< DocumentationHelper > DocumentationHelperPtr
boost::shared_ptr< OptimizationRunner > OptimizationRunnerPtr
CombinatorialTrainer(CombinatorialTrainerParameters params)
boost::shared_ptr< Tree > TreePtr
Definition: typedef.hpp:46
static OptimizationRunnerPtr createOptimizationRunner(EvaluatorParameters evaluatorParams, TopologyGeneratorParameters topologyGeneratorParams, OptimizationAlgorithmParameters optimizationAlgorithmParameters, CostFunctionParameters costFunctionParameters, TreeValidatorParameters treeValidatorParams, std::pair< PatternNameToObjectSet, PatternNameToObjectSet > testSets, std::map< std::string, ISM::ObjectRelations > allObjectRelationsPerPattern, DocumentationHelperPtr documentationHelper, bool storeFullyMeshedISM, bool storeStartTopologyISM)
std::map< std::string, std::pair< double, TreePtr > > learn()
void initTestSets(double binSize, double maxAngleDeviation, double confidenceThreshold, std::string loadValidTestSetsFrom, std::string loadInvalidTestSetsFrom, unsigned int testSetCount)
static void init(path logFilePath, LogLevel level)
Definition: LogHelper.cpp:106
boost::shared_ptr< Tracks > TracksPtr
Definition: Tracks.hpp:42
std::map< std::string, std::vector< TopologyPtr > > mStartTopologiesPerPattern
DocumentationHelperPtr mDocumentationHelper
PatternNameToObjectSet loadTestSetsFromDB(std::string fileName)
boost::shared_ptr< TestSetGenerator > TestSetGeneratorPtr
const std::map< std::string, ISM::ObjectRelations > calculateAllObjectRelations()
std::map< std::string, IsmPtr > mFullyMeshedTopologyPerPattern
static const char * LOG_COLOR_RED
Definition: LogHelper.hpp:45
std::vector< std::string > patternNames
boost::shared_ptr< Track > TrackPtr
Definition: Track.hpp:55
std::map< std::string, ISM::TracksPtr > objectTracksPerPattern
this namespace contains all generally usable classes.
std::map< std::string, IsmPtr > learnFullyMeshedTopologyPerPattern(bool naive=false)
boost::shared_ptr< Object > ObjectPtr
Definition: Object.hpp:82
static void close()
Definition: LogHelper.cpp:116


asr_lib_ism
Author(s): Hanselmann Fabian, Heller Florian, Heizmann Heinrich, Kübler Marcel, Mehlhaus Jonas, Meißner Pascal, Qattan Mohamad, Reckling Reno, Stroh Daniel
autogenerated on Wed Jan 8 2020 04:02:40