00001
00002 __all__ = ['find_parameters']
00003
00004 import os, sys, traceback, getpass, time, re
00005 from threading import Thread
00006 from subprocess import *
00007
00008 if sys.version_info[0] < 3:
00009 from Queue import Queue
00010 else:
00011 from queue import Queue
00012
# Worker configuration read by find_parameters():
#   telnet_workers  -- hosts that each get one TelnetWorker thread
#   ssh_workers     -- hosts that each get one SSHWorker thread
#   nr_local_worker -- how many LocalWorker threads to start
telnet_workers, ssh_workers = [], []
nr_local_worker = 1
00016
class GridOption:
    """Settings for a grid search over log2(C) and log2(gamma).

    __init__ fills in platform-dependent executable paths and the default
    search (log2c -5..15 step 2, log2g 3..-15 step -2, 5-fold CV), then
    applies the caller's command-line style options via parse_options().
    """

    def __init__(self, dataset_pathname, options):
        dirname = os.path.dirname(__file__)
        if sys.platform != 'win32':
            self.svmtrain_pathname = os.path.join(dirname, '../svm-train')
            self.gnuplot_pathname = '/usr/bin/gnuplot'
        else:
            self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe')
            self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe'
        self.fold = 5
        self.c_begin, self.c_end, self.c_step = -5, 15, 2
        self.g_begin, self.g_end, self.g_step = 3, -15, -2
        self.grid_with_c, self.grid_with_g = True, True
        self.dataset_pathname = dataset_pathname
        self.dataset_title = os.path.split(dataset_pathname)[1]
        self.out_pathname = '{0}.out'.format(self.dataset_title)
        self.png_pathname = '{0}.png'.format(self.dataset_title)
        self.pass_through_string = ' '
        self.resume_pathname = None
        self.parse_options(options)

    @staticmethod
    def _parse_range(flag, text):
        """Parse a 'begin,end,step' argument of *flag* into three floats.

        Raises ValueError when the text is not three numbers, the step is
        zero (the range would never end), or stepping from begin moves away
        from end (the range would be empty).
        """
        try:
            begin, end, step = map(float, text.split(','))
        except ValueError:
            raise ValueError('{0} expects begin,end,step but got {1!r}'.format(flag, text))
        if step == 0:
            raise ValueError('{0}: step must be non-zero'.format(flag))
        if (end - begin) * step < 0:
            raise ValueError('{0}: step {1} never reaches {2} from {3}'.format(flag, step, end, begin))
        return begin, end, step

    def parse_options(self, options):
        """Apply grid.py options given as a string or a list of tokens.

        Recognised grid options update the matching attributes. Every other
        token is collected, in order, into pass_through_string for svm-train.

        Raises:
            ValueError: an option lacks its argument, a range is malformed,
                -c/-g is used directly, or both -log2c and -log2g are null.
            IOError: the svm-train executable, the dataset or the resume file
                does not exist.
        """
        if isinstance(options, str):
            options = options.split()
        # Options that always consume the following token.
        takes_value = ('-log2c', '-log2g', '-v', '-svmtrain', '-gnuplot', '-out', '-png')
        pass_through_options = []

        i = 0
        while i < len(options):
            flag = options[i]
            value = None
            if flag in takes_value:
                i += 1
                if i >= len(options):
                    # Without this check a trailing flag surfaces as IndexError.
                    raise ValueError('option {0} requires an argument'.format(flag))
                value = options[i]

            if flag == '-log2c':
                if value == 'null':
                    self.grid_with_c = False
                else:
                    self.c_begin, self.c_end, self.c_step = self._parse_range(flag, value)
            elif flag == '-log2g':
                if value == 'null':
                    self.grid_with_g = False
                else:
                    self.g_begin, self.g_end, self.g_step = self._parse_range(flag, value)
            elif flag == '-v':
                self.fold = value
            elif flag in ('-c', '-g'):
                raise ValueError('Use -log2c and -log2g.')
            elif flag == '-svmtrain':
                self.svmtrain_pathname = value
            elif flag == '-gnuplot':
                self.gnuplot_pathname = None if value == 'null' else value
            elif flag == '-out':
                self.out_pathname = None if value == 'null' else value
            elif flag == '-png':
                self.png_pathname = value
            elif flag == '-resume':
                # The resume path is optional: absent (end of list or next
                # token is another option) means "<dataset title>.out".
                if i == len(options) - 1 or options[i + 1].startswith('-'):
                    self.resume_pathname = self.dataset_title + '.out'
                else:
                    i += 1
                    self.resume_pathname = options[i]
            else:
                pass_through_options.append(flag)
            i += 1

        self.pass_through_string = ' '.join(pass_through_options)
        if not os.path.exists(self.svmtrain_pathname):
            raise IOError('svm-train executable not found')
        if not os.path.exists(self.dataset_pathname):
            raise IOError('dataset not found')
        if self.resume_pathname and not os.path.exists(self.resume_pathname):
            raise IOError('file for resumption not found')
        if not self.grid_with_c and not self.grid_with_g:
            raise ValueError('-log2c and -log2g should not be null simultaneously')
        if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname):
            # Plotting is optional: warn and carry on without it.
            sys.stderr.write('gnuplot executable not found\n')
            self.gnuplot_pathname = None
