00001
00002
00003
00004
00005 import os, sys, traceback
00006 import getpass
00007 from threading import Thread
00008 from subprocess import *
00009
00010 if(sys.hexversion < 0x03000000):
00011 import Queue
00012 else:
00013 import queue as Queue
00014
00015
00016
00017
# Platform-dependent default locations of svm-train and gnuplot; both can be
# overridden on the command line with -svmtrain / -gnuplot.
is_win32 = (sys.platform == 'win32')
if not is_win32:
    svmtrain_exe = "../svm-train"
    gnuplot_exe = "/usr/bin/gnuplot"
else:
    svmtrain_exe = r"..\windows\svm-train.exe"
    gnuplot_exe = r"c:\tmp\gnuplot\bin\pgnuplot.exe"

# Default cross-validation fold count and search grid, given as
# (begin, end, step) exponents of 2 for C and gamma.
fold = 5
c_begin, c_end, c_step = -5, 15, 2
g_begin, g_end, g_step = 3, -15, -2

# Filled in by process_options(). A `global` statement at module level is a
# no-op, so these names get explicit placeholders instead.
dataset_pathname = dataset_title = pass_through_string = None
out_filename = png_filename = None
gnuplot = None

# Hosts to use as telnet / ssh workers, and the number of local worker threads.
telnet_workers = []
ssh_workers = []
nr_local_worker = 1
00040
00041
def process_options(argv=None):
    """Parse grid.py's command line into the module-level configuration.

    ``argv`` is the full argument vector, program name first; when omitted,
    ``sys.argv`` is read at call time. The last element is always taken as
    the dataset path. Recognised options set the globals ``c_begin/c_end/
    c_step``, ``g_begin/g_end/g_step``, ``fold``, ``svmtrain_exe``,
    ``gnuplot_exe``, ``out_filename`` and ``png_filename``. Any other argument
    is forwarded to svm-train through ``pass_through_string``. Finally the two
    executables and the dataset are checked to exist, and gnuplot is started
    with its stdin pipe stored in the global ``gnuplot``.

    Exits the interpreter with status 1 on a usage error or a missing file.
    """
    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename

    if argv is None:
        argv = sys.argv

    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold]
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    def fail(message=None, show_usage=True):
        # Report a fatal command-line problem and terminate with status 1.
        if message:
            print(message)
        if show_usage:
            print(usage)
        sys.exit(1)

    def parse_range(option, text):
        # "begin,end,step" -> (begin, end, step) as floats. A zero step is
        # rejected because the grid generator would never terminate.
        try:
            begin, end, step = map(float, text.split(","))
        except ValueError:
            fail("%s expects begin,end,step but got %r." % (option, text))
        if step == 0:
            fail("%s step must be non-zero." % option)
        return begin, end, step

    if len(argv) < 2:
        fail()

    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '%s.out' % dataset_title
    png_filename = '%s.png' % dataset_title
    pass_through_options = []

    # Options that consume the following argument as their value.
    valued_options = ('-log2c', '-log2g', '-v', '-svmtrain', '-gnuplot',
                      '-out', '-png')
    dataset_index = len(argv) - 1
    i = 1
    while i < dataset_index:
        option = argv[i]
        if option in ('-c', '-g'):
            fail("Option -c and -g are renamed.")
        elif option in valued_options:
            # Without this check the dataset path would silently be taken
            # as the option's value.
            if i + 1 >= dataset_index:
                fail("Option %s requires a value." % option)
            i = i + 1
            value = argv[i]
            if option == '-log2c':
                c_begin, c_end, c_step = parse_range(option, value)
            elif option == '-log2g':
                g_begin, g_end, g_step = parse_range(option, value)
            elif option == '-v':
                fold = value
            elif option == '-svmtrain':
                svmtrain_exe = value
            elif option == '-gnuplot':
                gnuplot_exe = value
            elif option == '-out':
                out_filename = value
            else:  # '-png'
                png_filename = value
        else:
            pass_through_options.append(option)
        i = i + 1

    pass_through_string = " ".join(pass_through_options)

    # Explicit checks instead of assert, which is stripped under python -O.
    for path, what in ((svmtrain_exe, "svm-train executable"),
                       (gnuplot_exe, "gnuplot executable"),
                       (dataset_pathname, "dataset")):
        if not os.path.exists(path):
            fail("%s not found" % what, show_usage=False)
    gnuplot = Popen(gnuplot_exe, stdin=PIPE).stdin
