svmutil.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 from svm import *
00004 
def svm_read_problem(data_file_name):
    """
    svm_read_problem(data_file_name) -> [y, x]

    Read LIBSVM-format data from data_file_name and return labels y
    and data instances x.

    Each non-blank line is expected to look like
    ``label index:value index:value ...``.

    Returns a tuple (y, x) where y is a list of float labels and x is a
    list of dicts mapping an integer feature index to a float value.
    A line that holds only a label produces an empty dict. Blank lines
    are skipped.

    Raises ValueError if a label or a feature token cannot be parsed.
    """
    prob_y = []
    prob_x = []
    # The context manager guarantees the file is closed, even when a
    # malformed line raises part-way through the file.
    with open(data_file_name) as data_file:
        for line in data_file:
            fields = line.split(None, 1)
            # A blank (or whitespace-only) line carries no instance.
            if not fields:
                continue
            # In case an instance with all zero features
            if len(fields) == 1:
                fields.append('')
            label, features = fields
            xi = {}
            for e in features.split():
                ind, val = e.split(":")
                xi[int(ind)] = float(val)
            prob_y.append(float(label))
            prob_x.append(xi)
    return (prob_y, prob_x)
00026 
def svm_load_model(model_file_name):
    """
    svm_load_model(model_file_name) -> model

    Load a LIBSVM model stored in model_file_name.

    Returns the model wrapped by toPyModel, or None (after printing a
    message) when the underlying library call returns a false value.
    """
    raw_model = libsvm.svm_load_model(model_file_name)
    if raw_model:
        return toPyModel(raw_model)
    # Loading failed: report it and signal the failure with None.
    print("can't open model file %s" % model_file_name)
    return None
00039 
def svm_save_model(model_file_name, model):
    """
    svm_save_model(model_file_name, model) -> None

    Write the LIBSVM model `model` to the file named model_file_name by
    delegating to libsvm.svm_save_model. The library call's return value
    is discarded.
    """
    libsvm.svm_save_model(model_file_name, model)
00047 
def evaluations(ty, pv):
    """
    evaluations(ty, pv) -> (ACC, MSE, SCC)

    Calculate accuracy, mean squared error and squared correlation coefficient
    using the true values (ty) and predicted values (pv).

    ACC is the percentage (0-100) of positions where ty and pv are equal.
    MSE is the mean of the squared differences.
    SCC is the squared correlation coefficient; it is NaN when its
    denominator is zero (for example when either sequence is constant).

    Raises ValueError if ty and pv differ in length, and ZeroDivisionError
    if both are empty.
    """
    if len(ty) != len(pv):
        raise ValueError("len(ty) must equal to len(pv)")
    total_correct = 0
    # Float accumulators keep the divisions below from truncating when the
    # inputs are plain integers on interpreters with integer '/'.
    total_error = 0.0
    sumv = sumy = sumvv = sumyy = sumvy = 0.0
    for v, y in zip(pv, ty):
        if y == v:
            total_correct += 1
        total_error += (v-y)*(v-y)
        sumv += v
        sumy += y
        sumvv += v*v
        sumyy += y*y
        sumvy += v*y
    l = len(ty)
    ACC = 100.0*total_correct/l
    MSE = total_error/l
    try:
        SCC = ((l*sumvy-sumv*sumy)*(l*sumvy-sumv*sumy))/((l*sumvv-sumv*sumv)*(l*sumyy-sumy*sumy))
    except ZeroDivisionError:
        # Only a zero denominator is expected here; anything else should
        # propagate instead of being silently turned into NaN.
        SCC = float('nan')
    return (ACC, MSE, SCC)
