00001
00002
00003 from ctypes import *
00004 from ctypes.util import find_library
00005 from os import path
00006 import sys
00007
# Python 3 has no ``xrange`` builtin; alias it to ``range`` so the code in
# this module that calls ``xrange`` runs on either major version.
if sys.version_info >= (3,):
    xrange = range
00010
# Names exported by ``from <this module> import *``.
__all__ = [
    'libsvm',
    'svm_problem',
    'svm_parameter',
    'toPyModel',
    'gen_svm_nodearray',
    'print_null',
    'svm_node',
    'C_SVC',
    'EPSILON_SVR',
    'LINEAR',
    'NU_SVC',
    'NU_SVR',
    'ONE_CLASS',
    'POLY',
    'PRECOMPUTED',
    'PRINT_STRING_FUN',
    'RBF',
    'SIGMOID',
    'c_double',
    'svm_model',
]
00016
# Load the LIBSVM shared library into ``libsvm``.
#
# First try the copy expected at a fixed location relative to this file
# (``..\windows\libsvm.dll`` on win32, ``../libsvm.so.2`` elsewhere).
try:
    dirname = path.dirname(path.abspath(__file__))
    if sys.platform == 'win32':
        libsvm = CDLL(path.join(dirname, r'..\windows\libsvm.dll'))
    else:
        libsvm = CDLL(path.join(dirname, '../libsvm.so.2'))
except Exception:
    # ``Exception`` rather than a bare ``except`` so that KeyboardInterrupt
    # and SystemExit still propagate; any ordinary failure (e.g. the file
    # is missing) falls back to a system-wide lookup.
    # Each candidate name is resolved once and the result reused.
    for _candidate in ('svm', 'libsvm'):
        _found = find_library(_candidate)
        if _found:
            libsvm = CDLL(_found)
            break
    else:
        raise Exception('LIBSVM library not found.')
00031
# Integer codes for svm_parameter.svm_type (C_SVC is the default set by
# svm_parameter.set_to_default_values).
# NOTE(review): presumably these mirror the enum in LIBSVM's svm.h -- confirm.
C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR = 0, 1, 2, 3, 4

# Integer codes for svm_parameter.kernel_type (RBF is the default).
# NOTE(review): presumably these mirror the enum in LIBSVM's svm.h -- confirm.
LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED = 0, 1, 2, 3, 4

# ctypes callback prototype: takes one C string (char *), returns nothing.
PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p)
def print_null(s):
    """Ignore *s* and return None: a print callback that prints nothing."""
    pass
00047
def genFields(names, types):
    """Pair each field name with the type at the same position.

    Returns a list of ``(name, type)`` tuples in input order; if the two
    inputs differ in length, the result stops at the shorter one.
    """
    return [(field_name, field_type)
            for field_name, field_type in zip(names, types)]
00050
def fillprototype(f, restype, argtypes):
    """Set ``f.restype`` and then ``f.argtypes`` to the given values."""
    for attribute, value in (('restype', restype), ('argtypes', argtypes)):
        setattr(f, attribute, value)
00054
class svm_node(Structure):
    """ctypes structure holding an integer ``index`` and a double ``value``.

    Elsewhere in this module, arrays of svm_node are terminated by a node
    whose index is -1.
    """
    _names = ["index", "value"]
    _types = [c_int, c_double]
    _fields_ = genFields(_names, _types)

    def __str__(self):
        """Return the node formatted as ``index:value``."""
        pair = (self.index, self.value)
        return '%d:%g' % pair
