svm.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 from ctypes import *
00004 from ctypes.util import find_library
00005 from os import path
00006 import sys
00007 
# Python 3 has no ``xrange``; alias it to ``range`` so the rest of the module
# can use the same name on either major version.
if sys.version_info >= (3,):
        xrange = range
00010 
# Names exported by ``from svm import *``: the loaded library handle, the
# ctypes structure wrappers, the helper functions and the type constants.
__all__ = [
        'libsvm',
        'svm_problem',
        'svm_parameter',
        'toPyModel',
        'gen_svm_nodearray',
        'print_null',
        'svm_node',
        'C_SVC',
        'EPSILON_SVR',
        'LINEAR',
        'NU_SVC',
        'NU_SVR',
        'ONE_CLASS',
        'POLY',
        'PRECOMPUTED',
        'PRINT_STRING_FUN',
        'RBF',
        'SIGMOID',
        'c_double',
        'svm_model',
]
00016 
# Load the LIBSVM shared library.  The copy shipped relative to this file is
# tried first; if that fails for any ordinary reason (file missing, wrong
# architecture, ``__file__`` undefined, ...) fall back to a library found on
# the system search path.
try:
        dirname = path.dirname(path.abspath(__file__))
        if sys.platform == 'win32':
                libsvm = CDLL(path.join(dirname, r'..\windows\libsvm.dll'))
        else:
                libsvm = CDLL(path.join(dirname, '../libsvm.so.2'))
except Exception:
        # ``except Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed by the fallback.
        # For unix the prefix 'lib' is not considered.
        _libsvm_path = find_library('svm') or find_library('libsvm')
        if not _libsvm_path:
                raise Exception('LIBSVM library not found.')
        libsvm = CDLL(_libsvm_path)
00031 
# Integer codes for svm_parameter.svm_type (selected with the "-s" option).
C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR = 0, 1, 2, 3, 4

# Integer codes for svm_parameter.kernel_type (selected with the "-t" option).
LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED = 0, 1, 2, 3, 4
00043 
# ctypes prototype for a C callback that receives a ``char *`` and returns
# nothing; used for the library's print-string hook.
PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p)
def print_null(s):
        """Ignore ``s`` and return None (installed as the print hook by "-q")."""
        return None
00047 
def genFields(names, types):
        """Pair field names with their types for a ctypes ``_fields_`` list.

        Returns a list of ``(name, type)`` tuples; pairing stops at the end of
        the shorter input.
        """
        return [(name, ctype) for name, ctype in zip(names, types)]
00050 
def fillprototype(f, restype, argtypes):
        """Declare the result type and argument types of the foreign function ``f``."""
        f.argtypes = argtypes
        f.restype = restype
00054 
class svm_node(Structure):
        """ctypes structure holding one sparse entry: an int index and a double value."""
        _names = ["index", "value"]
        _types = [c_int, c_double]
        _fields_ = genFields(_names, _types)

        def __str__(self):
                """Render the node as ``index:value``."""
                pair = (self.index, self.value)
                return '%d:%g' % pair
00062 
def gen_svm_nodearray(xi, feature_max=None, isKernel=None):
        """Convert one instance into a ctypes array of svm_node.

        xi          -- a dict mapping feature index to value, or a list/tuple
                       of values.  List/tuple values are numbered from 1,
                       unless isKernel is true, in which case numbering
                       starts from 0.
        feature_max -- when truthy, entries whose index is greater than this
                       int are skipped.
        isKernel    -- when true, zero-valued entries are kept and list/tuple
                       indices are not shifted.

        Returns ``(nodes, max_idx)``: ``nodes`` holds the selected entries in
        ascending index order followed by a terminator node with index -1;
        ``max_idx`` is the largest index stored, or 0 if nothing was stored.

        Raises TypeError if xi is not a dict, list or tuple.
        """
        if isinstance(xi, dict):
                index_range = xi.keys()
        elif isinstance(xi, (list, tuple)):
                if not isKernel:
                        # Prepend a dummy entry so that idx starts from 1.
                        # list(xi) is required for tuples: ``[0] + tuple``
                        # raises TypeError.
                        xi = [0] + list(xi)
                index_range = range(len(xi))
        else:
                raise TypeError('xi should be a dictionary, list or tuple')

        if feature_max:
                assert(isinstance(feature_max, int))
                index_range = [j for j in index_range if j <= feature_max]
        if not isKernel:
                # Sparse representation: zero-valued features are not stored.
                index_range = [j for j in index_range if xi[j] != 0]

        index_range = sorted(index_range)
        # One extra slot for the terminator node.
        ret = (svm_node * (len(index_range) + 1))()
        ret[-1].index = -1
        for idx, j in enumerate(index_range):
                ret[idx].index = j
                ret[idx].value = xi[j]
        max_idx = index_range[-1] if index_range else 0
        return ret, max_idx
