svm.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 from ctypes import *
00004 from ctypes.util import find_library
00005 import sys
00006 import os
00007 
# Locate and load the libsvm shared library.
# For unix the prefix 'lib' is not considered, so both spellings are tried.
def _load_libsvm():
    """Return a CDLL handle for libsvm.

    The library is first looked up through find_library under the names
    'svm' and 'libsvm'.  If neither is found, a path relative to this
    file is used: '../windows/libsvm.dll' on win32, '../libsvm.so.2'
    everywhere else.
    """
    for candidate in ('svm', 'libsvm'):
        located = find_library(candidate)
        if located:
            return CDLL(located)

    here = os.path.dirname(__file__)
    if sys.platform == 'win32':
        return CDLL(os.path.join(here, '../windows/libsvm.dll'))
    return CDLL(os.path.join(here, '../libsvm.so.2'))

libsvm = _load_libsvm()
00020 
# Construct constants
# Every name in the two lists below becomes a module-level integer constant
# whose value is the name's position in its list (C_SVC = 0, NU_SVC = 1, ...;
# LINEAR = 0, POLY = 1, ...).
SVM_TYPE = ['C_SVC', 'NU_SVC', 'ONE_CLASS', 'EPSILON_SVR', 'NU_SVR' ]
KERNEL_TYPE = ['LINEAR', 'POLY', 'RBF', 'SIGMOID', 'PRECOMPUTED']
# Bind the names directly in the module namespace instead of exec-ing
# generated source text.
for i, s in enumerate(SVM_TYPE): globals()[s] = i
for i, s in enumerate(KERNEL_TYPE): globals()[s] = i
00026 
# ctypes callback prototype: one c_char_p argument, no return value.
PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p)

def print_null(s):
    """Discard *s* and return None (a do-nothing print callback)."""
    return None
00030 
def genFields(names, types):
    """Pair names[k] with types[k] and return the pairs as a list of tuples.

    Pairing stops at the shorter of the two inputs.
    """
    return [(name, ctype) for name, ctype in zip(names, types)]
00033 
def fillprototype(f, restype, argtypes):
    """Store *restype* and *argtypes* on *f* as its ``restype`` and
    ``argtypes`` attributes (restype first)."""
    f.restype, f.argtypes = restype, argtypes
00037 
class svm_node(Structure):
    """ctypes structure holding an int ``index`` and a double ``value``."""
    # Field names and their ctypes types are kept as parallel lists and
    # zipped into the _fields_ layout by genFields.
    _names = ["index", "value"]
    _types = [c_int, c_double]
    _fields_ = genFields(names=_names, types=_types)
00042 
def gen_svm_nodearray(xi, feature_max=None, isKernel=None):
    """Convert one instance into a ctypes array of svm_node.

    xi          -- a dict mapping index -> value, or a list/tuple of values.
                   For a list/tuple, unless isKernel is true, a leading 0 is
                   prepended so that the first given value gets index 1.
    feature_max -- if truthy, must be an int; indices greater than it are
                   dropped.
    isKernel    -- if false, entries whose value equals 0 are dropped.

    Returns (ret, max_idx).  ret has one svm_node per kept index, in
    ascending index order, followed by a final node whose index is -1.
    max_idx is the largest kept index, or 0 when nothing was kept.

    Raises TypeError if xi is not a dict, list or tuple.
    """
    if isinstance(xi, dict):
        index_range = xi.keys()
    elif isinstance(xi, (list, tuple)):
        if not isKernel:
            # idx should start from 1.  list(xi) is required because a
            # list cannot be concatenated with a tuple.
            xi = [0] + list(xi)
        index_range = range(len(xi))
    else:
        raise TypeError('xi should be a dictionary, list or tuple')

    if feature_max:
        assert isinstance(feature_max, int)
        index_range = [j for j in index_range if j <= feature_max]
    if not isKernel:
        index_range = [j for j in index_range if xi[j] != 0]

    index_range = sorted(index_range)
    # One extra slot for the terminating node with index -1.
    ret = (svm_node * (len(index_range)+1))()
    ret[-1].index = -1
    for idx, j in enumerate(index_range):
        ret[idx].index = j
        ret[idx].value = xi[j]
    max_idx = index_range[-1] if index_range else 0
    return ret, max_idx
