grid.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 
00004 
00005 import os, sys, traceback
00006 import getpass
00007 from threading import Thread
00008 from subprocess import *
00009 
00010 if(sys.hexversion < 0x03000000):
00011         import Queue
00012 else:
00013         import queue as Queue
00014 
00015 
# svm-train and gnuplot executables (may be overridden with -svmtrain / -gnuplot)

is_win32 = (sys.platform == 'win32')
if not is_win32:
    svmtrain_exe = "../svm-train"
    gnuplot_exe = "/usr/bin/gnuplot"
else:
    # example for windows
    svmtrain_exe = r"..\windows\svm-train.exe"
    # svmtrain_exe = r"c:\Program Files\libsvm\windows\svm-train.exe"
    gnuplot_exe = r"c:\tmp\gnuplot\binary\pgnuplot.exe"

# global parameters and their default values
# (log2(C) and log2(gamma) run from *_begin to *_end in *_step increments)

fold = 5
c_begin, c_end, c_step = -5,  15, 2
g_begin, g_end, g_step =  3, -15, -2

# Filled in by process_options(). A ``global`` statement at module level is a
# no-op and does not create the names, so bind them explicitly here.
dataset_pathname = dataset_title = pass_through_string = None
out_filename = png_filename = None
gnuplot = None  # stdin pipe of the gnuplot process, opened by process_options()

# experimental: remote worker host names and the number of local workers

telnet_workers = []
ssh_workers = []
nr_local_worker = 1
00041 
# process command line options, set global parameters
def process_options(argv=None):
    """Parse the command line and set the module-level run configuration.

    The last element of *argv* is always the dataset path. Every other
    element after the program name is either a grid.py option (consumed
    here) or passed through verbatim to svm-train.

    Side effects: rebinds the module globals declared below and starts a
    gnuplot subprocess whose stdin pipe is stored in the global ``gnuplot``.

    :param argv: argument vector including the program name; defaults to
        ``sys.argv`` as it is at call time.

    Exits with status 1 after printing the usage text on bad usage. Exits
    with an error message when svm-train, gnuplot or the dataset does not
    exist. These are explicit checks rather than ``assert`` so they still
    run under ``python -O``.
    """
    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename

    if argv is None:
        argv = sys.argv

    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold] 
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    if len(argv) < 2:
        print(usage)
        sys.exit(1)

    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '{0}.out'.format(dataset_title)
    png_filename = '{0}.png'.format(dataset_title)
    pass_through_options = []

    value_options = ('-log2c', '-log2g', '-v', '-svmtrain', '-gnuplot',
                     '-out', '-png')
    last = len(argv) - 1  # index of the dataset argument

    i = 1
    while i < last:
        opt = argv[i]
        if opt in ('-c', '-g'):
            print("Option -c and -g are renamed.")
            print(usage)
            sys.exit(1)
        elif opt in value_options:
            i += 1
            if i >= last:
                # The option's value would be the dataset path itself.
                print(usage)
                sys.exit(1)
            value = argv[i]
            if opt == '-log2c':
                (c_begin, c_end, c_step) = map(float, value.split(","))
            elif opt == '-log2g':
                (g_begin, g_end, g_step) = map(float, value.split(","))
            elif opt == '-v':
                fold = value
            elif opt == '-svmtrain':
                svmtrain_exe = value
            elif opt == '-gnuplot':
                gnuplot_exe = value
            elif opt == '-out':
                out_filename = value
            else:  # '-png'
                png_filename = value
        else:
            pass_through_options.append(opt)
        i += 1

    pass_through_string = " ".join(pass_through_options)

    for path, what in ((svmtrain_exe, "svm-train executable"),
                       (gnuplot_exe, "gnuplot executable"),
                       (dataset_pathname, "dataset")):
        if not os.path.exists(path):
            sys.exit("{0} not found".format(what))

    gnuplot = Popen(gnuplot_exe, stdin=PIPE).stdin
00102 
00103 
def range_f(begin, end, step):
    """Return begin, begin+step, begin+2*step, ... up to and including end.

    Like ``range``, but works on non-integers too, includes *end* when it is
    reached, and counts downwards when *step* is negative. Terms are computed
    as ``begin + k*step`` rather than by repeated addition, and the end point
    is accepted within 1e-10*|step|. This way float rounding (e.g.
    ``3*0.1 == 0.30000000000000004``) does not drop the last grid point.

    Returns an empty list when *step* points away from *end*.

    Raises ValueError if *step* is zero (which would otherwise never end).
    """
    if step == 0:
        raise ValueError("step must be non-zero")
    slack = abs(step) * 1e-10
    seq = []
    k = 0
    while True:
        value = begin + k * step
        if step > 0 and value > end + slack:
            break
        if step < 0 and value < end - slack:
            break
        seq.append(value)
        k += 1
    return seq