00089 
class svm_problem(Structure):
        """ctypes structure for a training set: count ``l``, labels ``y``, rows ``x``."""
        _names = ["l", "y", "x"]
        _types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))]
        _fields_ = genFields(_names, _types)

        def __init__(self, y, x, isKernel=None):
                """Fill the structure from labels ``y`` and instances ``x``.

                Every instance is converted with gen_svm_nodearray (isKernel is
                passed through).  The converted rows are kept in
                ``self.x_space`` and the largest feature index seen is stored
                in ``self.n``.

                Raises ValueError if ``y`` and ``x`` differ in length.
                """
                count = len(y)
                if count != len(x):
                        raise ValueError("len(y) != len(x)")
                self.l = count

                # Convert each instance and track the largest feature index.
                largest = 0
                rows = self.x_space = []
                for instance in x:
                        nodes, top = gen_svm_nodearray(instance, isKernel=isKernel)
                        rows.append(nodes)
                        if top > largest:
                                largest = top
                self.n = largest

                # Label array handed to the library through the ``y`` field.
                labels = (c_double * count)()
                for pos, label in enumerate(y):
                        labels[pos] = label
                self.y = labels

                # Array of row pointers handed over through the ``x`` field.
                row_ptrs = (POINTER(svm_node) * count)()
                for pos, nodes in enumerate(rows):
                        row_ptrs[pos] = nodes
                self.x = row_ptrs