00069 
class svm_problem(Structure):
    """ctypes structure describing a training set.

    C fields: ``l`` (number of instances), ``y`` (pointer to doubles) and
    ``x`` (pointer to svm_node pointers).  Two Python-side attributes are
    also set by the constructor: ``x_space`` (the list of per-instance node
    arrays) and ``n`` (the largest feature index seen, 0 if none).
    """
    _names = ["l", "y", "x"]
    _types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))]
    _fields_ = genFields(_names, _types)

    def __init__(self, y, x, isKernel=None):
        """Build the problem from labels *y* and instances *x*.

        Each element of x is converted with gen_svm_nodearray (isKernel is
        forwarded).  Raises ValueError when y and x differ in length.
        """
        if len(y) != len(x):
            raise ValueError("len(y) != len(x)")
        count = len(y)
        self.l = count

        largest = 0
        node_arrays = []
        self.x_space = node_arrays
        for instance in x:
            nodes, top = gen_svm_nodearray(instance, isKernel=isKernel)
            node_arrays.append(nodes)
            largest = max(largest, top)
        self.n = largest

        # Labels go into a freshly allocated double array.
        self.y = (c_double * count)()
        for pos, label in enumerate(y):
            self.y[pos] = label

        # One svm_node pointer per instance, pointing at the arrays that
        # x_space holds on to.
        self.x = (POINTER(svm_node) * count)()
        for pos, nodes in enumerate(self.x_space):
            self.x[pos] = nodes
00093 
class svm_parameter(Structure):
    """ctypes structure of training parameters, filled from an option string.

    Besides the C fields listed in ``_names``, instances carry the
    Python-side attributes ``cross_validation``, ``nr_fold`` and
    ``print_func``.
    """
    _names = ["svm_type", "kernel_type", "degree", "gamma", "coef0",
              "cache_size", "eps", "C", "nr_weight", "weight_label", "weight",
              "nu", "p", "shrinking", "probability"]
    _types = [c_int, c_int, c_int, c_double, c_double,
              c_double, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double),
              c_double, c_double, c_int, c_int]
    _fields_ = genFields(_names, _types)

    def __init__(self, options = None):
        """Parse *options* (a whitespace-separated option string).

        None is treated like the empty string, which leaves every field at
        its default value.
        """
        if options is None:
            options = ''
        self.parse_options(options)

    def show(self):
        """Print every C field, then every Python-side attribute, one per line."""
        # keys() is wrapped in list() because on Python 3 it is a view and
        # cannot be concatenated to a list.
        attrs = svm_parameter._names + list(self.__dict__.keys())
        for attr in attrs:
            print(' %s: %s' % (attr, getattr(self, attr)))

    def set_to_default_values(self):
        """Reset all C fields and Python-side attributes to their defaults."""
        self.svm_type = C_SVC
        self.kernel_type = RBF
        self.degree = 3
        self.gamma = 0
        self.coef0 = 0
        self.nu = 0.5
        self.cache_size = 100
        self.C = 1
        self.eps = 0.001
        self.p = 0.1
        self.shrinking = 1
        self.probability = 0
        self.nr_weight = 0
        self.weight_label = (c_int*0)()
        self.weight = (c_double*0)()
        self.cross_validation = False
        self.nr_fold = 0
        self.print_func = None

    def parse_options(self, options):
        """Reset to defaults, then apply the flags found in *options*.

        Recognised flags: -s -t -d -h -b (int value), -g -r -n -m -c -e -p
        (float value), -q (no value; installs print_null as the print
        callback), -v N (cross validation with N folds, N >= 2) and
        -wLABEL WEIGHT (may be repeated).

        Raises ValueError for an unknown flag, for -v with N < 2, or when a
        value cannot be converted; IndexError when the last flag lacks its
        value.  Finally hands the print callback to
        libsvm.svm_set_print_string_function.
        """
        argv = options.split()
        self.set_to_default_values()
        self.print_func = cast(None, PRINT_STRING_FUN)
        weight_label = []
        weight = []

        # Flags that take exactly one value which is stored in a field:
        # flag -> (field name, converter).
        valued_flags = {
            "-s": ("svm_type", int),
            "-t": ("kernel_type", int),
            "-d": ("degree", int),
            "-g": ("gamma", float),
            "-r": ("coef0", float),
            "-n": ("nu", float),
            "-m": ("cache_size", float),
            "-c": ("C", float),
            "-e": ("eps", float),
            "-p": ("p", float),
            "-h": ("shrinking", int),
            "-b": ("probability", int),
        }

        i = 0
        while i < len(argv):
            flag = argv[i]
            if flag in valued_flags:
                i = i + 1
                field, convert = valued_flags[flag]
                setattr(self, field, convert(argv[i]))
            elif flag == "-q":
                self.print_func = PRINT_STRING_FUN(print_null)
            elif flag == "-v":
                i = i + 1
                self.cross_validation = 1
                self.nr_fold = int(argv[i])
                if self.nr_fold < 2:
                    raise ValueError("n-fold cross validation: n must >= 2")
            elif flag.startswith("-w"):
                # The class label is the text following "-w" in the flag
                # itself; the weight is the next token.
                i = i + 1
                self.nr_weight += 1
                weight_label += [int(flag[2:])]
                weight += [float(argv[i])]
            else:
                raise ValueError("Wrong options")
            i += 1

        libsvm.svm_set_print_string_function(self.print_func)
        # Copy the collected per-class weights into newly allocated C arrays.
        self.weight_label = (c_int*self.nr_weight)()
        self.weight = (c_double*self.nr_weight)()
        for i in range(self.nr_weight):
            self.weight[i] = weight[i]
            self.weight_label[i] = weight_label[i]