00113 
def permute_sequence(seq):
    """Reorder *seq* so that a prefix of the result samples it evenly.

    The middle element comes first. It is followed by the recursively
    permuted left and right halves, interleaved one element at a time
    (left first). A sequence of length 0 or 1 is returned unchanged
    (the same object).
    """
    size = len(seq)
    if size <= 1:
        return seq

    pivot = size // 2
    lower = permute_sequence(seq[:pivot])
    upper = permute_sequence(seq[pivot + 1:])

    merged = [seq[pivot]]
    for k in range(max(len(lower), len(upper))):
        if k < len(lower):
            merged.append(lower[k])
        if k < len(upper):
            merged.append(upper[k])
    return merged
00128 
def redraw(db,best_param,tofile=False):
    """Send the current cross-validation results to gnuplot as a contour plot.

    Writes gnuplot commands and inline data to the global ``gnuplot`` pipe.

    :param db: list of (log2 C, log2 gamma, rate) tuples. NOTE: it is sorted
        in place.
    :param best_param: [best log2 C, best log2 gamma, best rate]; shown in
        the plot labels.
    :param tofile: if True, render to the PNG file named by the global
        ``png_filename``; otherwise use the ``windows`` terminal on win32
        and ``x11`` elsewhere.

    Returns without drawing when *db* is empty, or when every entry shares
    the same C, the same gamma, or the same rate.
    """
    if len(db) == 0: return
    # lowest contour level: 3 below the (rounded) highest rate in db
    begin_level = round(max(x[2] for x in db)) - 3
    step_size = 0.5

    best_log2c,best_log2g,best_rate = best_param

    # if newly obtained c, g, or cv values are the same,
    # then stop redrawing the contour.
    if all(x[0] == db[0][0]  for x in db): return
    if all(x[1] == db[0][1]  for x in db): return
    if all(x[2] == db[0][2]  for x in db): return

    if tofile:
        gnuplot.write(b"set term png transparent small linewidth 2 medium enhanced\n")
        # backslashes (e.g. in Windows paths) are doubled because the file
        # name is placed inside a double-quoted gnuplot string
        gnuplot.write("set output \"{0}\"\n".format(png_filename.replace('\\','\\\\')).encode())
        #gnuplot.write(b"set term postscript color solid\n")
        #gnuplot.write("set output \"{0}.ps\"\n".format(dataset_title).encode().encode())
    elif is_win32:
        gnuplot.write(b"set term windows\n")
    else:
        gnuplot.write( b"set term x11\n")
    gnuplot.write(b"set xlabel \"log2(C)\"\n")
    gnuplot.write(b"set ylabel \"log2(gamma)\"\n")
    gnuplot.write("set xrange [{0}:{1}]\n".format(c_begin,c_end).encode())
    gnuplot.write("set yrange [{0}:{1}]\n".format(g_begin,g_end).encode())
    # draw only contour lines (no surface, no z tics), viewed from above
    gnuplot.write(b"set contour\n")
    gnuplot.write("set cntrparam levels incremental {0},{1},100\n".format(begin_level,step_size).encode())
    gnuplot.write(b"unset surface\n")
    gnuplot.write(b"unset ztics\n")
    gnuplot.write(b"set view 0,0\n")
    gnuplot.write("set title \"{0}\"\n".format(dataset_title).encode())
    gnuplot.write(b"unset label\n")
    gnuplot.write("set label \"Best log2(C) = {0}  log2(gamma) = {1}  accuracy = {2}%\" \
                  at screen 0.5,0.85 center\n". \
                  format(best_log2c, best_log2g, best_rate).encode())
    gnuplot.write("set label \"C = {0}  gamma = {1}\""
                  " at screen 0.5,0.8 center\n".format(2**best_log2c, 2**best_log2g).encode())
    gnuplot.write(b"set key at screen 0.9,0.9\n")
    # the data follows inline on the same pipe
    gnuplot.write(b"splot \"-\" with lines\n")

    # order by C, then by decreasing gamma, so that each C value forms one
    # contiguous run of points
    db.sort(key = lambda x:(x[0], -x[1]))

    prevc = db[0][0]
    for line in db:
        # a blank line separates the runs for different C values
        if prevc != line[0]:
            gnuplot.write(b"\n")
            prevc = line[0]
        gnuplot.write("{0[0]} {0[1]} {0[2]}\n".format(line).encode())
    # "e" terminates the inline data block
    gnuplot.write(b"e\n")
    gnuplot.write(b"\n") # force gnuplot back to prompt when term set failure
    gnuplot.flush()
