svm.java
Go to the documentation of this file.
00001 
00002 
00003 
00004 
00005 
00006 package libsvm;
00007 import java.io.*;
00008 import java.util.*;
00009 
00010 //
00011 // Kernel Cache
00012 //
00013 // l is the number of total data items
00014 // size is the cache size limit in bytes
00015 //
/**
 * LRU cache of (prefixes of) kernel-matrix columns.
 *
 * <p>{@code l_} is the total number of data items (one slot per column) and {@code size_}
 * is the cache budget in bytes. Internally the budget is tracked in floats (4 bytes each),
 * after subtracting a 16-byte estimate per slot header. It is never allowed to fall below
 * two full columns' worth ({@code 2*l} floats).
 */
class Cache {
        private final int l;
        // remaining budget, counted in floats
        private long size;

        /** One cache slot; a slot with len > 0 is also linked into the LRU list. */
        private static final class head_t
        {
                head_t prev, next;      // a circular doubly-linked list
                float[] data;
                int len;                // data[0,len) is cached in this entry
        }
        private final head_t[] head;
        // Sentinel of the circular LRU list: lru_head.next is the least recently used
        // slot (evicted first), lru_head.prev the most recently used one.
        private final head_t lru_head;

        Cache(int l_, long size_)
        {
                l = l_;
                head = new head_t[l];
                for(int i=0;i<l;i++) head[i] = new head_t();
                long budget = size_ / 4;                // bytes -> floats
                budget -= (long) l * (16/4);            // sizeof(head_t) == 16; long math avoids int overflow
                size = Math.max(budget, 2 * (long) l);  // cache must be large enough for two columns
                lru_head = new head_t();
                lru_head.next = lru_head.prev = lru_head;
        }

        /** Unlinks h from the LRU list (h's own prev/next are left untouched). */
        private void lru_delete(head_t h)
        {
                // delete from current location
                h.prev.next = h.next;
                h.next.prev = h.prev;
        }

        /** Links h in as the most recently used entry (just before the sentinel). */
        private void lru_insert(head_t h)
        {
                // insert to last position
                h.next = lru_head;
                h.prev = lru_head.prev;
                h.prev.next = h;
                h.next.prev = h;
        }

        /**
         * Requests entries [0,len) of column {@code index}.
         *
         * <p>The slot is (re)allocated to at least {@code len} floats if needed, evicting
         * least-recently-used slots to make room. Previously cached entries are kept.
         *
         * @param index column to fetch
         * @param data  out-parameter (java: simulates a pointer); data[0] receives the slot's array
         * @param len   number of leading entries requested
         * @return a position p such that data[0][p,len) still has to be filled by the caller
         *         (p >= len when nothing needs filling)
         */
        int get_data(int index, float[][] data, int len)
        {
                head_t h = head[index];
                if(h.len > 0) lru_delete(h);
                int more = len - h.len;

                if(more > 0)
                {
                        // free old space: evict least recently used slots until it fits
                        while(size < more)
                        {
                                head_t old = lru_head.next;
                                lru_delete(old);
                                size += old.len;
                                old.data = null;
                                old.len = 0;
                        }

                        // allocate new space, keeping what was already cached
                        float[] new_data = new float[len];
                        if(h.data != null) System.arraycopy(h.data,0,new_data,0,h.len);
                        h.data = new_data;
                        size -= more;
                        // hand back the old length: entries [old len, len) are still unfilled
                        int cached_len = h.len;
                        h.len = len;
                        len = cached_len;
                }

                lru_insert(h);
                data[0] = h.data;
                return len;
        }

        /**
         * Swaps columns i and j, and rows i and j inside every cached column.
         *
         * <p>A cached column that covers row i but not row j cannot be fixed up, so it is
         * dropped from the cache.
         */
        void swap_index(int i, int j)
        {
                if(i==j) return;

                if(head[i].len > 0) lru_delete(head[i]);
                if(head[j].len > 0) lru_delete(head[j]);
                float[] tmp_data = head[i].data;
                head[i].data = head[j].data;
                head[j].data = tmp_data;
                int tmp_len = head[i].len;
                head[i].len = head[j].len;
                head[j].len = tmp_len;
                if(head[i].len > 0) lru_insert(head[i]);
                if(head[j].len > 0) lru_insert(head[j]);

                if(i>j)
                {
                        int tmp_idx = i;
                        i = j;
                        j = tmp_idx;
                }
                // lru_delete leaves h.next intact, so iteration may continue after a drop
                for(head_t h = lru_head.next; h!=lru_head; h=h.next)
                {
                        if(h.len > i)
                        {
                                if(h.len > j)
                                {
                                        float tmp_val = h.data[i];
                                        h.data[i] = h.data[j];
                                        h.data[j] = tmp_val;
                                }
                                else
                                {
                                        // give up
                                        lru_delete(h);
                                        size += h.len;
                                        h.data = null;
                                        h.len = 0;
                                }
                        }
                }
        }
}
00122 
00123 //
00124 // Kernel evaluation
00125 //
00126 // the static method k_function is for doing single kernel evaluation
00127 // the constructor of Kernel prepares to calculate the l*l kernel matrix
00128 // the member function get_Q is for getting one column from the Q Matrix
00129 //
/**
 * Abstract access to a Q matrix, one column at a time.
 *
 * <p>In the visible code, Solver reads Q only through these three operations.
 */
abstract class QMatrix {
        /** Exchanges indices i and j, keeping the implementation consistent with a permuted problem. */
        abstract void swap_index(int i, int j);

        /** Returns column {@code column}; callers read only entries [0, len) of the result. */
        abstract float[] get_Q(int column, int len);

