grid.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 __all__ = ['find_parameters']
00003 
00004 import os, sys, traceback, getpass, time, re
00005 from threading import Thread
00006 from subprocess import *
00007 
00008 if sys.version_info[0] < 3:
00009         from Queue import Queue
00010 else:
00011         from queue import Queue
00012 
# Hosts to drive over telnet; find_parameters starts one TelnetWorker per entry.
telnet_workers = []
# Hosts to drive over ssh; find_parameters starts one SSHWorker per entry.
ssh_workers = []
# Number of LocalWorker threads find_parameters starts on this machine.
nr_local_worker = 1
00016 
class GridOption:
    """Settings of one grid search: built-in defaults overridden by options.

    Attributes set here: svmtrain_pathname, gnuplot_pathname, fold,
    c_begin/c_end/c_step, g_begin/g_end/g_step, grid_with_c, grid_with_g,
    dataset_pathname, dataset_title, out_pathname, png_pathname,
    pass_through_string and resume_pathname.
    """

    # Options that consume the token following them as their value.
    _VALUE_OPTIONS = ('-log2c', '-log2g', '-v', '-svmtrain',
                      '-gnuplot', '-out', '-png')

    def __init__(self, dataset_pathname, options):
        """Set the defaults, then apply *options* through parse_options().

        dataset_pathname -- path of the training data file
        options          -- option string or sequence of option tokens

        Raises whatever parse_options() raises.
        """
        dirname = os.path.dirname(__file__)
        if sys.platform != 'win32':
            self.svmtrain_pathname = os.path.join(dirname, '../svm-train')
            self.gnuplot_pathname = '/usr/bin/gnuplot'
        else:
            # example for windows
            self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe')
            # svmtrain_pathname = r'c:\Program Files\libsvm\windows\svm-train.exe'
            self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe'
        self.fold = 5
        self.c_begin, self.c_end, self.c_step = -5,  15,  2
        self.g_begin, self.g_end, self.g_step =  3, -15, -2
        self.grid_with_c, self.grid_with_g = True, True
        self.dataset_pathname = dataset_pathname
        self.dataset_title = os.path.split(dataset_pathname)[1]
        self.out_pathname = '{0}.out'.format(self.dataset_title)
        self.png_pathname = '{0}.png'.format(self.dataset_title)
        self.pass_through_string = ' '
        self.resume_pathname = None
        self.parse_options(options)

    def parse_options(self, options):
        """Apply *options* to this object and validate the resulting paths.

        options -- a whitespace separated string or a sequence of tokens.
                   Tokens that are not grid options are collected into
                   pass_through_string for svm-train.

        Raises ValueError for -c/-g, for an option that lacks its value, for
        a malformed begin,end,step triple, and when both -log2c and -log2g
        are "null".  Raises IOError when the svm-train executable, the
        dataset, or the resume file does not exist.  A missing gnuplot
        executable only disables plotting (a message goes to stderr).
        """
        if isinstance(options, str):
            options = options.split()
        pass_through_options = []
        nr_options = len(options)

        i = 0
        while i < nr_options:
            opt = options[i]
            if opt in ('-c', '-g'):
                raise ValueError('Use -log2c and -log2g.')

            if opt in self._VALUE_OPTIONS:
                # Report a missing value as ValueError (which the command-line
                # entry point handles) instead of leaking an IndexError.
                if i + 1 >= nr_options:
                    raise ValueError('option {0} requires a value'.format(opt))
                i = i + 1
                value = options[i]
                if opt == '-log2c':
                    if value == 'null':
                        self.grid_with_c = False
                    else:
                        self.c_begin, self.c_end, self.c_step = map(float, value.split(','))
                elif opt == '-log2g':
                    if value == 'null':
                        self.grid_with_g = False
                    else:
                        self.g_begin, self.g_end, self.g_step = map(float, value.split(','))
                elif opt == '-v':
                    self.fold = value
                elif opt == '-svmtrain':
                    self.svmtrain_pathname = value
                elif opt == '-gnuplot':
                    self.gnuplot_pathname = None if value == 'null' else value
                elif opt == '-out':
                    self.out_pathname = None if value == 'null' else value
                else:  # '-png'
                    self.png_pathname = value
            elif opt == '-resume':
                # The pathname is optional: it is absent when -resume is the
                # last token or is directly followed by another option.
                if i == nr_options - 1 or options[i + 1].startswith('-'):
                    self.resume_pathname = self.dataset_title + '.out'
                else:
                    i = i + 1
                    self.resume_pathname = options[i]
            else:
                pass_through_options.append(opt)
            i = i + 1

        self.pass_through_string = ' '.join(pass_through_options)
        if not os.path.exists(self.svmtrain_pathname):
            raise IOError('svm-train executable not found')
        if not os.path.exists(self.dataset_pathname):
            raise IOError('dataset not found')
        if self.resume_pathname and not os.path.exists(self.resume_pathname):
            raise IOError('file for resumption not found')
        if not self.grid_with_c and not self.grid_with_g:
            raise ValueError('-log2c and -log2g should not be null simultaneously')
        if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname):
            sys.stderr.write('gnuplot executable not found\n')
            self.gnuplot_pathname = None