00111                 self.x = (POINTER(svm_node) * l)()
00112                 for i, xi in enumerate(self.x_space): self.x[i] = xi
00113 
class svm_parameter(Structure):
        """ctypes structure of training parameters, filled from an option string.

        Besides the structure fields, instances carry three plain Python
        attributes set while parsing: ``cross_validation``, ``nr_fold`` and
        ``print_func``.
        """
        _names = ["svm_type", "kernel_type", "degree", "gamma", "coef0",
                        "cache_size", "eps", "C", "nr_weight", "weight_label", "weight",
                        "nu", "p", "shrinking", "probability"]
        _types = [c_int, c_int, c_int, c_double, c_double,
                        c_double, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double),
                        c_double, c_double, c_int, c_int]
        _fields_ = genFields(_names, _types)

        # Options of the form "<flag> <value>": flag -> (attribute, converter).
        _value_options = {
                "-s": ("svm_type", int),
                "-t": ("kernel_type", int),
                "-d": ("degree", int),
                "-g": ("gamma", float),
                "-r": ("coef0", float),
                "-n": ("nu", float),
                "-m": ("cache_size", float),
                "-c": ("C", float),
                "-e": ("eps", float),
                "-p": ("p", float),
                "-h": ("shrinking", int),
                "-b": ("probability", int),
        }

        def __init__(self, options = None):
                """Parse ``options`` (a str or list; None means no options)."""
                if options is None:
                        options = ''
                self.parse_options(options)

        def __str__(self):
                """Return one ``name: value`` line per field and extra attribute."""
                attrs = svm_parameter._names + list(self.__dict__.keys())
                lines = [' %s: %s' % (attr, getattr(self, attr)) for attr in attrs]
                return '\n'.join(lines).strip()

        def set_to_default_values(self):
                """Reset every field and extra attribute to its default."""
                self.svm_type = C_SVC
                self.kernel_type = RBF
                self.degree = 3
                self.gamma = 0
                self.coef0 = 0
                self.nu = 0.5
                self.cache_size = 100
                self.C = 1
                self.eps = 0.001
                self.p = 0.1
                self.shrinking = 1
                self.probability = 0
                self.nr_weight = 0
                self.weight_label = None
                self.weight = None
                self.cross_validation = False
                self.nr_fold = 0
                self.print_func = cast(None, PRINT_STRING_FUN)

        def parse_options(self, options):
                """Reset to defaults, then apply command-line style ``options``.

                options -- a list of tokens, or a str that is split on
                           whitespace.

                Recognised tokens: the "<flag> <value>" pairs listed in
                ``_value_options``; "-q" (install print_null as the print
                hook); "-v n" (cross validation with n folds); and
                "-w<label> <weight>" (per-label weight, may repeat).

                Also registers ``self.print_func`` with the library.

                Raises TypeError if options is neither list nor str, and
                ValueError for an unknown option or a fold count below 2.
                """
                if isinstance(options, list):
                        argv = options
                elif isinstance(options, str):
                        argv = options.split()
                else:
                        raise TypeError("arg 1 should be a list or a str.")
                self.set_to_default_values()
                self.print_func = cast(None, PRINT_STRING_FUN)
                weight_label = []
                weight = []

                i = 0
                while i < len(argv):
                        flag = argv[i]
                        if flag in self._value_options:
                                attr, convert = self._value_options[flag]
                                i += 1
                                setattr(self, attr, convert(argv[i]))
                        elif flag == "-q":
                                self.print_func = PRINT_STRING_FUN(print_null)
                        elif flag == "-v":
                                i += 1
                                self.cross_validation = 1
                                self.nr_fold = int(argv[i])
                                if self.nr_fold < 2:
                                        raise ValueError("n-fold cross validation: n must >= 2")
                        elif flag.startswith("-w"):
                                # The label is the text following "-w" in the flag itself.
                                i += 1
                                self.nr_weight += 1
                                weight_label.append(int(flag[2:]))
                                weight.append(float(argv[i]))
                        else:
                                raise ValueError("Wrong options")
                        i += 1

                libsvm.svm_set_print_string_function(self.print_func)
                self.weight_label = (c_int*self.nr_weight)()
                self.weight = (c_double*self.nr_weight)()
                for k, (label, value) in enumerate(zip(weight_label, weight)):
                        self.weight[k] = value
                        self.weight_label[k] = label
00231 
class svm_model(Structure):
        """ctypes structure of a trained model, with accessors backed by the library."""
        _names = ['param', 'nr_class', 'l', 'SV', 'sv_coef', 'rho',
                        'probA', 'probB', 'sv_indices', 'label', 'nSV', 'free_sv']
        _types = [svm_parameter, c_int, c_int, POINTER(POINTER(svm_node)),
                        POINTER(POINTER(c_double)), POINTER(c_double),
                        POINTER(c_double), POINTER(c_double), POINTER(c_int),
                        POINTER(c_int), POINTER(c_int), c_int]
        _fields_ = genFields(_names, _types)

        def __init__(self):
                """Mark the instance as created from Python (not freed via the library)."""
                self.__createfrom__ = 'python'

        def __del__(self):
                """Release the model through the library if it was created by C."""
                # free memory created by C to avoid memory leak
                if hasattr(self, '__createfrom__') and self.__createfrom__ == 'C':
                        libsvm.svm_free_and_destroy_model(pointer(self))

        def get_svm_type(self):
                """Return the result of ``svm_get_svm_type`` for this model."""
                return libsvm.svm_get_svm_type(self)

        def get_nr_class(self):
                """Return the result of ``svm_get_nr_class`` for this model."""
                return libsvm.svm_get_nr_class(self)

        def get_svr_probability(self):
                """Return the result of ``svm_get_svr_probability`` for this model."""
                return libsvm.svm_get_svr_probability(self)

        def get_labels(self):
                """Return the labels written by ``svm_get_labels`` as a list of ints."""
                nr_class = self.get_nr_class()
                labels = (c_int * nr_class)()
                libsvm.svm_get_labels(self, labels)
                return labels[:nr_class]

        def get_sv_indices(self):
                """Return the indices written by ``svm_get_sv_indices`` as a list of ints."""
                total_sv = self.get_nr_sv()
                sv_indices = (c_int * total_sv)()
                libsvm.svm_get_sv_indices(self, sv_indices)
                return sv_indices[:total_sv]

        def get_nr_sv(self):
                """Return the result of ``svm_get_nr_sv`` for this model."""
                return libsvm.svm_get_nr_sv(self)

        def is_probability_model(self):
                """Return True if ``svm_check_probability_model`` reports 1."""
                return (libsvm.svm_check_probability_model(self) == 1)

        def get_sv_coef(self):
                """Return ``l`` tuples, each holding ``nr_class - 1`` coefficients."""
                return [tuple(self.sv_coef[j][i] for j in xrange(self.nr_class - 1))
                                for i in xrange(self.l)]

        def get_SV(self):
                """Return the support vectors as a list of ``{index: value}`` dicts.

                Each row is read up to, but not including, its terminator node
                (index -1), so the terminator never appears as a key.
                """
                result = []
                for sparse_sv in self.SV[:self.l]:
                        row = dict()

                        i = 0
                        # Stop before storing the terminator: checking after the
                        # assignment would add a spurious -1 entry to every row.
                        while sparse_sv[i].index != -1:
                                row[sparse_sv[i].index] = sparse_sv[i].value
                                i += 1

                        result.append(row)
                return result