00104
def redraw(db, best_param, gnuplot, options, tofile=False):
    """Send a contour plot of the cross-validation results to gnuplot.

    Args:
        db: list of (log2c, log2g, rate) tuples collected so far. It is read
            only; the caller's list is not reordered.
        best_param: (best_log2c, best_log2g, best_rate) shown in the labels.
        gnuplot: writable binary stream with write(bytes) and flush(). In
            find_parameters this is gnuplot's stdin pipe.
        options: GridOption supplying the axis ranges, dataset title and
            png_pathname.
        tofile: when true, render to options.png_pathname with the png
            terminal. Otherwise use the interactive terminal for the platform.

    Nothing is written when db is empty, or when all points share the same
    log2c, the same log2g, or the same rate. The likely reason is that
    gnuplot cannot contour such degenerate data.
    """
    if not db:
        return
    begin_level = round(max(x[2] for x in db)) - 3
    step_size = 0.5

    best_log2c, best_log2g, best_rate = best_param

    for axis in range(3):
        if all(x[axis] == db[0][axis] for x in db):
            return

    def emit(command):
        gnuplot.write(command.encode())

    if tofile:
        emit("set term png transparent small linewidth 2 medium enhanced\n")
        # gnuplot's double-quoted strings treat backslash as an escape.
        emit('set output "{0}"\n'.format(options.png_pathname.replace('\\', '\\\\')))
    elif sys.platform == 'win32':
        emit("set term windows\n")
    else:
        emit("set term x11\n")
    emit('set xlabel "log2(C)"\n')
    emit('set ylabel "log2(gamma)"\n')
    emit("set xrange [{0}:{1}]\n".format(options.c_begin, options.c_end))
    emit("set yrange [{0}:{1}]\n".format(options.g_begin, options.g_end))
    emit("set contour\n")
    emit("set cntrparam levels incremental {0},{1},100\n".format(begin_level, step_size))
    emit("unset surface\n")
    emit("unset ztics\n")
    emit("set view 0,0\n")
    emit('set title "{0}"\n'.format(options.dataset_title))
    emit("unset label\n")
    # Each label is one literal; a backslash line continuation inside the
    # string would copy the source indentation into the gnuplot command.
    emit('set label "Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%"'
         ' at screen 0.5,0.85 center\n'.format(best_log2c, best_log2g, best_rate))
    emit('set label "C = {0} gamma = {1}"'
         ' at screen 0.5,0.8 center\n'.format(2**best_log2c, 2**best_log2g))
    emit("set key at screen 0.9,0.9\n")
    emit('splot "-" with lines\n')

    # Inline data: rows grouped by log2c (ascending), each group ordered by
    # log2g descending, with a blank line between groups.
    rows = sorted(db, key=lambda x: (x[0], -x[1]))
    prevc = rows[0][0]
    for row in rows:
        if prevc != row[0]:
            emit("\n")
            prevc = row[0]
        emit("{0[0]} {0[1]} {0[2]}\n".format(row))
    emit("e\n")
    emit("\n")
    gnuplot.flush()