00104 
def redraw(db,best_param,gnuplot,options,tofile=False):
    """Send a contour plot of the results gathered so far to gnuplot.

    db         -- list of (log2c, log2g, rate) tuples; sorted in place
    best_param -- (best_log2c, best_log2g, best_rate) shown in the labels
    gnuplot    -- writable binary stream connected to gnuplot
    options    -- object providing png_pathname, dataset_title and the
                  c_begin/c_end/g_begin/g_end plot ranges
    tofile     -- draw into options.png_pathname instead of a window

    Nothing is written when db is empty or when all points share the same
    log2c, the same log2g, or the same rate.
    """
    if not db:
        return

    contour_start = round(max(point[2] for point in db)) - 3
    contour_step = 0.5
    best_log2c, best_log2g, best_rate = best_param

    # A contour needs variation along every axis; otherwise stop redrawing.
    first = db[0]
    for axis in (0, 1, 2):
        if all(point[axis] == first[axis] for point in db):
            return

    def send(text):
        gnuplot.write(text.encode())

    if tofile:
        send('set term png transparent small linewidth 2 medium enhanced\n')
        send('set output "{0}"\n'.format(options.png_pathname.replace('\\', '\\\\')))
    elif sys.platform == 'win32':
        send('set term windows\n')
    else:
        send('set term x11\n')

    send('set xlabel "log2(C)"\n')
    send('set ylabel "log2(gamma)"\n')
    send('set xrange [{0}:{1}]\n'.format(options.c_begin, options.c_end))
    send('set yrange [{0}:{1}]\n'.format(options.g_begin, options.g_end))
    send('set contour\n')
    send('set cntrparam levels incremental {0},{1},100\n'.format(contour_start, contour_step))
    send('unset surface\n')
    send('unset ztics\n')
    send('set view 0,0\n')
    send('set title "{0}"\n'.format(options.dataset_title))
    send('unset label\n')

    # The run of blanks before "at" reproduces the padding of the command as
    # it has always been sent.
    best_label = ('set label "Best log2(C) = {0}  log2(gamma) = {1}  accuracy = {2}%" '
                  + ' ' * 34
                  + 'at screen 0.5,0.85 center\n')
    send(best_label.format(best_log2c, best_log2g, best_rate))
    send('set label "C = {0}  gamma = {1}" at screen 0.5,0.8 center\n'.format(
        2**best_log2c, 2**best_log2g))
    send('set key at screen 0.9,0.9\n')
    send('splot "-" with lines\n')

    # Group points by log2c (ascending), log2g descending inside each group;
    # a blank line separates the groups in the data sent to splot.
    db.sort(key=lambda point: (point[0], -point[1]))
    current_c = db[0][0]
    for point in db:
        if point[0] != current_c:
            send('\n')
            current_c = point[0]
        send('{0[0]} {0[1]} {0[2]}\n'.format(point))
    send('e\n')
    send('\n')  # force gnuplot back to prompt when term set failure
    gnuplot.flush()
