grid.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 
00004 
00005 import os, sys, traceback
00006 import getpass
00007 from threading import Thread
00008 from subprocess import *
00009 
00010 if(sys.hexversion < 0x03000000):
00011         import Queue
00012 else:
00013         import queue as Queue
00014 
00015 
# Locations of the svm-train and gnuplot executables (platform dependent).

is_win32 = (sys.platform == 'win32')
if is_win32:
    # example for windows
    svmtrain_exe = r"..\windows\svm-train.exe"
    gnuplot_exe = r"c:\tmp\gnuplot\bin\pgnuplot.exe"
else:
    svmtrain_exe = "../svm-train"
    gnuplot_exe = "/usr/bin/gnuplot"

# Global search parameters and their default values.

fold = 5
c_begin = -5
c_end = 15
c_step = 2
g_begin = 3
g_end = -15
g_step = -2
# dataset_pathname, dataset_title, pass_through_string, out_filename and
# png_filename are module globals as well; they have no default and are
# assigned by process_options().

# experimental: hosts for remote workers and the number of local workers

telnet_workers = []
ssh_workers = []
nr_local_worker = 1
00040 
00041 # process command line options, set global parameters
def process_options(argv=sys.argv):
    """Parse the command line and store the settings in module globals.

    The last element of ``argv`` is always taken as the dataset path.
    Recognised options (-log2c, -log2g, -v, -svmtrain, -gnuplot, -out,
    -png) update the corresponding globals; every other argument is
    collected into ``pass_through_string`` and later handed to svm-train
    unchanged.  On success a gnuplot process is started and its stdin is
    stored in the global ``gnuplot``.

    Exits with status 1 (after printing the usage text) when no dataset is
    given, when the obsolete -c/-g options are used, or when an option is
    missing its value.

    Raises:
        IOError: if the svm-train executable, the gnuplot executable or
            the dataset does not exist.
    """
    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename

    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold] 
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    if len(argv) < 2:
        print(usage)
        sys.exit(1)

    last = len(argv) - 1
    dataset_pathname = argv[last]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '%s.out' % dataset_title
    png_filename = '%s.png' % dataset_title
    pass_through_options = []

    def option_value(index):
        # The final argv entry is the dataset, so an option's value must sit
        # strictly before it; otherwise the dataset path would silently be
        # consumed as the value.
        if index >= last:
            print("Option %s requires a value." % argv[index - 1])
            print(usage)
            sys.exit(1)
        return argv[index]

    i = 1
    while i < last:
        option = argv[i]
        if option == "-log2c":
            i = i + 1
            (c_begin, c_end, c_step) = map(float, option_value(i).split(","))
        elif option == "-log2g":
            i = i + 1
            (g_begin, g_end, g_step) = map(float, option_value(i).split(","))
        elif option == "-v":
            i = i + 1
            # Kept as a string: it is only interpolated into the command line.
            fold = option_value(i)
        elif option in ('-c', '-g'):
            print("Option -c and -g are renamed.")
            print(usage)
            sys.exit(1)
        elif option == '-svmtrain':
            i = i + 1
            svmtrain_exe = option_value(i)
        elif option == '-gnuplot':
            i = i + 1
            gnuplot_exe = option_value(i)
        elif option == '-out':
            i = i + 1
            out_filename = option_value(i)
        elif option == '-png':
            i = i + 1
            png_filename = option_value(i)
        else:
            pass_through_options.append(option)
        i = i + 1

    pass_through_string = " ".join(pass_through_options)

    # Explicit checks instead of assert: asserts vanish under "python -O".
    if not os.path.exists(svmtrain_exe):
        raise IOError("svm-train executable not found")
    if not os.path.exists(gnuplot_exe):
        raise IOError("gnuplot executable not found")
    if not os.path.exists(dataset_pathname):
        raise IOError("dataset not found")
    gnuplot = Popen(gnuplot_exe, stdin=PIPE).stdin