00101
00102
def range_f(begin, end, step):
    """Return the inclusive arithmetic sequence begin, begin+step, ... up to end.

    Works with a positive step (ascending, stops once a value exceeds
    ``end``) or a negative step (descending, stops once a value falls below
    ``end``). Values are produced by repeated addition, so float steps
    accumulate rounding exactly as before.

    Raises ValueError if ``step`` is zero, because the sequence would never
    terminate.
    """
    if step == 0:
        raise ValueError("range_f() step must be non-zero")
    seq = []
    value = begin
    while (value <= end) if step > 0 else (value >= end):
        seq.append(value)
        value = value + step
    return seq
00112
def permute_sequence(seq):
    """Reorder ``seq``: middle element first, then the recursively reordered
    left and right halves interleaved (left first).

    For example [0, 1, 2, 3, 4] becomes [2, 1, 4, 0, 3], so the first values
    taken are spread across the whole range. A sequence of length <= 1 is
    returned unchanged (the same object).
    """
    n = len(seq)
    if n <= 1:
        return seq

    mid = n // 2
    left = permute_sequence(seq[:mid])
    right = permute_sequence(seq[mid + 1:])

    ret = [seq[mid]]
    # Interleave by index; popping from the front of a list in a loop would
    # make every recursion level quadratic.
    for k in range(max(len(left), len(right))):
        if k < len(left):
            ret.append(left[k])
        if k < len(right):
            ret.append(right[k])
    return ret
00127
def redraw(db, best_param, tofile=False):
    """Send a contour plot of the results collected so far to gnuplot.

    db         -- sequence of (log2c, log2g, rate) tuples; it is not modified.
    best_param -- [best_log2c, best_log2g, best_rate], shown in the labels.
    tofile     -- if true, render to the PNG file ``png_filename`` instead of
                  an interactive terminal.

    Commands are written as bytes to the global ``gnuplot`` stream (the
    stdin pipe opened by process_options) and then flushed. Returns
    immediately if ``db`` is empty.
    """
    if not db:
        return
    # Contour levels start 3 below the best (rounded) rate, spaced 0.5 apart.
    begin_level = round(max(x[2] for x in db)) - 3
    step_size = 0.5

    best_log2c, best_log2g, best_rate = best_param

    def send(command):
        # gnuplot's stdin pipe is a binary stream.
        gnuplot.write(command.encode())

    if tofile:
        send("set term png transparent small\n")
        # Backslashes (e.g. in Windows paths) are doubled because gnuplot
        # treats backslash as an escape inside double-quoted strings.
        send("set output \"%s\"\n" % png_filename.replace('\\', '\\\\'))
    elif is_win32:
        send("set term windows\n")
    else:
        send("set term x11\n")
    send("set xlabel \"log2(C)\"\n")
    send("set ylabel \"log2(gamma)\"\n")
    send("set xrange [%s:%s]\n" % (c_begin, c_end))
    send("set yrange [%s:%s]\n" % (g_begin, g_end))
    send("set contour\n")
    send("set cntrparam levels incremental %s,%s,100\n" % (begin_level, step_size))
    send("unset surface\n")
    send("unset ztics\n")
    send("set view 0,0\n")
    send("set title \"%s\"\n" % dataset_title)
    send("unset label\n")
    send("set label \"Best log2(C) = %s log2(gamma) = %s accuracy = %s%%\" "
         "at screen 0.5,0.85 center\n" % (best_log2c, best_log2g, best_rate))
    send("set label \"C = %s gamma = %s\""
         " at screen 0.5,0.8 center\n" % (2**best_log2c, 2**best_log2g))
    send("splot \"-\" with lines\n")

    # Inline data sorted by log2c, then log2g descending. A blank line
    # separates consecutive log2c groups (one scan each for splot). Sort a
    # copy so the caller's list is left untouched.
    points = sorted(db, key=lambda x: (x[0], -x[1]))
    prevc = points[0][0]
    for line in points:
        if prevc != line[0]:
            send("\n")
            prevc = line[0]
        send("%s %s %s\n" % line)
    send("e\n")
    send("\n")
    gnuplot.flush()
