00001
00002
00003 from ctypes import *
00004 from ctypes.util import find_library
00005 import sys
00006 import os
00007
00008
# Locate the libsvm shared library.  A system-wide install is preferred
# ('svm', then 'libsvm'); find_library() yields a falsy value when nothing is
# found, in which case a copy shipped relative to this file is used instead.
# Each lookup runs at most once.
_libsvm_path = find_library('svm') or find_library('libsvm')
if not _libsvm_path:
    _module_dir = os.path.dirname(__file__)
    if sys.platform == 'win32':
        _libsvm_path = os.path.join(_module_dir, '../windows/libsvm.dll')
    else:
        _libsvm_path = os.path.join(_module_dir, '../libsvm.so.2')
libsvm = CDLL(_libsvm_path)
00020
00021
# Names of the SVM formulations and kernels, in the order of their integer
# codes: the position of a name in its list is the code of that name.
SVM_TYPE = ['C_SVC', 'NU_SVC', 'ONE_CLASS', 'EPSILON_SVR', 'NU_SVR']
KERNEL_TYPE = ['LINEAR', 'POLY', 'RBF', 'SIGMOID', 'PRECOMPUTED']

# Module-level integer constants, one per name above.  Spelled out explicitly
# (instead of exec-ing generated assignments) so the names are visible to
# readers and static tools; the unpacking fails loudly if a list and its
# constants ever get out of step.
C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR = range(len(SVM_TYPE))
LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED = range(len(KERNEL_TYPE))
00026
# ctypes callback type: takes one c_char_p argument and returns nothing.
PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p)


def print_null(s):
    """Ignore ``s`` and return None (a do-nothing print callback)."""
    pass
00030
def genFields(names, types):
    """Pair each field name with its type.

    Returns a list of ``(name, type)`` tuples, truncated to the shorter of
    the two inputs, in the form expected for a ctypes ``_fields_`` attribute.
    """
    return [(field_name, field_type)
            for field_name, field_type in zip(names, types)]
00033
def fillprototype(f, restype, argtypes):
    """Declare the C signature of ``f``.

    Sets ``f.restype`` first and ``f.argtypes`` second; returns None.
    """
    # Targets are assigned left to right, so restype is still set first.
    f.restype, f.argtypes = restype, argtypes
00037
class svm_node(Structure):
    """One (index, value) entry of a sparse vector.

    ``index`` is a C int and ``value`` a C double.
    """
    # Parallel lists: _names[k] is declared with type _types[k].
    _names = ["index", "value"]
    _types = [c_int, c_double]
    _fields_ = genFields(_names, _types)
00042
def gen_svm_nodearray(xi, feature_max=None, isKernel=None):
    """Convert one instance into a terminated ctypes array of ``svm_node``.

    Parameters:
        xi: the instance.  A dict maps index -> value.  A list or tuple is
            dense: unless ``isKernel`` is true, element ``k`` is given index
            ``k + 1``; with ``isKernel`` true, element ``k`` keeps index ``k``.
        feature_max: if truthy, must be an int; entries whose index is
            greater than it are dropped.
        isKernel: if true, zero-valued entries are kept and dense input is
            not shifted.

    Returns:
        ``(nodes, max_idx)``: ``nodes`` holds the kept entries sorted by
        index, followed by one extra node whose index is -1; ``max_idx`` is
        the largest kept index, or 0 when no entry was kept.

    Raises:
        TypeError: if ``xi`` is not a dict, list or tuple.
    """
    if isinstance(xi, dict):
        index_range = xi.keys()
    elif isinstance(xi, (list, tuple)):
        if not isKernel:
            # Pad slot 0 so that position k of the input lands on index k + 1.
            # list() is required for tuples: "[0] + some_tuple" raises
            # TypeError even though tuples are accepted above.
            xi = [0] + list(xi)
        index_range = range(len(xi))
    else:
        raise TypeError('xi should be a dictionary, list or tuple')

    if feature_max:
        assert(isinstance(feature_max, int))
        index_range = [j for j in index_range if j <= feature_max]
    if not isKernel:
        # Sparse representation: zero entries are simply left out.
        index_range = [j for j in index_range if xi[j] != 0]

    index_range = sorted(index_range)

    # One extra slot for the terminating node (index -1).
    ret = (svm_node * (len(index_range) + 1))()
    ret[-1].index = -1
    for idx, j in enumerate(index_range):
        ret[idx].index = j
        ret[idx].value = xi[j]

    max_idx = index_range[-1] if index_range else 0
    return ret, max_idx