00062
def gen_svm_nodearray(xi, feature_max=None, isKernel=None):
    """Convert a Python feature vector into a ctypes array of svm_node.

    Args:
        xi: Either a dict mapping integer index -> value, or a list/tuple
            of values. For a list/tuple, unless ``isKernel`` is truthy, a
            leading 0 is prepended so the first element gets index 1.
        feature_max: If truthy, must be an int; indices greater than it
            are dropped.
        isKernel: If truthy, ``xi`` is used without the index shift and
            zero-valued entries are kept; otherwise entries whose value
            equals 0 are dropped.

    Returns:
        A tuple ``(nodes, max_idx)``. ``nodes`` is an svm_node array with
        the kept entries in ascending index order followed by one
        terminator node whose index is -1. ``max_idx`` is the largest kept
        index, or 0 when no entry was kept.

    Raises:
        TypeError: If ``xi`` is not a dict, list or tuple.
    """
    if isinstance(xi, dict):
        index_range = xi.keys()
    elif isinstance(xi, (list, tuple)):
        if not isKernel:
            # list(...) so that tuple input works too; ``[0] + xi`` raises
            # TypeError when xi is a tuple.
            xi = [0] + list(xi)
        index_range = range(len(xi))
    else:
        raise TypeError('xi should be a dictionary, list or tuple')

    if feature_max:
        assert isinstance(feature_max, int)
        index_range = [j for j in index_range if j <= feature_max]
    if not isKernel:
        index_range = [j for j in index_range if xi[j] != 0]

    index_range = sorted(index_range)
    # One extra slot at the end for the index == -1 terminator.
    ret = (svm_node * (len(index_range) + 1))()
    ret[-1].index = -1
    for idx, j in enumerate(index_range):
        ret[idx].index = j
        ret[idx].value = xi[j]
    max_idx = index_range[-1] if index_range else 0
    return ret, max_idx
00089
class svm_problem(Structure):
    """ctypes structure with fields ``l`` (count), ``y`` and ``x``.

    Built from parallel Python sequences: ``y`` holds one number per
    instance and ``x`` one feature vector per instance (each converted by
    gen_svm_nodearray).
    """
    _names = ["l", "y", "x"]
    _types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))]
    _fields_ = genFields(_names, _types)

    def __init__(self, y, x, isKernel=None):
        """Fill the structure from ``y`` and ``x``.

        Besides the C fields, two Python attributes are set: ``x_space``
        (the list of converted node arrays, which also keeps them
        referenced) and ``n`` (the largest feature index seen, 0 if none).

        Raises:
            ValueError: If ``y`` and ``x`` have different lengths.
        """
        if len(y) != len(x):
            raise ValueError("len(y) != len(x)")
        count = len(y)
        self.l = count

        largest_index = 0
        self.x_space = []
        for row in x:
            nodes, row_max = gen_svm_nodearray(row, isKernel=isKernel)
            self.x_space.append(nodes)
            largest_index = max(largest_index, row_max)
        self.n = largest_index

        self.y = (c_double * count)()
        for pos, label in enumerate(y):
            self.y[pos] = label

        self.x = (POINTER(svm_node) * count)()
        for pos, nodes in enumerate(self.x_space):
            self.x[pos] = nodes
