svm.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 from ctypes import *
00004 from ctypes.util import find_library
00005 import sys
00006 import subprocess as sb
00007 
# Locate and load the libsvm shared library.
# find_library() is given the name without the Unix 'lib' prefix ('svm'),
# and then the literal 'libsvm' name as a second attempt; it is called once
# per name and the result reused.
_libsvm_path = find_library('svm') or find_library('libsvm')
if _libsvm_path:
    libsvm = CDLL(_libsvm_path)
elif sys.platform == 'win32':
    # NOTE(review): relative path, resolved by the OS loader (presumably
    # against the current working directory, not this file) - confirm.
    libsvm = CDLL('../windows/libsvm.dll')
else:
    # Fall back to the copy shipped inside the ROS 'libsvm3' package.
    # universal_newlines=True makes communicate() return str on both
    # Python 2 and 3; without it Python 3 returns bytes and the
    # concatenation with a str below raises TypeError.
    _libsvm_pkg_dir = sb.Popen(['rospack', 'find', 'libsvm3'],
                               stdout=sb.PIPE,
                               universal_newlines=True).communicate()[0].rstrip()
    libsvm = CDLL(_libsvm_pkg_dir + '/libsvm.so.2')
00019 
# Construct constants: integer codes for the svm_type and kernel_type
# parameters. Each name's value is its position in the list that names it.
SVM_TYPE = ['C_SVC', 'NU_SVC', 'ONE_CLASS', 'EPSILON_SVR', 'NU_SVR']
KERNEL_TYPE = ['LINEAR', 'POLY', 'RBF', 'SIGMOID', 'PRECOMPUTED']
# Spelled out explicitly (rather than generated with exec) so readers,
# IDEs and linters can see the names; the values equal the list indices.
C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR = range(len(SVM_TYPE))
LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED = range(len(KERNEL_TYPE))
00025 
# C callback type for libsvm's print hook: void return, one char* argument.
PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p)


def print_null(s):
    """Ignore *s*.

    Wrapped in PRINT_STRING_FUN and installed as libsvm's print hook when
    the '-q' option is parsed, so anything routed through the hook is
    discarded.
    """
    return None
00029 
def genFields(names, types):
    """Return ``[(name, type), ...]`` pairing *names* with *types* in order.

    Used to build ctypes ``_fields_`` lists. If the inputs differ in length,
    the surplus items of the longer one are ignored.
    """
    return [(field_name, field_type)
            for field_name, field_type in zip(names, types)]
00032 
def fillprototype(f, restype, argtypes):
    """Declare the C signature of foreign function *f*.

    Sets ``f.restype`` to *restype* and then ``f.argtypes`` to *argtypes*.
    """
    for attr_name, attr_value in (('restype', restype),
                                  ('argtypes', argtypes)):
        setattr(f, attr_name, attr_value)
00035         f.argtypes = argtypes
00036 
class svm_node(Structure):
    """ctypes counterpart of the C ``svm_node`` structure.

    Holds one (feature index, value) entry of a sample.
    gen_svm_nodearray() builds arrays of these that end with an entry
    whose index is -1.
    """
    _names = ["index", "value"]
    _types = [c_int, c_double]
    # Pair each field name with its ctypes type, in declaration order.
    _fields_ = list(zip(_names, _types))
00041 
def gen_svm_nodearray(xi, feature_max=None, issparse=None):
    """Convert one sample into a ctypes array of svm_node.

    :param xi: dict mapping feature index -> value, or a list/tuple whose
        positions (starting at 0) serve as the feature indices.
    :param feature_max: if truthy, must be an int; features whose index is
        greater than it are dropped.
    :param issparse: if truthy, features whose value equals 0 are dropped.
    :returns: ``(nodes, max_idx)``. ``nodes`` holds the kept features in
        ascending index order, followed by one extra entry with index -1.
        ``max_idx`` is the largest kept index, or 0 when nothing is kept.
    :raises TypeError: if *xi* is not a dict, list or tuple.
    """
    if isinstance(xi, dict):
        candidates = xi.keys()
    elif isinstance(xi, (list, tuple)):
        candidates = range(len(xi))
    else:
        raise TypeError('xi should be a dictionary, list or tuple')

    limit_index = bool(feature_max)
    if limit_index:
        assert isinstance(feature_max, int)
    drop_zeros = bool(issparse)

    kept = sorted(
        j for j in candidates
        if (not limit_index or j <= feature_max)
        and (not drop_zeros or xi[j] != 0)
    )

    nodes = (svm_node * (len(kept) + 1))()
    for slot, feat in enumerate(kept):
        nodes[slot].index = feat
        nodes[slot].value = xi[feat]
    # Terminator entry (index -1) after the last real feature.
    nodes[-1].index = -1

    return nodes, (kept[-1] if kept else 0)
