00001
00002
00003
00004
00005 import os, sys, traceback
00006 import getpass
00007 from threading import Thread
00008 from subprocess import *
00009
00010 if(sys.hexversion < 0x03000000):
00011 import Queue
00012 else:
00013 import queue as Queue
00014
00015
00016
00017
# ---------------------------------------------------------------------------
# Default configuration.  process_options() may overwrite most of these
# module-level names from the command line.
# ---------------------------------------------------------------------------

# Locations of the external programs; the defaults depend on the platform.
is_win32 = sys.platform == 'win32'
if is_win32:
    # example for windows
    svmtrain_exe = r"..\windows\svm-train.exe"
    # pgnuplot.exe is used because it accepts commands on its stdin
    gnuplot_exe = r"c:\tmp\gnuplot\binary\pgnuplot.exe"
else:
    svmtrain_exe = "../svm-train"
    gnuplot_exe = "/usr/bin/gnuplot"

# Cross-validation fold count and the (begin, end, step) ranges searched for
# log2(C) and log2(gamma).
fold = 5
(c_begin, c_end, c_step) = (-5, 15, 2)
(g_begin, g_end, g_step) = (3, -15, -2)

# dataset_pathname, dataset_title, pass_through_string, out_filename and
# png_filename are further module-level settings; they have no default here
# and are assigned by process_options().

# Worker pools: remote hosts reached by telnet / ssh, plus the number of
# worker threads that run svm-train on the local machine.
telnet_workers = []
ssh_workers = []
nr_local_worker = 1
00041
00042
def process_options(argv=sys.argv):
    """Parse the command line and set the module-level search settings.

    The last element of *argv* is always taken as the dataset path.  The
    options ``-log2c``, ``-log2g``, ``-v``, ``-svmtrain``, ``-gnuplot``,
    ``-out`` and ``-png`` each take one value; every other argument is
    collected unchanged and later passed through to svm-train.

    Side effects: assigns the globals listed below, and starts the gnuplot
    executable, binding the write end of its stdin pipe to the global
    ``gnuplot``.

    Exits with status 1 (after printing the usage text) when no dataset is
    given, when the renamed ``-c``/``-g`` options are used, or when an
    option's value is missing (i.e. it would have to be the dataset path).

    Raises:
        AssertionError: if the svm-train executable, the gnuplot executable
            or the dataset does not exist.
        ValueError: if a ``-log2c``/``-log2g`` value is not three
            comma-separated numbers.
    """
    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename

    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold]
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    if len(argv) < 2:
        print(usage)
        sys.exit(1)

    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '{0}.out'.format(dataset_title)
    png_filename = '{0}.png'.format(dataset_title)
    pass_through_options = []

    # Index of the dataset argument; options and their values live before it.
    last = len(argv) - 1
    value_options = ('-log2c', '-log2g', '-v',
                     '-svmtrain', '-gnuplot', '-out', '-png')

    i = 1
    while i < last:
        option = argv[i]
        if option in ('-c', '-g'):
            print("Option -c and -g are renamed.")
            print(usage)
            sys.exit(1)

        if option in value_options:
            i = i + 1
            if i >= last:
                # The only argument left is the dataset itself, so the option
                # has no value.  Refuse instead of silently using the dataset
                # path as the value (``-out data`` would otherwise make the
                # result file overwrite the dataset).
                print(usage)
                sys.exit(1)
            value = argv[i]
            if option == '-log2c':
                (c_begin, c_end, c_step) = map(float, value.split(","))
            elif option == '-log2g':
                (g_begin, g_end, g_step) = map(float, value.split(","))
            elif option == '-v':
                fold = value
            elif option == '-svmtrain':
                svmtrain_exe = value
            elif option == '-gnuplot':
                gnuplot_exe = value
            elif option == '-out':
                out_filename = value
            else:  # '-png'
                png_filename = value
        else:
            pass_through_options.append(option)
        i = i + 1

    pass_through_string = " ".join(pass_through_options)

    # Explicit raises instead of ``assert`` so the checks still run under
    # ``python -O``; AssertionError is kept for backward compatibility.
    required_paths = (
        (svmtrain_exe, "svm-train executable not found"),
        (gnuplot_exe, "gnuplot executable not found"),
        (dataset_pathname, "dataset not found"),
    )
    for path, message in required_paths:
        if not os.path.exists(path):
            raise AssertionError(message)

    # Only the stdin pipe is kept: redraw() feeds plotting commands into it.
    gnuplot = Popen(gnuplot_exe, stdin=PIPE).stdin