00157 
00158 
def calculate_jobs(options):
    """Build the list of grid jobs and load results from a previous run.

    options -- object providing c_begin/c_end/c_step, g_begin/g_end/g_step,
               grid_with_c, grid_with_g and resume_pathname

    Returns (jobs, resumed_jobs).  jobs is a list of lines, each a list of
    (log2c, log2g) pairs; a coordinate is None when that parameter is not
    searched.  resumed_jobs maps (log2c, log2g) to the rate read from
    options.resume_pathname and is empty when resume_pathname is None.

    Raises ValueError when a range step is zero.
    """

    def range_f(begin,end,step):
        # like range, but works on non-integer too
        if step == 0:
            # A zero step never reaches `end`: the loop would run forever.
            raise ValueError('step of a log2 range must not be 0')
        seq = []
        while (step > 0 and begin <= end) or (step < 0 and begin >= end):
            seq.append(begin)
            begin = begin + step
        return seq

    def permute_sequence(seq):
        # Middle element first, then the permuted halves interleaved, so the
        # early jobs already span the whole range.
        n = len(seq)
        if n <= 1: return seq

        mid = n // 2
        left = permute_sequence(seq[:mid])
        right = permute_sequence(seq[mid+1:])

        ret = [seq[mid]]
        while left or right:
            if left: ret.append(left.pop(0))
            if right: ret.append(right.pop(0))

        return ret

    c_seq = permute_sequence(range_f(options.c_begin,options.c_end,options.c_step))
    g_seq = permute_sequence(range_f(options.g_begin,options.g_end,options.g_step))

    if not options.grid_with_c:
        c_seq = [None]
    if not options.grid_with_g:
        g_seq = [None]

    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i, j = 0, 0
    jobs = []

    # Refine whichever axis has the smaller fraction of its values used.
    while i < nr_c or j < nr_g:
        if i/nr_c < j/nr_g:
            # increase C resolution
            line = [(c_seq[i], g_seq[k]) for k in range(0, j)]
            i = i + 1
        else:
            # increase g resolution
            line = [(c_seq[k], g_seq[j]) for k in range(0, i)]
            j = j + 1
        jobs.append(line)

    resumed_jobs = {}

    if options.resume_pathname is None:
        return jobs, resumed_jobs

    # The with-block closes the file as soon as parsing is done.
    with open(options.resume_pathname, 'r') as resume_file:
        for line in resume_file:
            line = line.strip()
            rst = re.findall(r'rate=([0-9.]+)',line)
            if not rst:
                continue
            rate = float(rst[0])

            c, g = None, None
            rst = re.findall(r'log2c=([0-9.-]+)',line)
            if rst:
                c = float(rst[0])
            rst = re.findall(r'log2g=([0-9.-]+)',line)
            if rst:
                g = float(rst[0])

            resumed_jobs[(c,g)] = rate

    return jobs, resumed_jobs
00239 
00240         
class WorkerStopToken:
    """Marker used to notify the worker to stop or if a worker is dead."""
00243 
class Worker(Thread):
    """Thread that takes (log2c, log2g) jobs from a queue and evaluates them.

    Subclasses supply run_one(c, g), which returns the cross validation
    rate, or None when no rate could be obtained.
    """

    def __init__(self,name,job_queue,result_queue,options):
        """Store the worker name, both queues and the grid options."""
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue
        self.options = options

    def run(self):
        """Process jobs until the stop token arrives or a job fails.

        Each finished job is reported as (name, log2c, log2g, rate) on
        result_queue.  A failed job is put back on job_queue and this worker
        quits.
        """
        while True:
            (cexp,gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # Put the token back so that the other workers see it too.
                self.job_queue.put((cexp,gexp))
                # print('worker {0} stop.'.format(self.name))
                break
            try:
                c, g = None, None
                if cexp is not None:
                    c = 2.0**cexp
                if gexp is not None:
                    g = 2.0**gexp
                rate = self.run_one(c,g)
                if rate is None: raise RuntimeError('get no rate')
            except Exception:
                # we failed, let others do that and we just quit
                traceback.print_exc()

                self.job_queue.put((cexp,gexp))
                sys.stderr.write('worker {0} quit.\n'.format(self.name))
                break
            else:
                self.result_queue.put((self.name,cexp,gexp,rate))

    def get_cmd(self,c,g):
        """Return the svm-train command line that evaluates C=c, gamma=g.

        -c / -g are only emitted for the parameters being searched.  The
        executable and the dataset pathname are both double-quoted so that
        paths containing blanks survive shell word splitting.
        """
        options=self.options
        cmdline = '"' + options.svmtrain_pathname + '"'
        if options.grid_with_c:
            cmdline += ' -c {0} '.format(c)
        if options.grid_with_g:
            cmdline += ' -g {0} '.format(g)
        cmdline += ' -v {0} {1} {2} '.format\
            (options.fold,options.pass_through_string,'"' + options.dataset_pathname + '"')
        return cmdline
00288                 
class LocalWorker(Worker):
    """Worker that runs svm-train as a child process on this machine."""

    def run_one(self,c,g):
        """Run one cross validation; return its rate, or None if none is printed."""
        cmdline = self.get_cmd(c,g)
        proc = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE)
        # communicate() drains stdout and stderr together, so a child that
        # fills the stderr pipe cannot block forever; it also waits for the
        # child and closes the pipes.
        output = proc.communicate()[0]
        for line in output.splitlines():
            if str(line).find('Cross') != -1:
                # last word of the line is the rate followed by '%'
                return float(line.split()[-1][0:-1])
        return None