00113
class svm_parameter(Structure):
    """ctypes structure of training parameters, filled from option flags.

    The C fields are listed in ``_names``/``_types``. Instances also carry
    three Python-only attributes set by this class: ``cross_validation``,
    ``nr_fold`` and ``print_func``.
    NOTE(review): field order presumably must match ``struct svm_parameter``
    of the loaded LIBSVM version -- confirm against its header.
    """
    _names = ["svm_type", "kernel_type", "degree", "gamma", "coef0",
              "cache_size", "eps", "C", "nr_weight", "weight_label", "weight",
              "nu", "p", "shrinking", "probability"]
    _types = [c_int, c_int, c_int, c_double, c_double,
              c_double, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double),
              c_double, c_double, c_int, c_int]
    _fields_ = genFields(_names, _types)

    # Flags that consume exactly one value:
    # flag -> (attribute to set, converter applied to the value token).
    _value_options = {
        "-s": ("svm_type", int),
        "-t": ("kernel_type", int),
        "-d": ("degree", int),
        "-g": ("gamma", float),
        "-r": ("coef0", float),
        "-n": ("nu", float),
        "-m": ("cache_size", float),
        "-c": ("C", float),
        "-e": ("eps", float),
        "-p": ("p", float),
        "-h": ("shrinking", int),
        "-b": ("probability", int),
    }

    def __init__(self, options=None):
        """Create parameters from ``options`` (see parse_options).

        ``None`` is treated as an empty option string, i.e. all defaults.
        """
        if options is None:
            options = ''
        self.parse_options(options)

    def __str__(self):
        """Return one `` name: value`` line per C field and per instance
        attribute, with leading/trailing whitespace stripped from the whole
        string."""
        attrs = svm_parameter._names + list(self.__dict__.keys())
        lines = [' %s: %s\n' % (attr, getattr(self, attr)) for attr in attrs]
        return ''.join(lines).strip()

    def set_to_default_values(self):
        """Reset every C field and Python-only attribute to its default."""
        self.svm_type = C_SVC
        self.kernel_type = RBF
        self.degree = 3
        self.gamma = 0
        self.coef0 = 0
        self.nu = 0.5
        self.cache_size = 100
        self.C = 1
        self.eps = 0.001
        self.p = 0.1
        self.shrinking = 1
        self.probability = 0
        self.nr_weight = 0
        self.weight_label = None
        self.weight = None
        self.cross_validation = False
        self.nr_fold = 0
        # NULL callback pointer of the PRINT_STRING_FUN prototype.
        self.print_func = cast(None, PRINT_STRING_FUN)

    def parse_options(self, options):
        """Reset to defaults, then apply command-line style ``options``.

        Args:
            options: A list of tokens, or a str that is split on
                whitespace. Recognised flags: the single-value flags in
                ``_value_options``; ``-q`` (install print_null as the print
                callback); ``-v n`` (enable cross validation with n folds);
                ``-wLABEL weight`` (append a per-label weight, LABEL being
                the integer that follows ``-w`` in the same token).

        Raises:
            TypeError: If ``options`` is neither a list nor a str.
            ValueError: On an unknown flag, a value that cannot be
                converted, or ``-v`` with fewer than 2 folds.
            IndexError: If a flag that needs a value is the last token.

        Side effect: passes ``self.print_func`` to
        ``libsvm.svm_set_print_string_function``.
        """
        if isinstance(options, list):
            argv = options
        elif isinstance(options, str):
            argv = options.split()
        else:
            raise TypeError("arg 1 should be a list or a str.")
        self.set_to_default_values()
        self.print_func = cast(None, PRINT_STRING_FUN)
        weight_label = []
        weight = []

        i = 0
        while i < len(argv):
            flag = argv[i]
            if flag in self._value_options:
                attr, convert = self._value_options[flag]
                i += 1
                setattr(self, attr, convert(argv[i]))
            elif flag == "-q":
                self.print_func = PRINT_STRING_FUN(print_null)
            elif flag == "-v":
                i += 1
                self.cross_validation = 1
                self.nr_fold = int(argv[i])
                if self.nr_fold < 2:
                    raise ValueError("n-fold cross validation: n must >= 2")
            elif flag.startswith("-w"):
                i += 1
                self.nr_weight += 1
                # The label is embedded in the flag token itself ("-w3").
                weight_label.append(int(flag[2:]))
                weight.append(float(argv[i]))
            else:
                raise ValueError("Wrong options")
            i += 1

        libsvm.svm_set_print_string_function(self.print_func)
        # Copy the collected weights into C arrays referenced by the struct.
        self.weight_label = (c_int * self.nr_weight)()
        self.weight = (c_double * self.nr_weight)()
        for k in range(self.nr_weight):
            self.weight[k] = weight[k]
            self.weight_label[k] = weight_label[k]