00203 
class svm_model(Structure):
    """ctypes structure mirroring a libsvm model, with accessor helpers.

    The Python-side attribute ``__createfrom__`` records where the instance
    came from: 'python' when constructed here, 'C' when toPyModel wrapped a
    pointer.  It decides whether __del__ hands the model back to libsvm.
    """
    _names = ['param', 'nr_class', 'l', 'SV', 'sv_coef', 'rho',
              'probA', 'probB', 'sv_indices', 'label', 'nSV', 'free_sv']
    _types = [svm_parameter, c_int, c_int, POINTER(POINTER(svm_node)),
              POINTER(POINTER(c_double)), POINTER(c_double),
              POINTER(c_double), POINTER(c_double), POINTER(c_int),
              POINTER(c_int), POINTER(c_int), c_int]
    _fields_ = genFields(_names, _types)

    def __init__(self):
        self.__createfrom__ = 'python'

    def __del__(self):
        # free memory created by C to avoid memory leak
        if hasattr(self, '__createfrom__') and self.__createfrom__ == 'C':
            libsvm.svm_free_and_destroy_model(pointer(self))

    def get_svm_type(self):
        """Return the result of libsvm.svm_get_svm_type for this model."""
        return libsvm.svm_get_svm_type(self)

    def get_nr_class(self):
        """Return the result of libsvm.svm_get_nr_class for this model."""
        return libsvm.svm_get_nr_class(self)

    def get_svr_probability(self):
        """Return the result of libsvm.svm_get_svr_probability for this model."""
        return libsvm.svm_get_svr_probability(self)

    def get_labels(self):
        """Return a list of get_nr_class() ints read back from a buffer
        passed to libsvm.svm_get_labels."""
        nr_class = self.get_nr_class()
        labels = (c_int * nr_class)()
        libsvm.svm_get_labels(self, labels)
        return labels[:nr_class]

    def get_sv_indices(self):
        """Return a list of get_nr_sv() ints read back from a buffer
        passed to libsvm.svm_get_sv_indices."""
        total_sv = self.get_nr_sv()
        sv_indices = (c_int * total_sv)()
        libsvm.svm_get_sv_indices(self, sv_indices)
        return sv_indices[:total_sv]

    def get_nr_sv(self):
        """Return the result of libsvm.svm_get_nr_sv for this model."""
        return libsvm.svm_get_nr_sv(self)

    def is_probability_model(self):
        """Return True when libsvm.svm_check_probability_model returns 1."""
        return (libsvm.svm_check_probability_model(self) == 1)

    def get_sv_coef(self):
        """Return a list of self.l tuples; tuple i holds sv_coef[j][i] for
        j in 0 .. nr_class - 2."""
        # range (not xrange) so the method also runs on Python 3.
        return [tuple(self.sv_coef[j][i] for j in range(self.nr_class - 1))
                for i in range(self.l)]

    def get_SV(self):
        """Return a list of self.l dicts, one per entry of SV, mapping
        node index -> node value."""
        result = []
        for sparse_sv in self.SV[:self.l]:
            row = dict()

            i = 0
            while True:
                # NOTE(review): the terminating node (index -1) is stored in
                # the dict as well before the loop stops -- confirm whether
                # callers rely on the -1 key.
                row[sparse_sv[i].index] = sparse_sv[i].value
                if sparse_sv[i].index == -1:
                    break
                i += 1

            result.append(row)
        return result