00296 
class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host through the ssh client."""

    def __init__(self,name,job_queue,result_queue,host,options):
        """Remember the host and the current directory to cd into remotely."""
        Worker.__init__(self,name,job_queue,result_queue,options)
        self.host = host
        self.cwd = os.getcwd()

    def run_one(self,c,g):
        """Run one cross validation on self.host; return its rate or None."""
        cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format\
            (self.host,self.cwd,self.get_cmd(c,g))
        proc = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE)
        try:
            for line in proc.stdout.readlines():
                if str(line).find('Cross') != -1:
                    # last word of the line is the rate followed by '%'
                    return float(line.split()[-1][0:-1])
            return None
        finally:
            # stdin is left open while reading, as before, and only closed
            # here; then reap the child so no pipes or zombie processes
            # accumulate over the many jobs of a grid search.
            for pipe in (proc.stdin, proc.stdout, proc.stderr):
                pipe.close()
            proc.wait()
00309 
class TelnetWorker(Worker):
    """Worker that runs svm-train on a remote host through a telnet session."""

    def __init__(self,name,job_queue,result_queue,host,username,password,options):
        """Remember the host and the credentials used to log in."""
        Worker.__init__(self,name,job_queue,result_queue,options)
        self.host = host
        self.username = username
        self.password = password

    def run(self):
        """Log in, cd to the current directory, process jobs, then log out."""
        # NOTE(review): telnetlib is deprecated and absent from recent Python
        # releases -- confirm the interpreter in use still ships it.
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        # telnetlib works on bytes; on Python 3 passing str raises TypeError.
        tn.read_until(b'login: ')
        tn.write((self.username + '\n').encode())
        tn.read_until(b'Password: ')
        tn.write((self.password + '\n').encode())

        # XXX: how to know whether login is successful?
        tn.read_until(self.username.encode())
        #
        print('login ok', self.host)
        tn.write(('cd '+os.getcwd()+'\n').encode())
        Worker.run(self)
        tn.write(b'exit\n')
        tn.close()

    def run_one(self,c,g):
        """Run one cross validation in the session; return its rate or None."""
        cmdline = self.get_cmd(c,g)
        self.tn.write((cmdline+'\n').encode())
        (idx,matchm,output) = self.tn.expect([b'Cross.*\n'])
        for line in output.split(b'\n'):
            if b'Cross' in line:
                # last word of the line is the rate followed by '%'
                return float(line.split()[-1][0:-1])
        return None
00338                         
def find_parameters(dataset_pathname, options=''):
    """Grid-search C and gamma for *dataset_pathname* by cross validation.

    dataset_pathname -- path of the training data file
    options          -- option string or token list, parsed by GridOption

    Returns (best_rate, best_param); best_param is a dict holding 'c'
    and/or 'g' for the parameters that were searched.  Progress is printed,
    each new result is appended to the output file when one is configured,
    and the contour plot is redrawn after every job line when gnuplot is
    available and both parameters are searched.

    Raises the IOError/ValueError of GridOption and calculate_jobs.

    NOTE: results are awaited with a blocking result_queue.get(); if every
    worker quits after a failure this call never returns.
    """

    def update_param(c,g,rate,best_c,best_g,best_rate,worker,resumed):
        # Report one result and return the updated best triple.  On an equal
        # rate with the same g the smaller c is preferred.
        if (rate > best_rate) or (rate==best_rate and g==best_g and c<best_c):
            best_rate,best_c,best_g = rate,c,g
        stdout_str = '[{0}] {1} {2} (best '.format\
            (worker,' '.join(str(x) for x in [c,g] if x is not None),rate)
        output_str = ''
        if c is not None:
            stdout_str += 'c={0}, '.format(2.0**best_c)
            output_str += 'log2c={0} '.format(c)
        if g is not None:
            stdout_str += 'g={0}, '.format(2.0**best_g)
            output_str += 'log2g={0} '.format(g)
        stdout_str += 'rate={0})'.format(best_rate)
        print(stdout_str)
        if options.out_pathname and not resumed:
            output_str += 'rate={0}\n'.format(rate)
            result_file.write(output_str)
            result_file.flush()

        return best_c,best_g,best_rate

    options = GridOption(dataset_pathname, options)

    if options.gnuplot_pathname:
        gnuplot = Popen(options.gnuplot_pathname,stdin = PIPE,stdout=PIPE,stderr=PIPE).stdin
    else:
        gnuplot = None

    # put jobs in queue

    jobs,resumed_jobs = calculate_jobs(options)
    job_queue = Queue(0)
    result_queue = Queue(0)

    for (c,g) in resumed_jobs:
        result_queue.put(('resumed',c,g,resumed_jobs[(c,g)]))

    for line in jobs:
        for (c,g) in line:
            if (c,g) not in resumed_jobs:
                job_queue.put((c,g))

    # hack the queue to become a stack --
    # this is important when some thread
    # failed and re-put a job. If we still
    # use FIFO, the job will be put
    # into the end of the queue, and the graph
    # will only be updated in the end

    job_queue._put = job_queue.queue.appendleft

    result_file = None
    db = []
    best_rate = -1
    best_c,best_g = None,None

    try:
        # fire telnet workers

        if telnet_workers:
            username = getpass.getuser()
            password = getpass.getpass()
            for host in telnet_workers:
                worker = TelnetWorker(host,job_queue,result_queue,
                         host,username,password,options)
                worker.start()

        # fire ssh workers

        if ssh_workers:
            for host in ssh_workers:
                worker = SSHWorker(host,job_queue,result_queue,host,options)
                worker.start()

        # fire local workers

        for _ in range(nr_local_worker):
            worker = LocalWorker('local',job_queue,result_queue,options)
            worker.start()

        # gather results

        done_jobs = {}

        if options.out_pathname:
            # Append when resuming so that earlier results are kept.
            if options.resume_pathname:
                result_file = open(options.out_pathname, 'a')
            else:
                result_file = open(options.out_pathname, 'w')

        for (c,g) in resumed_jobs:
            rate = resumed_jobs[(c,g)]
            best_c,best_g,best_rate = update_param(c,g,rate,best_c,best_g,best_rate,'resumed',True)

        for line in jobs:
            for (c,g) in line:
                # Results arrive in any order; keep draining until the one
                # for this job has been seen.
                while (c,g) not in done_jobs:
                    (worker,c1,g1,rate1) = result_queue.get()
                    done_jobs[(c1,g1)] = rate1
                    if (c1,g1) not in resumed_jobs:
                        best_c,best_g,best_rate = update_param(c1,g1,rate1,best_c,best_g,best_rate,worker,False)
                db.append((c,g,done_jobs[(c,g)]))
            if gnuplot and options.grid_with_c and options.grid_with_g:
                redraw(db,[best_c, best_g, best_rate],gnuplot,options)
                redraw(db,[best_c, best_g, best_rate],gnuplot,options,True)
    finally:
        # Also on errors: close the output file and tell the workers to
        # stop, otherwise the worker threads keep the process alive.
        if result_file is not None:
            result_file.close()
        job_queue.put((WorkerStopToken,None))

    best_param, best_cg  = {}, []
    if best_c is not None:
        best_param['c'] = 2.0**best_c
        best_cg += [2.0**best_c]
    if best_g is not None:
        best_param['g'] = 2.0**best_g
        best_cg += [2.0**best_g]
    print('{0} {1}'.format(' '.join(map(str,best_cg)), best_rate))

    return best_rate, best_param
00461 
00462 
if __name__ == '__main__':

    def exit_with_help():
        """Print the usage text and terminate with exit status 1."""
        usage = """\