00184 
00185 
def calculate_jobs():
    """Return the (log2 C, log2 gamma) grid as a list of batches ("lines").

    Both axes come from range_f over the global begin/end/step settings and
    are reordered with permute_sequence. The grid is then grown by one C
    value or one gamma value at a time, choosing whichever axis has the
    smaller fraction of its values used so far. Each step yields one batch
    pairing the new value with all values of the other axis already used.
    Batches may be empty (the first one always is).

    Returns an empty list when either axis range is empty. (The float ratio
    test used to raise ZeroDivisionError in that case.)
    """
    c_seq = permute_sequence(range_f(c_begin, c_end, c_step))
    g_seq = permute_sequence(range_f(g_begin, g_end, g_step))
    nr_c = len(c_seq)
    nr_g = len(g_seq)
    if nr_c == 0 or nr_g == 0:
        return []

    i = 0  # number of C values already in the grid
    j = 0  # number of gamma values already in the grid
    jobs = []

    while i < nr_c or j < nr_g:
        # i/nr_c < j/nr_g, compared exactly in integers
        if i * nr_g < j * nr_c:
            # increase C resolution
            jobs.append([(c_seq[i], g) for g in g_seq[:j]])
            i += 1
        else:
            # increase g resolution
            jobs.append([(c, g_seq[j]) for c in c_seq[:i]])
            j += 1
    return jobs
00211 
class WorkerStopToken:  # used to notify the worker to stop
    """Marker that tells workers to stop.

    The class object itself is put in the job queue as
    ``(WorkerStopToken, None)``.
    """


