sklearn_classifier_trainer.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 #
4 import os
5 import sys
6 import gzip
7 import cPickle as pickle
8 import argparse
9 
10 import numpy as np
11 from sklearn.preprocessing import normalize
12 from sklearn.linear_model import LogisticRegression
13 from sklearn.cross_validation import train_test_split
14 from sklearn.metrics import classification_report, accuracy_score
15 
16 
def main():
    """Train a classifier on a pickled dataset, save it, and report test scores.

    Command-line arguments:
        dataset: path to a gzip-compressed pickle whose object exposes
            ``data``, ``target`` and ``target_names`` attributes.
        -c/--classifier: classifier to train; only ``logistic_regression``
            is supported.
        -O/--output: path of the gzip-compressed pickle the fitted classifier
            is written to (default ``clf.pkl.gz``).
        --random-state: integer seed for the train/test split. When omitted,
            a fresh seed is drawn on every run (the historical behaviour).

    The dataset is split into train and test parts. The classifier is fitted
    on the train part and pickled (with ``target_names_`` attached). Accuracy
    and a per-class report on the test part are then printed.

    Raises:
        ValueError: if an unsupported classifier is requested.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset',
                        help='dataset must have data, target, target_names attributes')
    parser.add_argument('-c', '--classifier', default='logistic_regression',
                        help='now supports logistic_regression only')
    parser.add_argument('-O', '--output', default='clf.pkl.gz',
                        help='saving clf filename')
    parser.add_argument('--random-state', type=int, default=None,
                        help='seed for the train/test split '
                             '(default: a fresh random seed on every run)')
    args = parser.parse_args(sys.argv[1:])

    # Validate the classifier choice before the (possibly large) dataset is
    # unpickled, so a bad option fails immediately.
    if args.classifier == 'logistic_regression':
        clf = LogisticRegression()
    else:
        raise ValueError('unsupported classifier: {0}'.format(args.classifier))

    print('loading dataset')
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only pass datasets that come from a trusted source.
    with gzip.open(args.dataset, 'rb') as f:
        dataset = pickle.load(f)

    X = dataset.data
    y = dataset.target
    target_names = dataset.target_names

    # create train and test data
    random_state = args.random_state
    if random_state is None:
        # Keep the original behaviour: a different split on every run.
        random_state = np.random.randint(1234)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=random_state)

    # train, save, then evaluate
    print('fitting {0}'.format(args.classifier))
    clf.fit(X_train, y_train)
    # Store the dataset's class names on the model before pickling it.
    clf.target_names_ = target_names
    with gzip.open(args.output, 'wb') as f:
        pickle.dump(clf, f)

    y_pred = clf.predict(X_test)
    print('score of classifier: {}'.format(accuracy_score(y_test, y_pred)))

    # classification_report pairs target_names positionally with the labels
    # it reports on. By default those are only the labels occurring in
    # y_test/y_pred, so a class absent from the random test split shifts the
    # names (or raises ValueError in newer scikit-learn). Report on every
    # label of the full dataset instead when it lines up with target_names;
    # otherwise fall back to scikit-learn's default (labels=None).
    labels = np.unique(y)
    if len(labels) != len(target_names):
        labels = None
    print(classification_report(y_test, y_pred, labels=labels,
                                target_names=target_names))
52 
53 
# Script entry point: run the trainer only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27