00266 
def toPyModel(model_ptr):
    """
    toPyModel(model_ptr) -> svm_model

    Dereference a ctypes POINTER(svm_model) and return the pointed-to
    svm_model, tagged as created by C.  Raises ValueError("Null pointer")
    when model_ptr is false.
    """
    if not model_ptr:
        raise ValueError("Null pointer")
    model = model_ptr.contents
    # Mark the object as C-created (read by svm_model.__del__).
    model.__createfrom__ = 'C'
    return model
00278 
def _declare_prototypes():
    """Set restype/argtypes on every libsvm function used by this module.

    Each table row is (function name, restype, argtypes); the rows are
    applied in order through fillprototype.
    """
    model_p = POINTER(svm_model)
    node_p = POINTER(svm_node)
    problem_p = POINTER(svm_problem)
    param_p = POINTER(svm_parameter)
    double_p = POINTER(c_double)
    int_p = POINTER(c_int)

    prototypes = [
        ('svm_train', model_p, [problem_p, param_p]),
        ('svm_cross_validation', None, [problem_p, param_p, c_int, double_p]),

        ('svm_save_model', c_int, [c_char_p, model_p]),
        ('svm_load_model', model_p, [c_char_p]),

        ('svm_get_svm_type', c_int, [model_p]),
        ('svm_get_nr_class', c_int, [model_p]),
        ('svm_get_labels', None, [model_p, int_p]),
        ('svm_get_sv_indices', None, [model_p, int_p]),
        ('svm_get_nr_sv', c_int, [model_p]),
        ('svm_get_svr_probability', c_double, [model_p]),

        ('svm_predict_values', c_double, [model_p, node_p, double_p]),
        ('svm_predict', c_double, [model_p, node_p]),
        ('svm_predict_probability', c_double, [model_p, node_p, double_p]),

        ('svm_free_model_content', None, [model_p]),
        ('svm_free_and_destroy_model', None, [POINTER(model_p)]),
        ('svm_destroy_param', None, [param_p]),

        ('svm_check_parameter', c_char_p, [problem_p, param_p]),
        ('svm_check_probability_model', c_int, [model_p]),
        ('svm_set_print_string_function', None, [PRINT_STRING_FUN]),
    ]
    for func_name, restype, argtypes in prototypes:
        fillprototype(getattr(libsvm, func_name), restype, argtypes)

_declare_prototypes()


ml_classifiers
Author(s): Scott Niekum
autogenerated on Fri Jan 3 2014 11:30:23