4 from __future__
import division
6 import cPickle
as pickle
9 from sklearn.preprocessing
import normalize
12 from jsk_topic_tools
import ConnectionBasedTransport
13 from jsk_recognition_msgs.msg
import VectorArray, ClassificationResult
18 super(ScikitLearnClassifier, self).
__init__()
20 self.
_pub = self.advertise(
'~output', ClassificationResult,
24 clf_path = rospy.get_param(
'~clf_path')
25 with gzip.open(clf_path)
as f:
26 self.
clf = pickle.load(f)
32 self.sub_hist.unregister()
35 X = np.array(msg.data).reshape((-1, msg.vector_dim))
36 normalize(X, copy=
False)
37 y_proba = self.clf.predict_proba(X)
38 y_pred = np.argmax(y_proba, axis=-1)
39 target_names = np.array(self.clf.target_names_)
40 label_proba = [p[i]
for p, i
in zip(y_proba, y_pred)]
41 out_msg = ClassificationResult(
44 label_names=target_names[y_pred],
45 label_proba=label_proba,
46 probabilities=y_proba.reshape(-1),
47 classifier=self.clf.__str__(),
48 target_names=target_names,
50 self._pub.publish(out_msg)
53 if __name__ ==
'__main__':
54 rospy.init_node(
'sklearn_classifier')
def _init_classifier(self)