00001
00002
00003 from ctypes import *
00004 from ctypes.util import find_library
00005 import sys
00006 import subprocess as sb
00007
00008
# Locate and load the libsvm shared library. Search order:
#   1. the system linker search path, under the name "svm" or "libsvm";
#   2. on Windows, a DLL at the relative path ../windows/libsvm.dll;
#   3. otherwise, libsvm.so.2 inside the ROS package "libsvm3" (via rospack).
_libsvm_path = find_library('svm') or find_library('libsvm')
if _libsvm_path:
    libsvm = CDLL(_libsvm_path)
elif sys.platform == 'win32':
    libsvm = CDLL('../windows/libsvm.dll')
else:
    # Ask for text output: on Python 3 communicate() otherwise returns bytes,
    # and concatenating bytes with the str suffix raises TypeError.
    _pkg_dir = sb.Popen(['rospack', 'find', 'libsvm3'], stdout=sb.PIPE,
                        universal_newlines=True).communicate()[0].rstrip()
    libsvm = CDLL(_pkg_dir + '/libsvm.so.2')
00018
00019
00020
SVM_TYPE = ['C_SVC', 'NU_SVC', 'ONE_CLASS', 'EPSILON_SVR', 'NU_SVR']
KERNEL_TYPE = ['LINEAR', 'POLY', 'RBF', 'SIGMOID', 'PRECOMPUTED']
# Publish every name above as a module-level int equal to its list position
# (C_SVC == 0, ..., RBF == 2, ...). Generator expressions keep the loop
# variables out of the module namespace, and no exec() is needed.
globals().update((name, code) for code, name in enumerate(SVM_TYPE))
globals().update((name, code) for code, name in enumerate(KERNEL_TYPE))

# ctypes callback type taking one char* argument and returning nothing;
# used for libsvm's print hook.
PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p)
def print_null(s):
    """Swallow a message and do nothing (used as a silent print hook)."""
    pass
00029
def genFields(names, types):
    """Pair each field name with its ctypes type, as a list suitable for ``_fields_``."""
    return [(field_name, field_type) for field_name, field_type in zip(names, types)]
00032
def fillprototype(f, restype, argtypes):
    """Attach a C signature to the foreign function *f*.

    Sets ``f.argtypes`` and ``f.restype`` so that ctypes converts arguments
    and the return value accordingly. Returns None.
    """
    f.argtypes = argtypes
    f.restype = restype
00036
class svm_node(Structure):
    """A single (index, value) feature entry, laid out as a ctypes struct.

    Arrays of these are handed to libsvm as POINTER(svm_node).
    """
    _names = ["index",  # int feature index
              "value"]  # double feature value
    _types = [c_int,
              c_double]
    _fields_ = genFields(_names, _types)
00041
def gen_svm_nodearray(xi, feature_max=None, issparse=None):
    """Convert one sample into a ctypes array of svm_node.

    Args:
        xi: dict mapping feature index -> value, or a list/tuple whose
            positions (0, 1, 2, ...) are used as the feature indices.
        feature_max: if truthy, keep only indices <= feature_max (asserted
            to be an int).
        issparse: if truthy, drop features whose value equals 0.

    Returns:
        (nodes, max_idx): ``nodes`` holds one entry per kept feature in
        ascending index order, followed by a final entry whose index is -1;
        ``max_idx`` is the largest kept index, or 0 when none are kept.

    Raises:
        TypeError: if xi is not a dict, list or tuple.
    """
    if isinstance(xi, dict):
        keys = xi.keys()
    elif isinstance(xi, (list, tuple)):
        keys = range(len(xi))
    else:
        raise TypeError('xi should be a dictionary, list or tuple')

    if feature_max:
        assert isinstance(feature_max, int)
        keys = [k for k in keys if k <= feature_max]
    if issparse:
        keys = [k for k in keys if xi[k] != 0]
    keys = sorted(keys)

    nodes = (svm_node * (len(keys) + 1))()
    # The extra trailing slot gets index -1, marking the end of the sample.
    nodes[-1].index = -1
    for slot, k in enumerate(keys):
        nodes[slot].index = k
        nodes[slot].value = xi[k]

    return nodes, (keys[-1] if keys else 0)
