4 from __future__
import division
7 if sys.version_info.major <= 2:
8 import cPickle
as pickle
10 import _pickle
as pickle
13 from sklearn.preprocessing
import normalize
16 from jsk_topic_tools
import ConnectionBasedTransport
17 from jsk_recognition_msgs.msg
import VectorArray, ClassificationResult
# ScikitLearnClassifier constructor (fragment): initialise the
# ConnectionBasedTransport base class, advertise the '~output' topic as
# ClassificationResult, and load a gzip-compressed pickled classifier from
# the path given by the private '~clf_path' parameter.
# NOTE(review): this listing is garbled -- the leading numbers are line
# numbers from the original file, statements are split across lines, and
# several original lines are missing (the remainder of the advertise(...)
# call and, presumably, the `else:` before the Python 3 branch). Restore
# from the upstream file before changing any logic here.
22 super(ScikitLearnClassifier, self).
__init__()
# Advertise '~output' with message type ClassificationResult; the rest of
# this call is not visible in this listing.
24 self.
_pub = self.advertise(
'~output', ClassificationResult,
# No default is passed, so presumably rospy raises if '~clf_path' is unset
# -- TODO confirm this is the intended failure mode.
28 clf_path = rospy.get_param(
'~clf_path')
# SECURITY: pickle.load can execute arbitrary code embedded in the file;
# '~clf_path' must point to a trusted model file.
29 with gzip.open(clf_path)
as f:
# Python 2 and Python 3 need different pickle.load calls.
30 if sys.version_info.major <= 2:
31 self.
clf = pickle.load(f)
# encoding='latin1' is what the pickle docs require for unpickling NumPy
# arrays that were pickled by Python 2; presumably the model file was
# written under Python 2.
33 self.
clf = pickle.load(f, encoding=
'latin1')
# Classification callback (fragment): reshape the incoming VectorArray into
# one row per vector, normalise each row, run the loaded classifier, and
# publish a ClassificationResult through the '~output' publisher.
# NOTE(review): garbled listing -- leading numbers are original line
# numbers; the enclosing `def` line and some argument/closing lines of the
# ClassificationResult(...) call are missing from this view.
# Flat msg.data -> 2-D array with msg.vector_dim columns (one row per sample).
42 X = np.array(msg.data).reshape((-1, msg.vector_dim))
# Row-wise normalisation (sklearn's default norm is 'l2'). The return value
# is discarded, so this relies on normalize() modifying X in place. sklearn
# documents that copy=False is not guaranteed to work in place (e.g. an int
# dtype array is copied anyway). NOTE(review): `X = normalize(X)` would be
# correct regardless of dtype.
43 normalize(X, copy=
False)
# Per-class probabilities, one row per sample.
44 y_proba = self.
clf.predict_proba(X)
# Index of the most probable class for each sample.
45 y_pred = np.argmax(y_proba, axis=-1)
# target_names_ is not a standard sklearn estimator attribute -- presumably
# attached to the classifier by the training script; TODO confirm.
46 target_names = np.array(self.
clf.target_names_)
# Probability of the predicted class for each sample, i.e. the row-wise
# maximum of y_proba.
47 label_proba = [p[i]
for p, i
in zip(y_proba, y_pred)]
48 out_msg = ClassificationResult(
# (original lines 49-50 are missing from this listing)
51 label_names=target_names[y_pred],
52 label_proba=label_proba,
# All class probabilities flattened in C (row-major) order into one list.
53 probabilities=y_proba.reshape(-1),
# Textual description of the classifier (equivalent to str(self.clf)).
54 classifier=self.
clf.__str__(),
55 target_names=target_names,
# (original line 56, presumably the closing parenthesis, is missing here)
57 self.
_pub.publish(out_msg)
# Script entry point (fragment): register this process as the ROS node
# 'sklearn_classifier'.
# NOTE(review): the lines that follow in the original file (presumably
# constructing ScikitLearnClassifier and spinning) are not visible here.
60 if __name__ ==
'__main__':
61 rospy.init_node(
'sklearn_classifier')