00176
00177
def calculate_jobs(c_seq=None, g_seq=None):
    """Group the (log2c, log2g) grid points into batches ("lines") of jobs.

    c_seq, g_seq -- the log2(C) / log2(gamma) values to search, in the order
                    they should be explored. When omitted, they are built
                    from the global c_*/g_* ranges via range_f and
                    permute_sequence.

    The two axes are advanced so that their completed fractions stay
    balanced. Each step either adds the next C value paired with every gamma
    value started so far, or adds the next gamma value paired with every C
    value started so far. Every (c, g) pair appears exactly once. The first
    line is always empty.

    Returns a list of lists of (c, g) tuples, or [] if either axis is empty.
    """
    if c_seq is None:
        c_seq = permute_sequence(range_f(c_begin, c_end, c_step))
    if g_seq is None:
        g_seq = permute_sequence(range_f(g_begin, g_end, g_step))
    nr_c = len(c_seq)
    nr_g = len(g_seq)
    if nr_c == 0 or nr_g == 0:
        # Empty grid: nothing to run (and the ratio test would divide by 0).
        return []

    i = j = 0
    jobs = []
    while i < nr_c or j < nr_g:
        # Exact integer form of i/nr_c < j/nr_g.
        if i * nr_g < j * nr_c:
            # Next C value against every gamma value started so far.
            jobs.append([(c_seq[i], g_seq[k]) for k in range(j)])
            i = i + 1
        else:
            # Next gamma value against every C value started so far.
            jobs.append([(c_seq[k], g_seq[j]) for k in range(i)])
            j = j + 1
    return jobs
00203
class WorkerStopToken:
    """Sentinel class put on the job queue as ``(WorkerStopToken, None)``.

    Workers compare the first element of a job against the class object
    itself (by identity) to know they should stop.
    """