00102
00103
def range_f(begin, end, step):
    """Return the list begin, begin+step, begin+2*step, ... up to *end*.

    Works like ``range`` but accepts floats and includes *end* itself when
    the progression lands on it.  A positive *step* counts upwards while the
    value is <= *end*; a negative *step* counts downwards while the value is
    >= *end*.  The result is empty when *begin* is already past *end*.

    Raises:
        ValueError: if *step* is zero (the progression would never end).
    """
    if step == 0:
        raise ValueError("range_f() step must not be zero")

    seq = []
    value = begin
    if step > 0:
        while value <= end:
            seq.append(value)
            value = value + step
    else:
        while value >= end:
            seq.append(value)
            value = value + step
    return seq
00113
def permute_sequence(seq):
    """Return the elements of *seq* reordered "middle first".

    The middle element comes first, followed by the recursively permuted
    left and right halves interleaved element by element (left, right,
    left, ...).  For example ``[1, 2, 3, 4, 5]`` becomes
    ``[3, 2, 5, 1, 4]``.

    *seq* may be any sliceable sequence; it is not modified and a new list
    is always returned.
    """
    n = len(seq)
    if n <= 1:
        return list(seq)

    mid = n // 2
    left = permute_sequence(seq[:mid])
    right = permute_sequence(seq[mid + 1:])

    ret = [seq[mid]]
    # Interleave by index rather than popping from the front of each half:
    # list.pop(0) shifts the whole list on every call.
    for k in range(max(len(left), len(right))):
        if k < len(left):
            ret.append(left[k])
        if k < len(right):
            ret.append(right[k])
    return ret
00128
def redraw(db, best_param, tofile=False):
    """Send a contour plot of the results gathered so far to gnuplot.

    db         -- list of (log2c, log2g, rate) entries; sorted in place by
                  log2c ascending, then log2g descending
    best_param -- sequence (best log2c, best log2g, best rate) shown in the
                  plot labels
    tofile     -- when true, render to the PNG file named by the global
                  ``png_filename`` instead of an on-screen terminal

    Nothing is drawn when *db* is empty or when all entries share the same
    log2c, the same log2g, or the same rate.  Commands are written to the
    global ``gnuplot`` pipe.
    """
    if len(db) == 0:
        return

    begin_level = round(max(entry[2] for entry in db)) - 3
    step_size = 0.5

    best_log2c, best_log2g, best_rate = best_param

    # if the newest result doesn't give any variation along an axis, there
    # is nothing to contour yet
    first = db[0]
    for axis in (0, 1, 2):
        if all(entry[axis] == first[axis] for entry in db):
            return

    send = gnuplot.write

    # --- output terminal ---
    if tofile:
        send(b"set term png transparent small linewidth 2 medium enhanced\n")
        escaped_png = png_filename.replace('\\', '\\\\')
        send('set output "{0}"\n'.format(escaped_png).encode())
    elif is_win32:
        send(b"set term windows\n")
    else:
        send(b"set term x11\n")

    # --- axes, contour levels and labels ---
    send(b'set xlabel "log2(C)"\n')
    send(b'set ylabel "log2(gamma)"\n')
    send('set xrange [{0}:{1}]\n'.format(c_begin, c_end).encode())
    send('set yrange [{0}:{1}]\n'.format(g_begin, g_end).encode())
    send(b"set contour\n")
    levels = 'set cntrparam levels incremental {0},{1},100\n'.format(
        begin_level, step_size)
    send(levels.encode())
    send(b"unset surface\n")
    send(b"unset ztics\n")
    send(b"set view 0,0\n")
    send('set title "{0}"\n'.format(dataset_title).encode())
    send(b"unset label\n")
    best_label = (
        'set label "Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%"'
        ' at screen 0.5,0.85 center\n'
    ).format(best_log2c, best_log2g, best_rate)
    send(best_label.encode())
    value_label = (
        'set label "C = {0} gamma = {1}" at screen 0.5,0.8 center\n'
    ).format(2**best_log2c, 2**best_log2g)
    send(value_label.encode())
    send(b"set key at screen 0.9,0.9\n")
    send(b'splot "-" with lines\n')

    # --- inline data: one group per log2c value, blank line between groups ---
    db.sort(key=lambda entry: (entry[0], -entry[1]))

    current_c = db[0][0]
    for entry in db:
        if entry[0] != current_c:
            send(b"\n")
            current_c = entry[0]
        row = "{0} {1} {2}\n".format(entry[0], entry[1], entry[2])
        send(row.encode())
    send(b"e\n")
    send(b"\n")  # force gnuplot back to prompt when term set failure
    gnuplot.flush()