00076 
def svm_train(arg1, arg2=None, arg3=None):
    """
    svm_train(y, x [, 'options']) -> model | ACC | MSE
    svm_train(prob [, 'options']) -> model | ACC | MSE
    svm_train(prob, param) -> model | ACC | MSE

    Train an SVM model from data (y, x) or an svm_problem prob using
    'options' or an svm_parameter param.
    If '-v' is specified in 'options' (i.e., cross validation)
    either accuracy (ACC) or mean-squared error (MSE) is returned.
    'options':
        -s svm_type : set type of SVM (default 0)
            0 -- C-SVC
            1 -- nu-SVC
            2 -- one-class SVM
            3 -- epsilon-SVR
            4 -- nu-SVR
        -t kernel_type : set type of kernel function (default 2)
            0 -- linear: u'*v
            1 -- polynomial: (gamma*u'*v + coef0)^degree
            2 -- radial basis function: exp(-gamma*|u-v|^2)
            3 -- sigmoid: tanh(gamma*u'*v + coef0)
            4 -- precomputed kernel (kernel values in training_set_file)
        -d degree : set degree in kernel function (default 3)
        -g gamma : set gamma in kernel function (default 1/num_features)
        -r coef0 : set coef0 in kernel function (default 0)
        -c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
        -n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
        -p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
        -m cachesize : set cache memory size in MB (default 100)
        -e epsilon : set tolerance of termination criterion (default 0.001)
        -h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)
        -b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
        -wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)
        -v n: n-fold cross validation mode
        -q : quiet mode (no outputs)

    Raises TypeError when the arguments match none of the call forms
    above, and ValueError when a precomputed-kernel problem is malformed
    or libsvm.svm_check_parameter reports an error.
    """
    prob, param = None, None
    if isinstance(arg1, (list, tuple)):
        # Explicit check instead of `assert`: assertions vanish under -O,
        # which would let a bad x reach svm_problem unchecked.
        if not isinstance(arg2, (list, tuple)):
            raise TypeError("Wrong types for the arguments")
        y, x, options = arg1, arg2, arg3
        param = svm_parameter(options)
        prob = svm_problem(y, x, isKernel=(param.kernel_type == PRECOMPUTED))
    elif isinstance(arg1, svm_problem):
        prob = arg1
        if isinstance(arg2, svm_parameter):
            param = arg2
        else:
            param = svm_parameter(arg2)
    if prob is None or param is None:
        raise TypeError("Wrong types for the arguments")

    if param.kernel_type == PRECOMPUTED:
        # Every instance must start with 0:<serial number>, and the serial
        # number must lie in 1..prob.n.
        for xi in prob.x_space:
            idx, val = xi[0].index, xi[0].value
            if idx != 0:
                raise ValueError('Wrong input format: first column must be 0:sample_serial_number')
            if val <= 0 or val > prob.n:
                raise ValueError('Wrong input format: sample_serial_number out of range')

    # A gamma of 0 means "not set": fall back to 1/prob.n.
    if param.gamma == 0 and prob.n > 0:
        param.gamma = 1.0 / prob.n
    libsvm.svm_set_print_string_function(param.print_func)
    err_msg = libsvm.svm_check_parameter(prob, param)
    if err_msg:
        raise ValueError('Error: %s' % err_msg)

    if param.cross_validation:
        l, nr_fold = prob.l, param.nr_fold
        # Output buffer filled by libsvm with one prediction per instance.
        target = (c_double * l)()
        libsvm.svm_cross_validation(prob, param, nr_fold, target)
        ACC, MSE, SCC = evaluations(prob.y[:l], target[:l])
        if param.svm_type in [EPSILON_SVR, NU_SVR]:
            print("Cross Validation Mean squared error = %g" % MSE)
            print("Cross Validation Squared correlation coefficient = %g" % SCC)
            return MSE
        else:
            print("Cross Validation Accuracy = %g%%" % ACC)
            return ACC
    else:
        m = libsvm.svm_train(prob, param)
        m = toPyModel(m)

        # If prob is destroyed, data including SVs pointed by m can remain.
        m.x_space = prob.x_space
        return m
