5 import os, sys, traceback
7 from threading
import Thread
8 from subprocess
import *
10 if(sys.hexversion < 0x03000000):
18 is_win32 = (sys.platform ==
'win32')
20 svmtrain_exe =
"../svm-train" 21 gnuplot_exe =
"/usr/bin/gnuplot" 24 svmtrain_exe =
r"..\windows\svm-train.exe" 26 gnuplot_exe =
r"c:\tmp\gnuplot\binary\pgnuplot.exe" 31 c_begin, c_end, c_step = -5, 15, 2
32 g_begin, g_end, g_step = 3, -15, -2
33 global dataset_pathname, dataset_title, pass_through_string
34 global out_filename, png_filename
46 global c_begin, c_end, c_step
47 global g_begin, g_end, g_step
48 global dataset_pathname, dataset_title, pass_through_string
49 global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename
52 Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold] 53 [-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname] 54 [additional parameters for svm-train] dataset""" 60 dataset_pathname = argv[-1]
61 dataset_title = os.path.split(dataset_pathname)[1]
62 out_filename =
'{0}.out'.format(dataset_title)
63 png_filename =
'{0}.png'.format(dataset_title)
64 pass_through_options = []
67 while i < len(argv) - 1:
68 if argv[i] ==
"-log2c":
70 (c_begin,c_end,c_step) = map(float,argv[i].split(
","))
71 elif argv[i] ==
"-log2g":
73 (g_begin,g_end,g_step) = map(float,argv[i].split(
","))
77 elif argv[i]
in (
'-c',
'-g'):
78 print(
"Option -c and -g are renamed.")
81 elif argv[i] ==
'-svmtrain':
83 svmtrain_exe = argv[i]
84 elif argv[i] ==
'-gnuplot':
87 elif argv[i] ==
'-out':
89 out_filename = argv[i]
90 elif argv[i] ==
'-png':
92 png_filename = argv[i]
94 pass_through_options.append(argv[i])
97 pass_through_string =
" ".join(pass_through_options)
98 assert os.path.exists(svmtrain_exe),
"svm-train executable not found" 99 assert os.path.exists(gnuplot_exe),
"gnuplot executable not found" 100 assert os.path.exists(dataset_pathname),
"dataset not found" 101 gnuplot = Popen(gnuplot_exe,stdin = PIPE).stdin
108 if step > 0
and begin > end:
break 109 if step < 0
and begin < end:
break 116 if n <= 1:
return seq
124 if left: ret.append(left.pop(0))
125 if right: ret.append(right.pop(0))
130 if len(db) == 0:
return 131 begin_level = round(
max(x[2]
for x
in db)) - 3
134 best_log2c,best_log2g,best_rate = best_param
138 if all(x[0] == db[0][0]
for x
in db):
return 139 if all(x[1] == db[0][1]
for x
in db):
return 140 if all(x[2] == db[0][2]
for x
in db):
return 143 gnuplot.write(b
"set term png transparent small linewidth 2 medium enhanced\n")
144 gnuplot.write(
"set output \"{0}\"\n".format(png_filename.replace(
'\\',
'\\\\')).encode())
148 gnuplot.write(b
"set term windows\n")
150 gnuplot.write( b
"set term x11\n")
151 gnuplot.write(b
"set xlabel \"log2(C)\"\n")
152 gnuplot.write(b
"set ylabel \"log2(gamma)\"\n")
153 gnuplot.write(
"set xrange [{0}:{1}]\n".format(c_begin,c_end).encode())
154 gnuplot.write(
"set yrange [{0}:{1}]\n".format(g_begin,g_end).encode())
155 gnuplot.write(b
"set contour\n")
156 gnuplot.write(
"set cntrparam levels incremental {0},{1},100\n".format(begin_level,step_size).encode())
157 gnuplot.write(b
"unset surface\n")
158 gnuplot.write(b
"unset ztics\n")
159 gnuplot.write(b
"set view 0,0\n")
160 gnuplot.write(
"set title \"{0}\"\n".format(dataset_title).encode())
161 gnuplot.write(b
"unset label\n")
162 gnuplot.write(
"set label \"Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%\" \ 163 at screen 0.5,0.85 center\n". \
164 format(best_log2c, best_log2g, best_rate).encode())
165 gnuplot.write(
"set label \"C = {0} gamma = {1}\"" 166 " at screen 0.5,0.8 center\n".format(2**best_log2c, 2**best_log2g).encode())
167 gnuplot.write(b
"set key at screen 0.9,0.9\n")
168 gnuplot.write(b
"splot \"-\" with lines\n")
173 db.sort(key =
lambda x:(x[0], -x[1]))
180 gnuplot.write(
"{0[0]} {0[1]} {0[2]}\n".format(line).encode())
181 gnuplot.write(b
"e\n")
189 nr_c = float(len(c_seq))
190 nr_g = float(len(g_seq))
195 while i < nr_c
or j < nr_g:
200 line.append((c_seq[i],g_seq[k]))
207 line.append((c_seq[k],g_seq[j]))
215 class Worker(Thread):
217 Thread.__init__(self)
223 (cexp,gexp) = self.job_queue.get()
224 if cexp
is WorkerStopToken:
225 self.job_queue.put((cexp,gexp))
229 rate = self.run_one(2.0**cexp,2.0**gexp)
230 if rate
is None:
raise RuntimeError(
"get no rate")
234 traceback.print_exception(sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2])
236 self.job_queue.put((cexp,gexp))
237 print(
'worker {0} quit.'.format(self.
name))
240 self.result_queue.put((self.
name,cexp,gexp,rate))
244 cmdline =
'{0} -c {1} -g {2} -v {3} {4} {5}'.format \
245 (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
246 result = Popen(cmdline,shell=
True,stdout=PIPE).stdout
247 for line
in result.readlines():
248 if str(line).find(
"Cross") != -1:
249 return float(line.split()[-1][0:-1])
252 def __init__(self,name,job_queue,result_queue,host):
253 Worker.__init__(self,name,job_queue,result_queue)
257 cmdline =
'ssh -x {0} "cd {1}; {2} -c {3} -g {4} -v {5} {6} {7}"'.format \
259 svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
260 result = Popen(cmdline,shell=
True,stdout=PIPE).stdout
261 for line
in result.readlines():
262 if str(line).find(
"Cross") != -1:
263 return float(line.split()[-1][0:-1])
266 def __init__(self,name,job_queue,result_queue,host,username,password):
267 Worker.__init__(self,name,job_queue,result_queue)
273 self.
tn = tn = telnetlib.Telnet(self.
host)
274 tn.read_until(
"login: ")
276 tn.read_until(
"Password: ")
282 print(
'login ok', self.
host)
283 tn.write(
"cd "+os.getcwd()+
"\n")
287 cmdline =
'{0} -c {1} -g {2} -v {3} {4} {5}'.format \
288 (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
289 result = self.tn.write(cmdline+
'\n')
290 (idx,matchm,output) = self.tn.expect([
'Cross.*\n'])
291 for line
in output.split(
'\n'):
292 if str(line).find(
"Cross") != -1:
293 return float(line.split()[-1][0:-1])
304 job_queue = Queue.Queue(0)
305 result_queue = Queue.Queue(0)
318 job_queue._put = job_queue.queue.appendleft
324 nr_telnet_worker = len(telnet_workers)
325 username = getpass.getuser()
326 password = getpass.getpass()
327 for host
in telnet_workers:
329 host,username,password).start()
334 for host
in ssh_workers:
335 SSHWorker(host,job_queue,result_queue,host).start()
339 for i
in range(nr_local_worker):
340 LocalWorker(
'local',job_queue,result_queue).start()
347 result_file = open(out_filename,
'w')
352 best_c1,best_g1 =
None,
None 356 while (c, g)
not in done_jobs:
357 (worker,c1,g1,rate) = result_queue.get()
358 done_jobs[(c1,g1)] = rate
359 result_file.write(
'{0} {1} {2}\n'.format(c1,g1,rate))
361 if (rate > best_rate)
or (rate==best_rate
and g1==best_g1
and c1<best_c1):
363 best_c1,best_g1=c1,g1
366 print(
"[{0}] {1} {2} {3} (best c={4}, g={5}, rate={6})".format \
367 (worker,c1,g1,rate, best_c, best_g, best_rate))
368 db.append((c,g,done_jobs[(c,g)]))
369 redraw(db,[best_c1, best_g1, best_rate])
370 redraw(db,[best_c1, best_g1, best_rate],
True)
372 job_queue.put((WorkerStopToken,
None))
373 print(
"{0} {1} {2}".format(best_c, best_g, best_rate))
def __init__(self, name, job_queue, result_queue, host, username, password)
def __init__(self, name, job_queue, result_queue)
def process_options(argv=sys.argv)
def __init__(self, name, job_queue, result_queue, host)
def redraw(db, best_param, tofile=False)
def permute_sequence(seq)
def range_f(begin, end, step)