00184
00185
def calculate_jobs():
    """Build the ordered list of job batches for the grid search.

    The log2(C) and log2(gamma) sequences come from the global
    begin/end/step settings via range_f() and are reordered by
    permute_sequence().  The grid is then grown one row or one column at a
    time, whichever axis is proportionally less covered: each batch pairs
    one newly added value of one parameter with every value of the other
    parameter added so far.

    Returns a list of batches, each a list of (log2c, log2g) tuples.  The
    first batch is empty because no value of the other axis exists yet.
    """
    c_seq = permute_sequence(range_f(c_begin, c_end, c_step))
    g_seq = permute_sequence(range_f(g_begin, g_end, g_step))
    total_c = float(len(c_seq))
    total_g = float(len(g_seq))

    used_c = 0  # number of C values already introduced
    used_g = 0  # number of gamma values already introduced
    jobs = []

    while used_c < total_c or used_g < total_g:
        if used_c / total_c < used_g / total_g:
            # increase C resolution: new C value against all known gammas
            batch = [(c_seq[used_c], g_val) for g_val in g_seq[:used_g]]
            used_c += 1
        else:
            # increase g resolution: new gamma value against all known Cs
            batch = [(c_val, g_seq[used_g]) for c_val in c_seq[:used_c]]
            used_g += 1
        jobs.append(batch)
    return jobs
00211
class WorkerStopToken:
    """Marker put on the job queue in place of a log2(C) value.

    A worker that fetches a job whose first element is this class puts the
    job back on the queue and leaves its processing loop.
    """
00214
class Worker(Thread):
    """Thread that takes (log2c, log2g) jobs from a queue and evaluates them.

    Subclasses supply ``run_one(c, g)``, which receives the actual parameter
    values (2**log2c, 2**log2g) and returns the cross-validation rate, or
    None when no rate could be obtained.
    """

    def __init__(self, name, job_queue, result_queue):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue

    def run(self):
        """Process jobs until the stop token arrives or a job fails.

        Successful jobs are reported on the result queue as
        (worker name, log2c, log2g, rate).  A failed job is put back on the
        job queue and this worker quits.
        """
        while True:
            job = self.job_queue.get()
            (cexp, gexp) = job

            if cexp is WorkerStopToken:
                # put the token back so the other workers see it as well
                self.job_queue.put((cexp, gexp))
                return

            try:
                rate = self.run_one(2.0**cexp, 2.0**gexp)
                if rate is None:
                    raise RuntimeError("get no rate")
            except:
                # we failed: report the error, let others do the job,
                # and quit
                traceback.print_exception(*sys.exc_info())
                self.job_queue.put((cexp, gexp))
                print('worker {0} quit.'.format(self.name))
                return

            self.result_queue.put((self.name, cexp, gexp, rate))