00066
class svm_problem(Structure):
    """Training set in the ctypes layout passed to libsvm.

    ``l`` is the number of samples, ``y`` points at their labels and ``x``
    at one svm_node array per sample. ``n`` (a plain Python attribute) is
    the largest feature index seen across all samples.
    """
    _names = ["l", "y", "x"]
    _types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))]
    _fields_ = genFields(_names, _types)

    def __init__(self, y, x):
        """Build the problem from labels *y* and samples *x* (equal lengths).

        Raises:
            ValueError: if len(y) != len(x).
        """
        if len(y) != len(x):
            raise ValueError("len(y) != len(x)")
        count = len(y)
        self.l = count

        # The per-sample node arrays are kept on self.x_space so that the
        # memory the pointers in self.x refer to stays alive with the object.
        self.x_space = []
        largest = 0
        for sample in x:
            nodes, sample_max = gen_svm_nodearray(sample)
            self.x_space.append(nodes)
            largest = max(largest, sample_max)
        self.n = largest

        labels = (c_double * count)()
        for pos, label in enumerate(y):
            labels[pos] = label
        self.y = labels

        rows = (POINTER(svm_node) * count)()
        for pos, nodes in enumerate(self.x_space):
            rows[pos] = nodes
        self.x = rows
00090
class svm_parameter(Structure):
    """Training parameters for libsvm, laid out as a ctypes Structure.

    The attributes listed in ``_names`` are the ctypes fields handed to
    libsvm. ``cross_validation``, ``nr_fold`` and ``print_func`` are ordinary
    Python attributes set by ``set_to_default_values`` / ``parse_options``.
    """
    _names = ["svm_type", "kernel_type", "degree", "gamma", "coef0",
              "cache_size", "eps", "C", "nr_weight", "weight_label", "weight",
              "nu", "p", "shrinking", "probability"]
    _types = [c_int, c_int, c_int, c_double, c_double,
              c_double, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double),
              c_double, c_double, c_int, c_int]
    _fields_ = genFields(_names, _types)

    # Options that consume the next token and store it, converted, in a field.
    _value_options = {
        "-s": ("svm_type", int),
        "-t": ("kernel_type", int),
        "-d": ("degree", int),
        "-g": ("gamma", float),
        "-r": ("coef0", float),
        "-n": ("nu", float),
        "-m": ("cache_size", float),
        "-c": ("C", float),
        "-e": ("eps", float),
        "-p": ("p", float),
        "-h": ("shrinking", int),
        "-b": ("probability", int),
    }

    def __init__(self, options=None):
        """Create parameters from an option string; None means all defaults."""
        if options is None:
            options = ''
        self.parse_options(options)

    def show(self):
        """Print each ctypes field, then each extra Python attribute, with its value."""
        # On Python 3 __dict__.keys() is a view, which cannot be added to a
        # list directly, so convert it first.
        attrs = svm_parameter._names + list(self.__dict__.keys())
        for attr in attrs:
            print(' %s: %s' % (attr, getattr(self, attr)))

    def set_to_default_values(self):
        """Reset every parameter to the wrapper's default value."""
        self.svm_type = C_SVC
        self.kernel_type = RBF
        self.degree = 3
        self.gamma = 0
        self.coef0 = 0
        self.nu = 0.5
        self.cache_size = 100
        self.C = 1
        self.eps = 0.001
        self.p = 0.1
        self.shrinking = 1
        self.probability = 0
        self.nr_weight = 0
        self.weight_label = (c_int * 0)()
        self.weight = (c_double * 0)()
        self.cross_validation = False
        self.nr_fold = 0
        self.print_func = None

    def parse_options(self, options):
        """Reset to defaults, then apply a libsvm-style option string.

        Recognized options:
            -s -t -d -h -b INT / -g -r -n -m -c -e -p FLOAT  set one field;
            -q        install print_null as libsvm's print hook;
            -v N      enable N-fold cross validation (N >= 2);
            -wLABEL W append weight W for class LABEL.

        The resulting print hook is passed to libsvm's
        svm_set_print_string_function, and weight_label/weight are rebuilt as
        ctypes arrays of length nr_weight.

        Raises:
            ValueError: unknown option, non-numeric value, or -v below 2.
            IndexError: an option that needs a value is the last token.
        """
        argv = options.split()
        self.set_to_default_values()
        # Start from a NULL callback pointer; -q replaces it below.
        self.print_func = cast(None, PRINT_STRING_FUN)
        weight_label = []
        weight = []

        i = 0
        while i < len(argv):
            opt = argv[i]
            if opt in self._value_options:
                i += 1
                attr, convert = self._value_options[opt]
                setattr(self, attr, convert(argv[i]))
            elif opt == "-q":
                self.print_func = PRINT_STRING_FUN(print_null)
            elif opt == "-v":
                i += 1
                self.cross_validation = 1
                self.nr_fold = int(argv[i])
                if self.nr_fold < 2:
                    raise ValueError("n-fold cross validation: n must >= 2")
            elif opt.startswith("-w"):
                i += 1
                self.nr_weight += 1
                # The class label is embedded in the option itself ("-w1").
                weight_label.append(int(opt[2:]))
                weight.append(float(argv[i]))
            else:
                raise ValueError("Wrong options")
            i += 1

        libsvm.svm_set_print_string_function(self.print_func)
        self.weight_label = (c_int * self.nr_weight)(*weight_label)
        self.weight = (c_double * self.nr_weight)(*weight)