00231
class svm_model(Structure):
    """ctypes structure for a LIBSVM model plus Python accessor methods.

    The Python-only attribute ``__createfrom__`` records where the instance
    came from: 'python' when constructed here, 'C' when toPyModel wrapped a
    pointer. Only 'C' models are released through the library on deletion.
    NOTE(review): field order presumably must match ``struct svm_model`` of
    the loaded LIBSVM version -- confirm against its header.
    """
    _names = ['param', 'nr_class', 'l', 'SV', 'sv_coef', 'rho',
              'probA', 'probB', 'sv_indices', 'label', 'nSV', 'free_sv']
    _types = [svm_parameter, c_int, c_int, POINTER(POINTER(svm_node)),
              POINTER(POINTER(c_double)), POINTER(c_double),
              POINTER(c_double), POINTER(c_double), POINTER(c_int),
              POINTER(c_int), POINTER(c_int), c_int]
    _fields_ = genFields(_names, _types)

    def __init__(self):
        """Mark the instance as created on the Python side."""
        self.__createfrom__ = 'python'

    def __del__(self):
        """Release the model through the library if it was created in C."""
        if getattr(self, '__createfrom__', None) == 'C':
            libsvm.svm_free_and_destroy_model(pointer(self))

    def get_svm_type(self):
        """Return the result of ``svm_get_svm_type`` for this model."""
        return libsvm.svm_get_svm_type(self)

    def get_nr_class(self):
        """Return the result of ``svm_get_nr_class`` for this model."""
        return libsvm.svm_get_nr_class(self)

    def get_svr_probability(self):
        """Return the result of ``svm_get_svr_probability`` for this model."""
        return libsvm.svm_get_svr_probability(self)

    def get_labels(self):
        """Return a list of ``get_nr_class()`` ints filled in by
        ``svm_get_labels``."""
        nr_class = self.get_nr_class()
        labels = (c_int * nr_class)()
        libsvm.svm_get_labels(self, labels)
        return labels[:nr_class]

    def get_sv_indices(self):
        """Return a list of ``get_nr_sv()`` ints filled in by
        ``svm_get_sv_indices``."""
        total_sv = self.get_nr_sv()
        sv_indices = (c_int * total_sv)()
        libsvm.svm_get_sv_indices(self, sv_indices)
        return sv_indices[:total_sv]

    def get_nr_sv(self):
        """Return the result of ``svm_get_nr_sv`` for this model."""
        return libsvm.svm_get_nr_sv(self)

    def is_probability_model(self):
        """Return True when ``svm_check_probability_model`` returns 1."""
        return (libsvm.svm_check_probability_model(self) == 1)

    def get_sv_coef(self):
        """Return ``self.l`` tuples, each with ``nr_class - 1`` values.

        Tuple ``i`` holds ``sv_coef[j][i]`` for every ``j``.
        """
        return [tuple(self.sv_coef[j][i] for j in xrange(self.nr_class - 1))
                for i in xrange(self.l)]

    def get_SV(self):
        """Return the first ``self.l`` entries of ``SV`` as dicts.

        Each dict maps node index -> node value. The terminator node
        (index -1) ends the scan and is not included in the dict.
        """
        result = []
        for sparse_sv in self.SV[:self.l]:
            row = dict()
            i = 0
            # Check for the terminator BEFORE storing, so the sentinel
            # entry never ends up in the returned row.
            while sparse_sv[i].index != -1:
                row[sparse_sv[i].index] = sparse_sv[i].value
                i += 1
            result.append(row)
        return result
00294
def toPyModel(model_ptr):
    """
    toPyModel(model_ptr) -> svm_model

    Dereference a ctypes POINTER(svm_model) and return the pointed-to
    object, tagged with ``__createfrom__ = 'C'``.

    Raises ValueError if model_ptr is a null pointer.
    """
    if not model_ptr:
        raise ValueError("Null pointer")
    model = model_ptr.contents
    model.__createfrom__ = 'C'
    return model
00306
# Declare the return type and argument types of every LIBSVM function this
# module calls. Entries are (function name, restype, argtypes) and are
# applied in the order listed.
for _func_name, _restype, _argtypes in [
    ('svm_train', POINTER(svm_model), [POINTER(svm_problem), POINTER(svm_parameter)]),
    ('svm_cross_validation', None, [POINTER(svm_problem), POINTER(svm_parameter), c_int, POINTER(c_double)]),

    ('svm_save_model', c_int, [c_char_p, POINTER(svm_model)]),
    ('svm_load_model', POINTER(svm_model), [c_char_p]),

    ('svm_get_svm_type', c_int, [POINTER(svm_model)]),
    ('svm_get_nr_class', c_int, [POINTER(svm_model)]),
    ('svm_get_labels', None, [POINTER(svm_model), POINTER(c_int)]),
    ('svm_get_sv_indices', None, [POINTER(svm_model), POINTER(c_int)]),
    ('svm_get_nr_sv', c_int, [POINTER(svm_model)]),
    ('svm_get_svr_probability', c_double, [POINTER(svm_model)]),

    ('svm_predict_values', c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]),
    ('svm_predict', c_double, [POINTER(svm_model), POINTER(svm_node)]),
    ('svm_predict_probability', c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]),

    ('svm_free_model_content', None, [POINTER(svm_model)]),
    ('svm_free_and_destroy_model', None, [POINTER(POINTER(svm_model))]),
    ('svm_destroy_param', None, [POINTER(svm_parameter)]),

    ('svm_check_parameter', c_char_p, [POINTER(svm_problem), POINTER(svm_parameter)]),
    ('svm_check_probability_model', c_int, [POINTER(svm_model)]),
    ('svm_set_print_string_function', None, [PRINT_STRING_FUN]),
]:
    fillprototype(getattr(libsvm, _func_name), _restype, _argtypes)
# Do not leave the loop variables behind in the module namespace.
del _func_name, _restype, _argtypes