00292                         result.append(row)
00293                 return result
00294 
def toPyModel(model_ptr):
        """
        toPyModel(model_ptr) -> svm_model

        Convert a ctypes POINTER(svm_model) to a Python svm_model.

        The returned object is tagged with ``__createfrom__ = 'C'``, which
        svm_model.__del__ checks before freeing the model through the library.

        Raises ValueError if model_ptr is a null pointer (or otherwise falsy).
        """
        if not model_ptr:
                raise ValueError("Null pointer")
        m = model_ptr.contents
        m.__createfrom__ = 'C'
        return m
00306 
def _declare_prototypes(lib):
        """Declare result and argument types for each library function used here."""
        model_p = POINTER(svm_model)
        node_p = POINTER(svm_node)
        problem_p = POINTER(svm_problem)
        param_p = POINTER(svm_parameter)
        double_p = POINTER(c_double)
        int_p = POINTER(c_int)

        # (symbol, result type, argument types)
        signatures = [
                # training
                ('svm_train', model_p, [problem_p, param_p]),
                ('svm_cross_validation', None, [problem_p, param_p, c_int, double_p]),
                # model files
                ('svm_save_model', c_int, [c_char_p, model_p]),
                ('svm_load_model', model_p, [c_char_p]),
                # model queries
                ('svm_get_svm_type', c_int, [model_p]),
                ('svm_get_nr_class', c_int, [model_p]),
                ('svm_get_labels', None, [model_p, int_p]),
                ('svm_get_sv_indices', None, [model_p, int_p]),
                ('svm_get_nr_sv', c_int, [model_p]),
                ('svm_get_svr_probability', c_double, [model_p]),
                # prediction
                ('svm_predict_values', c_double, [model_p, node_p, double_p]),
                ('svm_predict', c_double, [model_p, node_p]),
                ('svm_predict_probability', c_double, [model_p, node_p, double_p]),
                # cleanup
                ('svm_free_model_content', None, [model_p]),
                ('svm_free_and_destroy_model', None, [POINTER(model_p)]),
                ('svm_destroy_param', None, [param_p]),
                # checks and print hook
                ('svm_check_parameter', c_char_p, [problem_p, param_p]),
                ('svm_check_probability_model', c_int, [model_p]),
                ('svm_set_print_string_function', None, [PRINT_STRING_FUN]),
        ]
        for symbol, restype, argtypes in signatures:
                fillprototype(getattr(lib, symbol), restype, argtypes)

_declare_prototypes(libsvm)


target_obejct_detector
Author(s): CIR-KIT
autogenerated on Thu Jun 6 2019 20:19:57