grid.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 
4 
5 import os, sys, traceback
6 import getpass
7 from threading import Thread
8 from subprocess import *
9 
10 if(sys.hexversion < 0x03000000):
11  import Queue
12 else:
13  import queue as Queue
14 
15 
# svmtrain and gnuplot executable
# (defaults; process_options() overrides them with -svmtrain / -gnuplot)

is_win32 = (sys.platform == 'win32')
if not is_win32:
    svmtrain_exe = "../svm-train"
    gnuplot_exe = "/usr/bin/gnuplot"
else:
    # example for windows
    svmtrain_exe = r"..\windows\svm-train.exe"
    # svmtrain_exe = r"c:\Program Files\libsvm\windows\svm-train.exe"
    gnuplot_exe = r"c:\tmp\gnuplot\binary\pgnuplot.exe"

# global parameters and their default values

fold = 5                              # -v: passed to svm-train as "-v fold"
c_begin, c_end, c_step = -5, 15, 2    # -log2c: log2(C) grid begin,end,step
g_begin, g_end, g_step = 3, -15, -2   # -log2g: log2(gamma) grid begin,end,step

# Set by process_options(); None until then. A `global` statement at module
# level is a no-op, so the names are bound explicitly here instead.
dataset_pathname = dataset_title = pass_through_string = None
out_filename = png_filename = None
gnuplot = None  # stdin pipe of the gnuplot process started by process_options()

# experimental

telnet_workers = []
ssh_workers = []
nr_local_worker = 1
41 
# process command line options, set global parameters
def process_options(argv=sys.argv):
    """Parse the grid.py command line and set the module-level parameters.

    argv[0] is the program name and argv[-1] is the dataset path. Options
    in between are either grid.py options (see ``usage``) or are collected
    verbatim into ``pass_through_string`` for svm-train. On success gnuplot
    is started and its stdin pipe is stored in the global ``gnuplot``.

    Prints a message and exits with status 1 when argv has no dataset,
    when the renamed -c/-g options are given, when an option that needs a
    value has none, or when the svm-train executable, the gnuplot
    executable or the dataset does not exist.
    """

    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename

    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold]
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    if len(argv) < 2:
        print(usage)
        sys.exit(1)

    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '{0}.out'.format(dataset_title)
    png_filename = '{0}.png'.format(dataset_title)
    pass_through_options = []

    def option_value(i):
        # Return the value following the option at argv[i]. The last element
        # is always the dataset, so it can never serve as an option value:
        # otherwise "grid.py -out data" would make the results file the
        # dataset itself, which main() opens for writing (truncating it).
        if i + 1 >= len(argv) - 1:
            print("Option {0} requires a value.".format(argv[i]))
            print(usage)
            sys.exit(1)
        return argv[i + 1]

    i = 1
    while i < len(argv) - 1:
        option = argv[i]
        if option == "-log2c":
            (c_begin, c_end, c_step) = map(float, option_value(i).split(","))
            i = i + 1
        elif option == "-log2g":
            (g_begin, g_end, g_step) = map(float, option_value(i).split(","))
            i = i + 1
        elif option == "-v":
            fold = option_value(i)
            i = i + 1
        elif option in ('-c', '-g'):
            print("Option -c and -g are renamed.")
            print(usage)
            sys.exit(1)
        elif option == '-svmtrain':
            svmtrain_exe = option_value(i)
            i = i + 1
        elif option == '-gnuplot':
            gnuplot_exe = option_value(i)
            i = i + 1
        elif option == '-out':
            out_filename = option_value(i)
            i = i + 1
        elif option == '-png':
            png_filename = option_value(i)
            i = i + 1
        else:
            pass_through_options.append(option)
        i = i + 1

    pass_through_string = " ".join(pass_through_options)

    # Explicit checks rather than assert, which `python -O` strips out.
    for path, what in ((svmtrain_exe, "svm-train executable"),
                       (gnuplot_exe, "gnuplot executable"),
                       (dataset_pathname, "dataset")):
        if not os.path.exists(path):
            print("{0} not found".format(what))
            sys.exit(1)
    gnuplot = Popen(gnuplot_exe, stdin=PIPE).stdin
102 
103 
def range_f(begin,end,step):
    """Return the sequence begin, begin+step, begin+2*step, ... up to end.

    Like ``range`` but also works on non-integer values, and ``end`` is
    included when it is reached exactly. A positive step counts up and a
    negative step counts down; the result is empty when ``begin`` is
    already past ``end`` in the direction of ``step``.

    Each element is computed as ``begin + k*step`` rather than by repeated
    addition, so floating-point rounding error does not accumulate.

    Raises ValueError if ``step`` is zero (the loop could never end).
    """
    if step == 0:
        raise ValueError("range_f() step must not be zero")
    seq = []
    k = 0
    while True:
        value = begin + k * step
        if step > 0 and value > end: break
        if step < 0 and value < end: break
        seq.append(value)
        k += 1
    return seq