00241
class LocalWorker(Worker):
    """Worker that runs svm-train as a subprocess on the local machine."""

    def run_one(self, c, g):
        """Cross-validate with parameters *c* and *g*; return the rate.

        Builds the svm-train command line from the global settings, runs it
        through the shell and scans its standard output for the line
        containing "Cross".  The last whitespace-separated field of that
        line, minus its final character, is returned as a float.

        Returns None when the output contains no such line.
        """
        cmdline = '{0} -c {1} -g {2} -v {3} {4} {5}'.format(
            svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        process = Popen(cmdline, shell=True, stdout=PIPE)
        # communicate() reads stdout to EOF, closes the pipe and waits for
        # the child, so no process or file descriptor is left behind.
        output = process.communicate()[0]
        for line in output.splitlines():
            if b"Cross" in line:
                return float(line.split()[-1][0:-1])
        return None
00250
class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host through ``ssh``."""

    def __init__(self, name, job_queue, result_queue, host):
        Worker.__init__(self, name, job_queue, result_queue)
        self.host = host
        # The remote command first changes into this directory, captured
        # when the worker is created.
        self.cwd = os.getcwd()

    def run_one(self, c, g):
        """Cross-validate with parameters *c* and *g* on the remote host.

        Runs ``ssh -x <host> "cd <cwd>; <svm-train command>"`` through the
        shell and scans its standard output for the line containing
        "Cross".  The last whitespace-separated field of that line, minus
        its final character, is returned as a float.

        Returns None when the output contains no such line.
        """
        cmdline = 'ssh -x {0} "cd {1}; {2} -c {3} -g {4} -v {5} {6} {7}"'.format(
            self.host, self.cwd,
            svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        process = Popen(cmdline, shell=True, stdout=PIPE)
        # communicate() reads stdout to EOF, closes the pipe and waits for
        # the child, so no process or file descriptor is left behind.
        output = process.communicate()[0]
        for line in output.splitlines():
            if b"Cross" in line:
                return float(line.split()[-1][0:-1])
        return None
00264
class TelnetWorker(Worker):
    """Worker that runs svm-train on a remote host over a telnet session.

    NOTE: relies on the standard-library ``telnetlib`` module, which was
    removed in Python 3.13; on newer interpreters run() fails at its import.
    """

    def __init__(self, name, job_queue, result_queue, host, username, password):
        Worker.__init__(self, name, job_queue, result_queue)
        self.host = host
        self.username = username
        self.password = password

    @staticmethod
    def _to_bytes(text):
        """Return *text* as bytes; telnetlib on Python 3 only accepts bytes."""
        if isinstance(text, bytes):
            return text
        return text.encode()

    def run(self):
        """Log in, change to the current local directory, then process jobs.

        After the job loop ends an ``exit`` command is sent, and the
        connection is closed even if the loop raised.
        """
        import telnetlib
        to_bytes = self._to_bytes

        self.tn = tn = telnetlib.Telnet(self.host)
        try:
            tn.read_until(b"login: ")
            tn.write(to_bytes(self.username + "\n"))
            tn.read_until(b"Password: ")
            tn.write(to_bytes(self.password + "\n"))

            # XXX: how to know whether login is successful?
            # Waits until the user name is echoed back (presumably as part
            # of the shell prompt -- TODO confirm).
            tn.read_until(to_bytes(self.username))

            print('login ok', self.host)
            tn.write(to_bytes("cd " + os.getcwd() + "\n"))
            Worker.run(self)
            tn.write(b"exit\n")
        finally:
            tn.close()

    def run_one(self, c, g):
        """Cross-validate with parameters *c* and *g* in the telnet session.

        Sends the svm-train command line, waits for output matching
        ``Cross.*\\n`` and parses the rate from the matching line: the last
        whitespace-separated field minus its final character, as a float.

        Returns None when no line containing "Cross" was received.
        """
        cmdline = '{0} -c {1} -g {2} -v {3} {4} {5}'.format(
            svmtrain_exe, c, g, fold, pass_through_string, dataset_pathname)
        self.tn.write(self._to_bytes(cmdline + '\n'))
        (idx, matchm, output) = self.tn.expect([b'Cross.*\n'])
        for line in output.split(b'\n'):
            if b"Cross" in line:
                return float(line.split()[-1][0:-1])
        return None
00294
def main():
    """Run the grid search: queue the jobs, start workers, collect results.

    Each result is appended to the output file as "log2c log2g rate",
    reported on stdout together with the best parameters so far, and the
    gnuplot contour is redrawn (on screen and to the PNG file) after every
    completed batch.  The final line printed is "best_c best_g best_rate".
    """
    # set parameters
    process_options()

    # put jobs in queue
    jobs = calculate_jobs()
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)

    for line in jobs:
        for (c, g) in line:
            job_queue.put((c, g))

    # Hack the queue to become a stack: from now on put() inserts at the
    # front of the underlying deque.  This matters when a worker fails and
    # re-queues its job -- that job is then fetched before the remaining
    # ones, which keeps the contour plot filling in order.
    job_queue._put = job_queue.queue.appendleft

    # fire telnet workers
    if telnet_workers:
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            TelnetWorker(host, job_queue, result_queue,
                         host, username, password).start()

    # fire ssh workers
    if ssh_workers:
        for host in ssh_workers:
            SSHWorker(host, job_queue, result_queue, host).start()

    # fire local workers
    for _ in range(nr_local_worker):
        LocalWorker('local', job_queue, result_queue).start()

    # gather results
    done_jobs = {}
    db = []
    best_rate = -1
    best_c1, best_g1 = None, None
    best_c, best_g = None, None

    try:
        # ``with`` guarantees the result file is closed, even on error.
        with open(out_filename, 'w') as result_file:
            for line in jobs:
                for (c, g) in line:
                    # Results arrive in arbitrary order; keep consuming
                    # until the job this batch is waiting for is done.
                    while (c, g) not in done_jobs:
                        (worker, c1, g1, rate) = result_queue.get()
                        done_jobs[(c1, g1)] = rate
                        result_file.write('{0} {1} {2}\n'.format(c1, g1, rate))
                        result_file.flush()
                        # On a tie in rate, prefer the smaller C at the
                        # same gamma.
                        if (rate > best_rate) or \
                                (rate == best_rate and g1 == best_g1 and c1 < best_c1):
                            best_rate = rate
                            best_c1, best_g1 = c1, g1
                            best_c = 2.0**c1
                            best_g = 2.0**g1
                        print("[{0}] {1} {2} {3} (best c={4}, g={5}, rate={6})".format(
                            worker, c1, g1, rate, best_c, best_g, best_rate))
                    db.append((c, g, done_jobs[(c, g)]))
                redraw(db, [best_c1, best_g1, best_rate])
                redraw(db, [best_c1, best_g1, best_rate], True)
    finally:
        # Always tell the workers to stop; otherwise an error here would
        # leave the non-daemon worker threads blocked on the job queue and
        # the process would never exit.
        job_queue.put((WorkerStopToken, None))

    print("{0} {1} {2}".format(best_c, best_g, best_rate))


# Run the search only when executed as a script, not when imported.
if __name__ == '__main__':
    main()