00101 
00102 
def range_f(begin,end,step):
    """Return the values begin, begin+step, ... up to and including end.

    Works like range() but accepts non-integer arguments and treats ``end``
    as inclusive.  A positive step counts upwards, a negative step counts
    downwards; the result is empty when ``end`` lies on the wrong side of
    ``begin``.

    Raises:
        ValueError: if ``step`` is zero (the sequence would never end).
    """
    if step == 0:
        raise ValueError("range_f() step must not be zero")
    seq = []
    value = begin
    while True:
        if step > 0 and value > end: break
        if step < 0 and value < end: break
        seq.append(value)
        value = value + step
    return seq
00112 
def permute_sequence(seq):
    """Reorder ``seq`` so that its middle element comes first.

    The middle element leads, followed by the recursively permuted left and
    right halves interleaved element by element.  Sequences of length 0 or 1
    are returned as they are.
    """
    size = len(seq)
    if size <= 1:
        return seq

    middle = int(size/2)
    lower = permute_sequence(seq[:middle])
    upper = permute_sequence(seq[middle+1:])

    ordered = [seq[middle]]
    # Alternate between the two halves; the longer one supplies the tail.
    for pos in range(max(len(lower), len(upper))):
        if pos < len(lower):
            ordered.append(lower[pos])
        if pos < len(upper):
            ordered.append(upper[pos])

    return ordered
00127 
def redraw(db,best_param,tofile=False):
    """Send the current accuracy contour plot to the gnuplot process.

    db         -- list of (log2c, log2g, rate) tuples; sorted in place
    best_param -- (best log2c, best log2g, best rate), shown in the labels
    tofile     -- when true, render into png_filename instead of a window

    Does nothing when db is empty.
    """
    if not db:
        return

    def send(text):
        # gnuplot's stdin is a byte stream
        gnuplot.write(text.encode())

    begin_level = round(max(entry[2] for entry in db)) - 3
    step_size = 0.5

    best_log2c, best_log2g, best_rate = best_param

    # choose the output terminal
    if tofile:
        send("set term png transparent small\n")
        send("set output \"%s\"\n" % png_filename.replace('\\','\\\\'))
    elif is_win32:
        send("set term windows\n")
    else:
        send("set term x11\n")

    # axes, contour levels and labels
    send("set xlabel \"log2(C)\"\n")
    send("set ylabel \"log2(gamma)\"\n")
    send("set xrange [%s:%s]\n" % (c_begin, c_end))
    send("set yrange [%s:%s]\n" % (g_begin, g_end))
    send("set contour\n")
    send("set cntrparam levels incremental %s,%s,100\n" % (begin_level, step_size))
    send("unset surface\n")
    send("unset ztics\n")
    send("set view 0,0\n")
    send("set title \"%s\"\n" % dataset_title)
    send("unset label\n")
    send(("set label \"Best log2(C) = %s  log2(gamma) = %s  accuracy = %s%%\" "
          "                  at screen 0.5,0.85 center\n")
         % (best_log2c, best_log2g, best_rate))
    send(("set label \"C = %s  gamma = %s\""
          " at screen 0.5,0.8 center\n") % (2**best_log2c, 2**best_log2g))
    send("splot \"-\" with lines\n")

    # inline data: ascending log2c, descending log2g, with a blank line
    # between groups that share the same log2c
    db.sort(key = lambda entry: (entry[0], -entry[1]))

    current_c = db[0][0]
    for entry in db:
        if entry[0] != current_c:
            send("\n")
            current_c = entry[0]
        send("%s %s %s\n" % entry)
    send("e\n")
    send("\n") # force gnuplot back to prompt when term set failure
    gnuplot.flush()