113 
def permute_sequence(seq):
    """Reorder seq: its middle element first, then the recursively
    reordered left and right halves taken alternately (left first).

    A sequence of length 0 or 1 is returned as is (the same object).
    """
    count = len(seq)
    if count <= 1:
        return seq

    pivot = count // 2
    lower = permute_sequence(seq[:pivot])
    upper = permute_sequence(seq[pivot + 1:])
    # Reverse both halves so that pop() from the end hands out their
    # elements in their original front-to-back order.
    lower.reverse()
    upper.reverse()

    ordered = [seq[pivot]]
    while lower or upper:
        if lower:
            ordered.append(lower.pop())
        if upper:
            ordered.append(upper.pop())
    return ordered
128 
def redraw(db,best_param,tofile=False):
    """Send the current grid results to gnuplot as a contour plot.

    db         -- list of (log2c, log2g, rate) tuples; read, not modified.
    best_param -- [best_log2c, best_log2g, best_rate], shown in the labels.
    tofile     -- when true, render a PNG into png_filename; otherwise use
                  the interactive terminal ("windows" on win32, else "x11").

    Commands are written to the global ``gnuplot`` pipe. Nothing is sent
    when db is empty, or while every entry still shares the same log2c,
    log2g or rate value.
    """
    if not db: return

    def quote(text):
        # Escape text for a double-quoted gnuplot string, in which a
        # backslash starts an escape sequence and '"' ends the string.
        return str(text).replace('\\', '\\\\').replace('"', '\\"')

    begin_level = round(max(x[2] for x in db)) - 3
    step_size = 0.5

    best_log2c,best_log2g,best_rate = best_param

    # if newly obtained c, g, or cv values are the same,
    # then stop redrawing the contour.
    if all(x[0] == db[0][0] for x in db): return
    if all(x[1] == db[0][1] for x in db): return
    if all(x[2] == db[0][2] for x in db): return

    commands = []
    if tofile:
        commands.append("set term png transparent small linewidth 2 medium enhanced\n")
        commands.append('set output "{0}"\n'.format(quote(png_filename)))
    elif is_win32:
        commands.append("set term windows\n")
    else:
        commands.append("set term x11\n")
    commands += [
        'set xlabel "log2(C)"\n',
        'set ylabel "log2(gamma)"\n',
        "set xrange [{0}:{1}]\n".format(c_begin,c_end),
        "set yrange [{0}:{1}]\n".format(g_begin,g_end),
        "set contour\n",
        "set cntrparam levels incremental {0},{1},100\n".format(begin_level,step_size),
        "unset surface\n",
        "unset ztics\n",
        "set view 0,0\n",
        'set title "{0}"\n'.format(quote(dataset_title)),
        "unset label\n",
        'set label "Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%" '
        'at screen 0.5,0.85 center\n'.format(best_log2c, best_log2g, best_rate),
        'set label "C = {0} gamma = {1}" '
        'at screen 0.5,0.8 center\n'.format(2**best_log2c, 2**best_log2g),
        "set key at screen 0.9,0.9\n",
        'splot "-" with lines\n',
    ]
    for command in commands:
        gnuplot.write(command.encode())

    # Inline data rows, ordered by log2c and then by descending log2g.
    # splot reads a blank line as the end of one scan, so one is emitted
    # whenever log2c changes. sorted() leaves the caller's db untouched.
    rows = sorted(db, key=lambda x: (x[0], -x[1]))
    prevc = rows[0][0]
    for line in rows:
        if prevc != line[0]:
            gnuplot.write(b"\n")
            prevc = line[0]
        gnuplot.write("{0[0]} {0[1]} {0[2]}\n".format(line).encode())
    gnuplot.write(b"e\n")
    gnuplot.write(b"\n") # force gnuplot back to prompt when term set failure
    gnuplot.flush()
184 
185 
def calculate_jobs():
    """Return the (log2c, log2g) grid points to evaluate, grouped in batches.

    The c and g axes come from range_f() over the global -log2c / -log2g
    settings, reordered by permute_sequence(). The grid is then grown one
    c value or one g value at a time, whichever axis currently has the
    lower fraction of its values in use. Each batch lists the new points
    that step adds: the new value paired with every value of the other
    axis used so far. The first batch is always empty.

    Returns [] when either axis is empty (e.g. begin already past end),
    instead of dividing by zero.
    """
    c_seq = permute_sequence(range_f(c_begin,c_end,c_step))
    g_seq = permute_sequence(range_f(g_begin,g_end,g_step))
    if not c_seq or not g_seq:
        return []
    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i = 0
    j = 0
    jobs = []

    while i < nr_c or j < nr_g:
        if i/nr_c < j/nr_g:
            # increase C resolution: next c against every g used so far
            jobs.append([(c_seq[i], g_seq[k]) for k in range(j)])
            i = i + 1
        else:
            # increase g resolution: next g against every c used so far
            jobs.append([(c_seq[k], g_seq[j]) for k in range(i)])
            j = j + 1
    return jobs
