Trainer.cpp
Go to the documentation of this file.
1 
18 #include <boost/shared_ptr.hpp>
19 #include <boost/foreach.hpp>
20 #include <vector>
21 #include <iostream>
22 #include <set>
23 #include <math.h>
24 
25 #include "Trainer.hpp"
29 #include "common_type/Pose.hpp"
30 #include "common_type/Tracks.hpp"
34 #include "DataCollector.hpp"
35 
36 namespace ISM {
37  Trainer::Trainer(std::string dbfilename, bool dropOldModelTables) : mUseManualDefHeuristic(false),
38  mUsePredefinedRefs(false),
39  staticBreakRatio(0.0), togetherRatio(0.0), maxAngleDeviation(0.0)
40  {
41  this->tableHelper.reset(new TableHelper(dbfilename));
42  if(dropOldModelTables)
43  {
44  tableHelper->createTablesIfNecessary();
45  tableHelper->dropModelTables();
46  tableHelper->createTablesIfNecessary();
47  }
48  this->skips = 0;
49  this->useClustering = true;
50  }
51 
53  this->skips = skips;
54  }
55 
57  this->useClustering = useClustering;
58  }
59 
61  std::vector<std::string> patternNames = this->tableHelper->getRecordedPatternNames();
62  std::cerr<<"found "<<patternNames.size()<<" patterns"<<std::endl;
63  for (std::string& name : patternNames) {
64  this->trainPattern(name);
65  }
66  }
67 
68  void Trainer::trainPattern(const std::string& patternName) {
69  boost::shared_ptr<RecordedPattern> r = this->tableHelper->getRecordedPattern(patternName);
70  if (!r) {
71  std::cerr<<"no pattern records found for pattern "<<patternName<<std::endl;
72  } else {
73  std::cerr<<"training "<<patternName<<std::endl;
74  this->recordedPattern = r;
75  this->learn();
76  }
77  }
78 
/**
 * Trains the pattern currently stored in recordedPattern.
 *
 * Object sets without objects are discarded first. While useClustering is set,
 * findHeuristicMatch() is asked for a cluster of tracks. Each accepted cluster
 * is trained via doTraining() as its own sub-pattern named "<pattern>_sub<id>",
 * and tracks->replace() is then called with the cluster's tracks and the
 * returned reference track. The tracks that remain are finally trained under
 * the pattern's own name.
 */
79  void Trainer::learn() {
80  std::vector<ObjectSetPtr> sets = this->recordedPattern->objectSets;
// Drop object sets that contain no objects. i is a size_t: for i == 0 the
// decrement wraps to SIZE_MAX and the loop's i++ wraps it back to 0. Unsigned
// wraparound is well-defined, so this works. However, each erase shifts the
// tail (quadratic in the worst case); the erase-remove idiom would be clearer.
81  for (size_t i = 0; i < sets.size(); i++) {
82  if (sets[i]->objects.size() == 0) {
83  sets.erase(sets.begin() + i);
84  i--;
85  }
86  }
87 
// Counter used in sub-pattern names when the heuristic's clusterId is negative.
88  int clusterId = 0;
89  TracksPtr tracks(new Tracks(sets));
90 
91  while (this->useClustering) {
92  HeuristicPtr heuristic = this->findHeuristicMatch(tracks);
// Stop clustering in three cases: findHeuristicMatch() returned nothing, at
// most two tracks remain, or the cluster would cover every remaining track.
93  if (!heuristic || (int) tracks->tracks.size() <= 2
94  || heuristic->cluster->tracks.size() == tracks->tracks.size()) {
95  break;
96  }
97 
98  TracksPtr cluster = heuristic->cluster;
99  std::stringstream subPatternNameStream;
// If the heuristic's clusterId is negative, the running counter is used in the
// name. Otherwise the heuristic's own id is used, and the counter is moved past
// it (max + 1 below).
100  if(heuristic->clusterId < 0)
101  {
102  subPatternNameStream<<this->recordedPattern->name<<"_sub"<<clusterId;
103  } else
104  {
105  subPatternNameStream<<this->recordedPattern->name<<"_sub"<<heuristic->clusterId;
106  clusterId = std::max(clusterId, heuristic->clusterId);
107  }
108  clusterId++;
109  std::string subPatternName = subPatternNameStream.str();
110 
111  TrackPtr refTrack = this->doTraining(cluster->toObjectSetVector(), subPatternName);
// NOTE(review): Doxygen line 112 is missing from this listing. The unmatched
// '}' on line 114 shows that it opened a block around the DataCollector call,
// presumably a DataCollector::shouldCollect() check. Confirm against the source.
113  DataCollector::getData()->tracksWithRef.push_back(TracksWithRef(cluster, refTrack));
114  }
// Presumably substitutes the clustered tracks with the sub-pattern's reference
// track, so later iterations treat the sub-pattern as a single track.
// Confirm against Tracks::replace().
115  tracks->replace(cluster->tracks, refTrack);
116  }
117 
118  //train remaining sets
119  TrackPtr refTrack = this->doTraining(tracks->toObjectSetVector(), this->recordedPattern->name);
// NOTE(review): Doxygen line 120 is missing here as well; as with line 112, the
// '}' on line 122 closes the block it opened.
121  DataCollector::getData()->tracksWithRef.push_back(TracksWithRef(tracks, refTrack));
122  }
123  }
124 
/**
 * Enables the manually defined pseudo heuristic (sets mUseManualDefHeuristic).
 *
 * NOTE(review): Doxygen line 128 is missing from this listing. Presumably it
 * stores @p cluster in mClusterForManualDefHeuristic; confirm against the source.
 */