00163 
def svm_predict(y, x, m, options=""):
    """
    svm_predict(y, x, m [, "options"]) -> (p_labels, p_acc, p_vals)

    Predict data (y, x) with the SVM model m.
    "options":
        -b probability_estimates: whether to predict probability estimates,
            0 or 1 (default 0); for one-class SVM only 0 is supported.

    The return tuple contains
    p_labels: a list of predicted labels
    p_acc: a tuple including  accuracy (for classification), mean-squared
           error, and squared correlation coefficient (for regression).
    p_vals: a list of decision values or probability estimates (if '-b 1'
            is specified). If k is the number of classes, for decision values,
            each element includes results of predicting k(k-1)/2 binary-class
            SVMs. For probabilities, each element contains k values indicating
            the probability that the testing instance is in each class.
            Note that the order of classes here is the same as 'model.label'
            field in the model structure.

    Raises ValueError for an unknown option, a '-b' without a value, a
    request for probabilities from a model that does not provide them, or
    (via evaluations) when y and x differ in length.
    """
    predict_probability = 0
    argv = options.split()
    i = 0
    while i < len(argv):
        if argv[i] == '-b':
            i += 1
            # '-b' must be followed by its value; report it like any other
            # malformed option rather than failing with an IndexError.
            if i >= len(argv):
                raise ValueError("Wrong options")
            predict_probability = int(argv[i])
        else:
            raise ValueError("Wrong options")
        i += 1

    svm_type = m.get_svm_type()
    is_prob_model = m.is_probability_model()
    nr_class = m.get_nr_class()
    # Loop-invariant: whether instances are rows of a precomputed kernel.
    is_kernel = (m.param.kernel_type == PRECOMPUTED)
    pred_labels = []
    pred_values = []

    if predict_probability:
        if not is_prob_model:
            raise ValueError("Model does not support probabiliy estimates")

        if svm_type in [NU_SVR, EPSILON_SVR]:
            print("Prob. model for test data: target value = predicted value + z,\n"
            "z: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g" % m.get_svr_probability())
            # Regression has no per-class probabilities to report.
            nr_class = 0

        prob_estimates = (c_double * nr_class)()
        for xi in x:
            nodes, _ = gen_svm_nodearray(xi, isKernel=is_kernel)
            label = libsvm.svm_predict_probability(m, nodes, prob_estimates)
            values = prob_estimates[:nr_class]
            pred_labels.append(label)
            pred_values.append(values)
    else:
        if is_prob_model:
            print("Model supports probability estimates, but disabled in predicton.")
        # One-class and regression models yield a single decision value.
        # Classifiers, nu-SVC included, yield one value per pair of classes,
        # so nu-SVC must take the pairwise branch or the buffer handed to
        # svm_predict_values would be too small for multi-class models.
        if svm_type in (ONE_CLASS, EPSILON_SVR, NU_SVR):
            nr_classifier = 1
        else:
            nr_classifier = nr_class*(nr_class-1)//2
        dec_values = (c_double * nr_classifier)()
        for xi in x:
            nodes, _ = gen_svm_nodearray(xi, isKernel=is_kernel)
            label = libsvm.svm_predict_values(m, nodes, dec_values)
            if nr_class == 1:
                values = [1]
            else:
                values = dec_values[:nr_classifier]
            pred_labels.append(label)
            pred_values.append(values)

    ACC, MSE, SCC = evaluations(y, pred_labels)
    l = len(y)
    if svm_type in [EPSILON_SVR, NU_SVR]:
        print("Mean squared error = %g (regression)" % MSE)
        print("Squared correlation coefficient = %g (regression)" % SCC)
    else:
        # Round before truncating: l*ACC/100 is a float reconstruction of an
        # integer count and may land just below it (e.g. 0.9999999).
        print("Accuracy = %g%% (%d/%d) (classification)" % (ACC, int(round(l*ACC/100)), l))

    return pred_labels, (ACC, MSE, SCC), pred_values
00245 


haf_grasping
Author(s): David Fischinger
autogenerated on Wed Jan 11 2017 03:48:49