        /** Returns the per-index array Solver uses as QD[i] (presumably the diagonal Q[i][i]). */
        abstract double[] get_QD();
}
00135 
00136 abstract class Kernel extends QMatrix {
00137         private svm_node[][] x;
00138         private final double[] x_square;
00139 
00140         // svm_parameter
00141         private final int kernel_type;
00142         private final int degree;
00143         private final double gamma;
00144         private final double coef0;
00145 
00146         abstract float[] get_Q(int column, int len);
00147         abstract double[] get_QD();
00148 
00149         void swap_index(int i, int j)
00150         {
00151                 do {svm_node[] _=x[i]; x[i]=x[j]; x[j]=_;} while(false);
00152                 if(x_square != null) do {double _=x_square[i]; x_square[i]=x_square[j]; x_square[j]=_;} while(false);
00153         }
00154 
00155         private static double powi(double base, int times)
00156         {
00157                 double tmp = base, ret = 1.0;
00158 
00159                 for(int t=times; t>0; t/=2)
00160                 {
00161                         if(t%2==1) ret*=tmp;
00162                         tmp = tmp * tmp;
00163                 }
00164                 return ret;
00165         }
00166 
00167         double kernel_function(int i, int j)
00168         {
00169                 switch(kernel_type)
00170                 {
00171                         case svm_parameter.LINEAR:
00172                                 return dot(x[i],x[j]);
00173                         case svm_parameter.POLY:
00174                                 return powi(gamma*dot(x[i],x[j])+coef0,degree);
00175                         case svm_parameter.RBF:
00176                                 return Math.exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
00177                         case svm_parameter.SIGMOID:
00178                                 return Math.tanh(gamma*dot(x[i],x[j])+coef0);
00179                         case svm_parameter.PRECOMPUTED:
00180                                 return x[i][(int)(x[j][0].value)].value;
00181                         default:
00182                                 return 0;       // java
00183                 }
00184         }
00185 
00186         Kernel(int l, svm_node[][] x_, svm_parameter param)
00187         {
00188                 this.kernel_type = param.kernel_type;
00189                 this.degree = param.degree;
00190                 this.gamma = param.gamma;
00191                 this.coef0 = param.coef0;
00192 
00193                 x = (svm_node[][])x_.clone();
00194 
00195                 if(kernel_type == svm_parameter.RBF)
00196                 {
00197                         x_square = new double[l];
00198                         for(int i=0;i<l;i++)
00199                                 x_square[i] = dot(x[i],x[i]);
00200                 }
00201                 else x_square = null;
00202         }
00203 
00204         static double dot(svm_node[] x, svm_node[] y)
00205         {
00206                 double sum = 0;
00207                 int xlen = x.length;
00208                 int ylen = y.length;
00209                 int i = 0;
00210                 int j = 0;
00211                 while(i < xlen && j < ylen)
00212                 {
00213                         if(x[i].index == y[j].index)
00214                                 sum += x[i++].value * y[j++].value;
00215                         else
00216                         {
00217                                 if(x[i].index > y[j].index)
00218                                         ++j;
00219                                 else
00220                                         ++i;
00221                         }
00222                 }
00223                 return sum;
00224         }
00225 
00226         static double k_function(svm_node[] x, svm_node[] y,
00227                                         svm_parameter param)
00228         {
00229                 switch(param.kernel_type)
00230                 {
00231                         case svm_parameter.LINEAR:
00232                                 return dot(x,y);
00233                         case svm_parameter.POLY:
00234                                 return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
00235                         case svm_parameter.RBF:
00236                         {
00237                                 double sum = 0;
00238                                 int xlen = x.length;
00239                                 int ylen = y.length;
00240                                 int i = 0;
00241                                 int j = 0;
00242                                 while(i < xlen && j < ylen)
00243                                 {
00244                                         if(x[i].index == y[j].index)
00245                                         {
00246                                                 double d = x[i++].value - y[j++].value;
00247                                                 sum += d*d;
00248                                         }
00249                                         else if(x[i].index > y[j].index)
00250                                         {
00251                                                 sum += y[j].value * y[j].value;
00252                                                 ++j;
00253                                         }
00254                                         else
00255                                         {
00256                                                 sum += x[i].value * x[i].value;
00257                                                 ++i;
00258                                         }
00259                                 }
00260 
00261                                 while(i < xlen)
00262                                 {
00263                                         sum += x[i].value * x[i].value;
00264                                         ++i;
00265                                 }
00266 
00267                                 while(j < ylen)
00268                                 {
00269                                         sum += y[j].value * y[j].value;
00270                                         ++j;
00271                                 }
00272 
00273                                 return Math.exp(-param.gamma*sum);
00274                         }
00275                         case svm_parameter.SIGMOID:
00276                                 return Math.tanh(param.gamma*dot(x,y)+param.coef0);
00277                         case svm_parameter.PRECOMPUTED:
00278                                 return  x[(int)(y[0].value)].value;
00279                         default:
00280                                 return 0;       // java
00281                 }
00282         }
00283 }
00284 
00285 // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
00286 // Solves:
00287 //
00288 //      min 0.5(\alpha^T Q \alpha) + p^T \alpha
00289 //
00290 //              y^T \alpha = \delta
00291 //              y_i = +1 or -1
00292 //              0 <= alpha_i <= Cp for y_i = 1
00293 //              0 <= alpha_i <= Cn for y_i = -1
00294 //
00295 // Given:
00296 //
00297 //      Q, p, y, Cp, Cn, and an initial feasible point \alpha
00298 //      l is the size of vectors and matrices
00299 //      eps is the stopping tolerance
00300 //
00301 // solution will be put in \alpha, objective value will be put in obj
00302 //
00303 class Solver {
00304         int active_size;
00305         byte[] y;
00306         double[] G;             // gradient of objective function
00307         static final byte LOWER_BOUND = 0;
00308         static final byte UPPER_BOUND = 1;
00309         static final byte FREE = 2;
00310         byte[] alpha_status;    // LOWER_BOUND, UPPER_BOUND, FREE
00311         double[] alpha;
00312         QMatrix Q;
00313         double[] QD;
00314         double eps;
00315         double Cp,Cn;
00316         double[] p;
00317         int[] active_set;
00318         double[] G_bar;         // gradient, if we treat free variables as 0
00319         int l;
00320         boolean unshrink;       // XXX
00321         
00322         static final double INF = java.lang.Double.POSITIVE_INFINITY;
00323 
00324         double get_C(int i)
00325         {
00326                 return (y[i] > 0)? Cp : Cn;
00327         }
00328         void update_alpha_status(int i)
00329         {
00330                 if(alpha[i] >= get_C(i))
00331                         alpha_status[i] = UPPER_BOUND;
00332                 else if(alpha[i] <= 0)
00333                         alpha_status[i] = LOWER_BOUND;
00334                 else alpha_status[i] = FREE;
00335         }
00336         boolean is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
00337         boolean is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
00338         boolean is_free(int i) {  return alpha_status[i] == FREE; }
00339 
00340         // java: information about solution except alpha,
00341         // because we cannot return multiple values otherwise...
00342         static class SolutionInfo {
00343                 double obj;
00344                 double rho;
00345                 double upper_bound_p;
00346                 double upper_bound_n;
00347                 double r;       // for Solver_NU
00348         }
00349 
00350         void swap_index(int i, int j)
00351         {
00352                 Q.swap_index(i,j);
00353                 do {byte _=y[i]; y[i]=y[j]; y[j]=_;} while(false);
00354                 do {double _=G[i]; G[i]=G[j]; G[j]=_;} while(false);
00355                 do {byte _=alpha_status[i]; alpha_status[i]=alpha_status[j]; alpha_status[j]=_;} while(false);
00356                 do {double _=alpha[i]; alpha[i]=alpha[j]; alpha[j]=_;} while(false);
00357                 do {double _=p[i]; p[i]=p[j]; p[j]=_;} while(false);
00358                 do {int _=active_set[i]; active_set[i]=active_set[j]; active_set[j]=_;} while(false);
00359                 do {double _=G_bar[i]; G_bar[i]=G_bar[j]; G_bar[j]=_;} while(false);
00360         }
00361 
00362         void reconstruct_gradient()
00363         {
00364                 // reconstruct inactive elements of G from G_bar and free variables
00365 
00366                 if(active_size == l) return;
00367 
00368                 int i,j;
00369                 int nr_free = 0;
00370 
00371                 for(j=active_size;j<l;j++)
00372                         G[j] = G_bar[j] + p[j];
00373 
00374                 for(j=0;j<active_size;j++)
00375                         if(is_free(j))
00376                                 nr_free++;
00377 
00378                 if(2*nr_free < active_size)
00379                         svm.info("\nWARNING: using -h 0 may be faster\n");
00380 
00381                 if (nr_free*l > 2*active_size*(l-active_size))
00382                 {
00383                         for(i=active_size;i<l;i++)
00384                         {
00385                                 float[] Q_i = Q.get_Q(i,active_size);
00386                                 for(j=0;j<active_size;j++)
00387                                         if(is_free(j))
00388                                                 G[i] += alpha[j] * Q_i[j];
00389                         }       
00390                 }
00391                 else
00392                 {
00393                         for(i=0;i<active_size;i++)
00394                                 if(is_free(i))
00395                                 {
00396                                         float[] Q_i = Q.get_Q(i,l);
00397                                         double alpha_i = alpha[i];
00398                                         for(j=active_size;j<l;j++)
00399                                                 G[j] += alpha_i * Q_i[j];
00400                                 }
00401                 }
00402         }
00403 
00404         void Solve(int l, QMatrix Q, double[] p_, byte[] y_,
00405                    double[] alpha_, double Cp, double Cn, double eps, SolutionInfo si, int shrinking)
00406         {
                      // Runs the SMO loop: minimise 0.5 alpha^T Q alpha + p^T alpha with
                      // 0 <= alpha_i <= Cp (y_i > 0) or Cn (otherwise), starting from alpha_.
                      // p_, y_ and alpha_ are cloned, so p_ and y_ are never modified. The final
                      // alpha is written back into alpha_ in the caller's original order (via
                      // active_set); rho, obj and upper_bound_p/n are stored in si.
                      // shrinking != 0 enables do_shrinking() at each periodic check.
00407                 this.l = l;
00408                 this.Q = Q;
00409                 QD = Q.get_QD();
00410                 p = (double[])p_.clone();
00411                 y = (byte[])y_.clone();
00412                 alpha = (double[])alpha_.clone();
00413                 this.Cp = Cp;
00414                 this.Cn = Cn;
00415                 this.eps = eps;
00416                 this.unshrink = false;
00417 
00418                 // initialize alpha_status
00419                 {
00420                         alpha_status = new byte[l];
00421                         for(int i=0;i<l;i++)
00422                                 update_alpha_status(i);
00423                 }
00424 
00425                 // initialize active set (for shrinking)
                      // identity map at first; swap_index permutes it, and it is used at the
                      // end to put alpha back in the caller's order
00426                 {
00427                         active_set = new int[l];
00428                         for(int i=0;i<l;i++)
00429                                 active_set[i] = i;
00430                         active_size = l;
00431                 }
00432 
00433                 // initialize gradient
                      // G = p + sum of alpha_i * Q_i over variables not at the lower bound;
                      // G_bar = sum of get_C(i) * Q_i over variables at the upper bound
00434                 {
00435                         G = new double[l];
00436                         G_bar = new double[l];
00437                         int i;
00438                         for(i=0;i<l;i++)
00439                         {
00440                                 G[i] = p[i];
00441                                 G_bar[i] = 0;
00442                         }
00443                         for(i=0;i<l;i++)
00444                                 if(!is_lower_bound(i))
00445                                 {
00446                                         float[] Q_i = Q.get_Q(i,l);
00447                                         double alpha_i = alpha[i];
00448                                         int j;
00449                                         for(j=0;j<l;j++)
00450                                                 G[j] += alpha_i*Q_i[j];
00451                                         if(is_upper_bound(i))
00452                                                 for(j=0;j<l;j++)
00453                                                         G_bar[j] += get_C(i) * Q_i[j];
00454                                 }
00455                 }
00456 
00457                 // optimization step
00458 
00459                 int iter = 0;
                      // at least 10^7 iterations, otherwise 100*l, capped at Integer.MAX_VALUE
                      // so that 100*l cannot overflow
00460                 int max_iter = Math.max(10000000, l>Integer.MAX_VALUE/100 ? Integer.MAX_VALUE : 100*l);
                      // the first shrink/progress check fires after min(l,1000)+1 iterations,
                      // later ones every min(l,1000) iterations
00461                 int counter = Math.min(l,1000)+1;
00462                 int[] working_set = new int[2];
00463 
00464                 while(iter < max_iter)
00465                 {
00466                         // show progress and do shrinking
00467 
00468                         if(--counter == 0)
00469                         {
00470                                 counter = Math.min(l,1000);
00471                                 if(shrinking!=0) do_shrinking();
00472                                 svm.info(".");
00473                         }
00474 
                              // select_working_set returns nonzero when it judges the active set
                              // optimal; before stopping, the full gradient is rebuilt and
                              // optimality is re-checked over all l variables
00475                         if(select_working_set(working_set)!=0)
00476                         {
00477                                 // reconstruct the whole gradient
00478                                 reconstruct_gradient();
00479                                 // reset active set size and check
00480                                 active_size = l;
00481                                 svm.info("*");
00482                                 if(select_working_set(working_set)!=0)
00483                                         break;
00484                                 else
00485                                         counter = 1;    // do shrinking next iteration
00486                         }
00487                         
00488                         int i = working_set[0];
00489                         int j = working_set[1];
00490 
00491                         ++iter;
00492 
00493                         // update alpha[i] and alpha[j], handle bounds carefully
00494 
00495                         float[] Q_i = Q.get_Q(i,active_size);
00496                         float[] Q_j = Q.get_Q(j,active_size);
00497 
00498                         double C_i = get_C(i);
00499                         double C_j = get_C(j);
00500 
00501                         double old_alpha_i = alpha[i];
00502                         double old_alpha_j = alpha[j];
00503 
00504                         if(y[i]!=y[j])
00505                         {
                                      // Opposite labels: both alphas move by the same delta, so
                                      // diff = alpha[i] - alpha[j] stays constant. The clipping
                                      // below pulls the pair back into [0,C_i] x [0,C_j] along
                                      // that line. A non-positive curvature (quad_coef) is replaced
                                      // by a tiny positive value so the step stays finite.
00506                                 double quad_coef = QD[i]+QD[j]+2*Q_i[j];
00507                                 if (quad_coef <= 0)
00508                                         quad_coef = 1e-12;
00509                                 double delta = (-G[i]-G[j])/quad_coef;
00510                                 double diff = alpha[i] - alpha[j];
00511                                 alpha[i] += delta;
00512                                 alpha[j] += delta;
00513                         
00514                                 if(diff > 0)
00515                                 {
00516                                         if(alpha[j] < 0)
00517                                         {
00518                                                 alpha[j] = 0;
00519                                                 alpha[i] = diff;
00520                                         }
00521                                 }
00522                                 else
00523                                 {
00524                                         if(alpha[i] < 0)
00525                                         {
00526                                                 alpha[i] = 0;
00527                                                 alpha[j] = -diff;
00528                                         }
00529                                 }
00530                                 if(diff > C_i - C_j)
00531                                 {
00532                                         if(alpha[i] > C_i)
00533                                         {
00534                                                 alpha[i] = C_i;
00535                                                 alpha[j] = C_i - diff;
00536                                         }
00537                                 }
00538                                 else
00539                                 {
00540                                         if(alpha[j] > C_j)
00541                                         {
00542                                                 alpha[j] = C_j;
00543                                                 alpha[i] = C_j + diff;
00544                                         }
00545                                 }
00546                         }
00547                         else
00548                         {
                                      // Equal labels: alpha[i] moves by -delta and alpha[j] by
                                      // +delta, so sum = alpha[i] + alpha[j] stays constant; then
                                      // clip into the box along that line.
00549                                 double quad_coef = QD[i]+QD[j]-2*Q_i[j];
00550                                 if (quad_coef <= 0)
00551                                         quad_coef = 1e-12;
00552                                 double delta = (G[i]-G[j])/quad_coef;
00553                                 double sum = alpha[i] + alpha[j];
00554                                 alpha[i] -= delta;
00555                                 alpha[j] += delta;
00556 
00557                                 if(sum > C_i)
00558                                 {
00559                                         if(alpha[i] > C_i)
00560                                         {
00561                                                 alpha[i] = C_i;
00562                                                 alpha[j] = sum - C_i;
00563                                         }
00564                                 }
00565                                 else
00566                                 {
00567                                         if(alpha[j] < 0)
00568                                         {
00569                                                 alpha[j] = 0;
00570                                                 alpha[i] = sum;
00571                                         }
00572                                 }
00573                                 if(sum > C_j)
00574                                 {
00575                                         if(alpha[j] > C_j)
00576                                         {
00577                                                 alpha[j] = C_j;
00578                                                 alpha[i] = sum - C_j;
00579                                         }
00580                                 }
00581                                 else
00582                                 {
00583                                         if(alpha[i] < 0)
00584                                         {
00585                                                 alpha[i] = 0;
00586                                                 alpha[j] = sum;
00587                                         }
00588                                 }
00589                         }
00590 
00591                         // update G
                              // only columns i and j changed; only active entries are updated
                              // here (inactive ones are rebuilt later by reconstruct_gradient)
00592 
00593                         double delta_alpha_i = alpha[i] - old_alpha_i;
00594                         double delta_alpha_j = alpha[j] - old_alpha_j;
00595 
00596                         for(int k=0;k<active_size;k++)
00597                         {
00598                                 G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
00599                         }
00600 
00601                         // update alpha_status and G_bar
                              // G_bar changes only when i or j enters or leaves the upper bound;
                              // full-length (l) columns are fetched because G_bar covers the
                              // inactive entries too
00602 
00603                         {
00604                                 boolean ui = is_upper_bound(i);
00605                                 boolean uj = is_upper_bound(j);
00606                                 update_alpha_status(i);
00607                                 update_alpha_status(j);
00608                                 int k;
00609                                 if(ui != is_upper_bound(i))
00610                                 {
00611                                         Q_i = Q.get_Q(i,l);
00612                                         if(ui)
00613                                                 for(k=0;k<l;k++)
00614                                                         G_bar[k] -= C_i * Q_i[k];
00615                                         else
00616                                                 for(k=0;k<l;k++)
00617                                                         G_bar[k] += C_i * Q_i[k];
00618                                 }
00619 
00620                                 if(uj != is_upper_bound(j))
00621                                 {
00622                                         Q_j = Q.get_Q(j,l);
00623                                         if(uj)
00624                                                 for(k=0;k<l;k++)
00625                                                         G_bar[k] -= C_j * Q_j[k];
00626                                         else
00627                                                 for(k=0;k<l;k++)
00628                                                         G_bar[k] += C_j * Q_j[k];
00629                                 }
00630                         }
00631 
00632                 }
00633                 
                      // iteration budget exhausted: still rebuild the full gradient so rho
                      // and the objective are computed over all l variables
00634                 if(iter >= max_iter)
00635                 {
00636                         if(active_size < l)
00637                         {
00638                                 // reconstruct the whole gradient to calculate objective value
00639                                 reconstruct_gradient();
00640                                 active_size = l;
00641                                 svm.info("*");
00642                         }
00643                         System.err.print("\nWARNING: reaching max number of iterations\n");
00644                 }
00645 
00646                 // calculate rho
00647 
00648                 si.rho = calculate_rho();
00649 
00650                 // calculate objective value
                      // with G = Q*alpha + p: 0.5 * alpha^T (G + p) = 0.5 alpha^T Q alpha + p^T alpha
00651                 {
00652                         double v = 0;
00653                         int i;
00654                         for(i=0;i<l;i++)
00655                                 v += alpha[i] * (G[i] + p[i]);
00656 
00657                         si.obj = v/2;
00658                 }
00659 
00660                 // put back the solution
00661                 {
00662                         for(int i=0;i<l;i++)
00663                                 alpha_[active_set[i]] = alpha[i];
00664                 }
00665 
00666                 si.upper_bound_p = Cp;
00667                 si.upper_bound_n = Cn;
00668 
00669                 svm.info("\noptimization finished, #iter = "+iter+"\n");
00670         }
00671 
00672         // return 1 if already optimal, return 0 otherwise
00673         int select_working_set(int[] working_set)
00674         {
00675                 // return i,j such that
00676                 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
00677                 // j: mimimizes the decrease of obj value
00678                 //    (if quadratic coefficeint <= 0, replace it with tau)
00679                 //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
00680                 
00681                 double Gmax = -INF;
00682                 double Gmax2 = -INF;
00683                 int Gmax_idx = -1;
00684                 int Gmin_idx = -1;
00685                 double obj_diff_min = INF;
00686         
00687                 for(int t=0;t<active_size;t++)
00688                         if(y[t]==+1)    
00689                         {
00690                                 if(!is_upper_bound(t))
00691                                         if(-G[t] >= Gmax)
00692                                         {
00693                                                 Gmax = -G[t];
00694                                                 Gmax_idx = t;
00695                                         }
00696                         }
00697                         else
00698                         {
00699                                 if(!is_lower_bound(t))
00700                                         if(G[t] >= Gmax)
00701                                         {
00702                                                 Gmax = G[t];
00703                                                 Gmax_idx = t;
00704                                         }
00705                         }
00706         
00707                 int i = Gmax_idx;
00708                 float[] Q_i = null;
00709                 if(i != -1) // null Q_i not accessed: Gmax=-INF if i=-1
00710                         Q_i = Q.get_Q(i,active_size);
00711         
00712                 for(int j=0;j<active_size;j++)
00713                 {
00714                         if(y[j]==+1)
00715                         {
00716                                 if (!is_lower_bound(j))
00717                                 {
00718                                         double grad_diff=Gmax+G[j];
00719                                         if (G[j] >= Gmax2)
00720                                                 Gmax2 = G[j];
00721                                         if (grad_diff > 0)
00722                                         {
00723                                                 double obj_diff; 
00724                                                 double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
00725                                                 if (quad_coef > 0)
00726                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00727                                                 else
00728                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00729         
00730                                                 if (obj_diff <= obj_diff_min)
00731                                                 {
00732                                                         Gmin_idx=j;
00733                                                         obj_diff_min = obj_diff;
00734                                                 }
00735                                         }
00736                                 }
00737                         }
00738                         else
00739                         {
00740                                 if (!is_upper_bound(j))
00741                                 {
00742                                         double grad_diff= Gmax-G[j];
00743                                         if (-G[j] >= Gmax2)
00744                                                 Gmax2 = -G[j];
00745                                         if (grad_diff > 0)
00746                                         {
00747                                                 double obj_diff; 
00748                                                 double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
00749                                                 if (quad_coef > 0)
00750                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00751                                                 else
00752                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00753         
00754                                                 if (obj_diff <= obj_diff_min)
00755                                                 {
00756                                                         Gmin_idx=j;
00757                                                         obj_diff_min = obj_diff;
00758                                                 }
00759                                         }
00760                                 }
00761                         }
00762                 }
00763 
00764                 if(Gmax+Gmax2 < eps || Gmin_idx == -1)
00765                         return 1;
00766 
00767                 working_set[0] = Gmax_idx;
00768                 working_set[1] = Gmin_idx;
00769                 return 0;
00770         }
00771 
00772         private boolean be_shrunk(int i, double Gmax1, double Gmax2)
00773         {       
00774                 if(is_upper_bound(i))
00775                 {
00776                         if(y[i]==+1)
00777                                 return(-G[i] > Gmax1);
00778                         else
00779                                 return(-G[i] > Gmax2);
00780                 }
00781                 else if(is_lower_bound(i))
00782                 {
00783                         if(y[i]==+1)
00784                                 return(G[i] > Gmax2);
00785                         else    
00786                                 return(G[i] > Gmax1);
00787                 }
00788                 else
00789                         return(false);
00790         }
00791 
00792         void do_shrinking()
00793         {
00794                 int i;
00795                 double Gmax1 = -INF;            // max { -y_i * grad(f)_i | i in I_up(\alpha) }
00796                 double Gmax2 = -INF;            // max { y_i * grad(f)_i | i in I_low(\alpha) }
00797 
00798                 // find maximal violating pair first
00799                 for(i=0;i<active_size;i++)
00800                 {
00801                         if(y[i]==+1)
00802                         {
00803                                 if(!is_upper_bound(i))  
00804                                 {
00805                                         if(-G[i] >= Gmax1)
00806                                                 Gmax1 = -G[i];
00807                                 }
00808                                 if(!is_lower_bound(i))
00809                                 {
00810                                         if(G[i] >= Gmax2)
00811                                                 Gmax2 = G[i];
00812                                 }
00813                         }
00814                         else            
00815                         {
00816                                 if(!is_upper_bound(i))  
00817                                 {
00818                                         if(-G[i] >= Gmax2)
00819                                                 Gmax2 = -G[i];
00820                                 }
00821                                 if(!is_lower_bound(i))  
00822                                 {
00823                                         if(G[i] >= Gmax1)
00824                                                 Gmax1 = G[i];
00825                                 }
00826                         }
00827                 }
00828 
00829                 if(unshrink == false && Gmax1 + Gmax2 <= eps*10) 
00830                 {
00831                         unshrink = true;
00832                         reconstruct_gradient();
00833                         active_size = l;
00834                 }
00835 
00836                 for(i=0;i<active_size;i++)
00837                         if (be_shrunk(i, Gmax1, Gmax2))
00838                         {
00839                                 active_size--;
00840                                 while (active_size > i)
00841                                 {
00842                                         if (!be_shrunk(active_size, Gmax1, Gmax2))
00843                                         {
00844                                                 swap_index(i,active_size);
00845                                                 break;
00846                                         }
00847                                         active_size--;
00848                                 }
00849                         }
00850         }
00851 
00852         double calculate_rho()
00853         {
00854                 double r;
00855                 int nr_free = 0;
00856                 double ub = INF, lb = -INF, sum_free = 0;
00857                 for(int i=0;i<active_size;i++)
00858                 {
00859                         double yG = y[i]*G[i];
00860 
00861                         if(is_lower_bound(i))
00862                         {
00863                                 if(y[i] > 0)
00864                                         ub = Math.min(ub,yG);
00865                                 else
00866                                         lb = Math.max(lb,yG);
00867                         }
00868                         else if(is_upper_bound(i))
00869                         {
00870                                 if(y[i] < 0)
00871                                         ub = Math.min(ub,yG);
00872                                 else
00873                                         lb = Math.max(lb,yG);
00874                         }
00875                         else
00876                         {
00877                                 ++nr_free;
00878                                 sum_free += yG;
00879                         }
00880                 }
00881 
00882                 if(nr_free>0)
00883                         r = sum_free/nr_free;
00884                 else
00885                         r = (ub+lb)/2;
00886 
00887                 return r;
00888         }
00889 
00890 }
00891 
00892 //
00893 // Solver for nu-svm classification and regression
00894 //
00895 // additional constraint: e^T \alpha = constant
00896 //
00897 final class Solver_NU extends Solver
00898 {
00899         private SolutionInfo si;
00900 
00901         void Solve(int l, QMatrix Q, double[] p, byte[] y,
00902                    double[] alpha, double Cp, double Cn, double eps,
00903                    SolutionInfo si, int shrinking)
00904         {
00905                 this.si = si;
00906                 super.Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
00907         }
00908 
00909         // return 1 if already optimal, return 0 otherwise
00910         int select_working_set(int[] working_set)
00911         {
00912                 // return i,j such that y_i = y_j and
00913                 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
00914                 // j: minimizes the decrease of obj value
00915                 //    (if quadratic coefficeint <= 0, replace it with tau)
00916                 //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
00917         
00918                 double Gmaxp = -INF;
00919                 double Gmaxp2 = -INF;
00920                 int Gmaxp_idx = -1;
00921         
00922                 double Gmaxn = -INF;
00923                 double Gmaxn2 = -INF;
00924                 int Gmaxn_idx = -1;
00925         
00926                 int Gmin_idx = -1;
00927                 double obj_diff_min = INF;
00928         
00929                 for(int t=0;t<active_size;t++)
00930                         if(y[t]==+1)
00931                         {
00932                                 if(!is_upper_bound(t))
00933                                         if(-G[t] >= Gmaxp)
00934                                         {
00935                                                 Gmaxp = -G[t];
00936                                                 Gmaxp_idx = t;
00937                                         }
00938                         }
00939                         else
00940                         {
00941                                 if(!is_lower_bound(t))
00942                                         if(G[t] >= Gmaxn)
00943                                         {
00944                                                 Gmaxn = G[t];
00945                                                 Gmaxn_idx = t;
00946                                         }
00947                         }
00948         
00949                 int ip = Gmaxp_idx;
00950                 int in = Gmaxn_idx;
00951                 float[] Q_ip = null;
00952                 float[] Q_in = null;
00953                 if(ip != -1) // null Q_ip not accessed: Gmaxp=-INF if ip=-1
00954                         Q_ip = Q.get_Q(ip,active_size);
00955                 if(in != -1)
00956                         Q_in = Q.get_Q(in,active_size);
00957         
00958                 for(int j=0;j<active_size;j++)
00959                 {
00960                         if(y[j]==+1)
00961                         {
00962                                 if (!is_lower_bound(j)) 
00963                                 {
00964                                         double grad_diff=Gmaxp+G[j];
00965                                         if (G[j] >= Gmaxp2)
00966                                                 Gmaxp2 = G[j];
00967                                         if (grad_diff > 0)
00968                                         {
00969                                                 double obj_diff; 
00970                                                 double quad_coef = QD[ip]+QD[j]-2*Q_ip[j];
00971                                                 if (quad_coef > 0)
00972                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00973                                                 else
00974                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00975         
00976                                                 if (obj_diff <= obj_diff_min)
00977                                                 {
00978                                                         Gmin_idx=j;
00979                                                         obj_diff_min = obj_diff;
00980                                                 }
00981                                         }
00982                                 }
00983                         }
00984                         else
00985                         {
00986                                 if (!is_upper_bound(j))
00987                                 {
00988                                         double grad_diff=Gmaxn-G[j];
00989                                         if (-G[j] >= Gmaxn2)
00990                                                 Gmaxn2 = -G[j];
00991                                         if (grad_diff > 0)
00992                                         {
00993                                                 double obj_diff; 
00994                                                 double quad_coef = QD[in]+QD[j]-2*Q_in[j];
00995                                                 if (quad_coef > 0)
00996                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00997                                                 else
00998                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00999         
01000                                                 if (obj_diff <= obj_diff_min)
01001                                                 {
01002                                                         Gmin_idx=j;
01003                                                         obj_diff_min = obj_diff;
01004                                                 }
01005                                         }
01006                                 }
01007                         }
01008                 }
01009 
01010                 if(Math.max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps || Gmin_idx == -1)
01011                         return 1;
01012         
01013                 if(y[Gmin_idx] == +1)
01014                         working_set[0] = Gmaxp_idx;
01015                 else
01016                         working_set[0] = Gmaxn_idx;
01017                 working_set[1] = Gmin_idx;
01018         
01019                 return 0;
01020         }
01021 
01022         private boolean be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
01023         {
01024                 if(is_upper_bound(i))
01025                 {
01026                         if(y[i]==+1)
01027                                 return(-G[i] > Gmax1);
01028                         else    
01029                                 return(-G[i] > Gmax4);
01030                 }
01031                 else if(is_lower_bound(i))
01032                 {
01033                         if(y[i]==+1)
01034                                 return(G[i] > Gmax2);
01035                         else    
01036                                 return(G[i] > Gmax3);
01037                 }
01038                 else
01039                         return(false);
01040         }
01041 
01042         void do_shrinking()
01043         {
01044                 double Gmax1 = -INF;    // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
01045                 double Gmax2 = -INF;    // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
01046                 double Gmax3 = -INF;    // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
01047                 double Gmax4 = -INF;    // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
01048  
01049                 // find maximal violating pair first
01050                 int i;
01051                 for(i=0;i<active_size;i++)
01052                 {
01053                         if(!is_upper_bound(i))
01054                         {
01055                                 if(y[i]==+1)
01056                                 {
01057                                         if(-G[i] > Gmax1) Gmax1 = -G[i];
01058                                 }
01059                                 else    if(-G[i] > Gmax4) Gmax4 = -G[i];
01060                         }
01061                         if(!is_lower_bound(i))
01062                         {
01063                                 if(y[i]==+1)
01064                                 {       
01065                                         if(G[i] > Gmax2) Gmax2 = G[i];
01066                                 }
01067                                 else    if(G[i] > Gmax3) Gmax3 = G[i];
01068                         }
01069                 }
01070 
01071                 if(unshrink == false && Math.max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) 
01072                 {
01073                         unshrink = true;
01074                         reconstruct_gradient();
01075                         active_size = l;
01076                 }
01077 
01078                 for(i=0;i<active_size;i++)
01079                         if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
01080                         {
01081                                 active_size--;
01082                                 while (active_size > i)
01083                                 {
01084                                         if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
01085                                         {
01086                                                 swap_index(i,active_size);
01087                                                 break;
01088                                         }
01089                                         active_size--;
01090                                 }
01091                         }
01092         }
01093         
01094         double calculate_rho()
01095         {
01096                 int nr_free1 = 0,nr_free2 = 0;
01097                 double ub1 = INF, ub2 = INF;
01098                 double lb1 = -INF, lb2 = -INF;
01099                 double sum_free1 = 0, sum_free2 = 0;
01100 
01101                 for(int i=0;i<active_size;i++)
01102                 {
01103                         if(y[i]==+1)
01104                         {
01105                                 if(is_lower_bound(i))
01106                                         ub1 = Math.min(ub1,G[i]);
01107                                 else if(is_upper_bound(i))
01108                                         lb1 = Math.max(lb1,G[i]);
01109                                 else
01110                                 {
01111                                         ++nr_free1;
01112                                         sum_free1 += G[i];
01113                                 }
01114                         }
01115                         else
01116                         {
01117                                 if(is_lower_bound(i))
01118                                         ub2 = Math.min(ub2,G[i]);
01119                                 else if(is_upper_bound(i))
01120                                         lb2 = Math.max(lb2,G[i]);
01121                                 else
01122                                 {
01123                                         ++nr_free2;
01124                                         sum_free2 += G[i];
01125                                 }
01126                         }
01127                 }
01128 
01129                 double r1,r2;
01130                 if(nr_free1 > 0)
01131                         r1 = sum_free1/nr_free1;
01132                 else
01133                         r1 = (ub1+lb1)/2;
01134 
01135                 if(nr_free2 > 0)
01136                         r2 = sum_free2/nr_free2;
01137                 else
01138                         r2 = (ub2+lb2)/2;
01139 
01140                 si.r = (r1+r2)/2;
01141                 return (r1-r2)/2;
01142         }
01143 }
01144 
01145 //
01146 // Q matrices for various formulations
01147 //
01148 class SVC_Q extends Kernel
01149 {
01150         private final byte[] y;
01151         private final Cache cache;
01152         private final double[] QD;
01153 
01154         SVC_Q(svm_problem prob, svm_parameter param, byte[] y_)
01155         {
01156                 super(prob.l, prob.x, param);
01157                 y = (byte[])y_.clone();
01158                 cache = new Cache(prob.l,(long)(param.cache_size*(1<<20)));
01159                 QD = new double[prob.l];
01160                 for(int i=0;i<prob.l;i++)
01161                         QD[i] = kernel_function(i,i);
01162         }
01163 
01164         float[] get_Q(int i, int len)
01165         {
01166                 float[][] data = new float[1][];
01167                 int start, j;
01168                 if((start = cache.get_data(i,data,len)) < len)
01169                 {
01170                         for(j=start;j<len;j++)
01171                                 data[0][j] = (float)(y[i]*y[j]*kernel_function(i,j));
01172                 }
01173                 return data[0];
01174         }
01175 
01176         double[] get_QD()
01177         {
01178                 return QD;
01179         }
01180 
01181         void swap_index(int i, int j)
01182         {
01183                 cache.swap_index(i,j);
01184                 super.swap_index(i,j);
01185                 do {byte _=y[i]; y[i]=y[j]; y[j]=_;} while(false);
01186                 do {double _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
01187         }
01188 }
01189 
01190 class ONE_CLASS_Q extends Kernel
01191 {
01192         private final Cache cache;
01193         private final double[] QD;
01194 
01195         ONE_CLASS_Q(svm_problem prob, svm_parameter param)
01196         {
01197                 super(prob.l, prob.x, param);
01198                 cache = new Cache(prob.l,(long)(param.cache_size*(1<<20)));
01199                 QD = new double[prob.l];
01200                 for(int i=0;i<prob.l;i++)
01201                         QD[i] = kernel_function(i,i);
01202         }
01203 
01204         float[] get_Q(int i, int len)
01205         {
01206                 float[][] data = new float[1][];
01207                 int start, j;
01208                 if((start = cache.get_data(i,data,len)) < len)
01209                 {
01210                         for(j=start;j<len;j++)
01211                                 data[0][j] = (float)kernel_function(i,j);
01212                 }
01213                 return data[0];
01214         }
01215 
01216         double[] get_QD()
01217         {
01218                 return QD;
01219         }
01220 
01221         void swap_index(int i, int j)
01222         {
01223                 cache.swap_index(i,j);
01224                 super.swap_index(i,j);
01225                 do {double _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
01226         }
01227 }
01228 
01229 class SVR_Q extends Kernel
01230 {
01231         private final int l;
01232         private final Cache cache;
01233         private final byte[] sign;
01234         private final int[] index;
01235         private int next_buffer;
01236         private float[][] buffer;
01237         private final double[] QD;
01238 
01239         SVR_Q(svm_problem prob, svm_parameter param)
01240         {
01241                 super(prob.l, prob.x, param);
01242                 l = prob.l;
01243                 cache = new Cache(l,(long)(param.cache_size*(1<<20)));
01244                 QD = new double[2*l];
01245                 sign = new byte[2*l];
01246                 index = new int[2*l];
01247                 for(int k=0;k<l;k++)
01248                 {
01249                         sign[k] = 1;
01250                         sign[k+l] = -1;
01251                         index[k] = k;
01252                         index[k+l] = k;
01253                         QD[k] = kernel_function(k,k);
01254                         QD[k+l] = QD[k];
01255                 }
01256                 buffer = new float[2][2*l];
01257                 next_buffer = 0;
01258         }
01259 
01260         void swap_index(int i, int j)
01261         {
01262                 do {byte _=sign[i]; sign[i]=sign[j]; sign[j]=_;} while(false);
01263                 do {int _=index[i]; index[i]=index[j]; index[j]=_;} while(false);
01264                 do {double _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
01265         }
01266 
01267         float[] get_Q(int i, int len)
01268         {
01269                 float[][] data = new float[1][];
01270                 int j, real_i = index[i];
01271                 if(cache.get_data(real_i,data,l) < l)
01272                 {
01273                         for(j=0;j<l;j++)
01274                                 data[0][j] = (float)kernel_function(real_i,j);
01275                 }
01276 
01277                 // reorder and copy
01278                 float buf[] = buffer[next_buffer];
01279                 next_buffer = 1 - next_buffer;
01280                 byte si = sign[i];
01281                 for(j=0;j<len;j++)
01282                         buf[j] = (float) si * sign[j] * data[0][index[j]];
01283                 return buf;
01284         }
01285 
01286         double[] get_QD()
01287         {
01288                 return QD;
01289         }
01290 }
01291 
01292 public class svm {
01293         //
01294         // construct and solve various formulations
01295         //
01296         public static final int LIBSVM_VERSION=321; 
01297         public static final Random rand = new Random();
01298 
01299         private static svm_print_interface svm_print_stdout = new svm_print_interface()
01300         {
01301                 public void print(String s)
01302                 {
01303                         System.out.print(s);
01304                         System.out.flush();
01305                 }
01306         };
01307 
01308         private static svm_print_interface svm_print_string = svm_print_stdout;
01309 
01310         static void info(String s) 
01311         {
01312                 svm_print_string.print(s);
01313         }
01314 
01315         private static void solve_c_svc(svm_problem prob, svm_parameter param,
01316                                         double[] alpha, Solver.SolutionInfo si,
01317                                         double Cp, double Cn)
01318         {
01319                 int l = prob.l;
01320                 double[] minus_ones = new double[l];
01321                 byte[] y = new byte[l];
01322 
01323                 int i;
01324 
01325                 for(i=0;i<l;i++)
01326                 {
01327                         alpha[i] = 0;
01328                         minus_ones[i] = -1;
01329                         if(prob.y[i] > 0) y[i] = +1; else y[i] = -1;
01330                 }
01331 
01332                 Solver s = new Solver();
01333                 s.Solve(l, new SVC_Q(prob,param,y), minus_ones, y,
01334                         alpha, Cp, Cn, param.eps, si, param.shrinking);
01335 
01336                 double sum_alpha=0;
01337                 for(i=0;i<l;i++)
01338                         sum_alpha += alpha[i];
01339 
01340                 if (Cp==Cn)
01341                         svm.info("nu = "+sum_alpha/(Cp*prob.l)+"\n");
01342 
01343                 for(i=0;i<l;i++)
01344                         alpha[i] *= y[i];
01345         }
01346 
01347         private static void solve_nu_svc(svm_problem prob, svm_parameter param,
01348                                         double[] alpha, Solver.SolutionInfo si)
01349         {
01350                 int i;
01351                 int l = prob.l;
01352                 double nu = param.nu;
01353 
01354                 byte[] y = new byte[l];
01355 
01356                 for(i=0;i<l;i++)
01357                         if(prob.y[i]>0)
01358                                 y[i] = +1;
01359                         else
01360                                 y[i] = -1;
01361 
01362                 double sum_pos = nu*l/2;
01363                 double sum_neg = nu*l/2;
01364 
01365                 for(i=0;i<l;i++)
01366                         if(y[i] == +1)
01367                         {
01368                                 alpha[i] = Math.min(1.0,sum_pos);
01369                                 sum_pos -= alpha[i];
01370                         }
01371                         else
01372                         {
01373                                 alpha[i] = Math.min(1.0,sum_neg);
01374                                 sum_neg -= alpha[i];
01375                         }
01376 
01377                 double[] zeros = new double[l];
01378 
01379                 for(i=0;i<l;i++)
01380                         zeros[i] = 0;
01381 
01382                 Solver_NU s = new Solver_NU();
01383                 s.Solve(l, new SVC_Q(prob,param,y), zeros, y,
01384                         alpha, 1.0, 1.0, param.eps, si, param.shrinking);
01385                 double r = si.r;
01386 
01387                 svm.info("C = "+1/r+"\n");
01388 
01389                 for(i=0;i<l;i++)
01390                         alpha[i] *= y[i]/r;
01391 
01392                 si.rho /= r;
01393                 si.obj /= (r*r);
01394                 si.upper_bound_p = 1/r;
01395                 si.upper_bound_n = 1/r;
01396         }
01397 
01398         private static void solve_one_class(svm_problem prob, svm_parameter param,
01399                                         double[] alpha, Solver.SolutionInfo si)
01400         {
01401                 int l = prob.l;
01402                 double[] zeros = new double[l];
01403                 byte[] ones = new byte[l];
01404                 int i;
01405 
01406                 int n = (int)(param.nu*prob.l); // # of alpha's at upper bound
01407 
01408                 for(i=0;i<n;i++)
01409                         alpha[i] = 1;
01410                 if(n<prob.l)
01411                         alpha[n] = param.nu * prob.l - n;
01412                 for(i=n+1;i<l;i++)
01413                         alpha[i] = 0;
01414 
01415                 for(i=0;i<l;i++)
01416                 {
01417                         zeros[i] = 0;
01418                         ones[i] = 1;
01419                 }
01420 
01421                 Solver s = new Solver();
01422                 s.Solve(l, new ONE_CLASS_Q(prob,param), zeros, ones,
01423                         alpha, 1.0, 1.0, param.eps, si, param.shrinking);
01424         }
01425 
01426         private static void solve_epsilon_svr(svm_problem prob, svm_parameter param,
01427                                         double[] alpha, Solver.SolutionInfo si)
01428         {
01429                 int l = prob.l;
01430                 double[] alpha2 = new double[2*l];
01431                 double[] linear_term = new double[2*l];
01432                 byte[] y = new byte[2*l];
01433                 int i;
01434 
01435                 for(i=0;i<l;i++)
01436                 {
01437                         alpha2[i] = 0;
01438                         linear_term[i] = param.p - prob.y[i];
01439                         y[i] = 1;
01440 
01441                         alpha2[i+l] = 0;
01442                         linear_term[i+l] = param.p + prob.y[i];
01443                         y[i+l] = -1;
01444                 }
01445 
01446                 Solver s = new Solver();
01447                 s.Solve(2*l, new SVR_Q(prob,param), linear_term, y,
01448                         alpha2, param.C, param.C, param.eps, si, param.shrinking);
01449 
01450                 double sum_alpha = 0;
01451                 for(i=0;i<l;i++)
01452                 {
01453                         alpha[i] = alpha2[i] - alpha2[i+l];
01454                         sum_alpha += Math.abs(alpha[i]);
01455                 }
01456                 svm.info("nu = "+sum_alpha/(param.C*l)+"\n");
01457         }
01458 
01459         private static void solve_nu_svr(svm_problem prob, svm_parameter param,
01460                                         double[] alpha, Solver.SolutionInfo si)
01461         {
01462                 int l = prob.l;
01463                 double C = param.C;
01464                 double[] alpha2 = new double[2*l];
01465                 double[] linear_term = new double[2*l];
01466                 byte[] y = new byte[2*l];
01467                 int i;
01468 
01469                 double sum = C * param.nu * l / 2;
01470                 for(i=0;i<l;i++)
01471                 {
01472                         alpha2[i] = alpha2[i+l] = Math.min(sum,C);
01473                         sum -= alpha2[i];
01474                         
01475                         linear_term[i] = - prob.y[i];
01476                         y[i] = 1;
01477 
01478                         linear_term[i+l] = prob.y[i];
01479                         y[i+l] = -1;
01480                 }
01481 
01482                 Solver_NU s = new Solver_NU();
01483                 s.Solve(2*l, new SVR_Q(prob,param), linear_term, y,
01484                         alpha2, C, C, param.eps, si, param.shrinking);
01485 
01486                 svm.info("epsilon = "+(-si.r)+"\n");
01487                 
01488                 for(i=0;i<l;i++)
01489                         alpha[i] = alpha2[i] - alpha2[i+l];
01490         }
01491 
01492         //
01493         // decision_function
01494         //
01495         static class decision_function
01496         {
01497                 double[] alpha;
01498                 double rho;     
01499         };
01500 
01501         static decision_function svm_train_one(
01502                 svm_problem prob, svm_parameter param,
01503                 double Cp, double Cn)
01504         {
01505                 double[] alpha = new double[prob.l];
01506                 Solver.SolutionInfo si = new Solver.SolutionInfo();
01507                 switch(param.svm_type)
01508                 {
01509                         case svm_parameter.C_SVC:
01510                                 solve_c_svc(prob,param,alpha,si,Cp,Cn);
01511                                 break;
01512                         case svm_parameter.NU_SVC:
01513                                 solve_nu_svc(prob,param,alpha,si);
01514                                 break;
01515                         case svm_parameter.ONE_CLASS:
01516                                 solve_one_class(prob,param,alpha,si);
01517                                 break;
01518                         case svm_parameter.EPSILON_SVR:
01519                                 solve_epsilon_svr(prob,param,alpha,si);
01520                                 break;
01521                         case svm_parameter.NU_SVR:
01522                                 solve_nu_svr(prob,param,alpha,si);
01523                                 break;
01524                 }
01525 
01526                 svm.info("obj = "+si.obj+", rho = "+si.rho+"\n");
01527 
01528                 // output SVs
01529 
01530                 int nSV = 0;
01531                 int nBSV = 0;
01532                 for(int i=0;i<prob.l;i++)
01533                 {
01534                         if(Math.abs(alpha[i]) > 0)
01535                         {
01536                                 ++nSV;
01537                                 if(prob.y[i] > 0)
01538                                 {
01539                                         if(Math.abs(alpha[i]) >= si.upper_bound_p)
01540                                         ++nBSV;
01541                                 }
01542                                 else
01543                                 {
01544                                         if(Math.abs(alpha[i]) >= si.upper_bound_n)
01545                                                 ++nBSV;
01546                                 }
01547                         }
01548                 }
01549 
01550                 svm.info("nSV = "+nSV+", nBSV = "+nBSV+"\n");
01551 
01552                 decision_function f = new decision_function();
01553                 f.alpha = alpha;
01554                 f.rho = si.rho;
01555                 return f;
01556         }
01557 
// Platt's binary SVM Probabilistic Output: an improvement from Lin et al.
//
// Fits A, B of the sigmoid P(y = +1 | f) = 1 / (1 + exp(A*f + B)) (the form
// evaluated by sigmoid_predict) to l decision values f = dec_values[i].
// It minimises the cross-entropy against smoothed targets (hiTarget for
// labels > 0, loTarget otherwise) using Newton's method on H + sigma*I with a
// backtracking line search.
//   l          : number of samples to use from dec_values / labels
//   dec_values : decision values, length >= l
//   labels     : length >= l; values > 0 are treated as the positive class
//   probAB     : output, length >= 2; receives A in [0] and B in [1]
// Logs a message and returns the current (A, B) if the line search fails or
// max_iter is reached.
private static void sigmoid_train(int l, double[] dec_values, double[] labels, 
                          double[] probAB)
{
    double A, B;
    double prior1=0, prior0 = 0;   // number of positive / negative labels
    int i;

    for (i=0;i<l;i++)
        if (labels[i] > 0) prior1+=1;
        else prior0+=1;

    int max_iter=100;       // Maximal number of iterations
    double min_step=1e-10;  // Minimal step taken in line search
    double sigma=1e-12;     // For numerically strict PD of Hessian
    double eps=1e-5;        // gradient tolerance for the stopping test
    // Smoothed targets instead of hard 1/0, computed from the class counts.
    double hiTarget=(prior1+1.0)/(prior1+2.0);
    double loTarget=1/(prior0+2.0);
    double[] t= new double[l];
    double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
    double newA,newB,newf,d1,d2;
    int iter; 

    // Initial point and initial function value
    A=0.0; B=Math.log((prior0+1.0)/(prior1+1.0));
    double fval = 0.0;

    for (i=0;i<l;i++)
    {
        if (labels[i]>0) t[i]=hiTarget;
        else t[i]=loTarget;
        fApB = dec_values[i]*A+B;
        // Two algebraically equal forms of the per-sample loss; the branch keeps
        // the exp() argument non-positive so it cannot overflow.
        if (fApB>=0)
            fval += t[i]*fApB + Math.log(1+Math.exp(-fApB));
        else
            fval += (t[i] - 1)*fApB +Math.log(1+Math.exp(fApB));
    }
    for (iter=0;iter<max_iter;iter++)
    {
        // Update Gradient and Hessian (use H' = H + sigma I)
        h11=sigma; // numerically ensures strict PD
        h22=sigma;
        h21=0.0;g1=0.0;g2=0.0;
        for (i=0;i<l;i++)
        {
            fApB = dec_values[i]*A+B;
            // p = 1/(1+exp(fApB)) and q = 1 - p, again in overflow-safe form.
            if (fApB >= 0)
            {
                p=Math.exp(-fApB)/(1.0+Math.exp(-fApB));
                q=1.0/(1.0+Math.exp(-fApB));
            }
            else
            {
                p=1.0/(1.0+Math.exp(fApB));
                q=Math.exp(fApB)/(1.0+Math.exp(fApB));
            }
            d2=p*q;
            h11+=dec_values[i]*dec_values[i]*d2;
            h22+=d2;
            h21+=dec_values[i]*d2;
            d1=t[i]-p;
            g1+=dec_values[i]*d1;   // dF/dA
            g2+=d1;                 // dF/dB
        }

        // Stopping Criteria
        if (Math.abs(g1)<eps && Math.abs(g2)<eps)
            break;
        
        // Finding Newton direction: -inv(H') * g  (explicit 2x2 inverse)
        det=h11*h22-h21*h21;
        dA=-(h22*g1 - h21 * g2) / det;
        dB=-(-h21*g1+ h11 * g2) / det;
        gd=g1*dA+g2*dB;   // directional derivative along the Newton step


        stepsize = 1;           // Line Search: halve the step until sufficient decrease
        while (stepsize >= min_step)
        {
            newA = A + stepsize * dA;
            newB = B + stepsize * dB;

            // New function value
            newf = 0.0;
            for (i=0;i<l;i++)
            {
                fApB = dec_values[i]*newA+newB;
                if (fApB >= 0)
                    newf += t[i]*fApB + Math.log(1+Math.exp(-fApB));
                else
                    newf += (t[i] - 1)*fApB +Math.log(1+Math.exp(fApB));
            }
            // Check sufficient decrease
            if (newf<fval+0.0001*stepsize*gd)
            {
                A=newA;B=newB;fval=newf;
                break;
            }
            else
                stepsize = stepsize / 2.0;
        }
        
        // The while loop only exits with stepsize < min_step when no step was accepted.
        if (stepsize < min_step)
        {
            svm.info("Line search fails in two-class probability estimates\n");
            break;
        }
    }
    
    if (iter>=max_iter)
        svm.info("Reaching maximal iterations in two-class probability estimates\n");
    probAB[0]=A;probAB[1]=B;
}
01671 
01672         private static double sigmoid_predict(double decision_value, double A, double B)
01673         {
01674                 double fApB = decision_value*A+B;
01675                 if (fApB >= 0)
01676                         return Math.exp(-fApB)/(1.0+Math.exp(-fApB));
01677                 else
01678                         return 1.0/(1+Math.exp(fApB)) ;
01679         }
01680 
// Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
//
// Couples pairwise probability estimates into one probability per class.
//   k : number of classes
//   r : k x k pairwise estimates; only off-diagonal entries are read.
//       NOTE(review): per the paper, r[i][j] presumably estimates
//       P(class i | class i or j) -- confirm against the caller.
//   p : output, length >= k; overwritten with the class probabilities.
//       Starts at 1/k each, and every update rescales p by 1/(1+diff),
//       so sum(p) stays 1 up to rounding.
// Builds Q with Q[t][t] = sum_{j != t} r[j][t]^2 and Q[t][j] = -r[j][t]*r[t][j].
// Then it iterates until max_t |(Qp)_t - p'Qp| < 0.005/k or max(100, k)
// iterations have run, and logs a message if the limit is hit.
private static void multiclass_probability(int k, double[][] r, double[] p)
{
    int t,j;
    int iter = 0, max_iter=Math.max(100,k);
    double[][] Q=new double[k][k];
    double[] Qp=new double[k];
    double pQp, eps=0.005/k;

    for (t=0;t<k;t++)
    {
        p[t]=1.0/k;  // Valid if k = 1
        Q[t][t]=0;
        for (j=0;j<t;j++)
        {
            Q[t][t]+=r[j][t]*r[j][t];
            Q[t][j]=Q[j][t];   // symmetric: row j was filled when t was j
        }
        for (j=t+1;j<k;j++)
        {
            Q[t][t]+=r[j][t]*r[j][t];
            Q[t][j]=-r[j][t]*r[t][j];
        }
    }
    for (iter=0;iter<max_iter;iter++)
    {
        // stopping condition, recalculate Qp, pQp for numerical accuracy
        pQp=0;
        for (t=0;t<k;t++)
        {
            Qp[t]=0;
            for (j=0;j<k;j++)
                Qp[t]+=Q[t][j]*p[j];
            pQp+=p[t]*Qp[t];
        }
        double max_error=0;
        for (t=0;t<k;t++)
        {
            double error=Math.abs(Qp[t]-pQp);
            if (error>max_error)
                max_error=error;
        }
        if (max_error<eps) break;
    
        // Update one coordinate at a time, then renormalise p and update Qp and
        // pQp incrementally instead of recomputing them.
        // NOTE(review): divides by Q[t][t]; if a column of r is all zero this
        // is a division by zero -- confirm callers never pass such r.
        for (t=0;t<k;t++)
        {
            double diff=(-Qp[t]+pQp)/Q[t][t];
            p[t]+=diff;
            pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
            for (j=0;j<k;j++)
            {
                Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
                p[j]/=(1+diff);
            }
        }
    }
    if (iter>=max_iter)
        svm.info("Exceeds max_iter in multiclass_prob\n");
}
01740 
01741         // Cross-validation decision values for probability estimates
01742         private static void svm_binary_svc_probability(svm_problem prob, svm_parameter param, double Cp, double Cn, double[] probAB)
01743         {
01744                 int i;
01745                 int nr_fold = 5;
01746                 int[] perm = new int[prob.l];
01747                 double[] dec_values = new double[prob.l];
01748 
01749                 // random shuffle
01750                 for(i=0;i<prob.l;i++) perm[i]=i;
01751                 for(i=0;i<prob.l;i++)
01752                 {
01753                         int j = i+rand.nextInt(prob.l-i);
01754                         do {int _=perm[i]; perm[i]=perm[j]; perm[j]=_;} while(false);
01755                 }
01756                 for(i=0;i<nr_fold;i++)
01757                 {
01758                         int begin = i*prob.l/nr_fold;
01759                         int end = (i+1)*prob.l/nr_fold;
01760                         int j,k;
01761                         svm_problem subprob = new svm_problem();
01762 
01763                         subprob.l = prob.l-(end-begin);
01764                         subprob.x = new svm_node[subprob.l][];
01765                         subprob.y = new double[subprob.l];
01766                         
01767                         k=0;
01768                         for(j=0;j<begin;j++)
01769                         {
01770                                 subprob.x[k] = prob.x[perm[j]];
01771                                 subprob.y[k] = prob.y[perm[j]];
01772                                 ++k;
01773                         }
01774                         for(j=end;j<prob.l;j++)
01775                         {
01776                                 subprob.x[k] = prob.x[perm[j]];
01777                                 subprob.y[k] = prob.y[perm[j]];
01778                                 ++k;
01779                         }
01780                         int p_count=0,n_count=0;
01781                         for(j=0;j<k;j++)
01782                                 if(subprob.y[j]>0)
01783                                         p_count++;
01784                                 else
01785                                         n_count++;
01786                         
01787                         if(p_count==0 && n_count==0)
01788                                 for(j=begin;j<end;j++)
01789                                         dec_values[perm[j]] = 0;
01790                         else if(p_count > 0 && n_count == 0)
01791                                 for(j=begin;j<end;j++)
01792                                         dec_values[perm[j]] = 1;
01793                         else if(p_count == 0 && n_count > 0)
01794                                 for(j=begin;j<end;j++)
01795                                         dec_values[perm[j]] = -1;
01796                         else
01797                         {
01798                                 svm_parameter subparam = (svm_parameter)param.clone();
01799                                 subparam.probability=0;
01800                                 subparam.C=1.0;
01801                                 subparam.nr_weight=2;
01802                                 subparam.weight_label = new int[2];
01803                                 subparam.weight = new double[2];
01804                                 subparam.weight_label[0]=+1;
01805                                 subparam.weight_label[1]=-1;
01806                                 subparam.weight[0]=Cp;
01807                                 subparam.weight[1]=Cn;
01808                                 svm_model submodel = svm_train(subprob,subparam);
01809                                 for(j=begin;j<end;j++)
01810                                 {
01811                                         double[] dec_value=new double[1];
01812                                         svm_predict_values(submodel,prob.x[perm[j]],dec_value);
01813                                         dec_values[perm[j]]=dec_value[0];
01814                                         // ensure +1 -1 order; reason not using CV subroutine
01815                                         dec_values[perm[j]] *= submodel.label[0];
01816                                 }               
01817                         }
01818                 }               
01819                 sigmoid_train(prob.l,dec_values,prob.y,probAB);
01820         }
01821 
01822         // Return parameter of a Laplace distribution 
01823         private static double svm_svr_probability(svm_problem prob, svm_parameter param)
01824         {
01825                 int i;
01826                 int nr_fold = 5;
01827                 double[] ymv = new double[prob.l];
01828                 double mae = 0;
01829 
01830                 svm_parameter newparam = (svm_parameter)param.clone();
01831                 newparam.probability = 0;
01832                 svm_cross_validation(prob,newparam,nr_fold,ymv);
01833                 for(i=0;i<prob.l;i++)
01834                 {
01835                         ymv[i]=prob.y[i]-ymv[i];
01836                         mae += Math.abs(ymv[i]);
01837                 }               
01838                 mae /= prob.l;
01839                 double std=Math.sqrt(2*mae*mae);
01840                 int count=0;
01841                 mae=0;
01842                 for(i=0;i<prob.l;i++)
01843                         if (Math.abs(ymv[i]) > 5*std) 
01844                                 count=count+1;
01845                         else 
01846                                 mae+=Math.abs(ymv[i]);
01847                 mae /= (prob.l-count);
01848                 svm.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+mae+"\n");
01849                 return mae;
01850         }
01851 
01852         // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
01853         // perm, length l, must be allocated before calling this subroutine
01854         private static void svm_group_classes(svm_problem prob, int[] nr_class_ret, int[][] label_ret, int[][] start_ret, int[][] count_ret, int[] perm)
01855         {
01856                 int l = prob.l;
01857                 int max_nr_class = 16;
01858                 int nr_class = 0;
01859                 int[] label = new int[max_nr_class];
01860                 int[] count = new int[max_nr_class];
01861                 int[] data_label = new int[l];
01862                 int i;
01863 
01864                 for(i=0;i<l;i++)
01865                 {
01866                         int this_label = (int)(prob.y[i]);
01867                         int j;
01868                         for(j=0;j<nr_class;j++)
01869                         {
01870                                 if(this_label == label[j])
01871                                 {
01872                                         ++count[j];
01873                                         break;
01874                                 }
01875                         }
01876                         data_label[i] = j;
01877                         if(j == nr_class)
01878                         {
01879                                 if(nr_class == max_nr_class)
01880                                 {
01881                                         max_nr_class *= 2;
01882                                         int[] new_data = new int[max_nr_class];
01883                                         System.arraycopy(label,0,new_data,0,label.length);
01884                                         label = new_data;
01885                                         new_data = new int[max_nr_class];
01886                                         System.arraycopy(count,0,new_data,0,count.length);
01887                                         count = new_data;                                       
01888                                 }
01889                                 label[nr_class] = this_label;
01890                                 count[nr_class] = 1;
01891                                 ++nr_class;
01892                         }
01893                 }
01894 
01895                 //
01896                 // Labels are ordered by their first occurrence in the training set. 
01897                 // However, for two-class sets with -1/+1 labels and -1 appears first, 
01898                 // we swap labels to ensure that internally the binary SVM has positive data corresponding to the +1 instances.
01899                 //
01900                 if (nr_class == 2 && label[0] == -1 && label[1] == +1)
01901                 {
01902                         do {int _=label[0]; label[0]=label[1]; label[1]=_;} while(false);
01903                         do {int _=count[0]; count[0]=count[1]; count[1]=_;} while(false);
01904                         for(i=0;i<l;i++)
01905                         {
01906                                 if(data_label[i] == 0)
01907                                         data_label[i] = 1;
01908                                 else
01909                                         data_label[i] = 0;
01910                         }
01911                 }
01912 
01913                 int[] start = new int[nr_class];
01914                 start[0] = 0;
01915                 for(i=1;i<nr_class;i++)
01916                         start[i] = start[i-1]+count[i-1];
01917                 for(i=0;i<l;i++)
01918                 {
01919                         perm[start[data_label[i]]] = i;
01920                         ++start[data_label[i]];
01921                 }
01922                 start[0] = 0;
01923                 for(i=1;i<nr_class;i++)
01924                         start[i] = start[i-1]+count[i-1];
01925 
01926                 nr_class_ret[0] = nr_class;
01927                 label_ret[0] = label;
01928                 start_ret[0] = start;
01929                 count_ret[0] = count;
01930         }
01931 
01932         //
01933         // Interface functions
01934         //
01935         public static svm_model svm_train(svm_problem prob, svm_parameter param)
01936         {
01937                 svm_model model = new svm_model();
01938                 model.param = param;
01939 
01940                 if(param.svm_type == svm_parameter.ONE_CLASS ||
01941                    param.svm_type == svm_parameter.EPSILON_SVR ||
01942                    param.svm_type == svm_parameter.NU_SVR)
01943                 {
01944                         // regression or one-class-svm
01945                         model.nr_class = 2;
01946                         model.label = null;
01947                         model.nSV = null;
01948                         model.probA = null; model.probB = null;
01949                         model.sv_coef = new double[1][];
01950 
01951                         if(param.probability == 1 &&
01952                            (param.svm_type == svm_parameter.EPSILON_SVR ||
01953                             param.svm_type == svm_parameter.NU_SVR))
01954                         {
01955                                 model.probA = new double[1];
01956                                 model.probA[0] = svm_svr_probability(prob,param);
01957                         }
01958 
01959                         decision_function f = svm_train_one(prob,param,0,0);
01960                         model.rho = new double[1];
01961                         model.rho[0] = f.rho;
01962 
01963                         int nSV = 0;
01964                         int i;
01965                         for(i=0;i<prob.l;i++)
01966                                 if(Math.abs(f.alpha[i]) > 0) ++nSV;
01967                         model.l = nSV;
01968                         model.SV = new svm_node[nSV][];
01969                         model.sv_coef[0] = new double[nSV];
01970                         model.sv_indices = new int[nSV];
01971                         int j = 0;
01972                         for(i=0;i<prob.l;i++)
01973                                 if(Math.abs(f.alpha[i]) > 0)
01974                                 {
01975                                         model.SV[j] = prob.x[i];
01976                                         model.sv_coef[0][j] = f.alpha[i];
01977                                         model.sv_indices[j] = i+1;
01978                                         ++j;
01979                                 }
01980                 }
01981                 else
01982                 {
01983                         // classification
01984                         int l = prob.l;
01985                         int[] tmp_nr_class = new int[1];
01986                         int[][] tmp_label = new int[1][];
01987                         int[][] tmp_start = new int[1][];
01988                         int[][] tmp_count = new int[1][];                       
01989                         int[] perm = new int[l];
01990 
01991                         // group training data of the same class
01992                         svm_group_classes(prob,tmp_nr_class,tmp_label,tmp_start,tmp_count,perm);
01993                         int nr_class = tmp_nr_class[0];                 
01994                         int[] label = tmp_label[0];
01995                         int[] start = tmp_start[0];
01996                         int[] count = tmp_count[0];
01997                         
01998                         if(nr_class == 1) 
01999                                 svm.info("WARNING: training data in only one class. See README for details.\n");
02000                         
02001                         svm_node[][] x = new svm_node[l][];
02002                         int i;
02003                         for(i=0;i<l;i++)
02004                                 x[i] = prob.x[perm[i]];
02005 
02006                         // calculate weighted C
02007 
02008                         double[] weighted_C = new double[nr_class];
02009                         for(i=0;i<nr_class;i++)
02010                                 weighted_C[i] = param.C;
02011                         for(i=0;i<param.nr_weight;i++)
02012                         {
02013                                 int j;
02014                                 for(j=0;j<nr_class;j++)
02015                                         if(param.weight_label[i] == label[j])
02016                                                 break;
02017                                 if(j == nr_class)
02018                                         System.err.print("WARNING: class label "+param.weight_label[i]+" specified in weight is not found\n");
02019                                 else
02020                                         weighted_C[j] *= param.weight[i];
02021                         }
02022 
02023                         // train k*(k-1)/2 models
02024 
02025                         boolean[] nonzero = new boolean[l];
02026                         for(i=0;i<l;i++)
02027                                 nonzero[i] = false;
02028                         decision_function[] f = new decision_function[nr_class*(nr_class-1)/2];
02029 
02030                         double[] probA=null,probB=null;
02031                         if (param.probability == 1)
02032                         {
02033                                 probA=new double[nr_class*(nr_class-1)/2];
02034                                 probB=new double[nr_class*(nr_class-1)/2];
02035                         }
02036 
02037                         int p = 0;
02038                         for(i=0;i<nr_class;i++)
02039                                 for(int j=i+1;j<nr_class;j++)
02040                                 {
02041                                         svm_problem sub_prob = new svm_problem();
02042                                         int si = start[i], sj = start[j];
02043                                         int ci = count[i], cj = count[j];
02044                                         sub_prob.l = ci+cj;
02045                                         sub_prob.x = new svm_node[sub_prob.l][];
02046                                         sub_prob.y = new double[sub_prob.l];
02047                                         int k;
02048                                         for(k=0;k<ci;k++)
02049                                         {
02050                                                 sub_prob.x[k] = x[si+k];
02051                                                 sub_prob.y[k] = +1;
02052                                         }
02053                                         for(k=0;k<cj;k++)
02054                                         {
02055                                                 sub_prob.x[ci+k] = x[sj+k];
02056                                                 sub_prob.y[ci+k] = -1;
02057                                         }
02058 
02059                                         if(param.probability == 1)
02060                                         {
02061                                                 double[] probAB=new double[2];
02062                                                 svm_binary_svc_probability(sub_prob,param,weighted_C[i],weighted_C[j],probAB);
02063                                                 probA[p]=probAB[0];
02064                                                 probB[p]=probAB[1];
02065                                         }
02066 
02067                                         f[p] = svm_train_one(sub_prob,param,weighted_C[i],weighted_C[j]);
02068                                         for(k=0;k<ci;k++)
02069                                                 if(!nonzero[si+k] && Math.abs(f[p].alpha[k]) > 0)
02070                                                         nonzero[si+k] = true;
02071                                         for(k=0;k<cj;k++)
02072                                                 if(!nonzero[sj+k] && Math.abs(f[p].alpha[ci+k]) > 0)
02073                                                         nonzero[sj+k] = true;
02074                                         ++p;
02075                                 }
02076 
02077                         // build output
02078 
02079                         model.nr_class = nr_class;
02080 
02081                         model.label = new int[nr_class];
02082                         for(i=0;i<nr_class;i++)
02083                                 model.label[i] = label[i];
02084 
02085                         model.rho = new double[nr_class*(nr_class-1)/2];
02086                         for(i=0;i<nr_class*(nr_class-1)/2;i++)
02087                                 model.rho[i] = f[i].rho;
02088 
02089                         if(param.probability == 1)
02090                         {
02091                                 model.probA = new double[nr_class*(nr_class-1)/2];
02092                                 model.probB = new double[nr_class*(nr_class-1)/2];
02093                                 for(i=0;i<nr_class*(nr_class-1)/2;i++)
02094                                 {
02095                                         model.probA[i] = probA[i];
02096                                         model.probB[i] = probB[i];
02097                                 }
02098                         }
02099                         else
02100                         {
02101                                 model.probA=null;
02102                                 model.probB=null;
02103                         }
02104 
02105                         int nnz = 0;
02106                         int[] nz_count = new int[nr_class];
02107                         model.nSV = new int[nr_class];
02108                         for(i=0;i<nr_class;i++)
02109                         {
02110                                 int nSV = 0;
02111                                 for(int j=0;j<count[i];j++)
02112                                         if(nonzero[start[i]+j])
02113                                         {
02114                                                 ++nSV;
02115                                                 ++nnz;
02116                                         }
02117                                 model.nSV[i] = nSV;
02118                                 nz_count[i] = nSV;
02119                         }
02120 
02121                         svm.info("Total nSV = "+nnz+"\n");
02122 
02123                         model.l = nnz;
02124                         model.SV = new svm_node[nnz][];
02125                         model.sv_indices = new int[nnz];
02126                         p = 0;
02127                         for(i=0;i<l;i++)
02128                                 if(nonzero[i])
02129                                 {
02130                                         model.SV[p] = x[i];
02131                                         model.sv_indices[p++] = perm[i] + 1;
02132                                 }
02133 
02134                         int[] nz_start = new int[nr_class];
02135                         nz_start[0] = 0;
02136                         for(i=1;i<nr_class;i++)
02137                                 nz_start[i] = nz_start[i-1]+nz_count[i-1];
02138 
02139                         model.sv_coef = new double[nr_class-1][];
02140                         for(i=0;i<nr_class-1;i++)
02141                                 model.sv_coef[i] = new double[nnz];
02142 
02143                         p = 0;
02144                         for(i=0;i<nr_class;i++)
02145                                 for(int j=i+1;j<nr_class;j++)
02146                                 {
02147                                         // classifier (i,j): coefficients with
02148                                         // i are in sv_coef[j-1][nz_start[i]...],
02149                                         // j are in sv_coef[i][nz_start[j]...]
02150 
02151                                         int si = start[i];
02152                                         int sj = start[j];
02153                                         int ci = count[i];
02154                                         int cj = count[j];
02155 
02156                                         int q = nz_start[i];
02157                                         int k;
02158                                         for(k=0;k<ci;k++)
02159                                                 if(nonzero[si+k])
02160                                                         model.sv_coef[j-1][q++] = f[p].alpha[k];
02161                                         q = nz_start[j];
02162                                         for(k=0;k<cj;k++)
02163                                                 if(nonzero[sj+k])
02164                                                         model.sv_coef[i][q++] = f[p].alpha[ci+k];
02165                                         ++p;
02166                                 }
02167                 }
02168                 return model;
02169         }
02170         
02171         // Stratified cross validation
02172         public static void svm_cross_validation(svm_problem prob, svm_parameter param, int nr_fold, double[] target)
02173         {
02174                 int i;
02175                 int[] fold_start = new int[nr_fold+1];
02176                 int l = prob.l;
02177                 int[] perm = new int[l];
02178                 
02179                 // stratified cv may not give leave-one-out rate
02180                 // Each class to l folds -> some folds may have zero elements
02181                 if((param.svm_type == svm_parameter.C_SVC ||
02182                     param.svm_type == svm_parameter.NU_SVC) && nr_fold < l)
02183                 {
02184                         int[] tmp_nr_class = new int[1];
02185                         int[][] tmp_label = new int[1][];
02186                         int[][] tmp_start = new int[1][];
02187                         int[][] tmp_count = new int[1][];
02188 
02189                         svm_group_classes(prob,tmp_nr_class,tmp_label,tmp_start,tmp_count,perm);
02190 
02191                         int nr_class = tmp_nr_class[0];
02192                         int[] start = tmp_start[0];
02193                         int[] count = tmp_count[0];             
02194 
02195                         // random shuffle and then data grouped by fold using the array perm
02196                         int[] fold_count = new int[nr_fold];
02197                         int c;
02198                         int[] index = new int[l];
02199                         for(i=0;i<l;i++)
02200                                 index[i]=perm[i];
02201                         for (c=0; c<nr_class; c++)
02202                                 for(i=0;i<count[c];i++)
02203                                 {
02204                                         int j = i+rand.nextInt(count[c]-i);
02205                                         do {int _=index[start[c]+j]; index[start[c]+j]=index[start[c]+i]; index[start[c]+i]=_;} while(false);
02206                                 }
02207                         for(i=0;i<nr_fold;i++)
02208                         {
02209                                 fold_count[i] = 0;
02210                                 for (c=0; c<nr_class;c++)
02211                                         fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
02212                         }
02213                         fold_start[0]=0;
02214                         for (i=1;i<=nr_fold;i++)
02215                                 fold_start[i] = fold_start[i-1]+fold_count[i-1];
02216                         for (c=0; c<nr_class;c++)
02217                                 for(i=0;i<nr_fold;i++)
02218                                 {
02219                                         int begin = start[c]+i*count[c]/nr_fold;
02220                                         int end = start[c]+(i+1)*count[c]/nr_fold;
02221                                         for(int j=begin;j<end;j++)
02222                                         {
02223                                                 perm[fold_start[i]] = index[j];
02224                                                 fold_start[i]++;
02225                                         }
02226                                 }
02227                         fold_start[0]=0;
02228                         for (i=1;i<=nr_fold;i++)
02229                                 fold_start[i] = fold_start[i-1]+fold_count[i-1];
02230                 }
02231                 else
02232                 {
02233                         for(i=0;i<l;i++) perm[i]=i;
02234                         for(i=0;i<l;i++)
02235                         {
02236                                 int j = i+rand.nextInt(l-i);
02237                                 do {int _=perm[i]; perm[i]=perm[j]; perm[j]=_;} while(false);
02238                         }
02239                         for(i=0;i<=nr_fold;i++)
02240                                 fold_start[i]=i*l/nr_fold;
02241                 }
02242 
02243                 for(i=0;i<nr_fold;i++)
02244                 {
02245                         int begin = fold_start[i];
02246                         int end = fold_start[i+1];
02247                         int j,k;
02248                         svm_problem subprob = new svm_problem();
02249 
02250                         subprob.l = l-(end-begin);
02251                         subprob.x = new svm_node[subprob.l][];
02252                         subprob.y = new double[subprob.l];
02253 
02254                         k=0;
02255                         for(j=0;j<begin;j++)
02256                         {
02257                                 subprob.x[k] = prob.x[perm[j]];
02258                                 subprob.y[k] = prob.y[perm[j]];
02259                                 ++k;
02260                         }
02261                         for(j=end;j<l;j++)
02262                         {
02263                                 subprob.x[k] = prob.x[perm[j]];
02264                                 subprob.y[k] = prob.y[perm[j]];
02265                                 ++k;
02266                         }
02267                         svm_model submodel = svm_train(subprob,param);
02268                         if(param.probability==1 &&
02269                            (param.svm_type == svm_parameter.C_SVC ||
02270                             param.svm_type == svm_parameter.NU_SVC))
02271                         {
02272                                 double[] prob_estimates= new double[svm_get_nr_class(submodel)];
02273                                 for(j=begin;j<end;j++)
02274                                         target[perm[j]] = svm_predict_probability(submodel,prob.x[perm[j]],prob_estimates);
02275                         }
02276                         else
02277                                 for(j=begin;j<end;j++)
02278                                         target[perm[j]] = svm_predict(submodel,prob.x[perm[j]]);
02279                 }
02280         }
02281 
02282         public static int svm_get_svm_type(svm_model model)
02283         {
02284                 return model.param.svm_type;
02285         }
02286 
02287         public static int svm_get_nr_class(svm_model model)
02288         {
02289                 return model.nr_class;
02290         }
02291 
02292         public static void svm_get_labels(svm_model model, int[] label)
02293         {
02294                 if (model.label != null)
02295                         for(int i=0;i<model.nr_class;i++)
02296                                 label[i] = model.label[i];
02297         }
02298 
02299         public static void svm_get_sv_indices(svm_model model, int[] indices)
02300         {
02301                 if (model.sv_indices != null)
02302                         for(int i=0;i<model.l;i++)
02303                                 indices[i] = model.sv_indices[i];
02304         }
02305 
02306         public static int svm_get_nr_sv(svm_model model)
02307         {
02308                 return model.l;
02309         }
02310 
02311         public static double svm_get_svr_probability(svm_model model)
02312         {
02313                 if ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) &&
02314                     model.probA!=null)
02315                 return model.probA[0];
02316                 else
02317                 {
02318                         System.err.print("Model doesn't contain information for SVR probability inference\n");
02319                         return 0;
02320                 }
02321         }
02322 
02323         public static double svm_predict_values(svm_model model, svm_node[] x, double[] dec_values)
02324         {
02325                 int i;
02326                 if(model.param.svm_type == svm_parameter.ONE_CLASS ||
02327                    model.param.svm_type == svm_parameter.EPSILON_SVR ||
02328                    model.param.svm_type == svm_parameter.NU_SVR)
02329                 {
02330                         double[] sv_coef = model.sv_coef[0];
02331                         double sum = 0;
02332                         for(i=0;i<model.l;i++)
02333                                 sum += sv_coef[i] * Kernel.k_function(x,model.SV[i],model.param);
02334                         sum -= model.rho[0];
02335                         dec_values[0] = sum;
02336 
02337                         if(model.param.svm_type == svm_parameter.ONE_CLASS)
02338                                 return (sum>0)?1:-1;
02339                         else
02340                                 return sum;
02341                 }
02342                 else
02343                 {
02344                         int nr_class = model.nr_class;
02345                         int l = model.l;
02346                 
02347                         double[] kvalue = new double[l];
02348                         for(i=0;i<l;i++)
02349                                 kvalue[i] = Kernel.k_function(x,model.SV[i],model.param);
02350 
02351                         int[] start = new int[nr_class];
02352                         start[0] = 0;
02353                         for(i=1;i<nr_class;i++)
02354                                 start[i] = start[i-1]+model.nSV[i-1];
02355 
02356                         int[] vote = new int[nr_class];
02357                         for(i=0;i<nr_class;i++)
02358                                 vote[i] = 0;
02359 
02360                         int p=0;
02361                         for(i=0;i<nr_class;i++)
02362                                 for(int j=i+1;j<nr_class;j++)
02363                                 {
02364                                         double sum = 0;
02365                                         int si = start[i];
02366                                         int sj = start[j];
02367                                         int ci = model.nSV[i];
02368                                         int cj = model.nSV[j];
02369                                 
02370                                         int k;
02371                                         double[] coef1 = model.sv_coef[j-1];
02372                                         double[] coef2 = model.sv_coef[i];
02373                                         for(k=0;k<ci;k++)
02374                                                 sum += coef1[si+k] * kvalue[si+k];
02375                                         for(k=0;k<cj;k++)
02376                                                 sum += coef2[sj+k] * kvalue[sj+k];
02377                                         sum -= model.rho[p];
02378                                         dec_values[p] = sum;                                    
02379 
02380                                         if(dec_values[p] > 0)
02381                                                 ++vote[i];
02382                                         else
02383                                                 ++vote[j];
02384                                         p++;
02385                                 }
02386 
02387                         int vote_max_idx = 0;
02388                         for(i=1;i<nr_class;i++)
02389                                 if(vote[i] > vote[vote_max_idx])
02390                                         vote_max_idx = i;
02391 
02392                         return model.label[vote_max_idx];
02393                 }
02394         }
02395 
02396         public static double svm_predict(svm_model model, svm_node[] x)
02397         {
02398                 int nr_class = model.nr_class;
02399                 double[] dec_values;
02400                 if(model.param.svm_type == svm_parameter.ONE_CLASS ||
02401                                 model.param.svm_type == svm_parameter.EPSILON_SVR ||
02402                                 model.param.svm_type == svm_parameter.NU_SVR)
02403                         dec_values = new double[1];
02404                 else
02405                         dec_values = new double[nr_class*(nr_class-1)/2];
02406                 double pred_result = svm_predict_values(model, x, dec_values);
02407                 return pred_result;
02408         }
02409 
02410         public static double svm_predict_probability(svm_model model, svm_node[] x, double[] prob_estimates)
02411         {
02412                 if ((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) &&
02413                     model.probA!=null && model.probB!=null)
02414                 {
02415                         int i;
02416                         int nr_class = model.nr_class;
02417                         double[] dec_values = new double[nr_class*(nr_class-1)/2];
02418                         svm_predict_values(model, x, dec_values);
02419 
02420                         double min_prob=1e-7;
02421                         double[][] pairwise_prob=new double[nr_class][nr_class];
02422                         
02423                         int k=0;
02424                         for(i=0;i<nr_class;i++)
02425                                 for(int j=i+1;j<nr_class;j++)
02426                                 {
02427                                         pairwise_prob[i][j]=Math.min(Math.max(sigmoid_predict(dec_values[k],model.probA[k],model.probB[k]),min_prob),1-min_prob);
02428                                         pairwise_prob[j][i]=1-pairwise_prob[i][j];
02429                                         k++;
02430                                 }
02431                         multiclass_probability(nr_class,pairwise_prob,prob_estimates);
02432 
02433                         int prob_max_idx = 0;
02434                         for(i=1;i<nr_class;i++)
02435                                 if(prob_estimates[i] > prob_estimates[prob_max_idx])
02436                                         prob_max_idx = i;
02437                         return model.label[prob_max_idx];
02438                 }
02439                 else 
02440                         return svm_predict(model, x);
02441         }
02442 
02443         static final String svm_type_table[] =
02444         {
02445                 "c_svc","nu_svc","one_class","epsilon_svr","nu_svr",
02446         };
02447 
02448         static final String kernel_type_table[]=
02449         {
02450                 "linear","polynomial","rbf","sigmoid","precomputed"
02451         };
02452 
02453         public static void svm_save_model(String model_file_name, svm_model model) throws IOException
02454         {
02455                 DataOutputStream fp = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(model_file_name)));
02456 
02457                 svm_parameter param = model.param;
02458 
02459                 fp.writeBytes("svm_type "+svm_type_table[param.svm_type]+"\n");
02460                 fp.writeBytes("kernel_type "+kernel_type_table[param.kernel_type]+"\n");
02461 
02462                 if(param.kernel_type == svm_parameter.POLY)
02463                         fp.writeBytes("degree "+param.degree+"\n");
02464 
02465                 if(param.kernel_type == svm_parameter.POLY ||
02466                    param.kernel_type == svm_parameter.RBF ||
02467                    param.kernel_type == svm_parameter.SIGMOID)
02468                         fp.writeBytes("gamma "+param.gamma+"\n");
02469 
02470                 if(param.kernel_type == svm_parameter.POLY ||
02471                    param.kernel_type == svm_parameter.SIGMOID)
02472                         fp.writeBytes("coef0 "+param.coef0+"\n");
02473 
02474                 int nr_class = model.nr_class;
02475                 int l = model.l;
02476                 fp.writeBytes("nr_class "+nr_class+"\n");
02477                 fp.writeBytes("total_sv "+l+"\n");
02478         
02479                 {
02480                         fp.writeBytes("rho");
02481                         for(int i=0;i<nr_class*(nr_class-1)/2;i++)
02482                                 fp.writeBytes(" "+model.rho[i]);
02483                         fp.writeBytes("\n");
02484                 }
02485         
02486                 if(model.label != null)
02487                 {
02488                         fp.writeBytes("label");
02489                         for(int i=0;i<nr_class;i++)
02490                                 fp.writeBytes(" "+model.label[i]);
02491                         fp.writeBytes("\n");
02492                 }
02493 
02494                 if(model.probA != null) // regression has probA only
02495                 {
02496                         fp.writeBytes("probA");
02497                         for(int i=0;i<nr_class*(nr_class-1)/2;i++)
02498                                 fp.writeBytes(" "+model.probA[i]);
02499                         fp.writeBytes("\n");
02500                 }
02501                 if(model.probB != null) 
02502                 {
02503                         fp.writeBytes("probB");
02504                         for(int i=0;i<nr_class*(nr_class-1)/2;i++)
02505                                 fp.writeBytes(" "+model.probB[i]);
02506                         fp.writeBytes("\n");
02507                 }
02508 
02509                 if(model.nSV != null)
02510                 {
02511                         fp.writeBytes("nr_sv");
02512                         for(int i=0;i<nr_class;i++)
02513                                 fp.writeBytes(" "+model.nSV[i]);
02514                         fp.writeBytes("\n");
02515                 }
02516 
02517                 fp.writeBytes("SV\n");
02518                 double[][] sv_coef = model.sv_coef;
02519                 svm_node[][] SV = model.SV;
02520 
02521                 for(int i=0;i<l;i++)
02522                 {
02523                         for(int j=0;j<nr_class-1;j++)
02524                                 fp.writeBytes(sv_coef[j][i]+" ");
02525 
02526                         svm_node[] p = SV[i];
02527                         if(param.kernel_type == svm_parameter.PRECOMPUTED)
02528                                 fp.writeBytes("0:"+(int)(p[0].value));
02529                         else    
02530                                 for(int j=0;j<p.length;j++)
02531                                         fp.writeBytes(p[j].index+":"+p[j].value+" ");
02532                         fp.writeBytes("\n");
02533                 }
02534 
02535                 fp.close();
02536         }
02537 
02538         private static double atof(String s)
02539         {
02540                 return Double.valueOf(s).doubleValue();
02541         }
02542 
02543         private static int atoi(String s)
02544         {
02545                 return Integer.parseInt(s);
02546         }
02547 
02548         private static boolean read_model_header(BufferedReader fp, svm_model model)
02549         {
02550                 svm_parameter param = new svm_parameter();
02551                 model.param = param;
02552                 try
02553                 {
02554                         while(true)
02555                         {
02556                                 String cmd = fp.readLine();
02557                                 String arg = cmd.substring(cmd.indexOf(' ')+1);
02558 
02559                                 if(cmd.startsWith("svm_type"))
02560                                 {
02561                                         int i;
02562                                         for(i=0;i<svm_type_table.length;i++)
02563                                         {
02564                                                 if(arg.indexOf(svm_type_table[i])!=-1)
02565                                                 {
02566                                                         param.svm_type=i;
02567                                                         break;
02568                                                 }
02569                                         }
02570                                         if(i == svm_type_table.length)
02571                                         {
02572                                                 System.err.print("unknown svm type.\n");
02573                                                 return false;
02574                                         }
02575                                 }
02576                                 else if(cmd.startsWith("kernel_type"))
02577                                 {
02578                                         int i;
02579                                         for(i=0;i<kernel_type_table.length;i++)
02580                                         {
02581                                                 if(arg.indexOf(kernel_type_table[i])!=-1)
02582                                                 {
02583                                                         param.kernel_type=i;
02584                                                         break;
02585                                                 }
02586                                         }
02587                                         if(i == kernel_type_table.length)
02588                                         {
02589                                                 System.err.print("unknown kernel function.\n");
02590                                                 return false;
02591                                         }
02592                                 }
02593                                 else if(cmd.startsWith("degree"))
02594                                         param.degree = atoi(arg);
02595                                 else if(cmd.startsWith("gamma"))
02596                                         param.gamma = atof(arg);
02597                                 else if(cmd.startsWith("coef0"))
02598                                         param.coef0 = atof(arg);
02599                                 else if(cmd.startsWith("nr_class"))
02600                                         model.nr_class = atoi(arg);
02601                                 else if(cmd.startsWith("total_sv"))
02602                                         model.l = atoi(arg);
02603                                 else if(cmd.startsWith("rho"))
02604                                 {
02605                                         int n = model.nr_class * (model.nr_class-1)/2;
02606                                         model.rho = new double[n];
02607                                         StringTokenizer st = new StringTokenizer(arg);
02608                                         for(int i=0;i<n;i++)
02609                                                 model.rho[i] = atof(st.nextToken());
02610                                 }
02611                                 else if(cmd.startsWith("label"))
02612                                 {
02613                                         int n = model.nr_class;
02614                                         model.label = new int[n];
02615                                         StringTokenizer st = new StringTokenizer(arg);
02616                                         for(int i=0;i<n;i++)
02617                                                 model.label[i] = atoi(st.nextToken());                                  
02618                                 }
02619                                 else if(cmd.startsWith("probA"))
02620                                 {
02621                                         int n = model.nr_class*(model.nr_class-1)/2;
02622                                         model.probA = new double[n];
02623                                         StringTokenizer st = new StringTokenizer(arg);
02624                                         for(int i=0;i<n;i++)
02625                                                 model.probA[i] = atof(st.nextToken());                                  
02626                                 }
02627                                 else if(cmd.startsWith("probB"))
02628                                 {
02629                                         int n = model.nr_class*(model.nr_class-1)/2;
02630                                         model.probB = new double[n];
02631                                         StringTokenizer st = new StringTokenizer(arg);
02632                                         for(int i=0;i<n;i++)
02633                                                 model.probB[i] = atof(st.nextToken());                                  
02634                                 }
02635                                 else if(cmd.startsWith("nr_sv"))
02636                                 {
02637                                         int n = model.nr_class;
02638                                         model.nSV = new int[n];
02639                                         StringTokenizer st = new StringTokenizer(arg);
02640                                         for(int i=0;i<n;i++)
02641                                                 model.nSV[i] = atoi(st.nextToken());
02642                                 }
02643                                 else if(cmd.startsWith("SV"))
02644                                 {
02645                                         break;
02646                                 }
02647                                 else
02648                                 {
02649                                         System.err.print("unknown text in model file: ["+cmd+"]\n");
02650                                         return false;
02651                                 }
02652                         }
02653                 }
02654                 catch(Exception e)
02655                 {
02656                         return false;
02657                 }
02658                 return true;
02659         }
02660 
02661         public static svm_model svm_load_model(String model_file_name) throws IOException
02662         {
02663                 return svm_load_model(new BufferedReader(new FileReader(model_file_name)));
02664         }
02665 
02666         public static svm_model svm_load_model(BufferedReader fp) throws IOException
02667         {
02668                 // read parameters
02669 
02670                 svm_model model = new svm_model();
02671                 model.rho = null;
02672                 model.probA = null;
02673                 model.probB = null;
02674                 model.label = null;
02675                 model.nSV = null;
02676 
02677                 if (read_model_header(fp, model) == false)
02678                 {
02679                         System.err.print("ERROR: failed to read model\n");
02680                         return null;
02681                 }
02682 
02683                 // read sv_coef and SV
02684 
02685                 int m = model.nr_class - 1;
02686                 int l = model.l;
02687                 model.sv_coef = new double[m][l];
02688                 model.SV = new svm_node[l][];
02689 
02690                 for(int i=0;i<l;i++)
02691                 {
02692                         String line = fp.readLine();
02693                         StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
02694 
02695                         for(int k=0;k<m;k++)
02696                                 model.sv_coef[k][i] = atof(st.nextToken());
02697                         int n = st.countTokens()/2;
02698                         model.SV[i] = new svm_node[n];
02699                         for(int j=0;j<n;j++)
02700                         {
02701                                 model.SV[i][j] = new svm_node();
02702                                 model.SV[i][j].index = atoi(st.nextToken());
02703                                 model.SV[i][j].value = atof(st.nextToken());
02704                         }
02705                 }
02706 
02707                 fp.close();
02708                 return model;
02709         }
02710 
02711         public static String svm_check_parameter(svm_problem prob, svm_parameter param)
02712         {
02713                 // svm_type
02714 
02715                 int svm_type = param.svm_type;
02716                 if(svm_type != svm_parameter.C_SVC &&
02717                    svm_type != svm_parameter.NU_SVC &&
02718                    svm_type != svm_parameter.ONE_CLASS &&
02719                    svm_type != svm_parameter.EPSILON_SVR &&
02720                    svm_type != svm_parameter.NU_SVR)
02721                 return "unknown svm type";
02722 
02723                 // kernel_type, degree
02724         
02725                 int kernel_type = param.kernel_type;
02726                 if(kernel_type != svm_parameter.LINEAR &&
02727                    kernel_type != svm_parameter.POLY &&
02728                    kernel_type != svm_parameter.RBF &&
02729                    kernel_type != svm_parameter.SIGMOID &&
02730                    kernel_type != svm_parameter.PRECOMPUTED)
02731                         return "unknown kernel type";
02732 
02733                 if(param.gamma < 0)
02734                         return "gamma < 0";
02735 
02736                 if(param.degree < 0)
02737                         return "degree of polynomial kernel < 0";
02738 
02739                 // cache_size,eps,C,nu,p,shrinking
02740 
02741                 if(param.cache_size <= 0)
02742                         return "cache_size <= 0";
02743 
02744                 if(param.eps <= 0)
02745                         return "eps <= 0";
02746 
02747                 if(svm_type == svm_parameter.C_SVC ||
02748                    svm_type == svm_parameter.EPSILON_SVR ||
02749                    svm_type == svm_parameter.NU_SVR)
02750                         if(param.C <= 0)
02751                                 return "C <= 0";
02752 
02753                 if(svm_type == svm_parameter.NU_SVC ||
02754                    svm_type == svm_parameter.ONE_CLASS ||
02755                    svm_type == svm_parameter.NU_SVR)
02756                         if(param.nu <= 0 || param.nu > 1)
02757                                 return "nu <= 0 or nu > 1";
02758 
02759                 if(svm_type == svm_parameter.EPSILON_SVR)
02760                         if(param.p < 0)
02761                                 return "p < 0";
02762 
02763                 if(param.shrinking != 0 &&
02764                    param.shrinking != 1)
02765                         return "shrinking != 0 and shrinking != 1";
02766 
02767                 if(param.probability != 0 &&
02768                    param.probability != 1)
02769                         return "probability != 0 and probability != 1";
02770 
02771                 if(param.probability == 1 &&
02772                    svm_type == svm_parameter.ONE_CLASS)
02773                         return "one-class SVM probability output not supported yet";
02774                 
02775                 // check whether nu-svc is feasible
02776         
02777                 if(svm_type == svm_parameter.NU_SVC)
02778                 {
02779                         int l = prob.l;
02780                         int max_nr_class = 16;
02781                         int nr_class = 0;
02782                         int[] label = new int[max_nr_class];
02783                         int[] count = new int[max_nr_class];
02784 
02785                         int i;
02786                         for(i=0;i<l;i++)
02787                         {
02788                                 int this_label = (int)prob.y[i];
02789                                 int j;
02790                                 for(j=0;j<nr_class;j++)
02791                                         if(this_label == label[j])
02792                                         {
02793                                                 ++count[j];
02794                                                 break;
02795                                         }
02796 
02797                                 if(j == nr_class)
02798                                 {
02799                                         if(nr_class == max_nr_class)
02800                                         {
02801                                                 max_nr_class *= 2;
02802                                                 int[] new_data = new int[max_nr_class];
02803                                                 System.arraycopy(label,0,new_data,0,label.length);
02804                                                 label = new_data;
02805                                                 
02806                                                 new_data = new int[max_nr_class];
02807                                                 System.arraycopy(count,0,new_data,0,count.length);
02808                                                 count = new_data;
02809                                         }
02810                                         label[nr_class] = this_label;
02811                                         count[nr_class] = 1;
02812                                         ++nr_class;
02813                                 }
02814                         }
02815 
02816                         for(i=0;i<nr_class;i++)
02817                         {
02818                                 int n1 = count[i];
02819                                 for(int j=i+1;j<nr_class;j++)
02820                                 {
02821                                         int n2 = count[j];
02822                                         if(param.nu*(n1+n2)/2 > Math.min(n1,n2))
02823                                                 return "specified nu is infeasible";
02824                                 }
02825                         }
02826                 }
02827 
02828                 return null;
02829         }
02830 
02831         public static int svm_check_probability_model(svm_model model)
02832         {
02833                 if (((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) &&
02834                 model.probA!=null && model.probB!=null) ||
02835                 ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) &&
02836                  model.probA!=null))
02837                         return 1;
02838                 else
02839                         return 0;
02840         }
02841 
02842         public static void svm_set_print_string_function(svm_print_interface print_func)
02843         {
02844                 if (print_func == null)
02845                         svm_print_string = svm_print_stdout;
02846                 else 
02847                         svm_print_string = print_func;
02848         }
02849 }


target_obejct_detector
Author(s): CIR-KIT
autogenerated on Thu Jun 6 2019 20:19:57