sklearn_classifier_trainer.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 # -*- coding: utf-8 -*-
00003 #
00004 import os
00005 import sys
00006 import gzip
00007 import cPickle as pickle
00008 import argparse
00009 
00010 import numpy as np
00011 from sklearn.preprocessing import normalize
00012 from sklearn.linear_model import LogisticRegression
00013 from sklearn.cross_validation import train_test_split
00014 from sklearn.metrics import classification_report, accuracy_score
00015 
00016 
def main(argv=None):
    """Train a classifier on a pickled dataset and save it gzip-pickled.

    The dataset file is a gzip-compressed pickle of an object exposing
    ``data``, ``target`` and ``target_names`` attributes.  The data is split
    into train and test parts, the classifier is fitted on the train part,
    written to the output file, and then evaluated on the test part (accuracy
    and a per-class report are printed).

    Args:
        argv: Command-line arguments without the program name.  ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Raises:
        ValueError: If ``--classifier`` names an unsupported classifier.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset',
        help='dataset must have data, target, target_names attributes')
    parser.add_argument('-c', '--classifier', default='logistic_regression',
                        help='now supports logistic_regression only')
    parser.add_argument('-O', '--output', default='clf.pkl.gz',
                        help='saving clf filename')
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed of the train/test split '
                             '(default: drawn at random)')
    args = parser.parse_args(argv)

    # Validate the classifier choice before touching the dataset so that a
    # mistyped option fails immediately instead of after a costly load.
    if args.classifier == 'logistic_regression':
        clf = LogisticRegression()
    else:
        raise ValueError('unsupported classifier')

    print('loading dataset')
    # NOTE(security): unpickling executes arbitrary code embedded in the
    # file; only pass dataset files that come from a trusted source.
    with gzip.open(args.dataset, 'rb') as f:
        dataset = pickle.load(f)

    X = dataset.data
    y = dataset.target
    target_names = dataset.target_names

    # create train and test data
    # Without --seed the split seed is drawn at random, as before.
    seed = args.seed
    if seed is None:
        seed = np.random.randint(1234)
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        random_state=seed)

    # train and test
    print('fitting {0}'.format(args.classifier))
    clf.fit(X_train, y_train)
    # Attach the class names so they are pickled along with the classifier.
    clf.target_names_ = target_names
    with gzip.open(args.output, 'wb') as f:
        pickle.dump(clf, f)
    y_pred = clf.predict(X_test)
    print('score of classifier: {}'.format(accuracy_score(y_test, y_pred)))
    print(classification_report(y_test, y_pred, target_names=target_names))
00052 
00053 
# Run the trainer only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Sun Oct 8 2017 02:43:23