svm_toy.java
Go to the documentation of this file.
00001 import libsvm.*;
00002 import java.applet.*;
00003 import java.awt.*;
00004 import java.util.*;
00005 import java.awt.event.*;
00006 import java.io.*;
00007 
00008 public class svm_toy extends Applet {
00009 
00010         static final String DEFAULT_PARAM="-t 2 -c 100";
00011         int XLEN;
00012         int YLEN;
00013 
00014         // off-screen buffer
00015 
00016         Image buffer;
00017         Graphics buffer_gc;
00018 
00019         // pre-allocated colors
00020 
00021         final static Color colors[] =
00022         {
00023           new Color(0,0,0),
00024           new Color(0,120,120),
00025           new Color(120,120,0),
00026           new Color(120,0,120),
00027           new Color(0,200,200),
00028           new Color(200,200,0),
00029           new Color(200,0,200)
00030         };
00031 
00032         class point {
00033                 point(double x, double y, byte value)
00034                 {
00035                         this.x = x;
00036                         this.y = y;
00037                         this.value = value;
00038                 }
00039                 double x, y;
00040                 byte value;
00041         }
00042 
00043         Vector<point> point_list = new Vector<point>();
00044         byte current_value = 1;
00045 
00046         public void init()
00047         {
00048                 setSize(getSize());
00049 
00050                 final Button button_change = new Button("Change");
00051                 Button button_run = new Button("Run");
00052                 Button button_clear = new Button("Clear");
00053                 Button button_save = new Button("Save");
00054                 Button button_load = new Button("Load");
00055                 final TextField input_line = new TextField(DEFAULT_PARAM);
00056 
00057                 BorderLayout layout = new BorderLayout();
00058                 this.setLayout(layout);
00059 
00060                 Panel p = new Panel();
00061                 GridBagLayout gridbag = new GridBagLayout();
00062                 p.setLayout(gridbag);
00063 
00064                 GridBagConstraints c = new GridBagConstraints();
00065                 c.fill = GridBagConstraints.HORIZONTAL;
00066                 c.weightx = 1;
00067                 c.gridwidth = 1;
00068                 gridbag.setConstraints(button_change,c);
00069                 gridbag.setConstraints(button_run,c);
00070                 gridbag.setConstraints(button_clear,c);
00071                 gridbag.setConstraints(button_save,c);
00072                 gridbag.setConstraints(button_load,c);
00073                 c.weightx = 5;
00074                 c.gridwidth = 5;
00075                 gridbag.setConstraints(input_line,c);
00076 
00077                 button_change.setBackground(colors[current_value]);
00078 
00079                 p.add(button_change);
00080                 p.add(button_run);
00081                 p.add(button_clear);
00082                 p.add(button_save);
00083                 p.add(button_load);
00084                 p.add(input_line);
00085                 this.add(p,BorderLayout.SOUTH);
00086 
00087                 button_change.addActionListener(new ActionListener()
00088                 { public void actionPerformed (ActionEvent e)
00089                   { button_change_clicked(); button_change.setBackground(colors[current_value]); }});
00090 
00091                 button_run.addActionListener(new ActionListener()
00092                 { public void actionPerformed (ActionEvent e)
00093                   { button_run_clicked(input_line.getText()); }});
00094 
00095                 button_clear.addActionListener(new ActionListener()
00096                 { public void actionPerformed (ActionEvent e)
00097                   { button_clear_clicked(); }});
00098 
00099                 button_save.addActionListener(new ActionListener()
00100                 { public void actionPerformed (ActionEvent e)
00101                   { button_save_clicked(input_line.getText()); }});
00102 
00103                 button_load.addActionListener(new ActionListener()
00104                 { public void actionPerformed (ActionEvent e)
00105                   { button_load_clicked(); }});
00106 
00107                 input_line.addActionListener(new ActionListener()
00108                 { public void actionPerformed (ActionEvent e)
00109                   { button_run_clicked(input_line.getText()); }});
00110 
00111                 this.enableEvents(AWTEvent.MOUSE_EVENT_MASK);
00112         }
00113 
00114         void draw_point(point p)
00115         {
00116                 Color c = colors[p.value+3];
00117 
00118                 Graphics window_gc = getGraphics();
00119                 buffer_gc.setColor(c);
00120                 buffer_gc.fillRect((int)(p.x*XLEN),(int)(p.y*YLEN),4,4);
00121                 window_gc.setColor(c);
00122                 window_gc.fillRect((int)(p.x*XLEN),(int)(p.y*YLEN),4,4);
00123         }
00124 
00125         void clear_all()
00126         {
00127                 point_list.removeAllElements();
00128                 if(buffer != null)
00129                 {
00130                         buffer_gc.setColor(colors[0]);
00131                         buffer_gc.fillRect(0,0,XLEN,YLEN);
00132                 }
00133                 repaint();
00134         }
00135 
00136         void draw_all_points()
00137         {
00138                 int n = point_list.size();
00139                 for(int i=0;i<n;i++)
00140                         draw_point(point_list.elementAt(i));
00141         }
00142 
00143         void button_change_clicked()
00144         {
00145                 ++current_value;
00146                 if(current_value > 3) current_value = 1;
00147         }
00148 
00149         private static double atof(String s)
00150         {
00151                 return Double.valueOf(s).doubleValue();
00152         }
00153 
00154         private static int atoi(String s)
00155         {
00156                 return Integer.parseInt(s);
00157         }
00158 
00159         void button_run_clicked(String args)
00160         {
00161                 // guard
00162                 if(point_list.isEmpty()) return;
00163 
00164                 svm_parameter param = new svm_parameter();
00165 
00166                 // default values
00167                 param.svm_type = svm_parameter.C_SVC;
00168                 param.kernel_type = svm_parameter.RBF;
00169                 param.degree = 3;
00170                 param.gamma = 0;
00171                 param.coef0 = 0;
00172                 param.nu = 0.5;
00173                 param.cache_size = 40;
00174                 param.C = 1;
00175                 param.eps = 1e-3;
00176                 param.p = 0.1;
00177                 param.shrinking = 1;
00178                 param.probability = 0;
00179                 param.nr_weight = 0;
00180                 param.weight_label = new int[0];
00181                 param.weight = new double[0];
00182 
00183                 // parse options
00184                 StringTokenizer st = new StringTokenizer(args);
00185                 String[] argv = new String[st.countTokens()];
00186                 for(int i=0;i<argv.length;i++)
00187                         argv[i] = st.nextToken();
00188 
00189                 for(int i=0;i<argv.length;i++)
00190                 {
00191                         if(argv[i].charAt(0) != '-') break;
00192                         if(++i>=argv.length)
00193                         {
00194                                 System.err.print("unknown option\n");
00195                                 break;
00196                         }
00197                         switch(argv[i-1].charAt(1))
00198                         {
00199                                 case 's':
00200                                         param.svm_type = atoi(argv[i]);
00201                                         break;
00202                                 case 't':
00203                                         param.kernel_type = atoi(argv[i]);
00204                                         break;
00205                                 case 'd':
00206                                         param.degree = atoi(argv[i]);
00207                                         break;
00208                                 case 'g':
00209                                         param.gamma = atof(argv[i]);
00210                                         break;
00211                                 case 'r':
00212                                         param.coef0 = atof(argv[i]);
00213                                         break;
00214                                 case 'n':
00215                                         param.nu = atof(argv[i]);
00216                                         break;
00217                                 case 'm':
00218                                         param.cache_size = atof(argv[i]);
00219                                         break;
00220                                 case 'c':
00221                                         param.C = atof(argv[i]);
00222                                         break;
00223                                 case 'e':
00224                                         param.eps = atof(argv[i]);
00225                                         break;
00226                                 case 'p':
00227                                         param.p = atof(argv[i]);
00228                                         break;
00229                                 case 'h':
00230                                         param.shrinking = atoi(argv[i]);
00231                                         break;
00232                                 case 'b':
00233                                         param.probability = atoi(argv[i]);
00234                                         break;
00235                                 case 'w':
00236                                         ++param.nr_weight;
00237                                         {
00238                                                 int[] old = param.weight_label;
00239                                                 param.weight_label = new int[param.nr_weight];
00240                                                 System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1);
00241                                         }
00242 
00243                                         {
00244                                                 double[] old = param.weight;
00245                                                 param.weight = new double[param.nr_weight];
00246                                                 System.arraycopy(old,0,param.weight,0,param.nr_weight-1);
00247                                         }
00248 
00249                                         param.weight_label[param.nr_weight-1] = atoi(argv[i-1].substring(2));
00250                                         param.weight[param.nr_weight-1] = atof(argv[i]);
00251                                         break;
00252                                 default:
00253                                         System.err.print("unknown option\n");
00254                         }
00255                 }
00256 
00257                 // build problem
00258                 svm_problem prob = new svm_problem();
00259                 prob.l = point_list.size();
00260                 prob.y = new double[prob.l];
00261 
00262                 if(param.kernel_type == svm_parameter.PRECOMPUTED)
00263                 {
00264                 }
00265                 else if(param.svm_type == svm_parameter.EPSILON_SVR ||
00266                         param.svm_type == svm_parameter.NU_SVR)
00267                 {
00268                         if(param.gamma == 0) param.gamma = 1;
00269                         prob.x = new svm_node[prob.l][1];
00270                         for(int i=0;i<prob.l;i++)
00271                         {
00272                                 point p = point_list.elementAt(i);
00273                                 prob.x[i][0] = new svm_node();
00274                                 prob.x[i][0].index = 1;
00275                                 prob.x[i][0].value = p.x;
00276                                 prob.y[i] = p.y;
00277                         }
00278 
00279                         // build model & classify
00280                         svm_model model = svm.svm_train(prob, param);
00281                         svm_node[] x = new svm_node[1];
00282                         x[0] = new svm_node();
00283                         x[0].index = 1;
00284                         int[] j = new int[XLEN];
00285 
00286                         Graphics window_gc = getGraphics();
00287                         for (int i = 0; i < XLEN; i++)
00288                         {
00289                                 x[0].value = (double) i / XLEN;
00290                                 j[i] = (int)(YLEN*svm.svm_predict(model, x));
00291                         }
00292                         
00293                         buffer_gc.setColor(colors[0]);
00294                         buffer_gc.drawLine(0,0,0,YLEN-1);
00295                         window_gc.setColor(colors[0]);
00296                         window_gc.drawLine(0,0,0,YLEN-1);
00297                         
00298                         int p = (int)(param.p * YLEN);
00299                         for(int i=1;i<XLEN;i++)
00300                         {
00301                                 buffer_gc.setColor(colors[0]);
00302                                 buffer_gc.drawLine(i,0,i,YLEN-1);
00303                                 window_gc.setColor(colors[0]);
00304                                 window_gc.drawLine(i,0,i,YLEN-1);
00305 
00306                                 buffer_gc.setColor(colors[5]);
00307                                 window_gc.setColor(colors[5]);
00308                                 buffer_gc.drawLine(i-1,j[i-1],i,j[i]);
00309                                 window_gc.drawLine(i-1,j[i-1],i,j[i]);
00310 
00311                                 if(param.svm_type == svm_parameter.EPSILON_SVR)
00312                                 {
00313                                         buffer_gc.setColor(colors[2]);
00314                                         window_gc.setColor(colors[2]);
00315                                         buffer_gc.drawLine(i-1,j[i-1]+p,i,j[i]+p);
00316                                         window_gc.drawLine(i-1,j[i-1]+p,i,j[i]+p);
00317 
00318                                         buffer_gc.setColor(colors[2]);
00319                                         window_gc.setColor(colors[2]);
00320                                         buffer_gc.drawLine(i-1,j[i-1]-p,i,j[i]-p);
00321                                         window_gc.drawLine(i-1,j[i-1]-p,i,j[i]-p);
00322                                 }
00323                         }
00324                 }
00325                 else
00326                 {
00327                         if(param.gamma == 0) param.gamma = 0.5;
00328                         prob.x = new svm_node [prob.l][2];
00329                         for(int i=0;i<prob.l;i++)
00330                         {
00331                                 point p = point_list.elementAt(i);
00332                                 prob.x[i][0] = new svm_node();
00333                                 prob.x[i][0].index = 1;
00334                                 prob.x[i][0].value = p.x;
00335                                 prob.x[i][1] = new svm_node();
00336                                 prob.x[i][1].index = 2;
00337                                 prob.x[i][1].value = p.y;
00338                                 prob.y[i] = p.value;
00339                         }
00340 
00341                         // build model & classify
00342                         svm_model model = svm.svm_train(prob, param);
00343                         svm_node[] x = new svm_node[2];
00344                         x[0] = new svm_node();
00345                         x[1] = new svm_node();
00346                         x[0].index = 1;
00347                         x[1].index = 2;
00348 
00349                         Graphics window_gc = getGraphics();
00350                         for (int i = 0; i < XLEN; i++)
00351                                 for (int j = 0; j < YLEN ; j++) {
00352                                         x[0].value = (double) i / XLEN;
00353                                         x[1].value = (double) j / YLEN;
00354                                         double d = svm.svm_predict(model, x);
00355                                         if (param.svm_type == svm_parameter.ONE_CLASS && d<0) d=2;
00356                                         buffer_gc.setColor(colors[(int)d]);
00357                                         window_gc.setColor(colors[(int)d]);
00358                                         buffer_gc.drawLine(i,j,i,j);
00359                                         window_gc.drawLine(i,j,i,j);
00360                         }
00361                 }
00362 
00363                 draw_all_points();
00364         }
00365 
00366         void button_clear_clicked()
00367         {
00368                 clear_all();
00369         }
00370 
00371         void button_save_clicked(String args)
00372         {
00373                 FileDialog dialog = new FileDialog(new Frame(),"Save",FileDialog.SAVE);
00374                 dialog.setVisible(true);
00375                 String filename = dialog.getDirectory() + dialog.getFile();
00376                 if (filename == null) return;
00377                 try {
00378                         DataOutputStream fp = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filename)));
00379 
00380                         int svm_type = svm_parameter.C_SVC;
00381                         int svm_type_idx = args.indexOf("-s ");
00382                         if(svm_type_idx != -1)
00383                         {
00384                                 StringTokenizer svm_str_st = new StringTokenizer(args.substring(svm_type_idx+2).trim());
00385                                 svm_type = atoi(svm_str_st.nextToken());
00386                         }
00387 
00388                         int n = point_list.size();
00389                         if(svm_type == svm_parameter.EPSILON_SVR || svm_type == svm_parameter.NU_SVR)
00390                         {
00391                                 for(int i=0;i<n;i++)
00392                                 {
00393                                         point p = point_list.elementAt(i);
00394                                         fp.writeBytes(p.y+" 1:"+p.x+"\n");
00395                                 }
00396                         }
00397                         else
00398                         {
00399                                 for(int i=0;i<n;i++)
00400                                 {
00401                                         point p = point_list.elementAt(i);
00402                                         fp.writeBytes(p.value+" 1:"+p.x+" 2:"+p.y+"\n");
00403                                 }
00404                         }
00405                         fp.close();
00406                 } catch (IOException e) { System.err.print(e); }
00407         }
00408 
00409         void button_load_clicked()
00410         {
00411                 FileDialog dialog = new FileDialog(new Frame(),"Load",FileDialog.LOAD);
00412                 dialog.setVisible(true);
00413                 String filename = dialog.getDirectory() + dialog.getFile();
00414                 if (filename == null) return;
00415                 clear_all();
00416                 try {
00417                         BufferedReader fp = new BufferedReader(new FileReader(filename));
00418                         String line;
00419                         while((line = fp.readLine()) != null)
00420                         {
00421                                 StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
00422                                 if(st.countTokens() == 5)
00423                                 {
00424                                         byte value = (byte)atoi(st.nextToken());
00425                                         st.nextToken();
00426                                         double x = atof(st.nextToken());
00427                                         st.nextToken();
00428                                         double y = atof(st.nextToken());
00429                                         point_list.addElement(new point(x,y,value));
00430                                 }
00431                                 else if(st.countTokens() == 3)
00432                                 {
00433                                         double y = atof(st.nextToken());
00434                                         st.nextToken();
00435                                         double x = atof(st.nextToken());
00436                                         point_list.addElement(new point(x,y,current_value));
00437                                 }else
00438                                         break;
00439                         }
00440                         fp.close();
00441                 } catch (IOException e) { System.err.print(e); }
00442                 draw_all_points();
00443         }
00444         
00445         protected void processMouseEvent(MouseEvent e)
00446         {
00447                 if(e.getID() == MouseEvent.MOUSE_PRESSED)
00448                 {
00449                         if(e.getX() >= XLEN || e.getY() >= YLEN) return;
00450                         point p = new point((double)e.getX()/XLEN,
00451                                             (double)e.getY()/YLEN,
00452                                             current_value);
00453                         point_list.addElement(p);
00454                         draw_point(p);
00455                 }
00456         }
00457 
00458         public void paint(Graphics g)
00459         {
00460                 // create buffer first time
00461                 if(buffer == null) {
00462                         buffer = this.createImage(XLEN,YLEN);
00463                         buffer_gc = buffer.getGraphics();
00464                         buffer_gc.setColor(colors[0]);
00465                         buffer_gc.fillRect(0,0,XLEN,YLEN);
00466                 }
00467                 g.drawImage(buffer,0,0,this);
00468         }
00469 
00470         public Dimension getPreferredSize() { return new Dimension(XLEN,YLEN+50); }
00471 
00472         public void setSize(Dimension d) { setSize(d.width,d.height); }
00473         public void setSize(int w,int h) {
00474                 super.setSize(w,h);
00475                 XLEN = w;
00476                 YLEN = h-50;
00477                 clear_all();
00478         }
00479 
00480         public static void main(String[] argv)
00481         {
00482                 new AppletFrame("svm_toy",new svm_toy(),500,500+50);
00483         }
00484 }
00485 
/**
 * Top-level window that hosts an applet so it can run as a standalone program.
 * It drives the applet's init/start lifecycle itself, sizes the applet, packs the
 * window around it and exits the JVM when the window is closed.
 */
class AppletFrame extends Frame {

	/** Ends the process when the user closes the hosting window. */
	private static final class ExitOnClose extends WindowAdapter {
		@Override
		public void windowClosing(WindowEvent event) {
			System.exit(0);
		}
	}

	/**
	 * Creates, populates and shows the window.
	 *
	 * @param title  window title
	 * @param applet applet to host; its init() and start() are invoked here
	 * @param width  width passed to applet.setSize
	 * @param height height passed to applet.setSize
	 */
	AppletFrame(String title, Applet applet, int width, int height)
	{
		super(title);
		WindowListener closer = new ExitOnClose();
		addWindowListener(closer);

		// No browser is involved, so the applet lifecycle is run by hand.
		applet.init();
		applet.setSize(width, height);
		applet.start();

		add(applet);
		pack();
		setVisible(true);
	}
}


ml_classifiers
Author(s): Scott Niekum
autogenerated on Fri Jan 3 2014 11:30:23