00206
class Worker(Thread):
    """Thread that takes (log2c, log2g) jobs from a queue and reports results.

    Subclasses provide ``run_one(c, g)``, which receives the actual C and
    gamma values (2 ** exponent) and returns the cross-validation rate, or
    None if no rate was obtained. Each result is put on ``result_queue`` as
    ``(worker_name, log2c, log2g, rate)``.
    """
    def __init__(self, name, job_queue, result_queue):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue

    def run(self):
        while True:
            (cexp, gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # Put the stop token back so every other worker sees it too.
                self.job_queue.put((cexp, gexp))
                break
            try:
                rate = self.run_one(2.0**cexp, 2.0**gexp)
                if rate is None:
                    raise RuntimeError("get no rate")
            except Exception:
                # Report the failure, hand the job back for another worker
                # to retry, and retire this worker.
                traceback.print_exc()
                self.job_queue.put((cexp, gexp))
                print('worker %s quit.' % self.name)
                break
            else:
                self.result_queue.put((self.name, cexp, gexp, rate))
00233
class LocalWorker(Worker):
    """Worker that runs svm-train cross-validation as a local subprocess."""
    def run_one(self, c, g):
        """Run svm-train with C=c, gamma=g and return the rate it reports.

        The rate is the last token of the first output line containing
        "Cross", minus its final character (presumably a trailing '%').
        Returns None if no such line is printed.
        """
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
            (svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        # communicate() reads all output and waits for the child, so each
        # job reaps its process and closes its pipe instead of leaking them.
        output = Popen(cmdline, shell=True, stdout=PIPE).communicate()[0]
        for line in output.splitlines():
            if b"Cross" in line:
                return float(line.split()[-1][0:-1])
00242
class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host through ``ssh -x``.

    The remote command first changes to this process's working directory
    (captured at construction), so the svm-train and dataset paths
    presumably need to be valid on the remote host as well (e.g. a shared
    filesystem) -- confirm for your setup.
    """
    def __init__(self, name, job_queue, result_queue, host):
        Worker.__init__(self, name, job_queue, result_queue)
        self.host = host
        self.cwd = os.getcwd()

    def run_one(self, c, g):
        """Run svm-train remotely with C=c, gamma=g and return its rate.

        Parses the first output line containing "Cross" the same way as the
        local worker does. Returns None if no such line is printed.
        """
        cmdline = 'ssh -x %s "cd %s; %s -c %s -g %s -v %s %s %s"' % \
            (self.host, self.cwd,
             svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        # communicate() drains stdout and waits for ssh to exit, so the child
        # is reaped and its pipe closed after every job.
        output = Popen(cmdline, shell=True, stdout=PIPE).communicate()[0]
        for line in output.splitlines():
            if b"Cross" in line:
                return float(line.split()[-1][0:-1])
00256
class TelnetWorker(Worker):
    """Worker that logs into a remote host over telnet and runs svm-train there.

    NOTE: ``telnetlib`` is deprecated and was removed in Python 3.13.
    """
    def __init__(self, name, job_queue, result_queue, host, username, password):
        Worker.__init__(self, name, job_queue, result_queue)
        self.host = host
        self.username = username
        self.password = password

    def run(self):
        """Log in, cd to the local working directory, then process jobs."""
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        # telnetlib reads and writes bytes, so every str sent or matched is
        # encoded (a plain str raises TypeError on Python 3).
        tn.read_until(b"login: ")
        tn.write((self.username + "\n").encode())
        tn.read_until(b"Password: ")
        tn.write((self.password + "\n").encode())

        # NOTE(review): seeing the username echoed back is taken as a
        # successful login; a failed login is not detected here.
        tn.read_until(self.username.encode())

        print('login ok', self.host)
        tn.write(("cd " + os.getcwd() + "\n").encode())
        Worker.run(self)
        tn.write(b"exit\n")

    def run_one(self, c, g):
        """Run svm-train in the telnet session with C=c, gamma=g.

        Waits for a line containing "Cross" and parses its last token minus
        the final character. Returns None if that line is not found in the
        matched output.
        """
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
            (svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        self.tn.write((cmdline + '\n').encode())
        (idx, matchm, output) = self.tn.expect([b'Cross.*\n'])
        for line in output.split(b'\n'):
            if b"Cross" in line:
                return float(line.split()[-1][0:-1])
00286
def main():
    """Run the parameter search end to end.

    Parses the command line and queues every (log2c, log2g) grid point. It
    then starts the configured telnet, ssh and local workers and collects
    results in the order produced by calculate_jobs(). Each result is
    appended to ``out_filename``, the best rate so far is tracked and
    printed, and the gnuplot contour is redrawn after each batch (and finally
    written to the PNG file). The last output line is
    "best_c best_g best_rate".
    """
    process_options()

    jobs = calculate_jobs()
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)

    for line in jobs:
        for (c, g) in line:
            job_queue.put((c, g))

    # Hack relying on Queue internals: after the initial fill, put() inserts
    # at the FRONT of the underlying deque, so a job handed back by a failed
    # worker is retried next.
    job_queue._put = job_queue.queue.appendleft

    if telnet_workers:
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            TelnetWorker(host, job_queue, result_queue,
                         host, username, password).start()

    for host in ssh_workers:
        SSHWorker(host, job_queue, result_queue, host).start()

    for i in range(nr_local_worker):
        LocalWorker('local', job_queue, result_queue).start()

    done_jobs = {}
    db = []
    best_rate = -1
    best_c1, best_g1 = None, None
    # Bound up front so the final report also works for an empty grid.
    best_c, best_g = None, None

    try:
        with open(out_filename, 'w') as result_file:
            for line in jobs:
                for (c, g) in line:
                    # Results arrive in any order; keep consuming until the
                    # one needed next is available.
                    while (c, g) not in done_jobs:
                        (worker, c1, g1, rate) = result_queue.get()
                        done_jobs[(c1, g1)] = rate
                        result_file.write('%s %s %s\n' % (c1, g1, rate))
                        result_file.flush()
                        # On a tie at the same gamma, prefer the smaller C.
                        if (rate > best_rate) or \
                                (rate == best_rate and g1 == best_g1 and c1 < best_c1):
                            best_rate = rate
                            best_c1, best_g1 = c1, g1
                            best_c = 2.0**c1
                            best_g = 2.0**g1
                        print("[%s] %s %s %s (best c=%s, g=%s, rate=%s)" %
                              (worker, c1, g1, rate, best_c, best_g, best_rate))
                    db.append((c, g, done_jobs[(c, g)]))
                redraw(db, [best_c1, best_g1, best_rate])
        redraw(db, [best_c1, best_g1, best_rate], True)
    finally:
        # Always release the worker threads, even if collection failed;
        # otherwise they keep the process alive.
        job_queue.put((WorkerStopToken, None))
    print("%s %s %s" % (best_c, best_g, best_rate))
# Only run the search when executed as a script, not when imported.
if __name__ == '__main__':
    main()