Usage: grid.py [grid_options] [svm_options] dataset

grid_options :
-log2c {begin,end,step | "null"} : set the range of c (default -5,15,2)
    begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end}
    "null"         -- do not grid with c
-log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2)
    begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end}
    "null"         -- do not grid with g
-v n : n-fold cross validation (default 5)
-svmtrain pathname : set svm executable path and name
-gnuplot {pathname | "null"} :
    pathname -- set gnuplot executable path and name
    "null"   -- do not plot 
-out {pathname | "null"} : (default dataset.out)
    pathname -- set output file path and name
    "null"   -- do not output file
-png pathname : set graphic output file path and name (default dataset.png)
-resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out)
    This is experimental. Try this option only if some parameters have been checked for the SAME data.

svm_options : additional options for svm-train"""
        print(usage)
        sys.exit(1)

    # At least the dataset argument is required.
    if not sys.argv[1:]:
        exit_with_help()

    # The last argument is the dataset; everything before it is options.
    options = sys.argv[1:-1]
    dataset_pathname = sys.argv[-1]
    try:
        find_parameters(dataset_pathname, options)
    except (IOError,ValueError) as err:
        sys.stderr.write('{0}\n'.format(str(err)))
        sys.stderr.write('Try "grid.py" for more information.\n')
        sys.exit(1)


target_obejct_detector
Author(s): CIR-KIT
autogenerated on Thu Jun 6 2019 20:19:57