00157
00158
def calculate_jobs(options):
    """Build the ordered job list and load previously computed results.

    Returns:
        (jobs, resumed_jobs)
        jobs: a list of "lines". Each line is a list of (log2c, log2g) pairs.
            Taken in order, the lines grow the searched rectangle
            c_seq[:i] x g_seq[:j] by one row or column at a time. A component
            is None when that parameter is not searched (-log2c/-log2g null).
        resumed_jobs: dict mapping (log2c, log2g) to the rate read from
            options.resume_pathname. It is empty when no resume file is set.

    Raises:
        ValueError: a searched range has a zero step or contains no values.
    """
    # One number in the .out format. Exponent form is accepted because
    # accumulated float steps can print like -1.1102230246251565e-16.
    number = r'([-+]?(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?:[eE][-+]?[0-9]+)?)'

    def range_f(begin, end, step):
        # Inclusive range from begin towards end, built by repeated addition.
        # Keep this exact arithmetic: resumed results are matched against
        # these values by float equality.
        seq = []
        while True:
            if step > 0 and begin > end:
                break
            if step < 0 and begin < end:
                break
            seq.append(begin)
            begin = begin + step
        return seq

    def permute_sequence(seq):
        # Middle element first, then the recursively permuted left and right
        # halves interleaved. Presumably this samples the range coarsely
        # before refining it.
        n = len(seq)
        if n <= 1:
            return seq
        mid = n // 2
        left = permute_sequence(seq[:mid])
        right = permute_sequence(seq[mid + 1:])
        ret = [seq[mid]]
        while left or right:
            if left:
                ret.append(left.pop(0))
            if right:
                ret.append(right.pop(0))
        return ret

    def axis_values(flag, enabled, begin, end, step):
        # Ordered values for one parameter, or [None] if it is not searched.
        if not enabled:
            return [None]
        if step == 0:
            # range_f would never terminate.
            raise ValueError('{0}: step must be non-zero'.format(flag))
        seq = range_f(begin, end, step)
        if not seq:
            # An empty axis would make the ordering loop below divide by zero.
            raise ValueError('{0}: range {1},{2},{3} contains no values'.format(flag, begin, end, step))
        return permute_sequence(seq)

    c_seq = axis_values('-log2c', options.grid_with_c, options.c_begin, options.c_end, options.c_step)
    g_seq = axis_values('-log2g', options.grid_with_g, options.g_begin, options.g_end, options.g_step)

    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i, j = 0, 0
    jobs = []

    # Extend whichever axis is proportionally behind, so the covered
    # fractions i/nr_c and j/nr_g stay balanced.
    while i < nr_c or j < nr_g:
        if i / nr_c < j / nr_g:
            # New row: c_seq[i] paired with every g value already covered.
            jobs.append([(c_seq[i], g_seq[k]) for k in range(j)])
            i += 1
        else:
            # New column: g_seq[j] paired with every c value already covered.
            jobs.append([(c_seq[k], g_seq[j]) for k in range(i)])
            j += 1

    resumed_jobs = {}
    if options.resume_pathname is None:
        return jobs, resumed_jobs

    rate_re = re.compile(r'rate=' + number)
    c_re = re.compile(r'log2c=' + number)
    g_re = re.compile(r'log2g=' + number)
    with open(options.resume_pathname, 'r') as resume_file:
        for line in resume_file:
            match = rate_re.search(line)
            if not match:
                continue
            rate = float(match.group(1))
            match = c_re.search(line)
            c = float(match.group(1)) if match else None
            match = g_re.search(line)
            g = float(match.group(1)) if match else None
            resumed_jobs[(c, g)] = rate

    return jobs, resumed_jobs