00069
class svm_problem(Structure):
    """Training set handed to libsvm: ``l`` instances, labels ``y``, data ``x``.

    Besides the three struct fields, instances keep two Python-side
    attributes: ``x_space`` (the list of per-instance node arrays that ``x``
    points into) and ``n`` (the largest feature index seen, 0 if none).
    """
    _names = ["l", "y", "x"]
    _types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))]
    _fields_ = genFields(_names, _types)

    def __init__(self, y, x, isKernel=None):
        """Build the problem from labels ``y`` and instances ``x``.

        Each instance is converted with ``gen_svm_nodearray``; ``isKernel``
        is forwarded to it unchanged.

        Raises:
            ValueError: if ``y`` and ``x`` differ in length.
        """
        if len(y) != len(x):
            raise ValueError("len(y) != len(x)")
        count = len(y)
        self.l = count

        # Convert every instance and track the largest index encountered.
        largest_index = 0
        node_arrays = self.x_space = []
        for instance in x:
            nodes, top_index = gen_svm_nodearray(instance, isKernel=isKernel)
            node_arrays.append(nodes)
            if top_index > largest_index:
                largest_index = top_index
        self.n = largest_index

        # Labels as a C array of doubles.
        labels = (c_double * count)()
        for pos, label in enumerate(y):
            labels[pos] = label
        self.y = labels

        # One pointer per instance, each referring to an array in x_space.
        rows = (POINTER(svm_node) * count)()
        for pos, nodes in enumerate(node_arrays):
            rows[pos] = nodes
        self.x = rows