00200
class svm_model(Structure):
    """Python handle on a libsvm model.

    No ``_fields_`` are declared, so the struct contents are opaque from
    Python and every query goes through a libsvm function. ``__createfrom__``
    records the origin: ``'python'`` for instances built by ``__init__``,
    ``'C'`` for models tagged by toPyModel. Only ``'C'`` models are passed
    to svm_free_and_destroy_model in ``__del__``.
    """

    def __init__(self):
        self.__createfrom__ = 'python'

    def __del__(self):
        # Only models whose memory libsvm allocated are released here.
        if getattr(self, '__createfrom__', None) == 'C':
            libsvm.svm_free_and_destroy_model(pointer(self))

    def get_svm_type(self):
        """Return the value of libsvm's svm_get_svm_type for this model."""
        return libsvm.svm_get_svm_type(self)

    def get_nr_class(self):
        """Return the value of libsvm's svm_get_nr_class for this model."""
        return libsvm.svm_get_nr_class(self)

    def get_svr_probability(self):
        """Return the value of libsvm's svm_get_svr_probability for this model."""
        return libsvm.svm_get_svr_probability(self)

    def get_labels(self):
        """Return the model's class labels as a list of get_nr_class() ints."""
        buf = (c_int * self.get_nr_class())()
        libsvm.svm_get_labels(self, buf)
        return list(buf)

    def is_probability_model(self):
        """True when libsvm's svm_check_probability_model reports exactly 1."""
        return libsvm.svm_check_probability_model(self) == 1

    def get_support_vectors(self, samples_len):
        """Return a list of samples_len ints filled by svm_get_support_vectors.

        NOTE(review): the buffer size is chosen by the caller; presumably it
        must be at least as large as what the C side writes. Confirm against
        the libsvm3 implementation of svm_get_support_vectors.
        """
        buf = (c_int * samples_len)()
        libsvm.svm_get_support_vectors(self, samples_len, buf)
        return list(buf)
00264
00265
00266
00267
def toPyModel(model_ptr):
    """
    toPyModel(model_ptr) -> svm_model

    Dereference a ctypes POINTER(svm_model) returned by libsvm and mark the
    resulting object as C-allocated, so svm_model.__del__ will free it.

    Raises ValueError if model_ptr is NULL.
    """
    if not model_ptr:
        raise ValueError("Null pointer")
    model = model_ptr.contents
    model.__createfrom__ = 'C'
    return model
00279
# Declare the C signature (restype, argtypes) of every libsvm entry point
# this wrapper calls, in the same order as before.
for _fname, _restype, _argtypes in (
    ('svm_train', POINTER(svm_model),
     [POINTER(svm_problem), POINTER(svm_parameter)]),
    ('svm_cross_validation', None,
     [POINTER(svm_problem), POINTER(svm_parameter), c_int, POINTER(c_double)]),

    ('svm_save_model', c_int, [c_char_p, POINTER(svm_model)]),
    ('svm_load_model', POINTER(svm_model), [c_char_p]),

    ('svm_get_svm_type', c_int, [POINTER(svm_model)]),
    ('svm_get_nr_class', c_int, [POINTER(svm_model)]),
    ('svm_get_labels', None, [POINTER(svm_model), POINTER(c_int)]),
    ('svm_get_support_vectors', None,
     [POINTER(svm_model), c_int, POINTER(c_int)]),
    ('svm_get_svr_probability', c_double, [POINTER(svm_model)]),

    ('svm_predict_values', c_double,
     [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]),
    ('svm_predict', c_double, [POINTER(svm_model), POINTER(svm_node)]),
    ('svm_predict_probability', c_double,
     [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]),

    ('svm_free_model_content', None, [POINTER(svm_model)]),
    ('svm_free_and_destroy_model', None, [POINTER(POINTER(svm_model))]),
    ('svm_destroy_param', None, [POINTER(svm_parameter)]),

    ('svm_check_parameter', c_char_p,
     [POINTER(svm_problem), POINTER(svm_parameter)]),
    ('svm_check_probability_model', c_int, [POINTER(svm_model)]),
    ('svm_set_print_string_function', None, [PRINT_STRING_FUN]),
):
    fillprototype(getattr(libsvm, _fname), _restype, _argtypes)
# Do not leave the loop variables behind as module attributes.
del _fname, _restype, _argtypes