00001
00002
00003 import sys
00004 import csv
00005 import numpy as np
00006 import math
00007 import matplotlib.pyplot as plt
00008 import pprint
00009 import pickle
00010 import argparse
00011
# Plot opacity (alpha) for an event, keyed by the degree label found in the
# first CSV column of each row.
DEGREES = {
    'WEAK': 0.33,
    'AVERAGE': 0.66,
    'STRONG': 1.0,
}

# Plot RGB colour for an event, keyed by the action label found in the second
# CSV column.  The keys must match the labels in the data files exactly,
# which is why the 'SUPRISE' spelling is kept.
# NOTE(review): 'SADNESS' reuses SMILE's colour [0.5, 0, 0] -- confirm intended.
ACTIONS = {
    'WINCE': [0, 0, 0],
    'SMILE': [0.5, 0, 0],
    'FROWN': [0, 0.5, 0],
    'LAUGH': [0, 0, 0.5],
    'GLARE': [0.5, 0.5, 0],
    'NOD': [0.5, 0, 0.5],
    'SHAKE': [0, 0.5, 0.5],
    'REQUEST FOR BOARD': [0.5, 0.5, 0.5],
    'EYE-ROLL': [1, 0, 0],
    'JOY': [0, 1, 0],
    'SUPRISE': [0, 0, 1],
    'FEAR': [1, 1, 0],
    'ANGER': [0, 1, 1],
    'DISGUST': [1, 0, 1],
    'SADNESS': [0.5, 0, 0],
}

# Action types summarised in the printed statistics and the bar charts, in
# display order.
ACT_LIST = [
    'WINCE', 'NOD', 'SHAKE', 'JOY', 'FEAR',
    'SUPRISE', 'ANGER', 'DISGUST', 'SADNESS',
]
00032
00033
def extract_data(files):
    """Read every CSV row from one or more files, in file order.

    Args:
        files: either a whitespace-separated string of paths (the form the
            command line passes in) or an iterable of individual paths.

    Returns:
        A list of rows, each a list of that row's fields as strings.  Blank
        lines are skipped because they carry no fields.
    """
    # Strings (the original calling convention) are split on whitespace;
    # anything else is treated as an iterable of paths.
    if hasattr(files, 'split'):
        paths = files.split()
    else:
        paths = list(files)

    data = []
    for path in paths:
        # The csv module expects binary mode on Python 2 but text mode with
        # newline='' on Python 3.
        if sys.version_info[0] >= 3:
            handle = open(path, newline='')
        else:
            handle = open(path, 'rb')
        with handle as f:
            # Empty rows (blank lines) would make the row list ragged and
            # break later array slicing, so drop them here.
            data.extend(row for row in csv.reader(f) if row)
    return data