00066 
class svm_problem(Structure):
    """ctypes counterpart of the C training-problem structure.

    C fields: ``l`` (sample count), ``y`` (pointer to the ``l`` labels) and
    ``x`` (pointer to ``l`` svm_node arrays, one per sample).
    """
    _names = ["l", "y", "x"]
    _types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))]
    _fields_ = genFields(_names, _types)

    def __init__(self, y, x):
        """Build a problem from labels *y* and samples *x*.

        Each sample is converted with gen_svm_nodearray(). The converted
        arrays are also kept in ``self.x_space`` (presumably so the Python
        objects behind the C pointers in ``self.x`` stay referenced), and
        ``self.n`` records the largest feature index seen.

        :raises ValueError: if *y* and *x* differ in length.
        """
        if len(y) != len(x):
            raise ValueError("len(y) != len(x)")
        count = len(y)
        self.l = count

        self.x_space = []
        largest = 0
        for sample in x:
            nodes, top_index = gen_svm_nodearray(sample)
            self.x_space.append(nodes)
            largest = max(largest, top_index)
        self.n = largest

        self.y = (c_double * count)()
        for pos, label in enumerate(y):
            self.y[pos] = label

        self.x = (POINTER(svm_node) * count)()
        for pos, nodes in enumerate(self.x_space):
            self.x[pos] = nodes
00089                 for i, xi in enumerate(self.x_space): self.x[i] = xi
00090 
class svm_parameter(Structure):
    """ctypes counterpart of the C training-parameter structure.

    Construct it from an svm-train style option string, e.g.
    ``svm_parameter('-s 0 -t 2 -c 4 -q')``. Besides the C fields listed in
    ``_names``, instances carry the Python-only attributes
    ``cross_validation``, ``nr_fold`` and ``print_func``.
    """
    _names = ["svm_type", "kernel_type", "degree", "gamma", "coef0",
              "cache_size", "eps", "C", "nr_weight", "weight_label", "weight",
              "nu", "p", "shrinking", "probability"]
    _types = [c_int, c_int, c_int, c_double, c_double,
              c_double, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double),
              c_double, c_double, c_int, c_int]
    _fields_ = genFields(_names, _types)

    # Options that take one argument and just store it in an attribute:
    # flag -> (attribute name, converter applied to the argument).
    _value_options = {
        "-s": ("svm_type", int),
        "-t": ("kernel_type", int),
        "-d": ("degree", int),
        "-g": ("gamma", float),
        "-r": ("coef0", float),
        "-n": ("nu", float),
        "-m": ("cache_size", float),
        "-c": ("C", float),
        "-e": ("eps", float),
        "-p": ("p", float),
        "-h": ("shrinking", int),
        "-b": ("probability", int),
    }

    def __init__(self, options=None):
        """Parse *options* (an option string; ``None`` means no options)."""
        if options is None:
            options = ''
        self.parse_options(options)

    def show(self):
        """Print every C field, then every Python-only instance attribute."""
        # list() is required on Python 3, where dict.keys() returns a view
        # that cannot be concatenated to a list.
        attrs = svm_parameter._names + list(self.__dict__.keys())
        for attr in attrs:
            print(' %s: %s' % (attr, getattr(self, attr)))

    def set_to_default_values(self):
        """Reset all C fields and Python-only attributes to their defaults."""
        self.svm_type = C_SVC
        self.kernel_type = RBF
        self.degree = 3
        self.gamma = 0
        self.coef0 = 0
        self.nu = 0.5
        self.cache_size = 100
        self.C = 1
        self.eps = 0.001
        self.p = 0.1
        self.shrinking = 1
        self.probability = 0
        self.nr_weight = 0
        self.weight_label = (c_int*0)()
        self.weight = (c_double*0)()
        self.cross_validation = False
        self.nr_fold = 0
        self.print_func = None

    def parse_options(self, options):
        """Reset to defaults, then apply the whitespace-separated *options*.

        Recognised flags are the keys of ``_value_options`` plus ``-q``
        (install print_null as the print hook), ``-v n`` (n-fold cross
        validation) and ``-wLABEL weight`` (per-label weight).
        Side effect: passes ``self.print_func`` to
        ``libsvm.svm_set_print_string_function``.

        :raises ValueError: on an unknown flag, or ``-v`` with n < 2.
        :raises IndexError: if a flag that needs an argument is last.
        """
        argv = options.split()
        self.set_to_default_values()
        # NULL function pointer unless '-q' is given (presumably libsvm then
        # keeps its default printer - confirm against the linked library).
        self.print_func = cast(None, PRINT_STRING_FUN)
        weight_label = []
        weight = []

        i = 0
        while i < len(argv):
            flag = argv[i]
            if flag in self._value_options:
                attr, convert = self._value_options[flag]
                i += 1
                setattr(self, attr, convert(argv[i]))
            elif flag == "-q":
                self.print_func = PRINT_STRING_FUN(print_null)
            elif flag == "-v":
                i += 1
                self.cross_validation = 1
                self.nr_fold = int(argv[i])
                if self.nr_fold < 2:
                    raise ValueError("n-fold cross validation: n must >= 2")
            elif flag.startswith("-w"):
                # The class label is embedded in the flag itself: -w<label>.
                i += 1
                self.nr_weight += 1
                weight_label.append(int(flag[2:]))
                weight.append(float(argv[i]))
            else:
                raise ValueError("Wrong options")
            i += 1

        libsvm.svm_set_print_string_function(self.print_func)
        self.weight_label = (c_int*self.nr_weight)()
        self.weight = (c_double*self.nr_weight)()
        for k in range(self.nr_weight):
            self.weight[k] = weight[k]
            self.weight_label[k] = weight_label[k]
