7 import cPickle
as pickle
11 from sklearn.preprocessing
import normalize
12 from sklearn.linear_model
import LogisticRegression
13 from sklearn.cross_validation
import train_test_split
14 from sklearn.metrics
import classification_report, accuracy_score
# Train a classifier on a gzipped, pickled dataset, save the fitted model as a
# gzipped pickle, and print its accuracy and a per-class report on a held-out
# split.
#
# Command-line arguments:
#   dataset           path to the gzipped pickle to train on
#   -c/--classifier   estimator name (only 'logistic_regression' is handled)
#   -O/--output       path the fitted classifier is written to
#
# Raises:
#   ValueError: if --classifier names anything other than
#       'logistic_regression'.
parser = argparse.ArgumentParser()
parser.add_argument(
    'dataset',
    help='dataset must have data, target, target_names attributes')
parser.add_argument(
    '-c', '--classifier', default='logistic_regression',
    help='now supports logistic_regression only')
parser.add_argument(
    '-O', '--output', default='clf.pkl.gz',
    help='saving clf filename')
args = parser.parse_args(sys.argv[1:])

print('loading dataset')
# SECURITY: unpickling can execute arbitrary code embedded in the file, so
# only pass dataset files that come from a trusted source.
with gzip.open(args.dataset, 'rb') as f:
    dataset = pickle.load(f)

# NOTE(review): the statements binding X and y were not visible in the text
# this block was rebuilt from. They are taken from the `data` and `target`
# attributes that the `dataset` help string requires. The file also imports
# sklearn.preprocessing.normalize; confirm whether X was meant to be
# normalized before the split.
X = dataset.data
y = dataset.target
target_names = dataset.target_names

# The split seed is drawn from NumPy's global RNG on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=np.random.randint(1234))

if args.classifier == 'logistic_regression':
    clf = LogisticRegression()
else:
    # Without this `else`, the raise is not tied to the unsupported case.
    raise ValueError('unsupported classifier')

print('fitting {0}'.format(args.classifier))
clf.fit(X_train, y_train)
# Attach the label names to the estimator so they are saved along with it.
clf.target_names_ = target_names
with gzip.open(args.output, 'wb') as f:
    # NOTE(review): the body of this `with` was not visible; dumping the
    # fitted classifier is what the --output help text describes. Confirm.
    pickle.dump(clf, f)

# Evaluate on the held-out part of the split.
y_pred = clf.predict(X_test)
print('score of classifier: {}'.format(accuracy_score(y_test, y_pred)))
print(classification_report(y_test, y_pred, target_names=target_names))
54 if __name__ ==
'__main__':