125  void Trainer::setClusterForManualDefHeuristic(std::vector<std::pair<std::vector<
126  ManuallyDefPseudoHeuristic::ClusterObject>, uint16_t>> cluster)
127  {
129  mUseManualDefHeuristic = true;
130  }
/**
 * Enables predefined reference objects (sets mUsePredefinedRefs).
 *
 * doTraining() looks up mPatternToTypesOfPredefinedRefs[patternName] and
 * compares the result to object types. The map is therefore keyed by pattern
 * name, with the reference object type as its value.
 *
 * NOTE(review): Doxygen line 133 is missing from this listing. Presumably it
 * copies @p refs into mPatternToTypesOfPredefinedRefs; confirm against the source.
 */
131  void Trainer::setPredefinedRefs(std::map<std::string, std::string>& refs)
132  {
134  mUsePredefinedRefs = true;
135  }
137  HeuristicPtr bestHeuristic;
138 
139  std::vector<HeuristicPtr> heuristics;
141  {
143  }
145  //heuristics.push_back(HeuristicPtr(new DirectionOrientationRelationHeuristic(staticBreakRatio, togetherRatio, maxAngleDeviation)));
146  for (HeuristicPtr& heuristic : heuristics) {
147 
148  heuristic->applyHeuristic(tracks);
149 
150  if (!heuristic->cluster) {
151  continue;
152  }
153  std::cerr<<"heuristic results of "<<heuristic->name<<std::endl;
154  std::cerr<<heuristic->cluster->tracks.size()<<" tracks, confidence: "<<heuristic->confidence<<std::endl;
155  if (heuristic->confidence > 0.7 && (!bestHeuristic || heuristic->confidence > bestHeuristic->confidence)) {
156  bestHeuristic = heuristic;
157  }
158  }
159 
160  return bestHeuristic ? bestHeuristic : HeuristicPtr();
161  }
162 
/**
 * Trains one (sub-)pattern from the given object sets and returns its reference track.
 *
 * First a reference object (type + observedId) is chosen. Then, for every
 * processed object set, a vote specifier is computed for each object relative
 * to the reference pose and stored via tableHelper->insertModelVoteSpecifier().
 * Finally the pattern is upserted with the rounded average per-set weight.
 *
 * @param sets        object sets to train from (taken by value).
 * @param patternName name stored in every vote and used for the pattern row.
 * @return track of reference objects named @p patternName, with one entry per
 *         processed set. A processed set with no objects yields a null
 *         ObjectPtr. Sets skipped because of `skips` contribute no entry.
 */
