4 from ctypes.util
import find_library
9 if find_library(
'svm'):
10 libsvm = CDLL(find_library(
'svm'))
11 elif find_library(
'libsvm'):
12 libsvm = CDLL(find_library(
'libsvm'))
14 if sys.platform ==
'win32':
15 libsvm = CDLL(os.path.join(os.path.dirname(__file__),\
16 '../windows/libsvm.dll'))
18 libsvm = CDLL(os.path.join(os.path.dirname(__file__),\
22 SVM_TYPE = [
'C_SVC',
'NU_SVC',
'ONE_CLASS',
'EPSILON_SVR',
'NU_SVR' ]
23 KERNEL_TYPE = [
'LINEAR',
'POLY',
'RBF',
'SIGMOID',
'PRECOMPUTED']
24 for i, s
in enumerate(SVM_TYPE): exec(
"%s = %d" % (s , i))
25 for i, s
in enumerate(KERNEL_TYPE): exec(
"%s = %d" % (s , i))
27 PRINT_STRING_FUN = CFUNCTYPE(
None, c_char_p)
32 return list(zip(names, types))
39 _names = [
"index",
"value"]
40 _types = [c_int, c_double]
44 if isinstance(xi, dict):
45 index_range = xi.keys()
46 elif isinstance(xi, (list, tuple)):
49 index_range = range(len(xi))
51 raise TypeError(
'xi should be a dictionary, list or tuple')
54 assert(isinstance(feature_max, int))
55 index_range = filter(
lambda j: j <= feature_max, index_range)
57 index_range = filter(
lambda j:xi[j] != 0, index_range)
59 index_range = sorted(index_range)
60 ret = (svm_node * (len(index_range)+1))()
62 for idx, j
in enumerate(index_range):
64 ret[idx].value = xi[j]
67 max_idx = index_range[-1]
71 _names = [
"l",
"y",
"x"]
72 _types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))]
77 raise ValueError(
"len(y) != len(x)")
82 for i, xi
in enumerate(x):
85 max_idx =
max(max_idx, tmp_idx)
88 self.
y = (c_double * l)()
89 for i, yi
in enumerate(y): self.
y[i] = yi
91 self.
x = (POINTER(svm_node) * l)()
92 for i, xi
in enumerate(self.
x_space): self.
x[i] = xi
95 _names = [
"svm_type",
"kernel_type",
"degree",
"gamma",
"coef0",
96 "cache_size",
"eps",
"C",
"nr_weight",
"weight_label",
"weight",
97 "nu",
"p",
"shrinking",
"probability"]
98 _types = [c_int, c_int, c_int, c_double, c_double,
99 c_double, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double),
100 c_double, c_double, c_int, c_int]
109 attrs = svm_parameter._names + self.__dict__.keys()
110 values = map(
lambda attr: getattr(self, attr), attrs)
111 for attr, val
in zip(attrs, values):
112 print(
' %s: %s' % (attr, val))
135 argv = options.split()
137 self.
print_func = cast(
None, PRINT_STRING_FUN)
146 elif argv[i] ==
"-t":
149 elif argv[i] ==
"-d":
151 self.
degree = int(argv[i])
152 elif argv[i] ==
"-g":
154 self.
gamma = float(argv[i])
155 elif argv[i] ==
"-r":
157 self.
coef0 = float(argv[i])
158 elif argv[i] ==
"-n":
160 self.
nu = float(argv[i])
161 elif argv[i] ==
"-m":
164 elif argv[i] ==
"-c":
166 self.
C = float(argv[i])
167 elif argv[i] ==
"-e":
169 self.
eps = float(argv[i])
170 elif argv[i] ==
"-p":
172 self.
p = float(argv[i])
173 elif argv[i] ==
"-h":
176 elif argv[i] ==
"-b":
179 elif argv[i] ==
"-q":
181 elif argv[i] ==
"-v":
186 raise ValueError(
"n-fold cross validation: n must >= 2")
187 elif argv[i].startswith(
"-w"):
191 weight_label += [int(argv[i-1][2:])]
192 weight += [float(argv[i])]
194 raise ValueError(
"Wrong options")
197 libsvm.svm_set_print_string_function(self.
print_func)
201 self.
weight[i] = weight[i]
205 _names = [
'param',
'nr_class',
'l',
'SV',
'sv_coef',
'rho',
206 'probA',
'probB',
'label',
'nSV',
'free_sv']
207 _types = [svm_parameter, c_int, c_int, POINTER(POINTER(svm_node)),
208 POINTER(POINTER(c_double)), POINTER(c_double),
209 POINTER(c_double), POINTER(c_double), POINTER(c_int),
210 POINTER(c_int), c_int]
218 if hasattr(self,
'__createfrom__')
and self.
__createfrom__ ==
'C':
219 libsvm.svm_free_and_destroy_model(pointer(self))
222 return libsvm.svm_get_svm_type(self)
225 return libsvm.svm_get_nr_class(self)
228 return libsvm.svm_get_svr_probability(self)
232 labels = (c_int * nr_class)()
233 libsvm.svm_get_labels(self, labels)
234 return labels[:nr_class]
237 return (libsvm.svm_check_probability_model(self) == 1)
240 return [tuple(self.sv_coef[j][i]
for j
in xrange(self.nr_class - 1))
241 for i
in xrange(self.l)]
245 for sparse_sv
in self.SV[:self.l]:
250 row[sparse_sv[i].index] = sparse_sv[i].value
251 if sparse_sv[i].index == -1:
260 toPyModel(model_ptr) -> svm_model 262 Convert a ctypes POINTER(svm_model) to a Python svm_model 264 if bool(model_ptr) ==
False:
265 raise ValueError(
"Null pointer")
266 m = model_ptr.contents
267 m.__createfrom__ =
'C' 270 fillprototype(libsvm.svm_train, POINTER(svm_model), [POINTER(svm_problem), POINTER(svm_parameter)])
271 fillprototype(libsvm.svm_cross_validation,
None, [POINTER(svm_problem), POINTER(svm_parameter), c_int, POINTER(c_double)])
273 fillprototype(libsvm.svm_save_model, c_int, [c_char_p, POINTER(svm_model)])
274 fillprototype(libsvm.svm_load_model, POINTER(svm_model), [c_char_p])
276 fillprototype(libsvm.svm_get_svm_type, c_int, [POINTER(svm_model)])
277 fillprototype(libsvm.svm_get_nr_class, c_int, [POINTER(svm_model)])
278 fillprototype(libsvm.svm_get_labels,
None, [POINTER(svm_model), POINTER(c_int)])
279 fillprototype(libsvm.svm_get_svr_probability, c_double, [POINTER(svm_model)])
281 fillprototype(libsvm.svm_predict_values, c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)])
282 fillprototype(libsvm.svm_predict, c_double, [POINTER(svm_model), POINTER(svm_node)])
283 fillprototype(libsvm.svm_predict_probability, c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)])
285 fillprototype(libsvm.svm_free_model_content,
None, [POINTER(svm_model)])
286 fillprototype(libsvm.svm_free_and_destroy_model,
None, [POINTER(POINTER(svm_model))])
287 fillprototype(libsvm.svm_destroy_param,
None, [POINTER(svm_parameter)])
289 fillprototype(libsvm.svm_check_parameter, c_char_p, [POINTER(svm_problem), POINTER(svm_parameter)])
290 fillprototype(libsvm.svm_check_probability_model, c_int, [POINTER(svm_model)])
291 fillprototype(libsvm.svm_set_print_string_function,
None, [PRINT_STRING_FUN])
def gen_svm_nodearray(xi, feature_max=None, isKernel=None)
def is_probability_model(self)
def get_svr_probability(self)
def parse_options(self, options)
def set_to_default_values(self)
def __init__(self, options=None)
def fillprototype(f, restype, argtypes)
def __init__(self, y, x, isKernel=None)
def genFields(names, types)