Go to the documentation of this file.00001
00002
00003
00004 import os
00005 import sys
00006 import gzip
00007 import cPickle as pickle
00008 import argparse
00009
00010 import numpy as np
00011 from sklearn.preprocessing import normalize
00012 from sklearn.linear_model import LogisticRegression
00013 from sklearn.cross_validation import train_test_split
00014 from sklearn.metrics import classification_report, accuracy_score
00015
00016
def main(argv=None):
    """Train a classifier on a pickled dataset and save the fitted model.

    Loads a gzip-compressed pickle whose object exposes ``data``,
    ``target`` and ``target_names`` attributes, holds out a test split,
    fits the requested classifier on the remainder, writes the fitted
    classifier to a gzip-compressed pickle, and prints the accuracy and a
    per-class report computed on the held-out split.

    Args:
        argv: Command-line arguments without the program name. When
            ``None`` (the default), ``sys.argv[1:]`` is used.

    Raises:
        ValueError: If ``--classifier`` names an unsupported classifier.
            This is raised before the dataset is read.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset',
                        help='dataset must have data, target, target_names attributes')
    parser.add_argument('-c', '--classifier', default='logistic_regression',
                        help='now supports logistic_regression only')
    parser.add_argument('-O', '--output', default='clf.pkl.gz',
                        help='saving clf filename')
    args = parser.parse_args(sys.argv[1:] if argv is None else argv)

    # Reject an unsupported classifier up front so a typo does not cost a
    # full dataset load before failing.
    if args.classifier != 'logistic_regression':
        raise ValueError('unsupported classifier')

    print('loading dataset')
    # NOTE(security): unpickling can execute arbitrary code. Only pass
    # dataset files that come from a trusted source.
    with gzip.open(args.dataset, 'rb') as f:
        dataset = pickle.load(f)

    X = dataset.data
    y = dataset.target
    target_names = dataset.target_names

    # The seed is itself drawn from NumPy's global RNG, so the train/test
    # split differs between runs unless that RNG was seeded beforehand.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=np.random.randint(1234))

    clf = LogisticRegression()
    print('fitting {0}'.format(args.classifier))
    clf.fit(X_train, y_train)
    # Attach the class names to the model so they are pickled along with it.
    clf.target_names_ = target_names
    with gzip.open(args.output, 'wb') as f:
        # Use a binary pickle protocol: the default under Python 2 is the
        # ASCII protocol 0, which is slow and large for array-backed models.
        # pickle.load() detects the protocol automatically.
        pickle.dump(clf, f, pickle.HIGHEST_PROTOCOL)

    y_pred = clf.predict(X_test)
    print('score of classifier: {}'.format(accuracy_score(y_test, y_pred)))
    print(classification_report(y_test, y_pred, target_names=target_names))
00052
00053
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()