00176 
00177 
def calculate_jobs():
    """Build the list of job batches for the (log2c, log2g) grid.

    Both axes are reordered with permute_sequence().  Each batch adds either
    one new C value paired with every gamma value used so far, or one new
    gamma value paired with every C value used so far, whichever axis is
    proportionally behind.  Returns a list of lists of (log2c, log2g) pairs.
    """
    c_values = permute_sequence(range_f(c_begin,c_end,c_step))
    g_values = permute_sequence(range_f(g_begin,g_end,g_step))
    total_c = float(len(c_values))
    total_g = float(len(g_values))
    used_c = 0
    used_g = 0
    batches = []

    while used_c < total_c or used_g < total_g:
        if used_c/total_c < used_g/total_g:
            # increase C resolution
            batch = [(c_values[used_c], g) for g in g_values[:used_g]]
            used_c += 1
        else:
            # increase g resolution
            batch = [(c, g_values[used_g]) for c in c_values[:used_c]]
            used_g += 1
        batches.append(batch)
    return batches
00203 
class WorkerStopToken:
    """Marker placed on the job queue to notify the workers to stop."""
00206 
class Worker(Thread):
    """Thread that takes (log2c, log2g) jobs from a queue and evaluates them.

    Subclasses must provide ``run_one(c, g)``, which returns the
    cross-validation rate for the given C and gamma, or None on failure.
    """

    def __init__(self,name,job_queue,result_queue):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue

    def run(self):
        """Process jobs until the stop token is seen or a job fails.

        Successful jobs are reported on the result queue as
        (worker name, log2c, log2g, rate).  A failed job is put back on the
        job queue for another worker and this worker quits.
        """
        while True:
            (cexp,gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # put the token back so the remaining workers see it too
                self.job_queue.put((cexp,gexp))
                break
            try:
                rate = self.run_one(2.0**cexp,2.0**gexp)
                if rate is None: raise RuntimeError("get no rate")
            except Exception:
                # Exception rather than a bare except, so SystemExit and
                # KeyboardInterrupt are not swallowed.
                # we failed, let others do that and we just quit
                traceback.print_exc()
                self.job_queue.put((cexp,gexp))
                print('worker %s quit.' % self.name)
                break
            else:
                self.result_queue.put((self.name,cexp,gexp,rate))
00233 
class LocalWorker(Worker):
    """Worker that runs svm-train as a subprocess on the local machine."""

    def run_one(self,c,g):
        """Cross-validate with the given C and gamma.

        Returns the number at the end of the output line containing "Cross"
        (with its trailing character, e.g. '%', removed), or None when no
        such line is printed.
        """
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
          (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        # communicate() reads all output, waits for the child and closes the
        # pipe, so neither the process nor the file descriptor is leaked.
        process = Popen(cmdline,shell=True,stdout=PIPE)
        output = process.communicate()[0]
        for line in output.decode('utf-8', 'replace').splitlines():
            if "Cross" in line:
                return float(line.split()[-1][0:-1])
        return None
00242 
class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host through ssh.

    The remote command first changes into the directory that was current
    locally when the worker was created.
    """

    def __init__(self,name,job_queue,result_queue,host):
        Worker.__init__(self,name,job_queue,result_queue)
        self.host = host
        self.cwd = os.getcwd()

    def run_one(self,c,g):
        """Cross-validate with the given C and gamma on the remote host.

        Returns the number at the end of the output line containing "Cross"
        (with its trailing character, e.g. '%', removed), or None when no
        such line is printed.
        """
        cmdline = 'ssh -x %s "cd %s; %s -c %s -g %s -v %s %s %s"' % \
          (self.host,self.cwd,
           svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        # communicate() reads all output, waits for the child and closes the
        # pipe, so neither the process nor the file descriptor is leaked.
        process = Popen(cmdline,shell=True,stdout=PIPE)
        output = process.communicate()[0]
        for line in output.decode('utf-8', 'replace').splitlines():
            if "Cross" in line:
                return float(line.split()[-1][0:-1])
        return None
00256 
class TelnetWorker(Worker):
    """Worker that runs svm-train on a remote host over a telnet session.

    NOTE: relies on the standard-library telnetlib module, which was removed
    in Python 3.13.
    """

    def __init__(self,name,job_queue,result_queue,host,username,password):
        Worker.__init__(self,name,job_queue,result_queue)
        self.host = host
        self.username = username
        self.password = password

    def run(self):
        """Log in, change to the local working directory, then process jobs."""
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        # telnetlib exchanges bytes on Python 3, so text is encoded on the
        # way out and prompts are given as byte strings.
        tn.read_until(b"login: ")
        tn.write((self.username + "\n").encode())
        tn.read_until(b"Password: ")
        tn.write((self.password + "\n").encode())

        # XXX: how to know whether login is successful?
        tn.read_until(self.username.encode())
        print('login ok', self.host)
        tn.write(("cd "+os.getcwd()+"\n").encode())
        Worker.run(self)
        tn.write(b"exit\n")

    def run_one(self,c,g):
        """Cross-validate with the given C and gamma in the telnet session.

        Returns the number at the end of the output line containing "Cross"
        (with its trailing character, e.g. '%', removed), or None when no
        such line is found.
        """
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
          (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        self.tn.write((cmdline+'\n').encode())
        (idx,matchm,output) = self.tn.expect([b'Cross.*\n'])
        for line in output.decode('utf-8', 'replace').split('\n'):
            if "Cross" in line:
                return float(line.split()[-1][0:-1])
        return None
00286 
def main():
    """Run the grid search: queue the jobs, start workers, collect results.

    Every finished job is appended to the file named by out_filename and
    echoed to stdout together with the best parameters so far.  After each
    batch of jobs the contour plot is redrawn on screen and into the PNG
    file.  The last line printed is "best_c best_g best_rate".
    """

    # set parameters

    process_options()

    # put jobs in queue

    jobs = calculate_jobs()
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)

    for line in jobs:
        for (c,g) in line:
            job_queue.put((c,g))

    # From now on anything put on the job queue (a re-queued failed job or
    # the stop token) goes to the front instead of the back.
    job_queue._put = job_queue.queue.appendleft

    # fire telnet workers

    if telnet_workers:
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            TelnetWorker(host,job_queue,result_queue,
                     host,username,password).start()

    # fire ssh workers

    if ssh_workers:
        for host in ssh_workers:
            SSHWorker(host,job_queue,result_queue,host).start()

    # fire local workers

    for i in range(nr_local_worker):
        LocalWorker('local',job_queue,result_queue).start()

    # gather results

    done_jobs = {}
    db = []
    best_rate = -1
    best_c1,best_g1 = None,None
    # defined up front so the final print cannot hit an unbound name
    best_c,best_g = None,None

    try:
        # the with-block closes the result file even if gathering fails
        with open(out_filename, 'w') as result_file:
            for line in jobs:
                for (c,g) in line:
                    # results arrive in any order; wait for this job's result
                    while (c, g) not in done_jobs:
                        (worker,c1,g1,rate) = result_queue.get()
                        done_jobs[(c1,g1)] = rate
                        result_file.write('%s %s %s\n' %(c1,g1,rate))
                        result_file.flush()
                        # on a tie with the same gamma, prefer the smaller C
                        if (rate > best_rate) or (rate==best_rate and g1==best_g1 and c1<best_c1):
                            best_rate = rate
                            best_c1,best_g1=c1,g1
                            best_c = 2.0**c1
                            best_g = 2.0**g1
                        print("[%s] %s %s %s (best c=%s, g=%s, rate=%s)" % \
                            (worker,c1,g1,rate, best_c, best_g, best_rate))
                    db.append((c,g,done_jobs[(c,g)]))
                redraw(db,[best_c1, best_g1, best_rate])
                redraw(db,[best_c1, best_g1, best_rate],True)
    finally:
        # always release the workers, otherwise their threads keep the
        # process alive after a failure
        job_queue.put((WorkerStopToken,None))

    print("%s %s %s" % (best_c, best_g, best_rate))
# Run the grid search only when executed as a script, so that importing this
# module does not start workers or parse sys.argv.
if __name__ == "__main__":
    main()


libsvm3
Author(s): various
autogenerated on Wed Nov 27 2013 11:36:23