00093
class svm_parameter(Structure):
    """Training parameters, filled in from a command-line style option string.

    The struct fields are listed in ``_names``/``_types``.  Instances also
    carry three Python-side attributes that are not part of the struct:
    ``cross_validation``, ``nr_fold`` and ``print_func``.
    """
    _names = ["svm_type", "kernel_type", "degree", "gamma", "coef0",
              "cache_size", "eps", "C", "nr_weight", "weight_label", "weight",
              "nu", "p", "shrinking", "probability"]
    _types = [c_int, c_int, c_int, c_double, c_double,
              c_double, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double),
              c_double, c_double, c_int, c_int]
    _fields_ = genFields(_names, _types)

    def __init__(self, options=None):
        """Create parameters from ``options`` (None means the empty string)."""
        if options is None:
            options = ''
        self.parse_options(options)

    def show(self):
        """Print every struct field, then every Python-side attribute."""
        # list(...) is needed on Python 3, where keys() is a view that cannot
        # be concatenated to a list.
        attrs = svm_parameter._names + list(self.__dict__.keys())
        values = [getattr(self, attr) for attr in attrs]
        for attr, val in zip(attrs, values):
            print(' %s: %s' % (attr, val))

    def set_to_default_values(self):
        """Reset all fields and Python-side attributes to their defaults."""
        self.svm_type = C_SVC
        self.kernel_type = RBF
        self.degree = 3
        self.gamma = 0
        self.coef0 = 0
        self.nu = 0.5
        self.cache_size = 100
        self.C = 1
        self.eps = 0.001
        self.p = 0.1
        self.shrinking = 1
        self.probability = 0
        self.nr_weight = 0
        self.weight_label = (c_int*0)()
        self.weight = (c_double*0)()
        self.cross_validation = False
        self.nr_fold = 0
        self.print_func = None

    def parse_options(self, options):
        """Reset to defaults, then apply the whitespace-separated ``options``.

        Recognised flags (each followed by one value unless noted):
        ``-s -t -d -h -b`` (int) and ``-g -r -n -m -c -e -p`` (float) set
        the struct field of the matching meaning; ``-q`` (no value)
        installs ``print_null`` as the print callback; ``-v n`` enables
        cross validation with ``n`` folds; ``-wLABEL weight`` appends one
        (label, weight) pair.

        The chosen print callback is registered with the library as a side
        effect.

        Raises:
            ValueError: on an unknown flag, a value that cannot be converted,
                or a fold count below 2.
            IndexError: if the last flag is missing its value.
        """
        # Flags that consume exactly one value: flag -> (field, converter).
        value_options = {
            "-s": ("svm_type", int),
            "-t": ("kernel_type", int),
            "-d": ("degree", int),
            "-g": ("gamma", float),
            "-r": ("coef0", float),
            "-n": ("nu", float),
            "-m": ("cache_size", float),
            "-c": ("C", float),
            "-e": ("eps", float),
            "-p": ("p", float),
            "-h": ("shrinking", int),
            "-b": ("probability", int),
        }

        argv = options.split()
        self.set_to_default_values()
        # Null callback pointer unless -q replaces it below.
        self.print_func = cast(None, PRINT_STRING_FUN)
        weight_label = []
        weight = []

        i = 0
        while i < len(argv):
            opt = argv[i]
            if opt in value_options:
                field, convert = value_options[opt]
                i += 1
                setattr(self, field, convert(argv[i]))
            elif opt == "-q":
                self.print_func = PRINT_STRING_FUN(print_null)
            elif opt == "-v":
                i += 1
                self.cross_validation = 1
                self.nr_fold = int(argv[i])
                if self.nr_fold < 2:
                    raise ValueError("n-fold cross validation: n must >= 2")
            elif opt.startswith("-w"):
                # The class label is glued to the flag itself: "-w<label>".
                i += 1
                self.nr_weight += 1
                weight_label.append(int(opt[2:]))
                weight.append(float(argv[i]))
            else:
                raise ValueError("Wrong options")
            i += 1

        libsvm.svm_set_print_string_function(self.print_func)

        # Copy the collected class weights into C arrays.
        self.weight_label = (c_int*self.nr_weight)()
        self.weight = (c_double*self.nr_weight)()
        for k in range(self.nr_weight):
            self.weight[k] = weight[k]
            self.weight_label[k] = weight_label[k]