class Worker(Thread):
    """Thread that evaluates (log2 C, log2 gamma) jobs taken from a queue.

    Subclasses implement ``run_one(c, g)``. It receives C and gamma
    (already exponentiated, i.e. ``2.0**cexp`` and ``2.0**gexp``) and
    returns the cross-validation rate, or None when no rate was obtained.

    Each result is reported as ``(name, cexp, gexp, rate)`` on
    *result_queue*. If a job fails, it is put back on *job_queue* for
    another worker and this worker exits. When the stop token arrives, the
    worker puts it back (so other workers see it too) and exits.
    """

    def __init__(self, name, job_queue, result_queue):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue

    def run(self):
        while True:
            (cexp, gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # pass the token on so the remaining workers stop as well
                self.job_queue.put((cexp, gexp))
                break
            try:
                rate = self.run_one(2.0**cexp, 2.0**gexp)
                if rate is None:
                    raise RuntimeError("get no rate")
            except Exception:
                # We failed: let others do this job and just quit.
                # ``Exception`` rather than a bare except, so that
                # KeyboardInterrupt and SystemExit are not swallowed.
                traceback.print_exc()
                self.job_queue.put((cexp, gexp))
                print('worker {0} quit.'.format(self.name))
                break
            else:
                self.result_queue.put((self.name, cexp, gexp, rate))
00241 
class LocalWorker(Worker):
    """Worker that runs svm-train cross-validation on this machine."""

    def run_one(self, c, g):
        """Run svm-train with cost *c* and gamma *g*; return the CV rate.

        The rate is the last whitespace-separated token of the first output
        line containing "Cross", with a trailing '%' removed if present
        (presumably a line like ``Cross Validation Accuracy = 96.5%``).
        Returns None if no such line appears.
        """
        cmdline = '{0} -c {1} -g {2} -v {3} {4} {5}'.format \
          (svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        # communicate() reads all output and waits for the child, so the
        # pipe is closed and no zombie process is left behind per job.
        proc = Popen(cmdline, shell=True, stdout=PIPE)
        output = proc.communicate()[0]
        if not isinstance(output, str):  # bytes on Python 3
            output = output.decode('utf-8', 'replace')
        for line in output.splitlines():
            if line.find("Cross") != -1:
                return float(line.split()[-1].rstrip('%'))
        return None
00250 
class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host through ``ssh -x``.

    The remote command first changes to this process's working directory
    (captured at construction) before running svm-train. This presumably
    requires the remote host to see the same directory layout (e.g. a
    shared file system) — confirm for your setup.
    """

    def __init__(self, name, job_queue, result_queue, host):
        Worker.__init__(self, name, job_queue, result_queue)
        self.host = host
        self.cwd = os.getcwd()

    def run_one(self, c, g):
        """Run svm-train on ``self.host`` with cost *c* and gamma *g*.

        Returns the rate parsed from the first output line containing
        "Cross" (last token, trailing '%' removed), or None if none appears.
        """
        cmdline = 'ssh -x {0} "cd {1}; {2} -c {3} -g {4} -v {5} {6} {7}"'.format \
          (self.host, self.cwd,
           svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        # communicate() waits for ssh to finish and closes the pipe.
        proc = Popen(cmdline, shell=True, stdout=PIPE)
        output = proc.communicate()[0]
        if not isinstance(output, str):  # bytes on Python 3
            output = output.decode('utf-8', 'replace')
        for line in output.splitlines():
            if line.find("Cross") != -1:
                return float(line.split()[-1].rstrip('%'))
        return None
00264 
class TelnetWorker(Worker):
    """Worker that runs svm-train on a remote host over a telnet session.

    NOTE: telnet sends the password in clear text, and ``telnetlib`` was
    removed from the standard library in Python 3.13.
    """

    def __init__(self, name, job_queue, result_queue, host, username, password):
        Worker.__init__(self, name, job_queue, result_queue)
        self.host = host
        self.username = username
        self.password = password

    def run(self):
        """Log in, cd to the local working directory, process jobs, log out."""
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        # telnetlib reads and writes bytes on Python 3, so every string sent
        # and every string searched for is encoded.
        tn.read_until(b"login: ")
        tn.write((self.username + "\n").encode())
        tn.read_until(b"Password: ")
        tn.write((self.password + "\n").encode())

        # XXX: how to know whether login is successful?
        tn.read_until(self.username.encode())
        print('login ok', self.host)
        tn.write(("cd " + os.getcwd() + "\n").encode())
        Worker.run(self)
        tn.write(b"exit\n")
        tn.close()

    def run_one(self, c, g):
        """Run svm-train in the telnet session; return the parsed CV rate.

        Waits for an output line matching ``Cross.*`` and returns its last
        token with a trailing '%' removed, or None if it cannot be found.
        """
        cmdline = '{0} -c {1} -g {2} -v {3} {4} {5}'.format \
          (svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        self.tn.write((cmdline + '\n').encode())
        (idx, matchm, output) = self.tn.expect([b'Cross.*\n'])
        for line in output.decode('utf-8', 'replace').split('\n'):
            if line.find("Cross") != -1:
                return float(line.split()[-1].rstrip('%'))
        return None
00294 
def main():
    """Run the grid search end to end.

    Parses the command line, queues every (log2 C, log2 gamma) job, and
    starts the configured telnet, ssh and local workers. It then collects
    results batch by batch. Each result is appended to ``out_filename``,
    and the contour plot is redrawn on screen and to the PNG after every
    batch. Finally it prints "best_c best_g best_rate".
    """

    # set parameters

    process_options()

    # put jobs in queue

    jobs = calculate_jobs()
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)

    for line in jobs:
        for (c, g) in line:
            job_queue.put((c, g))

    # Hack the queue to become a stack. This matters when some thread fails
    # and re-puts a job: with FIFO order the job would go to the end of the
    # queue, and the graph would only be updated at the very end.
    # NOTE: relies on Queue internals (``_put`` and the ``queue`` deque).
    job_queue._put = job_queue.queue.appendleft

    # fire telnet workers

    if telnet_workers:
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            TelnetWorker(host, job_queue, result_queue,
                         host, username, password).start()

    # fire ssh workers

    for host in ssh_workers:
        SSHWorker(host, job_queue, result_queue, host).start()

    # fire local workers

    for i in range(nr_local_worker):
        LocalWorker('local', job_queue, result_queue).start()

    # gather results
    # NOTE(review): if every worker quits on errors, result_queue.get()
    # below blocks forever.

    done_jobs = {}
    db = []
    best_rate = -1
    best_c1, best_g1 = None, None
    best_c, best_g = None, None  # stay None if there are no jobs at all

    with open(out_filename, 'w') as result_file:
        for line in jobs:
            for (c, g) in line:
                # results arrive in any order; wait until this one is in
                while (c, g) not in done_jobs:
                    (worker, c1, g1, rate) = result_queue.get()
                    done_jobs[(c1, g1)] = rate
                    result_file.write('{0} {1} {2}\n'.format(c1, g1, rate))
                    result_file.flush()
                    # ties with the same gamma are broken towards smaller C
                    if (rate > best_rate) or (rate == best_rate and g1 == best_g1 and c1 < best_c1):
                        best_rate = rate
                        best_c1, best_g1 = c1, g1
                        best_c = 2.0**c1
                        best_g = 2.0**g1
                    print("[{0}] {1} {2} {3} (best c={4}, g={5}, rate={6})".format \
                        (worker, c1, g1, rate, best_c, best_g, best_rate))
                db.append((c, g, done_jobs[(c, g)]))
            redraw(db, [best_c1, best_g1, best_rate])
            redraw(db, [best_c1, best_g1, best_rate], True)

    job_queue.put((WorkerStopToken, None))
    print("{0} {1} {2}".format(best_c, best_g, best_rate))


if __name__ == '__main__':
    main()


ml_classifiers
Author(s): Scott Niekum
autogenerated on Fri Jan 3 2014 11:30:23