18 #include <boost/shared_ptr.hpp> 19 #include <boost/foreach.hpp> 37 Trainer::Trainer(std::string dbfilename,
bool dropOldModelTables) : mUseManualDefHeuristic(false),
38 mUsePredefinedRefs(false),
39 staticBreakRatio(0.0), togetherRatio(0.0), maxAngleDeviation(0.0)
42 if(dropOldModelTables)
61 std::vector<std::string> patternNames = this->
tableHelper->getRecordedPatternNames();
62 std::cerr<<
"found "<<patternNames.size()<<
" patterns"<<std::endl;
63 for (std::string& name : patternNames) {
69 boost::shared_ptr<RecordedPattern> r = this->
tableHelper->getRecordedPattern(patternName);
71 std::cerr<<
"no pattern records found for pattern "<<patternName<<std::endl;
73 std::cerr<<
"training "<<patternName<<std::endl;
81 for (
size_t i = 0; i < sets.size(); i++) {
82 if (sets[i]->objects.size() == 0) {
83 sets.erase(sets.begin() + i);
93 if (!heuristic || (
int) tracks->tracks.size() <= 2
94 || heuristic->cluster->tracks.size() == tracks->tracks.size()) {
99 std::stringstream subPatternNameStream;
100 if(heuristic->clusterId < 0)
105 subPatternNameStream<<this->
recordedPattern->name<<
"_sub"<<heuristic->clusterId;
106 clusterId = std::max(clusterId, heuristic->clusterId);
109 std::string subPatternName = subPatternNameStream.str();
115 tracks->replace(cluster->tracks, refTrack);
139 std::vector<HeuristicPtr> heuristics;
148 heuristic->applyHeuristic(tracks);
150 if (!heuristic->cluster) {
153 std::cerr<<
"heuristic results of "<<heuristic->name<<std::endl;
154 std::cerr<<heuristic->cluster->tracks.size()<<
" tracks, confidence: "<<heuristic->confidence<<std::endl;
155 if (heuristic->confidence > 0.7 && (!bestHeuristic || heuristic->confidence > bestHeuristic->confidence)) {
156 bestHeuristic = heuristic;
166 double objectsWeightSum = 0;
167 std::string refType =
"";
168 std::string refId =
"";
174 bool refFound =
false;
179 for (std::vector<TrackPtr>::iterator track = tracks->tracks.begin(); track != tracks->tracks.end();
182 for (
ObjectPtr& obj : (*track)->objects) {
187 refId = obj->observedId;
200 double bestViewRatio = 0;
201 double bestMovement = 0;
202 for (
TrackPtr& track : tracks->tracks) {
211 refI = obj->observedId;
220 double ratio = (double)views / (
double)track->objects.size();
221 if (ratio > bestViewRatio || (ratio == bestViewRatio && movement < bestMovement)) {
224 bestViewRatio = ratio;
225 bestMovement = movement;
230 std::cerr<<
"choose ref "<<refType<<
" : "<<refId<<std::endl;
231 std::cerr<<
"training "<<patternName<<
" "<<std::endl;
233 for (
size_t setIdx = 0; setIdx < sets.size(); setIdx++)
235 double setWeightSum = 0;
237 toSkip = this->
skips;
249 std::vector<ObjectPtr> objects = sets[setIdx]->objects;
252 if (o->type == refType && o->observedId == refId) {
253 referencePose.reset(
new Pose(*(o->pose)));
257 if (!referencePose && objects.size() > 0) {
258 referencePose.reset(
new Pose(*(objects[0]->pose)));
259 }
else if (!referencePose) {
260 refTrack->objects.push_back(
ObjectPtr());
267 vote->observedId = o->observedId;
268 vote->objectType = o->type;
269 vote->trackIndex = setIdx;
271 setWeightSum += o->weight;
281 refObj->weight = setWeightSum;
283 refTrack->objects.push_back(refObj);
285 objectsWeightSum += setWeightSum;
291 floor(((
float)objectsWeightSum / (
float)setCount) + 0.5)
294 refTrack->calculateWeight();
296 std::cerr<<
"done ("<<refTrack->weight<<
")"<<std::endl;
static bool shouldCollect()
static VoteSpecifierPtr createVoteSpecifier(const PosePtr &sourcePose, const PosePtr &refPose)
std::map< std::string, std::string > mPatternToTypesOfPredefinedRefs
void setPredefinedRefs(std::map< std::string, std::string > &refs)
Trainer(std::string dbfilename="record.sqlite", bool dropOldModelTables=false)
void setUseClustering(const bool useClustering)
static CollectedDataPtr getData()
boost::shared_ptr< Heuristic > HeuristicPtr
TableHelperPtr tableHelper
static double getDistanceBetweenPoints(const PointPtr &p1, const PointPtr &p2)
HeuristicPtr findHeuristicMatch(const TracksPtr &tracks)
std::vector< std::pair< std::vector< ManuallyDefPseudoHeuristic::ClusterObject >, uint16_t > > mClusterForManualDefHeuristic
boost::shared_ptr< VoteSpecifier > VoteSpecifierPtr
bool mUseManualDefHeuristic
boost::shared_ptr< Tracks > TracksPtr
boost::shared_ptr< Pose > PosePtr
void setClusterForManualDefHeuristic(std::vector< std::pair< std::vector< ManuallyDefPseudoHeuristic::ClusterObject >, uint16_t >>)
RecordedPatternPtr recordedPattern
TrackPtr doTraining(const std::vector< ObjectSetPtr > sets, std::string patternName)
void setSkipsPerCycle(const int skips)
boost::shared_ptr< Track > TrackPtr
this namespace contains all generally usable classes.
boost::shared_ptr< Object > ObjectPtr