00203
class svm_model(Structure):
    """Trained model structure with Python-side accessors.

    ``__createfrom__`` records where the instance came from: ``'python'``
    when constructed through ``__init__``, ``'C'`` when it was marked as
    library-allocated (see ``toPyModel``).  Only the latter is released
    through the library in ``__del__``.
    """
    _names = ['param', 'nr_class', 'l', 'SV', 'sv_coef', 'rho',
              'probA', 'probB', 'label', 'nSV', 'free_sv']
    _types = [svm_parameter, c_int, c_int, POINTER(POINTER(svm_node)),
              POINTER(POINTER(c_double)), POINTER(c_double),
              POINTER(c_double), POINTER(c_double), POINTER(c_int),
              POINTER(c_int), c_int]
    _fields_ = genFields(_names, _types)

    def __init__(self):
        self.__createfrom__ = 'python'

    def __del__(self):
        # Hand the memory back to the library only if the library created it.
        # NOTE(review): pointer(self) is passed where the prototype declares
        # POINTER(POINTER(svm_model)); this relies on ctypes accepting an
        # instance of the pointed-to type for a pointer parameter -- confirm.
        if hasattr(self, '__createfrom__') and self.__createfrom__ == 'C':
            libsvm.svm_free_and_destroy_model(pointer(self))

    def get_svm_type(self):
        """Return the value of ``svm_get_svm_type`` for this model."""
        return libsvm.svm_get_svm_type(self)

    def get_nr_class(self):
        """Return the value of ``svm_get_nr_class`` for this model."""
        return libsvm.svm_get_nr_class(self)

    def get_svr_probability(self):
        """Return the value of ``svm_get_svr_probability`` for this model."""
        return libsvm.svm_get_svr_probability(self)

    def get_labels(self):
        """Return the class labels as a list of ``get_nr_class()`` ints."""
        nr_class = self.get_nr_class()
        labels = (c_int * nr_class)()
        libsvm.svm_get_labels(self, labels)
        return labels[:nr_class]

    def is_probability_model(self):
        """Return True if ``svm_check_probability_model`` reports 1."""
        return (libsvm.svm_check_probability_model(self) == 1)

    def get_sv_coef(self):
        """Return one tuple per support vector.

        Tuple ``i`` holds ``sv_coef[j][i]`` for ``j`` in
        ``0 .. nr_class - 2``.
        """
        # range (not xrange) so this also runs on Python 3.
        return [tuple(self.sv_coef[j][i] for j in range(self.nr_class - 1))
                for i in range(self.l)]

    def get_SV(self):
        """Return the support vectors as a list of ``{index: value}`` dicts.

        Each vector is read up to, but not including, its terminating node
        (index -1), so the terminator never appears as a key.
        """
        result = []
        for sparse_sv in self.SV[:self.l]:
            row = dict()

            i = 0
            # Stop at the terminator before recording it.
            while sparse_sv[i].index != -1:
                row[sparse_sv[i].index] = sparse_sv[i].value
                i += 1

            result.append(row)
        return result
00257
def toPyModel(model_ptr):
    """
    toPyModel(model_ptr) -> svm_model

    Dereference a ctypes POINTER(svm_model) and return the svm_model it
    points to, tagged with __createfrom__ = 'C'.

    Raises ValueError("Null pointer") if model_ptr is falsy.
    """
    if not model_ptr:
        raise ValueError("Null pointer")
    model = model_ptr.contents
    # Tag the object as library-allocated.
    model.__createfrom__ = 'C'
    return model
00269
# C signatures of the libsvm entry points used from Python, as
# (function name, restype, argtypes).  They are applied in this order.
_LIBSVM_PROTOTYPES = [
    # training
    ('svm_train', POINTER(svm_model), [POINTER(svm_problem), POINTER(svm_parameter)]),
    ('svm_cross_validation', None, [POINTER(svm_problem), POINTER(svm_parameter), c_int, POINTER(c_double)]),

    # model persistence
    ('svm_save_model', c_int, [c_char_p, POINTER(svm_model)]),
    ('svm_load_model', POINTER(svm_model), [c_char_p]),

    # model inspection
    ('svm_get_svm_type', c_int, [POINTER(svm_model)]),
    ('svm_get_nr_class', c_int, [POINTER(svm_model)]),
    ('svm_get_labels', None, [POINTER(svm_model), POINTER(c_int)]),
    ('svm_get_svr_probability', c_double, [POINTER(svm_model)]),

    # prediction
    ('svm_predict_values', c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]),
    ('svm_predict', c_double, [POINTER(svm_model), POINTER(svm_node)]),
    ('svm_predict_probability', c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]),

    # cleanup
    ('svm_free_model_content', None, [POINTER(svm_model)]),
    ('svm_free_and_destroy_model', None, [POINTER(POINTER(svm_model))]),
    ('svm_destroy_param', None, [POINTER(svm_parameter)]),

    # checks and output
    ('svm_check_parameter', c_char_p, [POINTER(svm_problem), POINTER(svm_parameter)]),
    ('svm_check_probability_model', c_int, [POINTER(svm_model)]),
    ('svm_set_print_string_function', None, [PRINT_STRING_FUN]),
]

for _func_name, _restype, _argtypes in _LIBSVM_PROTOTYPES:
    fillprototype(getattr(libsvm, _func_name), _restype, _argtypes)