00239
00240
class WorkerStopToken:
    """Sentinel telling workers to exit.

    The class object itself (not an instance) is queued as the first element
    of a job tuple, and workers compare it with ``is``.
    """
00243
class Worker(Thread):
    """Thread that evaluates (log2c, log2g) jobs from a shared queue.

    Subclasses implement run_one(c, g). It receives C = 2**log2c and
    gamma = 2**log2g (None for a parameter not being searched) and returns the
    cross-validation rate, or None if no rate could be obtained.
    """

    def __init__(self, name, job_queue, result_queue, options):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue
        self.options = options

    def run(self):
        """Process jobs until the stop token arrives or a job fails.

        Each success is reported as (worker name, log2c, log2g, rate) on
        result_queue.
        """
        while True:
            (cexp, gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # Re-queue the token so every other worker sees it as well.
                self.job_queue.put((cexp, gexp))
                break
            try:
                c = None if cexp is None else 2.0 ** cexp
                g = None if gexp is None else 2.0 ** gexp
                rate = self.run_one(c, g)
                if rate is None:
                    raise RuntimeError('get no rate')
            except Exception:
                # Hand the job back for another worker and retire this one.
                # NOTE(review): if every worker retires this way, the main
                # thread waits on result_queue indefinitely.
                traceback.print_exc()
                self.job_queue.put((cexp, gexp))
                sys.stderr.write('worker {0} quit.\n'.format(self.name))
                break
            else:
                self.result_queue.put((self.name, cexp, gexp, rate))

    def get_cmd(self, c, g):
        """Return the svm-train shell command line for one (C, gamma) pair."""
        options = self.options
        cmdline = '"' + options.svmtrain_pathname + '"'
        if options.grid_with_c:
            cmdline += ' -c {0} '.format(c)
        if options.grid_with_g:
            cmdline += ' -g {0} '.format(g)
        # Quote the dataset path like the executable so that a path
        # containing spaces survives the shell.
        cmdline += ' -v {0} {1} "{2}" '.format(
            options.fold, options.pass_through_string, options.dataset_pathname)
        return cmdline
00288
class LocalWorker(Worker):
    """Worker that runs svm-train as a local subprocess."""

    def run_one(self, c, g):
        """Run one cross-validation and return the accuracy, or None.

        The rate is parsed from the first output line containing 'Cross',
        taking its last field without the trailing character (the '%' sign).
        """
        cmdline = self.get_cmd(c, g)
        proc = Popen(cmdline, shell=True, stdout=PIPE, stderr=PIPE, stdin=PIPE)
        # communicate() drains stdout and stderr together and reaps the child.
        # A full, unread stderr pipe therefore cannot stall svm-train, and no
        # finished process is left unreaped.
        output = proc.communicate()[0]
        for line in output.splitlines():
            if str(line).find('Cross') != -1:
                return float(line.split()[-1][0:-1])
        return None
00296
class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host through ssh.

    The remote command first changes to the local process's current working
    directory path, captured when the worker is created.
    """

    def __init__(self, name, job_queue, result_queue, host, options):
        Worker.__init__(self, name, job_queue, result_queue, options)
        self.host = host
        self.cwd = os.getcwd()

    def run_one(self, c, g):
        """Run one remote cross-validation and return the accuracy, or None."""
        cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format(
            self.host, self.cwd, self.get_cmd(c, g))
        proc = Popen(cmdline, shell=True, stdout=PIPE, stderr=PIPE, stdin=PIPE)
        # Read both pipes to EOF and reap the ssh process, so no finished
        # child or open pipe is left behind.
        output = proc.communicate()[0]
        for line in output.splitlines():
            if str(line).find('Cross') != -1:
                return float(line.split()[-1][0:-1])
        return None
00309
class TelnetWorker(Worker):
    """Worker that runs svm-train on a remote host through a telnet session.

    NOTE(review): telnetlib was removed from the standard library in
    Python 3.13, so this worker needs an older interpreter.
    """

    def __init__(self, name, job_queue, result_queue, host, username, password, options):
        Worker.__init__(self, name, job_queue, result_queue, options)
        self.host = host
        self.username = username
        self.password = password

    def run(self):
        """Log in, cd to the local working directory path, process jobs, log out."""
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        try:
            # telnetlib exchanges bytes on Python 3, so every prompt and
            # command is encoded before use.
            tn.read_until(b'login: ')
            tn.write((self.username + '\n').encode())
            tn.read_until(b'Password: ')
            tn.write((self.password + '\n').encode())
            # Consume output up to the next occurrence of the user name
            # (presumably the shell prompt).
            tn.read_until(self.username.encode())

            print('login ok', self.host)
            tn.write(('cd ' + os.getcwd() + '\n').encode())
            Worker.run(self)
            tn.write(b'exit\n')
        finally:
            tn.close()

    def run_one(self, c, g):
        """Type one svm-train command into the session and return its rate, or None."""
        cmdline = self.get_cmd(c, g)
        self.tn.write((cmdline + '\n').encode())
        (idx, matchm, output) = self.tn.expect([b'Cross.*\n'])
        for line in output.decode('utf-8', 'replace').split('\n'):
            if line.find('Cross') != -1:
                return float(line.split()[-1][0:-1])
        return None
00338
def find_parameters(dataset_pathname, options=''):
    """Grid-search C and gamma for *dataset_pathname* by cross-validation.

    Args:
        dataset_pathname: training data passed to svm-train.
        options: grid.py command-line options, as a string or a token list
            (parsed by GridOption).

    Jobs go to the workers configured by the module-level telnet_workers,
    ssh_workers and nr_local_worker. Every result is printed. Results are
    also written to options.out_pathname when it is set (appended when
    resuming). They are plotted with gnuplot when both parameters are
    searched.

    Returns:
        (best_rate, best_param). best_param maps 'c' and/or 'g' to the best
        C and gamma; a parameter that was not searched is omitted.

    Raises:
        Whatever GridOption raises for bad options or missing files
        (IOError / ValueError).
    """

    def update_param(c, g, rate, best_c, best_g, best_rate, worker, resumed):
        # Report one result and return the updated best triple. On a tied
        # rate with the same g, the smaller c wins.
        if (rate > best_rate) or (rate == best_rate and g == best_g and c < best_c):
            best_rate, best_c, best_g = rate, c, g
        stdout_str = '[{0}] {1} {2} (best '.format(
            worker, ' '.join(str(x) for x in [c, g] if x is not None), rate)
        output_str = ''
        if c is not None:
            stdout_str += 'c={0}, '.format(2.0**best_c)
            output_str += 'log2c={0} '.format(c)
        if g is not None:
            stdout_str += 'g={0}, '.format(2.0**best_g)
            output_str += 'log2g={0} '.format(g)
        stdout_str += 'rate={0})'.format(best_rate)
        print(stdout_str)
        # Resumed results already exist in the output file.
        if result_file is not None and not resumed:
            output_str += 'rate={0}\n'.format(rate)
            result_file.write(output_str)
            result_file.flush()

        return best_c, best_g, best_rate

    options = GridOption(dataset_pathname, options)
    jobs, resumed_jobs = calculate_jobs(options)

    if options.gnuplot_pathname:
        # NOTE(review): gnuplot's stdout/stderr pipes are never read.
        gnuplot = Popen(options.gnuplot_pathname, stdin=PIPE, stdout=PIPE, stderr=PIPE).stdin
    else:
        gnuplot = None

    job_queue = Queue(0)
    result_queue = Queue(0)

    for (c, g) in resumed_jobs:
        result_queue.put(('resumed', c, g, resumed_jobs[(c, g)]))

    for line in jobs:
        for (c, g) in line:
            if (c, g) not in resumed_jobs:
                job_queue.put((c, g))

    # From here on, put() inserts at the FRONT of the queue. A job handed
    # back by a failing worker is then retried next, and the stop token is
    # seen before any remaining job.
    job_queue._put = job_queue.queue.appendleft

    # Open the output file before starting any worker, so an I/O error here
    # leaves no threads running.
    result_file = None
    if options.out_pathname:
        result_file = open(options.out_pathname, 'a' if options.resume_pathname else 'w')

    try:
        if telnet_workers:
            username = getpass.getuser()
            password = getpass.getpass()
            for host in telnet_workers:
                TelnetWorker(host, job_queue, result_queue,
                             host, username, password, options).start()

        for host in ssh_workers:
            SSHWorker(host, job_queue, result_queue, host, options).start()

        for _ in range(nr_local_worker):
            LocalWorker('local', job_queue, result_queue, options).start()

        done_jobs = {}
        db = []
        best_rate = -1
        best_c, best_g = None, None

        for (c, g) in resumed_jobs:
            rate = resumed_jobs[(c, g)]
            best_c, best_g, best_rate = update_param(c, g, rate, best_c, best_g, best_rate, 'resumed', True)

        for line in jobs:
            for (c, g) in line:
                # Consume results until this job is done. Results for other
                # jobs that arrive meanwhile are recorded as well.
                while (c, g) not in done_jobs:
                    (worker, c1, g1, rate1) = result_queue.get()
                    done_jobs[(c1, g1)] = rate1
                    if (c1, g1) not in resumed_jobs:
                        best_c, best_g, best_rate = update_param(c1, g1, rate1, best_c, best_g, best_rate, worker, False)
                db.append((c, g, done_jobs[(c, g)]))
            if gnuplot and options.grid_with_c and options.grid_with_g:
                redraw(db, [best_c, best_g, best_rate], gnuplot, options)
                redraw(db, [best_c, best_g, best_rate], gnuplot, options, True)
    finally:
        if result_file is not None:
            result_file.close()
        # Stop the workers on success and on error alike.
        job_queue.put((WorkerStopToken, None))

    best_param, best_cg = {}, []
    if best_c is not None:
        best_param['c'] = 2.0**best_c
        best_cg += [2.0**best_c]
    if best_g is not None:
        best_param['g'] = 2.0**best_g
        best_cg += [2.0**best_g]
    print('{0} {1}'.format(' '.join(map(str, best_cg)), best_rate))

    return best_rate, best_param
00461
00462
if __name__ == '__main__':

    def exit_with_help():
        # Print the usage text and terminate with exit status 1.
        print("""\
Usage: grid.py [grid_options] [svm_options] dataset

grid_options :
-log2c {begin,end,step | "null"} : set the range of c (default -5,15,2)
begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end}
"null" -- do not grid with c
-log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2)
begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end}
"null" -- do not grid with g
-v n : n-fold cross validation (default 5)
-svmtrain pathname : set svm executable path and name
-gnuplot {pathname | "null"} :
pathname -- set gnuplot executable path and name
"null" -- do not plot
-out {pathname | "null"} : (default dataset.out)
pathname -- set output file path and name
"null" -- do not output file
-png pathname : set graphic output file path and name (default dataset.png)
-resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out)
This is experimental. Try this option only if some parameters have been checked for the SAME data.

svm_options : additional options for svm-train""")
        sys.exit(1)

    # The last argument is the dataset; everything before it is options.
    cli_args = sys.argv[1:]
    if not cli_args:
        exit_with_help()
    dataset_pathname, options = cli_args[-1], cli_args[:-1]

    try:
        find_parameters(dataset_pathname, options)
    except (IOError, ValueError) as err:
        sys.stderr.write(str(err) + '\n')
        sys.stderr.write('Try "grid.py" for more information.\n')
        sys.exit(1)