00198                         self.weight[i] = weight[i]
00199                         self.weight_label[i] = weight_label[i]
00200 
class svm_model(Structure):
    """Handle for a trained model owned by libsvm.

    No ``_fields_`` are declared: the C-side layout is not mirrored here, so
    instances are only passed back to libsvm by reference. Instances marked
    with ``__createfrom__ == 'C'`` (as toPyModel() does) are released through
    libsvm when garbage-collected.
    """

    def __init__(self):
        # Python-constructed instances own no libsvm memory; see __del__.
        self.__createfrom__ = 'python'

    def __del__(self):
        # Free memory created by C to avoid a memory leak. getattr() covers
        # instances that never ran __init__ (e.g. obtained via .contents).
        if getattr(self, '__createfrom__', None) == 'C':
            # The C function takes a pointer-to-pointer; ctypes passes the
            # address of this temporary POINTER(svm_model).
            libsvm.svm_free_and_destroy_model(pointer(self))

    def get_svm_type(self):
        """Return the model's svm_type code (compare with C_SVC, NU_SVC, ...)."""
        return libsvm.svm_get_svm_type(self)

    def get_nr_class(self):
        """Return the number of classes, as reported by libsvm."""
        return libsvm.svm_get_nr_class(self)

    def get_svr_probability(self):
        """Return the value of libsvm's svm_get_svr_probability for this model."""
        return libsvm.svm_get_svr_probability(self)

    def get_labels(self):
        """Return the model's class labels as a list of ``get_nr_class()`` ints."""
        nr_class = self.get_nr_class()
        labels = (c_int * nr_class)()
        libsvm.svm_get_labels(self, labels)
        return labels[:nr_class]

    def is_probability_model(self):
        """Return True if libsvm reports that the model has probability info."""
        return (libsvm.svm_check_probability_model(self) == 1)

    def get_support_vectors(self, samples_len):
        """Return a list of *samples_len* ints filled by svm_get_support_vectors.

        NOTE(review): svm_get_support_vectors is not part of every libsvm
        release; what each int means (presumably per-training-sample support
        vector flags or indices) should be confirmed against the linked build.
        """
        svs = (c_int * samples_len)()
        libsvm.svm_get_support_vectors(self, samples_len, svs)
        return svs[:samples_len]
00264 
00265 
00266 
00267 
def toPyModel(model_ptr):
    """
    toPyModel(model_ptr) -> svm_model

    Convert a ctypes POINTER(svm_model) to a Python svm_model.

    The returned object shares the pointed-to memory and is marked with
    ``__createfrom__ = 'C'``, which tells svm_model.__del__ to release it
    through libsvm.

    :raises ValueError: if *model_ptr* is NULL (e.g. the C call that
        produced it returned no model).
    """
    if not model_ptr:
        raise ValueError("Null pointer")
    model = model_ptr.contents
    model.__createfrom__ = 'C'
    return model
00279 
def _declare_prototypes():
    """Attach the C signature to every libsvm entry point this module uses."""
    # POINTER() caches its result, so these are the same type objects the
    # rest of the module obtains from POINTER(...).
    model_p = POINTER(svm_model)
    problem_p = POINTER(svm_problem)
    param_p = POINTER(svm_parameter)
    node_p = POINTER(svm_node)
    double_p = POINTER(c_double)
    int_p = POINTER(c_int)

    # (function name, restype, argtypes), applied in this order.
    signatures = [
        ('svm_train', model_p, [problem_p, param_p]),
        ('svm_cross_validation', None, [problem_p, param_p, c_int, double_p]),

        ('svm_save_model', c_int, [c_char_p, model_p]),
        ('svm_load_model', model_p, [c_char_p]),

        ('svm_get_svm_type', c_int, [model_p]),
        ('svm_get_nr_class', c_int, [model_p]),
        ('svm_get_labels', None, [model_p, int_p]),
        ('svm_get_support_vectors', None, [model_p, c_int, int_p]),
        ('svm_get_svr_probability', c_double, [model_p]),

        ('svm_predict_values', c_double, [model_p, node_p, double_p]),
        ('svm_predict', c_double, [model_p, node_p]),
        ('svm_predict_probability', c_double, [model_p, node_p, double_p]),

        ('svm_free_model_content', None, [model_p]),
        ('svm_free_and_destroy_model', None, [POINTER(model_p)]),
        ('svm_destroy_param', None, [param_p]),

        ('svm_check_parameter', c_char_p, [problem_p, param_p]),
        ('svm_check_probability_model', c_int, [model_p]),
        ('svm_set_print_string_function', None, [PRINT_STRING_FUN]),
    ]
    for func_name, restype, argtypes in signatures:
        fillprototype(getattr(libsvm, func_name), restype, argtypes)


_declare_prototypes()


libsvm3
Author(s): various
autogenerated on Wed Nov 27 2013 11:36:23