163  TrackPtr Trainer::doTraining(std::vector<ObjectSetPtr> sets, std::string patternName) {
164  int toSkip = 0;
165  int setCount = 0;
166  double objectsWeightSum = 0;
167  std::string refType = "";
168  std::string refId = "";
169 
170  TrackPtr refTrack(new Track(patternName));
171 
172  TracksPtr tracks(new Tracks(sets));
173 
174  bool refFound = false;
// NOTE(review): Doxygen lines 175 and 177 are missing from this listing. They
// open the two blocks below and presumably test mUsePredefinedRefs and whether
// patternName has a predefined reference type. Confirm against the source.
// Behavior of the scan below:
// - operator[] on the std::map inserts an empty entry if patternName is absent.
// - Within a track, the scan stops at the first non-null object whose type
//   differs.
// - If several tracks match, the last one wins, because the outer loop never
//   stops early.
176  {
178  {
179  for (std::vector<TrackPtr>::iterator track = tracks->tracks.begin(); track != tracks->tracks.end();
180  ++track)
181  {
182  for (ObjectPtr& obj : (*track)->objects) {
183  if (obj) {
184  if(obj->type == mPatternToTypesOfPredefinedRefs[patternName])
185  {
186  refType = obj->type;
187  refId = obj->observedId;
188  refFound = true;
189  } else
190  {
191  break;
192  }
193  }
194  }
195  }
196  }
197  }
// Fallback: pick the track with the largest fraction of non-null observations.
// Ties go to the track with less accumulated movement between consecutive
// observations. Edge cases:
// - Both "best" values start at 0, so a track with zero views (movement 0)
//   can never be chosen.
// - A track with no objects gives 0/0 = NaN and is never chosen either.
// - refType/refId may therefore stay empty.
198  if(!refFound)
199  {
200  double bestViewRatio = 0;
201  double bestMovement = 0;
202  for (TrackPtr& track : tracks->tracks) {
203  int views = 0;
204  std::string refT;
205  std::string refI;
206  ObjectPtr lastObj;
207  double movement = 0;
208  for (ObjectPtr& obj : track->objects) {
209  if (obj) {
210  refT = obj->type;
211  refI = obj->observedId;
212  views++;
213  if (lastObj) {
214  movement += GeometryHelper::getDistanceBetweenPoints(obj->pose->point, lastObj->pose->point);
215  }
216  lastObj = obj;
217  }
218  }
219 
220  double ratio = (double)views / (double)track->objects.size();
221  if (ratio > bestViewRatio || (ratio == bestViewRatio && movement < bestMovement)) {
222  refType = refT;
223  refId = refI;
224  bestViewRatio = ratio;
225  bestMovement = movement;
226  }
227  }
228  }
229 
230  std::cerr<<"choose ref "<<refType<<" : "<<refId<<std::endl;
231  std::cerr<<"training "<<patternName<<" "<<std::endl;
232 
// Process one set, then skip the next `skips` sets ('.' = processed, '_' = skipped).
// The first set is always processed.
233  for (size_t setIdx = 0; setIdx < sets.size(); setIdx++)
234  {
235  double setWeightSum = 0;
236  if (toSkip == 0) {
237  toSkip = this->skips;
238  std::cerr<<".";
239  std::cerr.flush();
240  } else {
241  std::cerr<<"_";
242  std::cerr.flush();
243  toSkip--;
244  continue;
245  }
246 
247  PosePtr referencePose;
248  setCount++;
249  std::vector<ObjectPtr> objects = sets[setIdx]->objects;
250 
// Reference pose selection:
// - use the chosen reference object's pose in this set if it is present;
// - otherwise use the first object's pose;
// - an empty set gets a null placeholder in refTrack.
// NOTE(review): o->type is accessed without a null check, so this assumes the
// sets contain no null ObjectPtrs.
251  for (ObjectPtr& o : objects) {
252  if (o->type == refType && o->observedId == refId) {
253  referencePose.reset(new Pose(*(o->pose)));
254  break;
255  }
256  }
257  if (!referencePose && objects.size() > 0) {
258  referencePose.reset(new Pose(*(objects[0]->pose)));
259  } else if (!referencePose) {
260  refTrack->objects.push_back(ObjectPtr());
261  continue;
262  }
263 
// One vote per object, relative to the reference pose. trackIndex records
// which input set the vote came from.
264  for (ObjectPtr& o : objects) {
265  VoteSpecifierPtr vote = GeometryHelper::createVoteSpecifier(o->pose, referencePose);
266  vote->patternName = patternName;
267  vote->observedId = o->observedId;
268  vote->objectType = o->type;
269  vote->trackIndex = setIdx;
270  this->tableHelper->insertModelVoteSpecifier(vote);
271  setWeightSum += o->weight;
272  }
273 
274  ObjectPtr refObj = ObjectPtr(
275  new Object(
276  patternName,
277  referencePose
278  )
279  );
280 
281  refObj->weight = setWeightSum;
282 
283  refTrack->objects.push_back(refObj);
284 
285  objectsWeightSum += setWeightSum;
286 
287  }
288 
// Pattern weight = average summed object weight per processed set, rounded.
// NOTE(review): if `sets` is empty, setCount stays 0 and this evaluates
// 0.0f / 0.0f = NaN. Guard this before upserting.
289  this->tableHelper->upsertModelPattern(
290  patternName,
291  floor(((float)objectsWeightSum / (float)setCount) + 0.5)
292  );
293 
294  refTrack->calculateWeight();
295 
296  std::cerr<<"done ("<<refTrack->weight<<")"<<std::endl;
297 
298  return refTrack;
299  }
300 }
static bool shouldCollect()
static VoteSpecifierPtr createVoteSpecifier(const PosePtr &sourcePose, const PosePtr &refPose)
double maxAngleDeviation
Definition: Trainer.hpp:51
std::map< std::string, std::string > mPatternToTypesOfPredefinedRefs
Definition: Trainer.hpp:46
void setPredefinedRefs(std::map< std::string, std::string > &refs)
Definition: Trainer.cpp:131
Trainer(std::string dbfilename="record.sqlite", bool dropOldModelTables=false)
Definition: Trainer.cpp:37
void setUseClustering(const bool useClustering)
Definition: Trainer.cpp:56
static CollectedDataPtr getData()
bool mUsePredefinedRefs
Definition: Trainer.hpp:43
boost::shared_ptr< Heuristic > HeuristicPtr
Definition: Heuristic.hpp:44
std::string patternName
void trainPattern()
Definition: Trainer.cpp:60
TableHelperPtr tableHelper
Definition: Trainer.hpp:37
static double getDistanceBetweenPoints(const PointPtr &p1, const PointPtr &p2)
HeuristicPtr findHeuristicMatch(const TracksPtr &tracks)
Definition: Trainer.cpp:136
double togetherRatio
Definition: Trainer.hpp:51
bool useClustering
Definition: Trainer.hpp:41
std::vector< std::pair< std::vector< ManuallyDefPseudoHeuristic::ClusterObject >, uint16_t > > mClusterForManualDefHeuristic
Definition: Trainer.hpp:45
boost::shared_ptr< VoteSpecifier > VoteSpecifierPtr
bool mUseManualDefHeuristic
Definition: Trainer.hpp:42
double staticBreakRatio
Definition: Trainer.hpp:51
boost::shared_ptr< Tracks > TracksPtr
Definition: Tracks.hpp:42
boost::shared_ptr< Pose > PosePtr
Definition: Pose.hpp:79
void setClusterForManualDefHeuristic(std::vector< std::pair< std::vector< ManuallyDefPseudoHeuristic::ClusterObject >, uint16_t >>)
Definition: Trainer.cpp:125
RecordedPatternPtr recordedPattern
Definition: Trainer.hpp:38
TrackPtr doTraining(const std::vector< ObjectSetPtr > sets, std::string patternName)
Definition: Trainer.cpp:163
void setSkipsPerCycle(const int skips)
Definition: Trainer.cpp:52
boost::shared_ptr< Track > TrackPtr
Definition: Track.hpp:55
This namespace contains all generally usable classes.
void learn()
Definition: Trainer.cpp:79
boost::shared_ptr< Object > ObjectPtr
Definition: Object.hpp:82


asr_lib_ism
Author(s): Hanselmann Fabian, Heller Florian, Heizmann Heinrich, Kübler Marcel, Mehlhaus Jonas, Meißner Pascal, Qattan Mohamad, Reckling Reno, Stroh Daniel
autogenerated on Wed Jan 8 2020 04:02:41