211 
class WorkerStopToken: # used to notify the worker to stop
    """Sentinel class; ``(WorkerStopToken, None)`` on the job queue tells
    workers to stop. The class object itself is the marker."""


class Worker(Thread):
    """Thread that evaluates (log2c, log2g) grid points from a shared queue.

    Subclasses provide run_one(c, g), which receives C and gamma as real
    values (2**log2c, 2**log2g) and returns the cross-validation rate, or
    None on failure. Each result is put on result_queue as
    (name, log2c, log2g, rate).
    """
    def __init__(self,name,job_queue,result_queue):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue

    def run(self):
        """Process jobs until the stop token arrives or a job fails.

        The stop token is put back so every other worker also sees it. If
        run_one() raises or returns None, the job is put back on the queue
        for another worker and this worker quits.
        """
        while True:
            (cexp,gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                self.job_queue.put((cexp,gexp))
                break
            try:
                rate = self.run_one(2.0**cexp,2.0**gexp)
                if rate is None: raise RuntimeError("get no rate")
            except Exception:
                # we failed, let others do that and we just quit
                traceback.print_exc()
                self.job_queue.put((cexp,gexp))
                print('worker {0} quit.'.format(self.name))
                break
            else:
                self.result_queue.put((self.name,cexp,gexp,rate))
241 
class LocalWorker(Worker):
    """Worker that runs svm-train as a child process on this machine."""
    def run_one(self,c,g):
        """Run one cross-validation with C = c, gamma = g.

        Returns the number on svm-train's first output line containing
        "Cross" (last token, trailing character dropped) as a float, or
        None if no such line is printed.
        """
        # Argument list instead of a shell command line: no shell is
        # involved, so paths with spaces or shell metacharacters are passed
        # intact. pass_through_string is the extra options joined with
        # single spaces, so splitting on whitespace recovers them.
        cmd = [svmtrain_exe, '-c', '{0}'.format(c), '-g', '{0}'.format(g),
               '-v', '{0}'.format(fold)]
        cmd += pass_through_string.split()
        cmd.append(dataset_pathname)
        # communicate() reads all output, waits for the child (so it is
        # reaped) and closes the pipe; universal_newlines yields str, not bytes.
        output = Popen(cmd, stdout=PIPE, universal_newlines=True).communicate()[0]
        for line in output.splitlines():
            if line.find("Cross") != -1:
                return float(line.split()[-1][0:-1])
250 
class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host through ssh.

    The remote command first changes to this process's working directory
    (captured at construction), so it presumably relies on the same path
    existing on the remote host, e.g. a shared filesystem -- TODO confirm.
    """
    def __init__(self,name,job_queue,result_queue,host):
        Worker.__init__(self,name,job_queue,result_queue)
        self.host = host
        self.cwd = os.getcwd()

    def run_one(self,c,g):
        """Run one remote cross-validation with C = c, gamma = g.

        Returns the number on the first output line containing "Cross"
        (last token, trailing character dropped) as a float, or None.
        """
        remote_cmd = 'cd {0}; {1} -c {2} -g {3} -v {4} {5} {6}'.format(
            self.cwd, svmtrain_exe, c, g, fold, pass_through_string,
            dataset_pathname)
        # Passing ssh an argument list avoids an extra local shell; the
        # remote shell still interprets remote_cmd. communicate() waits for
        # ssh (so it is reaped) and closes the pipe.
        output = Popen(['ssh', '-x', self.host, remote_cmd], stdout=PIPE,
                       universal_newlines=True).communicate()[0]
        for line in output.splitlines():
            if line.find("Cross") != -1:
                return float(line.split()[-1][0:-1])
264 
class TelnetWorker(Worker):
    """Worker that logs in to a remote host over telnet and runs svm-train there.

    NOTE: telnetlib was removed from the standard library in Python 3.13,
    so this worker cannot be used on 3.13 and later.
    """
    def __init__(self,name,job_queue,result_queue,host,username,password):
        Worker.__init__(self,name,job_queue,result_queue)
        self.host = host
        self.username = username
        self.password = password

    def run(self):
        """Log in, cd to this process's working directory, process jobs via
        Worker.run(), then send "exit"."""
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        # telnetlib reads and writes bytes, so text is encoded/decoded at
        # this boundary; passing str fails on Python 3.
        tn.read_until(b"login: ")
        tn.write((self.username + "\n").encode())
        tn.read_until(b"Password: ")
        tn.write((self.password + "\n").encode())

        # XXX: how to know whether login is successful?
        tn.read_until(self.username.encode())
        #
        print('login ok', self.host)
        tn.write(("cd " + os.getcwd() + "\n").encode())
        Worker.run(self)
        tn.write(b"exit\n")

    def run_one(self,c,g):
        """Run one cross-validation in the telnet session with C = c,
        gamma = g; return the "Cross" line's rate as a float, or None."""
        cmdline = '{0} -c {1} -g {2} -v {3} {4} {5}'.format \
            (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        self.tn.write((cmdline + '\n').encode())
        (idx, match, output) = self.tn.expect([b'Cross.*\n'])
        for line in output.decode('utf-8', 'replace').split('\n'):
            if line.find("Cross") != -1:
                return float(line.split()[-1][0:-1])
294 
def main():
    """Run the grid search end to end.

    Parses the command line, queues every (log2c, log2g) grid point, starts
    the telnet, ssh and local workers, and appends each result to
    out_filename as "log2c log2g rate". After each batch it redraws the
    contour on screen and into the PNG. Finally it tells the workers to
    stop and prints "best_c best_g best_rate".
    """

    # set parameters

    process_options()

    # put jobs in queue

    jobs = calculate_jobs()
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)

    for line in jobs:
        for (c,g) in line:
            job_queue.put((c,g))

    # hack the queue to become a stack --
    # this is important when some thread
    # failed and re-put a job. If we still
    # use FIFO, the job will be put
    # into the end of the queue, and the graph
    # will only be updated in the end.
    # (Relies on Queue's private _put hook; jobs queued above keep FIFO order.)

    job_queue._put = job_queue.queue.appendleft

    # fire telnet workers

    if telnet_workers:
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            TelnetWorker(host,job_queue,result_queue,
                         host,username,password).start()

    # fire ssh workers

    for host in ssh_workers:
        SSHWorker(host,job_queue,result_queue,host).start()

    # fire local workers

    for _ in range(nr_local_worker):
        LocalWorker('local',job_queue,result_queue).start()

    # gather results

    done_jobs = {}
    db = []
    best_rate = -1
    best_c1,best_g1 = None,None
    # bound up front so the final print also works for an empty grid
    best_c,best_g = None,None

    with open(out_filename, 'w') as result_file:
        for line in jobs:
            for (c,g) in line:
                # Results arrive in completion order; keep consuming them
                # until this grid point is done, so the plot grows batch by batch.
                while (c, g) not in done_jobs:
                    (worker,c1,g1,rate) = result_queue.get()
                    done_jobs[(c1,g1)] = rate
                    result_file.write('{0} {1} {2}\n'.format(c1,g1,rate))
                    result_file.flush()
                    # higher rate wins; on a tie at the same g, smaller C wins
                    if (rate > best_rate) or (rate==best_rate and g1==best_g1 and c1<best_c1):
                        best_rate = rate
                        best_c1,best_g1=c1,g1
                        best_c = 2.0**c1
                        best_g = 2.0**g1
                    print("[{0}] {1} {2} {3} (best c={4}, g={5}, rate={6})".format \
                        (worker,c1,g1,rate, best_c, best_g, best_rate))
                db.append((c,g,done_jobs[(c,g)]))
            redraw(db,[best_c1, best_g1, best_rate])
            redraw(db,[best_c1, best_g1, best_rate],True)

    job_queue.put((WorkerStopToken,None))
    print("{0} {1} {2}".format(best_c, best_g, best_rate))


if __name__ == '__main__':
    main()
def run_one(self, c, g)
Definition: grid.py:243
#define max(x, y)
Definition: libsvmread.c:15
def __init__(self, name, job_queue, result_queue, host, username, password)
Definition: grid.py:266
result_queue
Definition: grid.py:220
def __init__(self, name, job_queue, result_queue)
Definition: grid.py:216
def calculate_jobs()
Definition: grid.py:186
def run(self)
Definition: grid.py:221
def process_options(argv=sys.argv)
Definition: grid.py:43
def run_one(self, c, g)
Definition: grid.py:256
def run(self)
Definition: grid.py:271
def main()
Definition: grid.py:295
def __init__(self, name, job_queue, result_queue, host)
Definition: grid.py:252
def redraw(db, best_param, tofile=False)
Definition: grid.py:129
def permute_sequence(seq)
Definition: grid.py:114
def range_f(begin, end, step)
Definition: grid.py:104
def run_one(self, c, g)
Definition: grid.py:286


ml_classifiers
Author(s): Scott Niekum, Joshua Whitley
autogenerated on Tue May 14 2019 02:28:35