00042
00043 def process(files, SVM_DATA_FILE, WINDOW_DUR, MAG_THRESH, plot):
00044 data = extract_data(files)
00045 npdata = np.array(data)
00046 txt = npdata[:,:2]
00047 nums=np.array(npdata[:,2:], dtype=float)
00048 x2 = np.square(nums[:,1])
00049
00050 y2 = np.square(nums[:,2])
00051
00052 mags=np.sqrt(np.add(x2, y2))
00053
00054 dirs=np.arctan2(nums[:,2], nums[:,1])
00055 nums = np.hstack((nums, np.column_stack((mags,dirs))))
00056 window = []
00057 o_type_cnt={}.fromkeys(ACT_LIST,0)
00058 f_type_cnt={}.fromkeys(ACT_LIST,0)
00059 legend_labels = []
00060 svm_label = []
00061 svm_data = []
00062 for dat in data:
00063
00064 o_type_cnt[dat[1]] += 1
00065 dat[2]=float(dat[2])
00066 dat[3]=float(dat[3])
00067 dat[4]=float(dat[4])
00068 dat.append((dat[3]**2. + dat[4]**2.)**(1./2))
00069 dat.append(math.atan2(dat[4], dat[3]))
00070 color = tuple(ACTIONS[dat[1]]+[DEGREES[dat[0]]])
00071 if plot:
00072 if dat[1] not in legend_labels:
00073 legend_labels.append(dat[1])
00074 plt.figure(1)
00075 plt.polar(dat[-1], dat[-2], '.', color=color, label=dat[1])
00076 plt.figure(2)
00077
00078 plt.figure(3)
00079
00080 else:
00081 plt.figure(1)
00082 plt.polar(dat[-1], dat[-2], '.', color=color)
00083 plt.figure(2)
00084
00085 plt.figure(3)
00086
00087 if (dat[5]<MAG_THRESH):
00088 continue
00089
00090 f_type_cnt[dat[1]] += 1
00091 window.append(dat)
00092 while (window[-1][2] - window[0][2]) > WINDOW_DUR:
00093 window.pop(0)
00094 dat.append(len(window))
00095 movement = [0.,0.]
00096 for datum in window:
00097 movement[0] += datum[3]
00098 movement[1] += datum[4]
00099 dat.append((movement[0]**2+movement[1]**2)**(1./2))
00100 dat.append(dat[-1]/dat[-2])
00101 dat.append(math.atan2(movement[1],movement[0]))
00102 if SVM_DATA_FILE is not None:
00103 if dat[1] == 'WINCE':
00104 svm_label.append(1)
00105 else:
00106 svm_label.append(0)
00107 svm_data.append([dat[5],dat[6],dat[7],dat[9],dat[10]])
00108
00109 mean_std = np.empty((4, len(ACT_LIST)))
00110 print " \r\n"*5
00111 print "Total Datapoints: ", len(data)
00112 print " \r\n"
00113 for i,act in enumerate(ACT_LIST):
00114 indices = np.nonzero(txt[:,1]==act)
00115 mean_std[0,i] = np.mean(nums[indices,3])
00116 mean_std[1,i] = np.std(nums[indices,3])
00117 mean_std[2,i] = np.mean(nums[indices,4])
00118 mean_std[3,i] = np.std(nums[indices,4])
00119 print "%s:" %act
00120 print "%s raw events" %indices[0].size
00121 print "Mag: %3.2f (%3.2f) \r\nDir: %3.2f (%3.2f)" %(mean_std[0,i], mean_std[1,i], mean_std[2,i], mean_std[3,i])
00122
00123
00124 print " \r\n"
00125 print "Impact of Filtering:"
00126 total_features=0
00127 for type_ in o_type_cnt.keys():
00128 total_features += f_type_cnt[type_]
00129 for type_ in o_type_cnt.keys():
00130 print "%s: \r\n %s (%2.2f%%) --> \r\n %s (%2.2f%%)" %(type_,
00131 o_type_cnt[type_],
00132 (100.*o_type_cnt[type_])/len(data),
00133 f_type_cnt[type_],
00134 (100.*f_type_cnt[type_])/total_features)
00135
00136 print " \r\n"
00137 print "Total Features: ", total_features
00138 print " \r\n"*2
00139
00140 if plot:
00141 plt.figure(1)
00142 plt.legend(loc=2,bbox_to_anchor=(1,1))
00143 plt.figure(3)
00144 ind = np.arange(len(ACT_LIST))
00145 width = 0.5
00146 p1 = plt.bar(ind, mean_std[0,:], width, yerr=mean_std[1,:])
00147 plt.ylabel('Mean Mag. +/- std.')
00148 plt.xlabel('Event Type')
00149 plt.title('Mean Magnitude per Event Type')
00150 plt.xticks(ind+width/2., tuple(ACT_LIST))
00151
00152
00153 plt.figure(4)
00154 ind = np.arange(len(ACT_LIST))
00155 width = 0.5
00156 p1 = plt.bar(ind, mean_std[2,:], width, yerr=mean_std[3,:])
00157 plt.ylabel('Mean Dir. +/- std.')
00158 plt.xlabel('Event Type')
00159 plt.title('Mean Direction per Event Type')
00160 plt.xticks(ind+width/2., tuple(ACT_LIST))
00161
00162
00163 if SVM_DATA_FILE is not None:
00164 svm_output = {'labels':svm_label,
00165 'data':svm_data}
00166 with open(SVM_DATA_FILE+'.pkl','wb+') as f_pkl:
00167 pickle.dump(svm_output, f_pkl)
00168
00169
00170 def create_ROC(filename):
00171 from scipy import interp
00172 from sklearn import preprocessing as pps, svm
00173 from sklearn.metrics import roc_curve, auc
00174 from sklearn.cross_validation import StratifiedKFold, LeaveOneOut
00175
00176 filepath=filename+'.pkl'
00177 with open(filepath, 'rb') as f:
00178 svm_data = pickle.load(f)
00179 labels = svm_data['labels']
00180 data = svm_data['data']
00181
00182 scaler = pps.Scaler().fit(data)
00183 print "Mean: ", scaler.mean_
00184 print "Std: ", scaler.std_
00185 data_scaled = scaler.transform(data)
00186
00187 classifier = svm.SVC(probability=True)
00188 classifier.fit(data_scaled, labels)
00189
00190
00191 print "SV's per class: \r\n", classifier.n_support_
00192
00193
00194
00195
00196 X, y = data_scaled, np.array(labels)
00197 n_samples, n_features = X.shape
00198 print n_samples, n_features
00199
00200
00201
00202
00203 cv = StratifiedKFold(y, k=9)
00204
00205 mean_tpr = 0.0
00206 mean_fpr = np.linspace(0, 1, n_samples)
00207 all_tpr = []
00208 plt.figure(2)
00209 for i, (train, test) in enumerate(cv):
00210 probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
00211
00212 fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
00213 mean_tpr += interp(mean_fpr, fpr, tpr)
00214 mean_tpr[0] = 0.0
00215 roc_auc = auc(fpr, tpr)
00216 plt.plot(fpr, tpr, '--', lw=1, label='ROC fold %d (area = %0.2f)' % (i, roc_auc))
00217
00218 plt.plot([0, 1], [0, 1], '--', color=(0.6, 0.6, 0.6), label='Luck')
00219
00220 mean_tpr /= len(cv)
00221 mean_tpr[-1] = 1.0
00222 mean_auc = auc(mean_fpr, mean_tpr)
00223 plt.plot(mean_fpr, mean_tpr, 'k-', lw=3,
00224 label='Mean ROC (area = %0.2f)' % mean_auc)
00225
00226 plt.xlim([0, 1])
00227 plt.ylim([0, 1])
00228 plt.xlabel('False Positive Rate')
00229 plt.ylabel('True Positive Rate')
00230 plt.title('Receiver Operating Characteristic')
00231 plt.legend(loc="lower right")
00232 plt.show()
00233 print "Finished!"
00234
00235 if __name__=='__main__':
00236 parser = argparse.ArgumentParser(
00237 description="Process raw wouse training data to output plots,"
00238 "statistics, and SVM-ready formatted data",
00239 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
00240 parser.add_argument('filename',
00241 help="One or more training data files to process")
00242 parser.add_argument('-o','--output', dest="SVM_DATA_FILE",
00243 help="Output file for SVM-formatted training data")
00244 parser.add_argument('-w','--window', default=0.250, type=float,
00245 help="Length of time window in seconds")
00246 parser.add_argument('-t','--threshold', default=2.5, type=float,
00247 help="Minimum activity threshold")
00248 parser.add_argument('-p','--plot', action='store_true',
00249 help="Produce plots regarding the data")
00250 parser.add_argument('-r','--ROC', action='store_true',
00251 help="Produce ROC Curve using stratified k-fold crossvalidation")
00252 args = parser.parse_args()
00253
00254 print "Parsing data from the following files: \r\n ", args.filename
00255
00256 process(args.filename, args.SVM_DATA_FILE, args.window, args.threshold, args.plot)
00257 if args.SVM_DATA_FILE is not None and args.ROC:
00258 print "Creating ROC"
00259 create_ROC(args.SVM_DATA_FILE)