svm.java
Go to the documentation of this file.
00001 
00002 
00003 
00004 
00005 package libsvm;
00006 import java.io.*;
00007 import java.util.*;
00008 
//
// Kernel Cache
//
// l is the number of total data items
// size is the cache size limit in bytes
//
/**
 * LRU cache of kernel-matrix column prefixes.
 *
 * <p>Every data item owns one entry that may hold a prefix data[0,len) of its
 * column. Entries with len > 0 are linked into a circular doubly-linked list
 * around the sentinel {@code lru_head}: new or touched entries are inserted just
 * before the sentinel, and eviction takes the entry just after it, i.e. the
 * least recently used one.
 */
class Cache {
        private final int l;
        private long size;      // remaining capacity, counted in floats (converted from bytes in the constructor)

        private static final class head_t
        {
                head_t prev, next;      // a circular list
                float[] data;
                int len;                // data[0,len) is cached in this entry
        }

        private final head_t[] head;
        private head_t lru_head;

        /**
         * @param l_    number of data items (one potential cached column each)
         * @param size_ cache size limit in bytes
         */
        Cache(int l_, long size_)
        {
                l = l_;
                size = size_;
                head = new head_t[l];
                for(int i=0;i<l;i++) head[i] = new head_t();
                size /= 4;                              // bytes -> floats
                size -= (long) l * (16/4);              // sizeof(head_t) == 16; long math avoids int overflow for huge l
                size = Math.max(size, 2* (long) l);     // cache must be large enough for two columns
                lru_head = new head_t();
                lru_head.next = lru_head.prev = lru_head;
        }

        /** Unlinks h from the LRU list (h must currently be linked). */
        private void lru_delete(head_t h)
        {
                // delete from current location
                h.prev.next = h.next;
                h.next.prev = h.prev;
        }

        /** Links h as the most recently used entry (just before the sentinel). */
        private void lru_insert(head_t h)
        {
                // insert to last position
                h.next = lru_head;
                h.prev = lru_head.prev;
                h.prev.next = h;
                h.next.prev = h;
        }

        /**
         * Requests entries [0,len) of the column for {@code index}.
         *
         * <p>The backing array is stored in data[0] (java: simulates a pointer
         * out-parameter). Least recently used entries are evicted as needed to
         * make room for a longer prefix.
         *
         * @return a position p such that data[0][p,len) still has to be filled
         *         by the caller; p >= len means nothing needs to be filled
         */
        int get_data(int index, float[][] data, int len)
        {
                head_t h = head[index];
                if(h.len > 0) lru_delete(h);
                int more = len - h.len;

                if(more > 0)
                {
                        // free old space
                        while(size < more)
                        {
                                head_t old = lru_head.next;
                                lru_delete(old);
                                size += old.len;
                                old.data = null;
                                old.len = 0;
                        }

                        // allocate new space, keeping the already-computed prefix
                        float[] new_data = new float[len];
                        if(h.data != null) System.arraycopy(h.data,0,new_data,0,h.len);
                        h.data = new_data;
                        size -= more;
                        // h now caches [0,len); report the old length as the first slot to fill
                        int old_len = h.len;
                        h.len = len;
                        len = old_len;
                }

                // Only non-empty entries live in the LRU list (every lru_delete is
                // guarded by len > 0). Linking an empty entry would let a later call
                // insert it a second time and corrupt the list.
                if(h.len > 0) lru_insert(h);
                data[0] = h.data;
                return len;
        }

        /**
         * Exchanges data items i and j: swaps their cache entries and swaps
         * positions i and j inside every cached column that contains both.
         * Columns that contain i but not j are dropped.
         */
        void swap_index(int i, int j)
        {
                if(i==j) return;

                if(head[i].len > 0) lru_delete(head[i]);
                if(head[j].len > 0) lru_delete(head[j]);
                float[] tmp_data = head[i].data;
                head[i].data = head[j].data;
                head[j].data = tmp_data;
                int tmp_len = head[i].len;
                head[i].len = head[j].len;
                head[j].len = tmp_len;
                if(head[i].len > 0) lru_insert(head[i]);
                if(head[j].len > 0) lru_insert(head[j]);

                if(i>j)
                {
                        int tmp = i;
                        i = j;
                        j = tmp;
                }
                // lru_delete leaves h.next intact, so unlinking during the walk is safe
                for(head_t h = lru_head.next; h!=lru_head; h=h.next)
                {
                        if(h.len > i)
                        {
                                if(h.len > j)
                                {
                                        float tmp = h.data[i];
                                        h.data[i] = h.data[j];
                                        h.data[j] = tmp;
                                }
                                else
                                {
                                        // give up
                                        lru_delete(h);
                                        size += h.len;
                                        h.data = null;
                                        h.len = 0;
                                }
                        }
                }
        }
}
00121 
00122 //
00123 // Kernel evaluation
00124 //
00125 // the static method k_function is for doing single kernel evaluation
00126 // the constructor of Kernel prepares to calculate the l*l kernel matrix
00127 // the member function get_Q is for getting one column from the Q Matrix
00128 //
/**
 * Abstract source of the l-by-l matrix Q consumed by the solver.
 * Implementations hand out columns on demand, expose the diagonal, and
 * support exchanging two data items.
 */
abstract class QMatrix
{
        /** Returns an array whose first {@code len} entries are column {@code column} of Q. */
        abstract float[] get_Q(int column, int len);

        /** Returns the diagonal entries of Q. */
        abstract double[] get_QD();

        /** Exchanges the positions of data items i and j. */
        abstract void swap_index(int i, int j);
}
00134 
00135 abstract class Kernel extends QMatrix {
00136         private svm_node[][] x;
00137         private final double[] x_square;
00138 
00139         // svm_parameter
00140         private final int kernel_type;
00141         private final int degree;
00142         private final double gamma;
00143         private final double coef0;
00144 
00145         abstract float[] get_Q(int column, int len);
00146         abstract double[] get_QD();
00147 
00148         void swap_index(int i, int j)
00149         {
00150                 do {svm_node[] _=x[i]; x[i]=x[j]; x[j]=_;} while(false);
00151                 if(x_square != null) do {double _=x_square[i]; x_square[i]=x_square[j]; x_square[j]=_;} while(false);
00152         }
00153 
00154         private static double powi(double base, int times)
00155         {
00156                 double tmp = base, ret = 1.0;
00157 
00158                 for(int t=times; t>0; t/=2)
00159                 {
00160                         if(t%2==1) ret*=tmp;
00161                         tmp = tmp * tmp;
00162                 }
00163                 return ret;
00164         }
00165 
00166         double kernel_function(int i, int j)
00167         {
00168                 switch(kernel_type)
00169                 {
00170                         case svm_parameter.LINEAR:
00171                                 return dot(x[i],x[j]);
00172                         case svm_parameter.POLY:
00173                                 return powi(gamma*dot(x[i],x[j])+coef0,degree);
00174                         case svm_parameter.RBF:
00175                                 return Math.exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
00176                         case svm_parameter.SIGMOID:
00177                                 return Math.tanh(gamma*dot(x[i],x[j])+coef0);
00178                         case svm_parameter.PRECOMPUTED:
00179                                 return x[i][(int)(x[j][0].value)].value;
00180                         default:
00181                                 return 0;       // java
00182                 }
00183         }
00184 
00185         Kernel(int l, svm_node[][] x_, svm_parameter param)
00186         {
00187                 this.kernel_type = param.kernel_type;
00188                 this.degree = param.degree;
00189                 this.gamma = param.gamma;
00190                 this.coef0 = param.coef0;
00191 
00192                 x = (svm_node[][])x_.clone();
00193 
00194                 if(kernel_type == svm_parameter.RBF)
00195                 {
00196                         x_square = new double[l];
00197                         for(int i=0;i<l;i++)
00198                                 x_square[i] = dot(x[i],x[i]);
00199                 }
00200                 else x_square = null;
00201         }
00202 
00203         static double dot(svm_node[] x, svm_node[] y)
00204         {
00205                 double sum = 0;
00206                 int xlen = x.length;
00207                 int ylen = y.length;
00208                 int i = 0;
00209                 int j = 0;
00210                 while(i < xlen && j < ylen)
00211                 {
00212                         if(x[i].index == y[j].index)
00213                                 sum += x[i++].value * y[j++].value;
00214                         else
00215                         {
00216                                 if(x[i].index > y[j].index)
00217                                         ++j;
00218                                 else
00219                                         ++i;
00220                         }
00221                 }
00222                 return sum;
00223         }
00224 
00225         static double k_function(svm_node[] x, svm_node[] y,
00226                                         svm_parameter param)
00227         {
00228                 switch(param.kernel_type)
00229                 {
00230                         case svm_parameter.LINEAR:
00231                                 return dot(x,y);
00232                         case svm_parameter.POLY:
00233                                 return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
00234                         case svm_parameter.RBF:
00235                         {
00236                                 double sum = 0;
00237                                 int xlen = x.length;
00238                                 int ylen = y.length;
00239                                 int i = 0;
00240                                 int j = 0;
00241                                 while(i < xlen && j < ylen)
00242                                 {
00243                                         if(x[i].index == y[j].index)
00244                                         {
00245                                                 double d = x[i++].value - y[j++].value;
00246                                                 sum += d*d;
00247                                         }
00248                                         else if(x[i].index > y[j].index)
00249                                         {
00250                                                 sum += y[j].value * y[j].value;
00251                                                 ++j;
00252                                         }
00253                                         else
00254                                         {
00255                                                 sum += x[i].value * x[i].value;
00256                                                 ++i;
00257                                         }
00258                                 }
00259 
00260                                 while(i < xlen)
00261                                 {
00262                                         sum += x[i].value * x[i].value;
00263                                         ++i;
00264                                 }
00265 
00266                                 while(j < ylen)
00267                                 {
00268                                         sum += y[j].value * y[j].value;
00269                                         ++j;
00270                                 }
00271 
00272                                 return Math.exp(-param.gamma*sum);
00273                         }
00274                         case svm_parameter.SIGMOID:
00275                                 return Math.tanh(param.gamma*dot(x,y)+param.coef0);
00276                         case svm_parameter.PRECOMPUTED:
00277                                 return  x[(int)(y[0].value)].value;
00278                         default:
00279                                 return 0;       // java
00280                 }
00281         }
00282 }
00283 
00284 // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
00285 // Solves:
00286 //
00287 //      min 0.5(\alpha^T Q \alpha) + p^T \alpha
00288 //
00289 //              y^T \alpha = \delta
00290 //              y_i = +1 or -1
00291 //              0 <= alpha_i <= Cp for y_i = 1
00292 //              0 <= alpha_i <= Cn for y_i = -1
00293 //
00294 // Given:
00295 //
00296 //      Q, p, y, Cp, Cn, and an initial feasible point \alpha
00297 //      l is the size of vectors and matrices
00298 //      eps is the stopping tolerance
00299 //
00300 // solution will be put in \alpha, objective value will be put in obj
00301 //
00302 class Solver {
00303         int active_size;
00304         byte[] y;
00305         double[] G;             // gradient of objective function
00306         static final byte LOWER_BOUND = 0;
00307         static final byte UPPER_BOUND = 1;
00308         static final byte FREE = 2;
00309         byte[] alpha_status;    // LOWER_BOUND, UPPER_BOUND, FREE
00310         double[] alpha;
00311         QMatrix Q;
00312         double[] QD;
00313         double eps;
00314         double Cp,Cn;
00315         double[] p;
00316         int[] active_set;
00317         double[] G_bar;         // gradient, if we treat free variables as 0
00318         int l;
00319         boolean unshrink;       // XXX
00320         
00321         static final double INF = java.lang.Double.POSITIVE_INFINITY;
00322 
00323         double get_C(int i)
00324         {
00325                 return (y[i] > 0)? Cp : Cn;
00326         }
00327         void update_alpha_status(int i)
00328         {
00329                 if(alpha[i] >= get_C(i))
00330                         alpha_status[i] = UPPER_BOUND;
00331                 else if(alpha[i] <= 0)
00332                         alpha_status[i] = LOWER_BOUND;
00333                 else alpha_status[i] = FREE;
00334         }
00335         boolean is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
00336         boolean is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
00337         boolean is_free(int i) {  return alpha_status[i] == FREE; }
00338 
00339         // java: information about solution except alpha,
00340         // because we cannot return multiple values otherwise...
00341         static class SolutionInfo {
00342                 double obj;
00343                 double rho;
00344                 double upper_bound_p;
00345                 double upper_bound_n;
00346                 double r;       // for Solver_NU
00347         }
00348 
00349         void swap_index(int i, int j)
00350         {
00351                 Q.swap_index(i,j);
00352                 do {byte _=y[i]; y[i]=y[j]; y[j]=_;} while(false);
00353                 do {double _=G[i]; G[i]=G[j]; G[j]=_;} while(false);
00354                 do {byte _=alpha_status[i]; alpha_status[i]=alpha_status[j]; alpha_status[j]=_;} while(false);
00355                 do {double _=alpha[i]; alpha[i]=alpha[j]; alpha[j]=_;} while(false);
00356                 do {double _=p[i]; p[i]=p[j]; p[j]=_;} while(false);
00357                 do {int _=active_set[i]; active_set[i]=active_set[j]; active_set[j]=_;} while(false);
00358                 do {double _=G_bar[i]; G_bar[i]=G_bar[j]; G_bar[j]=_;} while(false);
00359         }
00360 
00361         void reconstruct_gradient()
00362         {
00363                 // reconstruct inactive elements of G from G_bar and free variables
00364 
00365                 if(active_size == l) return;
00366 
00367                 int i,j;
00368                 int nr_free = 0;
00369 
00370                 for(j=active_size;j<l;j++)
00371                         G[j] = G_bar[j] + p[j];
00372 
00373                 for(j=0;j<active_size;j++)
00374                         if(is_free(j))
00375                                 nr_free++;
00376 
00377                 if(2*nr_free < active_size)
00378                         svm.info("\nWARNING: using -h 0 may be faster\n");
00379 
00380                 if (nr_free*l > 2*active_size*(l-active_size))
00381                 {
00382                         for(i=active_size;i<l;i++)
00383                         {
00384                                 float[] Q_i = Q.get_Q(i,active_size);
00385                                 for(j=0;j<active_size;j++)
00386                                         if(is_free(j))
00387                                                 G[i] += alpha[j] * Q_i[j];
00388                         }       
00389                 }
00390                 else
00391                 {
00392                         for(i=0;i<active_size;i++)
00393                                 if(is_free(i))
00394                                 {
00395                                         float[] Q_i = Q.get_Q(i,l);
00396                                         double alpha_i = alpha[i];
00397                                         for(j=active_size;j<l;j++)
00398                                                 G[j] += alpha_i * Q_i[j];
00399                                 }
00400                 }
00401         }
00402 
00403         void Solve(int l, QMatrix Q, double[] p_, byte[] y_,
00404                    double[] alpha_, double Cp, double Cn, double eps, SolutionInfo si, int shrinking)
00405         {
00406                 this.l = l;
00407                 this.Q = Q;
00408                 QD = Q.get_QD();
00409                 p = (double[])p_.clone();
00410                 y = (byte[])y_.clone();
00411                 alpha = (double[])alpha_.clone();
00412                 this.Cp = Cp;
00413                 this.Cn = Cn;
00414                 this.eps = eps;
00415                 this.unshrink = false;
00416 
00417                 // initialize alpha_status
00418                 {
00419                         alpha_status = new byte[l];
00420                         for(int i=0;i<l;i++)
00421                                 update_alpha_status(i);
00422                 }
00423 
00424                 // initialize active set (for shrinking)
00425                 {
00426                         active_set = new int[l];
00427                         for(int i=0;i<l;i++)
00428                                 active_set[i] = i;
00429                         active_size = l;
00430                 }
00431 
00432                 // initialize gradient
00433                 {
00434                         G = new double[l];
00435                         G_bar = new double[l];
00436                         int i;
00437                         for(i=0;i<l;i++)
00438                         {
00439                                 G[i] = p[i];
00440                                 G_bar[i] = 0;
00441                         }
00442                         for(i=0;i<l;i++)
00443                                 if(!is_lower_bound(i))
00444                                 {
00445                                         float[] Q_i = Q.get_Q(i,l);
00446                                         double alpha_i = alpha[i];
00447                                         int j;
00448                                         for(j=0;j<l;j++)
00449                                                 G[j] += alpha_i*Q_i[j];
00450                                         if(is_upper_bound(i))
00451                                                 for(j=0;j<l;j++)
00452                                                         G_bar[j] += get_C(i) * Q_i[j];
00453                                 }
00454                 }
00455 
00456                 // optimization step
00457 
00458                 int iter = 0;
00459                 int max_iter = Math.max(10000000, l>Integer.MAX_VALUE/100 ? Integer.MAX_VALUE : 100*l);
00460                 int counter = Math.min(l,1000)+1;
00461                 int[] working_set = new int[2];
00462 
00463                 while(iter < max_iter)
00464                 {
00465                         // show progress and do shrinking
00466 
00467                         if(--counter == 0)
00468                         {
00469                                 counter = Math.min(l,1000);
00470                                 if(shrinking!=0) do_shrinking();
00471                                 svm.info(".");
00472                         }
00473 
00474                         if(select_working_set(working_set)!=0)
00475                         {
00476                                 // reconstruct the whole gradient
00477                                 reconstruct_gradient();
00478                                 // reset active set size and check
00479                                 active_size = l;
00480                                 svm.info("*");
00481                                 if(select_working_set(working_set)!=0)
00482                                         break;
00483                                 else
00484                                         counter = 1;    // do shrinking next iteration
00485                         }
00486                         
00487                         int i = working_set[0];
00488                         int j = working_set[1];
00489 
00490                         ++iter;
00491 
00492                         // update alpha[i] and alpha[j], handle bounds carefully
00493 
00494                         float[] Q_i = Q.get_Q(i,active_size);
00495                         float[] Q_j = Q.get_Q(j,active_size);
00496 
00497                         double C_i = get_C(i);
00498                         double C_j = get_C(j);
00499 
00500                         double old_alpha_i = alpha[i];
00501                         double old_alpha_j = alpha[j];
00502 
00503                         if(y[i]!=y[j])
00504                         {
00505                                 double quad_coef = QD[i]+QD[j]+2*Q_i[j];
00506                                 if (quad_coef <= 0)
00507                                         quad_coef = 1e-12;
00508                                 double delta = (-G[i]-G[j])/quad_coef;
00509                                 double diff = alpha[i] - alpha[j];
00510                                 alpha[i] += delta;
00511                                 alpha[j] += delta;
00512                         
00513                                 if(diff > 0)
00514                                 {
00515                                         if(alpha[j] < 0)
00516                                         {
00517                                                 alpha[j] = 0;
00518                                                 alpha[i] = diff;
00519                                         }
00520                                 }
00521                                 else
00522                                 {
00523                                         if(alpha[i] < 0)
00524                                         {
00525                                                 alpha[i] = 0;
00526                                                 alpha[j] = -diff;
00527                                         }
00528                                 }
00529                                 if(diff > C_i - C_j)
00530                                 {
00531                                         if(alpha[i] > C_i)
00532                                         {
00533                                                 alpha[i] = C_i;
00534                                                 alpha[j] = C_i - diff;
00535                                         }
00536                                 }
00537                                 else
00538                                 {
00539                                         if(alpha[j] > C_j)
00540                                         {
00541                                                 alpha[j] = C_j;
00542                                                 alpha[i] = C_j + diff;
00543                                         }
00544                                 }
00545                         }
00546                         else
00547                         {
00548                                 double quad_coef = QD[i]+QD[j]-2*Q_i[j];
00549                                 if (quad_coef <= 0)
00550                                         quad_coef = 1e-12;
00551                                 double delta = (G[i]-G[j])/quad_coef;
00552                                 double sum = alpha[i] + alpha[j];
00553                                 alpha[i] -= delta;
00554                                 alpha[j] += delta;
00555 
00556                                 if(sum > C_i)
00557                                 {
00558                                         if(alpha[i] > C_i)
00559                                         {
00560                                                 alpha[i] = C_i;
00561                                                 alpha[j] = sum - C_i;
00562                                         }
00563                                 }
00564                                 else
00565                                 {
00566                                         if(alpha[j] < 0)
00567                                         {
00568                                                 alpha[j] = 0;
00569                                                 alpha[i] = sum;
00570                                         }
00571                                 }
00572                                 if(sum > C_j)
00573                                 {
00574                                         if(alpha[j] > C_j)
00575                                         {
00576                                                 alpha[j] = C_j;
00577                                                 alpha[i] = sum - C_j;
00578                                         }
00579                                 }
00580                                 else
00581                                 {
00582                                         if(alpha[i] < 0)
00583                                         {
00584                                                 alpha[i] = 0;
00585                                                 alpha[j] = sum;
00586                                         }
00587                                 }
00588                         }
00589 
00590                         // update G
00591 
00592                         double delta_alpha_i = alpha[i] - old_alpha_i;
00593                         double delta_alpha_j = alpha[j] - old_alpha_j;
00594 
00595                         for(int k=0;k<active_size;k++)
00596                         {
00597                                 G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
00598                         }
00599 
00600                         // update alpha_status and G_bar
00601 
00602                         {
00603                                 boolean ui = is_upper_bound(i);
00604                                 boolean uj = is_upper_bound(j);
00605                                 update_alpha_status(i);
00606                                 update_alpha_status(j);
00607                                 int k;
00608                                 if(ui != is_upper_bound(i))
00609                                 {
00610                                         Q_i = Q.get_Q(i,l);
00611                                         if(ui)
00612                                                 for(k=0;k<l;k++)
00613                                                         G_bar[k] -= C_i * Q_i[k];
00614                                         else
00615                                                 for(k=0;k<l;k++)
00616                                                         G_bar[k] += C_i * Q_i[k];
00617                                 }
00618 
00619                                 if(uj != is_upper_bound(j))
00620                                 {
00621                                         Q_j = Q.get_Q(j,l);
00622                                         if(uj)
00623                                                 for(k=0;k<l;k++)
00624                                                         G_bar[k] -= C_j * Q_j[k];
00625                                         else
00626                                                 for(k=0;k<l;k++)
00627                                                         G_bar[k] += C_j * Q_j[k];
00628                                 }
00629                         }
00630 
00631                 }
00632                 
00633                 if(iter >= max_iter)
00634                 {
00635                         if(active_size < l)
00636                         {
00637                                 // reconstruct the whole gradient to calculate objective value
00638                                 reconstruct_gradient();
00639                                 active_size = l;
00640                                 svm.info("*");
00641                         }
00642                         System.err.print("\nWARNING: reaching max number of iterations\n");
00643                 }
00644 
00645                 // calculate rho
00646 
00647                 si.rho = calculate_rho();
00648 
00649                 // calculate objective value
00650                 {
00651                         double v = 0;
00652                         int i;
00653                         for(i=0;i<l;i++)
00654                                 v += alpha[i] * (G[i] + p[i]);
00655 
00656                         si.obj = v/2;
00657                 }
00658 
00659                 // put back the solution
00660                 {
00661                         for(int i=0;i<l;i++)
00662                                 alpha_[active_set[i]] = alpha[i];
00663                 }
00664 
00665                 si.upper_bound_p = Cp;
00666                 si.upper_bound_n = Cn;
00667 
00668                 svm.info("\noptimization finished, #iter = "+iter+"\n");
00669         }
00670 
00671         // return 1 if already optimal, return 0 otherwise
00672         int select_working_set(int[] working_set)
00673         {
00674                 // return i,j such that
00675                 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
00676                 // j: mimimizes the decrease of obj value
00677                 //    (if quadratic coefficeint <= 0, replace it with tau)
00678                 //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
00679                 
00680                 double Gmax = -INF;
00681                 double Gmax2 = -INF;
00682                 int Gmax_idx = -1;
00683                 int Gmin_idx = -1;
00684                 double obj_diff_min = INF;
00685         
00686                 for(int t=0;t<active_size;t++)
00687                         if(y[t]==+1)    
00688                         {
00689                                 if(!is_upper_bound(t))
00690                                         if(-G[t] >= Gmax)
00691                                         {
00692                                                 Gmax = -G[t];
00693                                                 Gmax_idx = t;
00694                                         }
00695                         }
00696                         else
00697                         {
00698                                 if(!is_lower_bound(t))
00699                                         if(G[t] >= Gmax)
00700                                         {
00701                                                 Gmax = G[t];
00702                                                 Gmax_idx = t;
00703                                         }
00704                         }
00705         
00706                 int i = Gmax_idx;
00707                 float[] Q_i = null;
00708                 if(i != -1) // null Q_i not accessed: Gmax=-INF if i=-1
00709                         Q_i = Q.get_Q(i,active_size);
00710         
00711                 for(int j=0;j<active_size;j++)
00712                 {
00713                         if(y[j]==+1)
00714                         {
00715                                 if (!is_lower_bound(j))
00716                                 {
00717                                         double grad_diff=Gmax+G[j];
00718                                         if (G[j] >= Gmax2)
00719                                                 Gmax2 = G[j];
00720                                         if (grad_diff > 0)
00721                                         {
00722                                                 double obj_diff; 
00723                                                 double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
00724                                                 if (quad_coef > 0)
00725                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00726                                                 else
00727                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00728         
00729                                                 if (obj_diff <= obj_diff_min)
00730                                                 {
00731                                                         Gmin_idx=j;
00732                                                         obj_diff_min = obj_diff;
00733                                                 }
00734                                         }
00735                                 }
00736                         }
00737                         else
00738                         {
00739                                 if (!is_upper_bound(j))
00740                                 {
00741                                         double grad_diff= Gmax-G[j];
00742                                         if (-G[j] >= Gmax2)
00743                                                 Gmax2 = -G[j];
00744                                         if (grad_diff > 0)
00745                                         {
00746                                                 double obj_diff; 
00747                                                 double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
00748                                                 if (quad_coef > 0)
00749                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00750                                                 else
00751                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00752         
00753                                                 if (obj_diff <= obj_diff_min)
00754                                                 {
00755                                                         Gmin_idx=j;
00756                                                         obj_diff_min = obj_diff;
00757                                                 }
00758                                         }
00759                                 }
00760                         }
00761                 }
00762 
00763                 if(Gmax+Gmax2 < eps)
00764                         return 1;
00765 
00766                 working_set[0] = Gmax_idx;
00767                 working_set[1] = Gmin_idx;
00768                 return 0;
00769         }
00770 
00771         private boolean be_shrunk(int i, double Gmax1, double Gmax2)
00772         {       
00773                 if(is_upper_bound(i))
00774                 {
00775                         if(y[i]==+1)
00776                                 return(-G[i] > Gmax1);
00777                         else
00778                                 return(-G[i] > Gmax2);
00779                 }
00780                 else if(is_lower_bound(i))
00781                 {
00782                         if(y[i]==+1)
00783                                 return(G[i] > Gmax2);
00784                         else    
00785                                 return(G[i] > Gmax1);
00786                 }
00787                 else
00788                         return(false);
00789         }
00790 
        // Shrinking step: removes from the active set [0, active_size) every
        // index that be_shrunk() flags, moving kept indices to the front via
        // swap_index().
        void do_shrinking()
        {
                int i;
                double Gmax1 = -INF;            // max { -y_i * grad(f)_i | i in I_up(\alpha) }
                double Gmax2 = -INF;            // max { y_i * grad(f)_i | i in I_low(\alpha) }

                // find maximal violating pair first
                for(i=0;i<active_size;i++)
                {
                        if(y[i]==+1)
                        {
                                if(!is_upper_bound(i))
                                {
                                        if(-G[i] >= Gmax1)
                                                Gmax1 = -G[i];
                                }
                                if(!is_lower_bound(i))
                                {
                                        if(G[i] >= Gmax2)
                                                Gmax2 = G[i];
                                }
                        }
                        else
                        {
                                if(!is_upper_bound(i))
                                {
                                        if(-G[i] >= Gmax2)
                                                Gmax2 = -G[i];
                                }
                                if(!is_lower_bound(i))
                                {
                                        if(G[i] >= Gmax1)
                                                Gmax1 = G[i];
                                }
                        }
                }

                // When the maximal violation is within 10*eps (and this has not
                // happened yet, tracked by `unshrink`), call reconstruct_gradient()
                // (presumably recomputes G for previously shrunk indices -- confirm)
                // and make all l indices active again.
                if(unshrink == false && Gmax1 + Gmax2 <= eps*10)
                {
                        unshrink = true;
                        reconstruct_gradient();
                        active_size = l;
                }

                // In-place compaction: for each shrinkable i, pull active_size down
                // past shrinkable tail entries and swap the first non-shrinkable
                // tail entry into slot i. Note Gmax1/Gmax2 are not recomputed
                // after reactivation above.
                for(i=0;i<active_size;i++)
                        if (be_shrunk(i, Gmax1, Gmax2))
                        {
                                active_size--;
                                while (active_size > i)
                                {
                                        if (!be_shrunk(active_size, Gmax1, Gmax2))
                                        {
                                                swap_index(i,active_size);
                                                break;
                                        }
                                        active_size--;
                                }
                        }
        }
00850 
00851         double calculate_rho()
00852         {
00853                 double r;
00854                 int nr_free = 0;
00855                 double ub = INF, lb = -INF, sum_free = 0;
00856                 for(int i=0;i<active_size;i++)
00857                 {
00858                         double yG = y[i]*G[i];
00859 
00860                         if(is_lower_bound(i))
00861                         {
00862                                 if(y[i] > 0)
00863                                         ub = Math.min(ub,yG);
00864                                 else
00865                                         lb = Math.max(lb,yG);
00866                         }
00867                         else if(is_upper_bound(i))
00868                         {
00869                                 if(y[i] < 0)
00870                                         ub = Math.min(ub,yG);
00871                                 else
00872                                         lb = Math.max(lb,yG);
00873                         }
00874                         else
00875                         {
00876                                 ++nr_free;
00877                                 sum_free += yG;
00878                         }
00879                 }
00880 
00881                 if(nr_free>0)
00882                         r = sum_free/nr_free;
00883                 else
00884                         r = (ub+lb)/2;
00885 
00886                 return r;
00887         }
00888 
00889 }
00890 
00891 //
00892 // Solver for nu-svm classification and regression
00893 //
00894 // additional constraint: e^T \alpha = constant
00895 //
00896 final class Solver_NU extends Solver
00897 {
00898         private SolutionInfo si;
00899 
00900         void Solve(int l, QMatrix Q, double[] p, byte[] y,
00901                    double[] alpha, double Cp, double Cn, double eps,
00902                    SolutionInfo si, int shrinking)
00903         {
00904                 this.si = si;
00905                 super.Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
00906         }
00907 
00908         // return 1 if already optimal, return 0 otherwise
00909         int select_working_set(int[] working_set)
00910         {
00911                 // return i,j such that y_i = y_j and
00912                 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
00913                 // j: minimizes the decrease of obj value
00914                 //    (if quadratic coefficeint <= 0, replace it with tau)
00915                 //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
00916         
00917                 double Gmaxp = -INF;
00918                 double Gmaxp2 = -INF;
00919                 int Gmaxp_idx = -1;
00920         
00921                 double Gmaxn = -INF;
00922                 double Gmaxn2 = -INF;
00923                 int Gmaxn_idx = -1;
00924         
00925                 int Gmin_idx = -1;
00926                 double obj_diff_min = INF;
00927         
00928                 for(int t=0;t<active_size;t++)
00929                         if(y[t]==+1)
00930                         {
00931                                 if(!is_upper_bound(t))
00932                                         if(-G[t] >= Gmaxp)
00933                                         {
00934                                                 Gmaxp = -G[t];
00935                                                 Gmaxp_idx = t;
00936                                         }
00937                         }
00938                         else
00939                         {
00940                                 if(!is_lower_bound(t))
00941                                         if(G[t] >= Gmaxn)
00942                                         {
00943                                                 Gmaxn = G[t];
00944                                                 Gmaxn_idx = t;
00945                                         }
00946                         }
00947         
00948                 int ip = Gmaxp_idx;
00949                 int in = Gmaxn_idx;
00950                 float[] Q_ip = null;
00951                 float[] Q_in = null;
00952                 if(ip != -1) // null Q_ip not accessed: Gmaxp=-INF if ip=-1
00953                         Q_ip = Q.get_Q(ip,active_size);
00954                 if(in != -1)
00955                         Q_in = Q.get_Q(in,active_size);
00956         
00957                 for(int j=0;j<active_size;j++)
00958                 {
00959                         if(y[j]==+1)
00960                         {
00961                                 if (!is_lower_bound(j)) 
00962                                 {
00963                                         double grad_diff=Gmaxp+G[j];
00964                                         if (G[j] >= Gmaxp2)
00965                                                 Gmaxp2 = G[j];
00966                                         if (grad_diff > 0)
00967                                         {
00968                                                 double obj_diff; 
00969                                                 double quad_coef = QD[ip]+QD[j]-2*Q_ip[j];
00970                                                 if (quad_coef > 0)
00971                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00972                                                 else
00973                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00974         
00975                                                 if (obj_diff <= obj_diff_min)
00976                                                 {
00977                                                         Gmin_idx=j;
00978                                                         obj_diff_min = obj_diff;
00979                                                 }
00980                                         }
00981                                 }
00982                         }
00983                         else
00984                         {
00985                                 if (!is_upper_bound(j))
00986                                 {
00987                                         double grad_diff=Gmaxn-G[j];
00988                                         if (-G[j] >= Gmaxn2)
00989                                                 Gmaxn2 = -G[j];
00990                                         if (grad_diff > 0)
00991                                         {
00992                                                 double obj_diff; 
00993                                                 double quad_coef = QD[in]+QD[j]-2*Q_in[j];
00994                                                 if (quad_coef > 0)
00995                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00996                                                 else
00997                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00998         
00999                                                 if (obj_diff <= obj_diff_min)
01000                                                 {
01001                                                         Gmin_idx=j;
01002                                                         obj_diff_min = obj_diff;
01003                                                 }
01004                                         }
01005                                 }
01006                         }
01007                 }
01008 
01009                 if(Math.max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps)
01010                         return 1;
01011         
01012                 if(y[Gmin_idx] == +1)
01013                         working_set[0] = Gmaxp_idx;
01014                 else
01015                         working_set[0] = Gmaxn_idx;
01016                 working_set[1] = Gmin_idx;
01017         
01018                 return 0;
01019         }
01020 
01021         private boolean be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
01022         {
01023                 if(is_upper_bound(i))
01024                 {
01025                         if(y[i]==+1)
01026                                 return(-G[i] > Gmax1);
01027                         else    
01028                                 return(-G[i] > Gmax4);
01029                 }
01030                 else if(is_lower_bound(i))
01031                 {
01032                         if(y[i]==+1)
01033                                 return(G[i] > Gmax2);
01034                         else    
01035                                 return(G[i] > Gmax3);
01036                 }
01037                 else
01038                         return(false);
01039         }
01040 
01041         void do_shrinking()
01042         {
01043                 double Gmax1 = -INF;    // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
01044                 double Gmax2 = -INF;    // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
01045                 double Gmax3 = -INF;    // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
01046                 double Gmax4 = -INF;    // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
01047  
01048                 // find maximal violating pair first
01049                 int i;
01050                 for(i=0;i<active_size;i++)
01051                 {
01052                         if(!is_upper_bound(i))
01053                         {
01054                                 if(y[i]==+1)
01055                                 {
01056                                         if(-G[i] > Gmax1) Gmax1 = -G[i];
01057                                 }
01058                                 else    if(-G[i] > Gmax4) Gmax4 = -G[i];
01059                         }
01060                         if(!is_lower_bound(i))
01061                         {
01062                                 if(y[i]==+1)
01063                                 {       
01064                                         if(G[i] > Gmax2) Gmax2 = G[i];
01065                                 }
01066                                 else    if(G[i] > Gmax3) Gmax3 = G[i];
01067                         }
01068                 }
01069 
01070                 if(unshrink == false && Math.max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) 
01071                 {
01072                         unshrink = true;
01073                         reconstruct_gradient();
01074                         active_size = l;
01075                 }
01076 
01077                 for(i=0;i<active_size;i++)
01078                         if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
01079                         {
01080                                 active_size--;
01081                                 while (active_size > i)
01082                                 {
01083                                         if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
01084                                         {
01085                                                 swap_index(i,active_size);
01086                                                 break;
01087                                         }
01088                                         active_size--;
01089                                 }
01090                         }
01091         }
01092         
01093         double calculate_rho()
01094         {
01095                 int nr_free1 = 0,nr_free2 = 0;
01096                 double ub1 = INF, ub2 = INF;
01097                 double lb1 = -INF, lb2 = -INF;
01098                 double sum_free1 = 0, sum_free2 = 0;
01099 
01100                 for(int i=0;i<active_size;i++)
01101                 {
01102                         if(y[i]==+1)
01103                         {
01104                                 if(is_lower_bound(i))
01105                                         ub1 = Math.min(ub1,G[i]);
01106                                 else if(is_upper_bound(i))
01107                                         lb1 = Math.max(lb1,G[i]);
01108                                 else
01109                                 {
01110                                         ++nr_free1;
01111                                         sum_free1 += G[i];
01112                                 }
01113                         }
01114                         else
01115                         {
01116                                 if(is_lower_bound(i))
01117                                         ub2 = Math.min(ub2,G[i]);
01118                                 else if(is_upper_bound(i))
01119                                         lb2 = Math.max(lb2,G[i]);
01120                                 else
01121                                 {
01122                                         ++nr_free2;
01123                                         sum_free2 += G[i];
01124                                 }
01125                         }
01126                 }
01127 
01128                 double r1,r2;
01129                 if(nr_free1 > 0)
01130                         r1 = sum_free1/nr_free1;
01131                 else
01132                         r1 = (ub1+lb1)/2;
01133 
01134                 if(nr_free2 > 0)
01135                         r2 = sum_free2/nr_free2;
01136                 else
01137                         r2 = (ub2+lb2)/2;
01138 
01139                 si.r = (r1+r2)/2;
01140                 return (r1-r2)/2;
01141         }
01142 }
01143 
01144 //
01145 // Q matrices for various formulations
01146 //
01147 class SVC_Q extends Kernel
01148 {
01149         private final byte[] y;
01150         private final Cache cache;
01151         private final double[] QD;
01152 
01153         SVC_Q(svm_problem prob, svm_parameter param, byte[] y_)
01154         {
01155                 super(prob.l, prob.x, param);
01156                 y = (byte[])y_.clone();
01157                 cache = new Cache(prob.l,(long)(param.cache_size*(1<<20)));
01158                 QD = new double[prob.l];
01159                 for(int i=0;i<prob.l;i++)
01160                         QD[i] = kernel_function(i,i);
01161         }
01162 
01163         float[] get_Q(int i, int len)
01164         {
01165                 float[][] data = new float[1][];
01166                 int start, j;
01167                 if((start = cache.get_data(i,data,len)) < len)
01168                 {
01169                         for(j=start;j<len;j++)
01170                                 data[0][j] = (float)(y[i]*y[j]*kernel_function(i,j));
01171                 }
01172                 return data[0];
01173         }
01174 
01175         double[] get_QD()
01176         {
01177                 return QD;
01178         }
01179 
01180         void swap_index(int i, int j)
01181         {
01182                 cache.swap_index(i,j);
01183                 super.swap_index(i,j);
01184                 do {byte _=y[i]; y[i]=y[j]; y[j]=_;} while(false);
01185                 do {double _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
01186         }
01187 }
01188 
01189 class ONE_CLASS_Q extends Kernel
01190 {
01191         private final Cache cache;
01192         private final double[] QD;
01193 
01194         ONE_CLASS_Q(svm_problem prob, svm_parameter param)
01195         {
01196                 super(prob.l, prob.x, param);
01197                 cache = new Cache(prob.l,(long)(param.cache_size*(1<<20)));
01198                 QD = new double[prob.l];
01199                 for(int i=0;i<prob.l;i++)
01200                         QD[i] = kernel_function(i,i);
01201         }
01202 
01203         float[] get_Q(int i, int len)
01204         {
01205                 float[][] data = new float[1][];
01206                 int start, j;
01207                 if((start = cache.get_data(i,data,len)) < len)
01208                 {
01209                         for(j=start;j<len;j++)
01210                                 data[0][j] = (float)kernel_function(i,j);
01211                 }
01212                 return data[0];
01213         }
01214 
01215         double[] get_QD()
01216         {
01217                 return QD;
01218         }
01219 
01220         void swap_index(int i, int j)
01221         {
01222                 cache.swap_index(i,j);
01223                 super.swap_index(i,j);
01224                 do {double _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
01225         }
01226 }
01227 
01228 class SVR_Q extends Kernel
01229 {
01230         private final int l;
01231         private final Cache cache;
01232         private final byte[] sign;
01233         private final int[] index;
01234         private int next_buffer;
01235         private float[][] buffer;
01236         private final double[] QD;
01237 
01238         SVR_Q(svm_problem prob, svm_parameter param)
01239         {
01240                 super(prob.l, prob.x, param);
01241                 l = prob.l;
01242                 cache = new Cache(l,(long)(param.cache_size*(1<<20)));
01243                 QD = new double[2*l];
01244                 sign = new byte[2*l];
01245                 index = new int[2*l];
01246                 for(int k=0;k<l;k++)
01247                 {
01248                         sign[k] = 1;
01249                         sign[k+l] = -1;
01250                         index[k] = k;
01251                         index[k+l] = k;
01252                         QD[k] = kernel_function(k,k);
01253                         QD[k+l] = QD[k];
01254                 }
01255                 buffer = new float[2][2*l];
01256                 next_buffer = 0;
01257         }
01258 
01259         void swap_index(int i, int j)
01260         {
01261                 do {byte _=sign[i]; sign[i]=sign[j]; sign[j]=_;} while(false);
01262                 do {int _=index[i]; index[i]=index[j]; index[j]=_;} while(false);
01263                 do {double _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
01264         }
01265 
01266         float[] get_Q(int i, int len)
01267         {
01268                 float[][] data = new float[1][];
01269                 int j, real_i = index[i];
01270                 if(cache.get_data(real_i,data,l) < l)
01271                 {
01272                         for(j=0;j<l;j++)
01273                                 data[0][j] = (float)kernel_function(real_i,j);
01274                 }
01275 
01276                 // reorder and copy
01277                 float buf[] = buffer[next_buffer];
01278                 next_buffer = 1 - next_buffer;
01279                 byte si = sign[i];
01280                 for(j=0;j<len;j++)
01281                         buf[j] = (float) si * sign[j] * data[0][index[j]];
01282                 return buf;
01283         }
01284 
01285         double[] get_QD()
01286         {
01287                 return QD;
01288         }
01289 }
01290 
01291 public class svm {
        //
        // construct and solve various formulations
        //
        // Library version constant (314 presumably denotes release 3.14 -- confirm).
        public static final int LIBSVM_VERSION=314; 
        // Shared random number generator; its users are outside this view.
        public static final Random rand = new Random();
01297 
        // Default print sink: writes each message to System.out and flushes
        // immediately so progress output appears without delay.
        private static svm_print_interface svm_print_stdout = new svm_print_interface()
        {
                public void print(String s)
                {
                        System.out.print(s);
                        System.out.flush();
                }
        };

        // Sink currently used by info(); initialised to the stdout sink.
        private static svm_print_interface svm_print_string = svm_print_stdout;
01308 
        // Emits a diagnostic/progress message through the currently configured
        // print sink (svm_print_string).
        static void info(String s) 
        {
                svm_print_string.print(s);
        }
01313 
01314         private static void solve_c_svc(svm_problem prob, svm_parameter param,
01315                                         double[] alpha, Solver.SolutionInfo si,
01316                                         double Cp, double Cn)
01317         {
01318                 int l = prob.l;
01319                 double[] minus_ones = new double[l];
01320                 byte[] y = new byte[l];
01321 
01322                 int i;
01323 
01324                 for(i=0;i<l;i++)
01325                 {
01326                         alpha[i] = 0;
01327                         minus_ones[i] = -1;
01328                         if(prob.y[i] > 0) y[i] = +1; else y[i] = -1;
01329                 }
01330 
01331                 Solver s = new Solver();
01332                 s.Solve(l, new SVC_Q(prob,param,y), minus_ones, y,
01333                         alpha, Cp, Cn, param.eps, si, param.shrinking);
01334 
01335                 double sum_alpha=0;
01336                 for(i=0;i<l;i++)
01337                         sum_alpha += alpha[i];
01338 
01339                 if (Cp==Cn)
01340                         svm.info("nu = "+sum_alpha/(Cp*prob.l)+"\n");
01341 
01342                 for(i=0;i<l;i++)
01343                         alpha[i] *= y[i];
01344         }
01345 
01346         private static void solve_nu_svc(svm_problem prob, svm_parameter param,
01347                                         double[] alpha, Solver.SolutionInfo si)
01348         {
01349                 int i;
01350                 int l = prob.l;
01351                 double nu = param.nu;
01352 
01353                 byte[] y = new byte[l];
01354 
01355                 for(i=0;i<l;i++)
01356                         if(prob.y[i]>0)
01357                                 y[i] = +1;
01358                         else
01359                                 y[i] = -1;
01360 
01361                 double sum_pos = nu*l/2;
01362                 double sum_neg = nu*l/2;
01363 
01364                 for(i=0;i<l;i++)
01365                         if(y[i] == +1)
01366                         {
01367                                 alpha[i] = Math.min(1.0,sum_pos);
01368                                 sum_pos -= alpha[i];
01369                         }
01370                         else
01371                         {
01372                                 alpha[i] = Math.min(1.0,sum_neg);
01373                                 sum_neg -= alpha[i];
01374                         }
01375 
01376                 double[] zeros = new double[l];
01377 
01378                 for(i=0;i<l;i++)
01379                         zeros[i] = 0;
01380 
01381                 Solver_NU s = new Solver_NU();
01382                 s.Solve(l, new SVC_Q(prob,param,y), zeros, y,
01383                         alpha, 1.0, 1.0, param.eps, si, param.shrinking);
01384                 double r = si.r;
01385 
01386                 svm.info("C = "+1/r+"\n");
01387 
01388                 for(i=0;i<l;i++)
01389                         alpha[i] *= y[i]/r;
01390 
01391                 si.rho /= r;
01392                 si.obj /= (r*r);
01393                 si.upper_bound_p = 1/r;
01394                 si.upper_bound_n = 1/r;
01395         }
01396 
01397         private static void solve_one_class(svm_problem prob, svm_parameter param,
01398                                         double[] alpha, Solver.SolutionInfo si)
01399         {
01400                 int l = prob.l;
01401                 double[] zeros = new double[l];
01402                 byte[] ones = new byte[l];
01403                 int i;
01404 
01405                 int n = (int)(param.nu*prob.l); // # of alpha's at upper bound
01406 
01407                 for(i=0;i<n;i++)
01408                         alpha[i] = 1;
01409                 if(n<prob.l)
01410                         alpha[n] = param.nu * prob.l - n;
01411                 for(i=n+1;i<l;i++)
01412                         alpha[i] = 0;
01413 
01414                 for(i=0;i<l;i++)
01415                 {
01416                         zeros[i] = 0;
01417                         ones[i] = 1;
01418                 }
01419 
01420                 Solver s = new Solver();
01421                 s.Solve(l, new ONE_CLASS_Q(prob,param), zeros, ones,
01422                         alpha, 1.0, 1.0, param.eps, si, param.shrinking);
01423         }
01424 
01425         private static void solve_epsilon_svr(svm_problem prob, svm_parameter param,
01426                                         double[] alpha, Solver.SolutionInfo si)
01427         {
01428                 int l = prob.l;
01429                 double[] alpha2 = new double[2*l];
01430                 double[] linear_term = new double[2*l];
01431                 byte[] y = new byte[2*l];
01432                 int i;
01433 
01434                 for(i=0;i<l;i++)
01435                 {
01436                         alpha2[i] = 0;
01437                         linear_term[i] = param.p - prob.y[i];
01438                         y[i] = 1;
01439 
01440                         alpha2[i+l] = 0;
01441                         linear_term[i+l] = param.p + prob.y[i];
01442                         y[i+l] = -1;
01443                 }
01444 
01445                 Solver s = new Solver();
01446                 s.Solve(2*l, new SVR_Q(prob,param), linear_term, y,
01447                         alpha2, param.C, param.C, param.eps, si, param.shrinking);
01448 
01449                 double sum_alpha = 0;
01450                 for(i=0;i<l;i++)
01451                 {
01452                         alpha[i] = alpha2[i] - alpha2[i+l];
01453                         sum_alpha += Math.abs(alpha[i]);
01454                 }
01455                 svm.info("nu = "+sum_alpha/(param.C*l)+"\n");
01456         }
01457 
01458         private static void solve_nu_svr(svm_problem prob, svm_parameter param,
01459                                         double[] alpha, Solver.SolutionInfo si)
01460         {
01461                 int l = prob.l;
01462                 double C = param.C;
01463                 double[] alpha2 = new double[2*l];
01464                 double[] linear_term = new double[2*l];
01465                 byte[] y = new byte[2*l];
01466                 int i;
01467 
01468                 double sum = C * param.nu * l / 2;
01469                 for(i=0;i<l;i++)
01470                 {
01471                         alpha2[i] = alpha2[i+l] = Math.min(sum,C);
01472                         sum -= alpha2[i];
01473                         
01474                         linear_term[i] = - prob.y[i];
01475                         y[i] = 1;
01476 
01477                         linear_term[i+l] = prob.y[i];
01478                         y[i+l] = -1;
01479                 }
01480 
01481                 Solver_NU s = new Solver_NU();
01482                 s.Solve(2*l, new SVR_Q(prob,param), linear_term, y,
01483                         alpha2, C, C, param.eps, si, param.shrinking);
01484 
01485                 svm.info("epsilon = "+(-si.r)+"\n");
01486                 
01487                 for(i=0;i<l;i++)
01488                         alpha[i] = alpha2[i] - alpha2[i+l];
01489         }
01490 
01491         //
01492         // decision_function
01493         //
01494         static class decision_function
01495         {
01496                 double[] alpha;
01497                 double rho;     
01498         };
01499 
01500         static decision_function svm_train_one(
01501                 svm_problem prob, svm_parameter param,
01502                 double Cp, double Cn)
01503         {
01504                 double[] alpha = new double[prob.l];
01505                 Solver.SolutionInfo si = new Solver.SolutionInfo();
01506                 switch(param.svm_type)
01507                 {
01508                         case svm_parameter.C_SVC:
01509                                 solve_c_svc(prob,param,alpha,si,Cp,Cn);
01510                                 break;
01511                         case svm_parameter.NU_SVC:
01512                                 solve_nu_svc(prob,param,alpha,si);
01513                                 break;
01514                         case svm_parameter.ONE_CLASS:
01515                                 solve_one_class(prob,param,alpha,si);
01516                                 break;
01517                         case svm_parameter.EPSILON_SVR:
01518                                 solve_epsilon_svr(prob,param,alpha,si);
01519                                 break;
01520                         case svm_parameter.NU_SVR:
01521                                 solve_nu_svr(prob,param,alpha,si);
01522                                 break;
01523                 }
01524 
01525                 svm.info("obj = "+si.obj+", rho = "+si.rho+"\n");
01526 
01527                 // output SVs
01528 
01529                 int nSV = 0;
01530                 int nBSV = 0;
01531                 for(int i=0;i<prob.l;i++)
01532                 {
01533                         if(Math.abs(alpha[i]) > 0)
01534                         {
01535                                 ++nSV;
01536                                 if(prob.y[i] > 0)
01537                                 {
01538                                         if(Math.abs(alpha[i]) >= si.upper_bound_p)
01539                                         ++nBSV;
01540                                 }
01541                                 else
01542                                 {
01543                                         if(Math.abs(alpha[i]) >= si.upper_bound_n)
01544                                                 ++nBSV;
01545                                 }
01546                         }
01547                 }
01548 
01549                 svm.info("nSV = "+nSV+", nBSV = "+nBSV+"\n");
01550 
01551                 decision_function f = new decision_function();
01552                 f.alpha = alpha;
01553                 f.rho = si.rho;
01554                 return f;
01555         }
01556 
01557         // Platt's binary SVM Probablistic Output: an improvement from Lin et al.
01558         private static void sigmoid_train(int l, double[] dec_values, double[] labels, 
01559                                   double[] probAB)
01560         {
01561                 double A, B;
01562                 double prior1=0, prior0 = 0;
01563                 int i;
01564 
01565                 for (i=0;i<l;i++)
01566                         if (labels[i] > 0) prior1+=1;
01567                         else prior0+=1;
01568         
01569                 int max_iter=100;       // Maximal number of iterations
01570                 double min_step=1e-10;  // Minimal step taken in line search
01571                 double sigma=1e-12;     // For numerically strict PD of Hessian
01572                 double eps=1e-5;
01573                 double hiTarget=(prior1+1.0)/(prior1+2.0);
01574                 double loTarget=1/(prior0+2.0);
01575                 double[] t= new double[l];
01576                 double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
01577                 double newA,newB,newf,d1,d2;
01578                 int iter; 
01579         
01580                 // Initial Point and Initial Fun Value
01581                 A=0.0; B=Math.log((prior0+1.0)/(prior1+1.0));
01582                 double fval = 0.0;
01583 
01584                 for (i=0;i<l;i++)
01585                 {
01586                         if (labels[i]>0) t[i]=hiTarget;
01587                         else t[i]=loTarget;
01588                         fApB = dec_values[i]*A+B;
01589                         if (fApB>=0)
01590                                 fval += t[i]*fApB + Math.log(1+Math.exp(-fApB));
01591                         else
01592                                 fval += (t[i] - 1)*fApB +Math.log(1+Math.exp(fApB));
01593                 }
01594                 for (iter=0;iter<max_iter;iter++)
01595                 {
01596                         // Update Gradient and Hessian (use H' = H + sigma I)
01597                         h11=sigma; // numerically ensures strict PD
01598                         h22=sigma;
01599                         h21=0.0;g1=0.0;g2=0.0;
01600                         for (i=0;i<l;i++)
01601                         {
01602                                 fApB = dec_values[i]*A+B;
01603                                 if (fApB >= 0)
01604                                 {
01605                                         p=Math.exp(-fApB)/(1.0+Math.exp(-fApB));
01606                                         q=1.0/(1.0+Math.exp(-fApB));
01607                                 }
01608                                 else
01609                                 {
01610                                         p=1.0/(1.0+Math.exp(fApB));
01611                                         q=Math.exp(fApB)/(1.0+Math.exp(fApB));
01612                                 }
01613                                 d2=p*q;
01614                                 h11+=dec_values[i]*dec_values[i]*d2;
01615                                 h22+=d2;
01616                                 h21+=dec_values[i]*d2;
01617                                 d1=t[i]-p;
01618                                 g1+=dec_values[i]*d1;
01619                                 g2+=d1;
01620                         }
01621 
01622                         // Stopping Criteria
01623                         if (Math.abs(g1)<eps && Math.abs(g2)<eps)
01624                                 break;
01625                         
01626                         // Finding Newton direction: -inv(H') * g
01627                         det=h11*h22-h21*h21;
01628                         dA=-(h22*g1 - h21 * g2) / det;
01629                         dB=-(-h21*g1+ h11 * g2) / det;
01630                         gd=g1*dA+g2*dB;
01631 
01632 
01633                         stepsize = 1;           // Line Search
01634                         while (stepsize >= min_step)
01635                         {
01636                                 newA = A + stepsize * dA;
01637                                 newB = B + stepsize * dB;
01638 
01639                                 // New function value
01640                                 newf = 0.0;
01641                                 for (i=0;i<l;i++)
01642                                 {
01643                                         fApB = dec_values[i]*newA+newB;
01644                                         if (fApB >= 0)
01645                                                 newf += t[i]*fApB + Math.log(1+Math.exp(-fApB));
01646                                         else
01647                                                 newf += (t[i] - 1)*fApB +Math.log(1+Math.exp(fApB));
01648                                 }
01649                                 // Check sufficient decrease
01650                                 if (newf<fval+0.0001*stepsize*gd)
01651                                 {
01652                                         A=newA;B=newB;fval=newf;
01653                                         break;
01654                                 }
01655                                 else
01656                                         stepsize = stepsize / 2.0;
01657                         }
01658                         
01659                         if (stepsize < min_step)
01660                         {
01661                                 svm.info("Line search fails in two-class probability estimates\n");
01662                                 break;
01663                         }
01664                 }
01665                 
01666                 if (iter>=max_iter)
01667                         svm.info("Reaching maximal iterations in two-class probability estimates\n");
01668                 probAB[0]=A;probAB[1]=B;
01669         }
01670 
01671         private static double sigmoid_predict(double decision_value, double A, double B)
01672         {
01673                 double fApB = decision_value*A+B;
01674                 if (fApB >= 0)
01675                         return Math.exp(-fApB)/(1.0+Math.exp(-fApB));
01676                 else
01677                         return 1.0/(1+Math.exp(fApB)) ;
01678         }
01679 
01680         // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
01681         private static void multiclass_probability(int k, double[][] r, double[] p)
01682         {
01683                 int t,j;
01684                 int iter = 0, max_iter=Math.max(100,k);
01685                 double[][] Q=new double[k][k];
01686                 double[] Qp=new double[k];
01687                 double pQp, eps=0.005/k;
01688         
01689                 for (t=0;t<k;t++)
01690                 {
01691                         p[t]=1.0/k;  // Valid if k = 1
01692                         Q[t][t]=0;
01693                         for (j=0;j<t;j++)
01694                         {
01695                                 Q[t][t]+=r[j][t]*r[j][t];
01696                                 Q[t][j]=Q[j][t];
01697                         }
01698                         for (j=t+1;j<k;j++)
01699                         {
01700                                 Q[t][t]+=r[j][t]*r[j][t];
01701                                 Q[t][j]=-r[j][t]*r[t][j];
01702                         }
01703                 }
01704                 for (iter=0;iter<max_iter;iter++)
01705                 {
01706                         // stopping condition, recalculate QP,pQP for numerical accuracy
01707                         pQp=0;
01708                         for (t=0;t<k;t++)
01709                         {
01710                                 Qp[t]=0;
01711                                 for (j=0;j<k;j++)
01712                                         Qp[t]+=Q[t][j]*p[j];
01713                                 pQp+=p[t]*Qp[t];
01714                         }
01715                         double max_error=0;
01716                         for (t=0;t<k;t++)
01717                         {
01718                                 double error=Math.abs(Qp[t]-pQp);
01719                                 if (error>max_error)
01720                                         max_error=error;
01721                         }
01722                         if (max_error<eps) break;
01723                 
01724                         for (t=0;t<k;t++)
01725                         {
01726                                 double diff=(-Qp[t]+pQp)/Q[t][t];
01727                                 p[t]+=diff;
01728                                 pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
01729                                 for (j=0;j<k;j++)
01730                                 {
01731                                         Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
01732                                         p[j]/=(1+diff);
01733                                 }
01734                         }
01735                 }
01736                 if (iter>=max_iter)
01737                         svm.info("Exceeds max_iter in multiclass_prob\n");
01738         }
01739 
01740         // Cross-validation decision values for probability estimates
01741         private static void svm_binary_svc_probability(svm_problem prob, svm_parameter param, double Cp, double Cn, double[] probAB)
01742         {
01743                 int i;
01744                 int nr_fold = 5;
01745                 int[] perm = new int[prob.l];
01746                 double[] dec_values = new double[prob.l];
01747 
01748                 // random shuffle
01749                 for(i=0;i<prob.l;i++) perm[i]=i;
01750                 for(i=0;i<prob.l;i++)
01751                 {
01752                         int j = i+rand.nextInt(prob.l-i);
01753                         do {int _=perm[i]; perm[i]=perm[j]; perm[j]=_;} while(false);
01754                 }
01755                 for(i=0;i<nr_fold;i++)
01756                 {
01757                         int begin = i*prob.l/nr_fold;
01758                         int end = (i+1)*prob.l/nr_fold;
01759                         int j,k;
01760                         svm_problem subprob = new svm_problem();
01761 
01762                         subprob.l = prob.l-(end-begin);
01763                         subprob.x = new svm_node[subprob.l][];
01764                         subprob.y = new double[subprob.l];
01765                         
01766                         k=0;
01767                         for(j=0;j<begin;j++)
01768                         {
01769                                 subprob.x[k] = prob.x[perm[j]];
01770                                 subprob.y[k] = prob.y[perm[j]];
01771                                 ++k;
01772                         }
01773                         for(j=end;j<prob.l;j++)
01774                         {
01775                                 subprob.x[k] = prob.x[perm[j]];
01776                                 subprob.y[k] = prob.y[perm[j]];
01777                                 ++k;
01778                         }
01779                         int p_count=0,n_count=0;
01780                         for(j=0;j<k;j++)
01781                                 if(subprob.y[j]>0)
01782                                         p_count++;
01783                                 else
01784                                         n_count++;
01785                         
01786                         if(p_count==0 && n_count==0)
01787                                 for(j=begin;j<end;j++)
01788                                         dec_values[perm[j]] = 0;
01789                         else if(p_count > 0 && n_count == 0)
01790                                 for(j=begin;j<end;j++)
01791                                         dec_values[perm[j]] = 1;
01792                         else if(p_count == 0 && n_count > 0)
01793                                 for(j=begin;j<end;j++)
01794                                         dec_values[perm[j]] = -1;
01795                         else
01796                         {
01797                                 svm_parameter subparam = (svm_parameter)param.clone();
01798                                 subparam.probability=0;
01799                                 subparam.C=1.0;
01800                                 subparam.nr_weight=2;
01801                                 subparam.weight_label = new int[2];
01802                                 subparam.weight = new double[2];
01803                                 subparam.weight_label[0]=+1;
01804                                 subparam.weight_label[1]=-1;
01805                                 subparam.weight[0]=Cp;
01806                                 subparam.weight[1]=Cn;
01807                                 svm_model submodel = svm_train(subprob,subparam);
01808                                 for(j=begin;j<end;j++)
01809                                 {
01810                                         double[] dec_value=new double[1];
01811                                         svm_predict_values(submodel,prob.x[perm[j]],dec_value);
01812                                         dec_values[perm[j]]=dec_value[0];
01813                                         // ensure +1 -1 order; reason not using CV subroutine
01814                                         dec_values[perm[j]] *= submodel.label[0];
01815                                 }               
01816                         }
01817                 }               
01818                 sigmoid_train(prob.l,dec_values,prob.y,probAB);
01819         }
01820 
01821         // Return parameter of a Laplace distribution 
01822         private static double svm_svr_probability(svm_problem prob, svm_parameter param)
01823         {
01824                 int i;
01825                 int nr_fold = 5;
01826                 double[] ymv = new double[prob.l];
01827                 double mae = 0;
01828 
01829                 svm_parameter newparam = (svm_parameter)param.clone();
01830                 newparam.probability = 0;
01831                 svm_cross_validation(prob,newparam,nr_fold,ymv);
01832                 for(i=0;i<prob.l;i++)
01833                 {
01834                         ymv[i]=prob.y[i]-ymv[i];
01835                         mae += Math.abs(ymv[i]);
01836                 }               
01837                 mae /= prob.l;
01838                 double std=Math.sqrt(2*mae*mae);
01839                 int count=0;
01840                 mae=0;
01841                 for(i=0;i<prob.l;i++)
01842                         if (Math.abs(ymv[i]) > 5*std) 
01843                                 count=count+1;
01844                         else 
01845                                 mae+=Math.abs(ymv[i]);
01846                 mae /= (prob.l-count);
01847                 svm.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+mae+"\n");
01848                 return mae;
01849         }
01850 
01851         // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
01852         // perm, length l, must be allocated before calling this subroutine
01853         private static void svm_group_classes(svm_problem prob, int[] nr_class_ret, int[][] label_ret, int[][] start_ret, int[][] count_ret, int[] perm)
01854         {
01855                 int l = prob.l;
01856                 int max_nr_class = 16;
01857                 int nr_class = 0;
01858                 int[] label = new int[max_nr_class];
01859                 int[] count = new int[max_nr_class];
01860                 int[] data_label = new int[l];
01861                 int i;
01862 
01863                 for(i=0;i<l;i++)
01864                 {
01865                         int this_label = (int)(prob.y[i]);
01866                         int j;
01867                         for(j=0;j<nr_class;j++)
01868                         {
01869                                 if(this_label == label[j])
01870                                 {
01871                                         ++count[j];
01872                                         break;
01873                                 }
01874                         }
01875                         data_label[i] = j;
01876                         if(j == nr_class)
01877                         {
01878                                 if(nr_class == max_nr_class)
01879                                 {
01880                                         max_nr_class *= 2;
01881                                         int[] new_data = new int[max_nr_class];
01882                                         System.arraycopy(label,0,new_data,0,label.length);
01883                                         label = new_data;
01884                                         new_data = new int[max_nr_class];
01885                                         System.arraycopy(count,0,new_data,0,count.length);
01886                                         count = new_data;                                       
01887                                 }
01888                                 label[nr_class] = this_label;
01889                                 count[nr_class] = 1;
01890                                 ++nr_class;
01891                         }
01892                 }
01893 
01894                 int[] start = new int[nr_class];
01895                 start[0] = 0;
01896                 for(i=1;i<nr_class;i++)
01897                         start[i] = start[i-1]+count[i-1];
01898                 for(i=0;i<l;i++)
01899                 {
01900                         perm[start[data_label[i]]] = i;
01901                         ++start[data_label[i]];
01902                 }
01903                 start[0] = 0;
01904                 for(i=1;i<nr_class;i++)
01905                         start[i] = start[i-1]+count[i-1];
01906 
01907                 nr_class_ret[0] = nr_class;
01908                 label_ret[0] = label;
01909                 start_ret[0] = start;
01910                 count_ret[0] = count;
01911         }
01912 
01913         //
01914         // Interface functions
01915         //
01916         public static svm_model svm_train(svm_problem prob, svm_parameter param)
01917         {
01918                 svm_model model = new svm_model();
01919                 model.param = param;
01920 
01921                 if(param.svm_type == svm_parameter.ONE_CLASS ||
01922                    param.svm_type == svm_parameter.EPSILON_SVR ||
01923                    param.svm_type == svm_parameter.NU_SVR)
01924                 {
01925                         // regression or one-class-svm
01926                         model.nr_class = 2;
01927                         model.label = null;
01928                         model.nSV = null;
01929                         model.probA = null; model.probB = null;
01930                         model.sv_coef = new double[1][];
01931 
01932                         if(param.probability == 1 &&
01933                            (param.svm_type == svm_parameter.EPSILON_SVR ||
01934                             param.svm_type == svm_parameter.NU_SVR))
01935                         {
01936                                 model.probA = new double[1];
01937                                 model.probA[0] = svm_svr_probability(prob,param);
01938                         }
01939 
01940                         decision_function f = svm_train_one(prob,param,0,0);
01941                         model.rho = new double[1];
01942                         model.rho[0] = f.rho;
01943 
01944                         int nSV = 0;
01945                         int i;
01946                         for(i=0;i<prob.l;i++)
01947                                 if(Math.abs(f.alpha[i]) > 0) ++nSV;
01948                         model.l = nSV;
01949                         model.SV = new svm_node[nSV][];
01950                         model.sv_coef[0] = new double[nSV];
01951                         model.sv_indices = new int[nSV];
01952                         int j = 0;
01953                         for(i=0;i<prob.l;i++)
01954                                 if(Math.abs(f.alpha[i]) > 0)
01955                                 {
01956                                         model.SV[j] = prob.x[i];
01957                                         model.sv_coef[0][j] = f.alpha[i];
01958                                         model.sv_indices[j] = i+1;
01959                                         ++j;
01960                                 }
01961                 }
01962                 else
01963                 {
01964                         // classification
01965                         int l = prob.l;
01966                         int[] tmp_nr_class = new int[1];
01967                         int[][] tmp_label = new int[1][];
01968                         int[][] tmp_start = new int[1][];
01969                         int[][] tmp_count = new int[1][];                       
01970                         int[] perm = new int[l];
01971 
01972                         // group training data of the same class
01973                         svm_group_classes(prob,tmp_nr_class,tmp_label,tmp_start,tmp_count,perm);
01974                         int nr_class = tmp_nr_class[0];                 
01975                         int[] label = tmp_label[0];
01976                         int[] start = tmp_start[0];
01977                         int[] count = tmp_count[0];
01978                         
01979                         if(nr_class == 1) 
01980                                 svm.info("WARNING: training data in only one class. See README for details.\n");
01981                         
01982                         svm_node[][] x = new svm_node[l][];
01983                         int i;
01984                         for(i=0;i<l;i++)
01985                                 x[i] = prob.x[perm[i]];
01986 
01987                         // calculate weighted C
01988 
01989                         double[] weighted_C = new double[nr_class];
01990                         for(i=0;i<nr_class;i++)
01991                                 weighted_C[i] = param.C;
01992                         for(i=0;i<param.nr_weight;i++)
01993                         {
01994                                 int j;
01995                                 for(j=0;j<nr_class;j++)
01996                                         if(param.weight_label[i] == label[j])
01997                                                 break;
01998                                 if(j == nr_class)
01999                                         System.err.print("WARNING: class label "+param.weight_label[i]+" specified in weight is not found\n");
02000                                 else
02001                                         weighted_C[j] *= param.weight[i];
02002                         }
02003 
02004                         // train k*(k-1)/2 models
02005 
02006                         boolean[] nonzero = new boolean[l];
02007                         for(i=0;i<l;i++)
02008                                 nonzero[i] = false;
02009                         decision_function[] f = new decision_function[nr_class*(nr_class-1)/2];
02010 
02011                         double[] probA=null,probB=null;
02012                         if (param.probability == 1)
02013                         {
02014                                 probA=new double[nr_class*(nr_class-1)/2];
02015                                 probB=new double[nr_class*(nr_class-1)/2];
02016                         }
02017 
02018                         int p = 0;
02019                         for(i=0;i<nr_class;i++)
02020                                 for(int j=i+1;j<nr_class;j++)
02021                                 {
02022                                         svm_problem sub_prob = new svm_problem();
02023                                         int si = start[i], sj = start[j];
02024                                         int ci = count[i], cj = count[j];
02025                                         sub_prob.l = ci+cj;
02026                                         sub_prob.x = new svm_node[sub_prob.l][];
02027                                         sub_prob.y = new double[sub_prob.l];
02028                                         int k;
02029                                         for(k=0;k<ci;k++)
02030                                         {
02031                                                 sub_prob.x[k] = x[si+k];
02032                                                 sub_prob.y[k] = +1;
02033                                         }
02034                                         for(k=0;k<cj;k++)
02035                                         {
02036                                                 sub_prob.x[ci+k] = x[sj+k];
02037                                                 sub_prob.y[ci+k] = -1;
02038                                         }
02039 
02040                                         if(param.probability == 1)
02041                                         {
02042                                                 double[] probAB=new double[2];
02043                                                 svm_binary_svc_probability(sub_prob,param,weighted_C[i],weighted_C[j],probAB);
02044                                                 probA[p]=probAB[0];
02045                                                 probB[p]=probAB[1];
02046                                         }
02047 
02048                                         f[p] = svm_train_one(sub_prob,param,weighted_C[i],weighted_C[j]);
02049                                         for(k=0;k<ci;k++)
02050                                                 if(!nonzero[si+k] && Math.abs(f[p].alpha[k]) > 0)
02051                                                         nonzero[si+k] = true;
02052                                         for(k=0;k<cj;k++)
02053                                                 if(!nonzero[sj+k] && Math.abs(f[p].alpha[ci+k]) > 0)
02054                                                         nonzero[sj+k] = true;
02055                                         ++p;
02056                                 }
02057 
02058                         // build output
02059 
02060                         model.nr_class = nr_class;
02061 
02062                         model.label = new int[nr_class];
02063                         for(i=0;i<nr_class;i++)
02064                                 model.label[i] = label[i];
02065 
02066                         model.rho = new double[nr_class*(nr_class-1)/2];
02067                         for(i=0;i<nr_class*(nr_class-1)/2;i++)
02068                                 model.rho[i] = f[i].rho;
02069 
02070                         if(param.probability == 1)
02071                         {
02072                                 model.probA = new double[nr_class*(nr_class-1)/2];
02073                                 model.probB = new double[nr_class*(nr_class-1)/2];
02074                                 for(i=0;i<nr_class*(nr_class-1)/2;i++)
02075                                 {
02076                                         model.probA[i] = probA[i];
02077                                         model.probB[i] = probB[i];
02078                                 }
02079                         }
02080                         else
02081                         {
02082                                 model.probA=null;
02083                                 model.probB=null;
02084                         }
02085 
02086                         int nnz = 0;
02087                         int[] nz_count = new int[nr_class];
02088                         model.nSV = new int[nr_class];
02089                         for(i=0;i<nr_class;i++)
02090                         {
02091                                 int nSV = 0;
02092                                 for(int j=0;j<count[i];j++)
02093                                         if(nonzero[start[i]+j])
02094                                         {
02095                                                 ++nSV;
02096                                                 ++nnz;
02097                                         }
02098                                 model.nSV[i] = nSV;
02099                                 nz_count[i] = nSV;
02100                         }
02101 
02102                         svm.info("Total nSV = "+nnz+"\n");
02103 
02104                         model.l = nnz;
02105                         model.SV = new svm_node[nnz][];
02106                         model.sv_indices = new int[nnz];
02107                         p = 0;
02108                         for(i=0;i<l;i++)
02109                                 if(nonzero[i])
02110                                 {
02111                                         model.SV[p] = x[i];
02112                                         model.sv_indices[p++] = perm[i] + 1;
02113                                 }
02114 
02115                         int[] nz_start = new int[nr_class];
02116                         nz_start[0] = 0;
02117                         for(i=1;i<nr_class;i++)
02118                                 nz_start[i] = nz_start[i-1]+nz_count[i-1];
02119 
02120                         model.sv_coef = new double[nr_class-1][];
02121                         for(i=0;i<nr_class-1;i++)
02122                                 model.sv_coef[i] = new double[nnz];
02123 
02124                         p = 0;
02125                         for(i=0;i<nr_class;i++)
02126                                 for(int j=i+1;j<nr_class;j++)
02127                                 {
02128                                         // classifier (i,j): coefficients with
02129                                         // i are in sv_coef[j-1][nz_start[i]...],
02130                                         // j are in sv_coef[i][nz_start[j]...]
02131 
02132                                         int si = start[i];
02133                                         int sj = start[j];
02134                                         int ci = count[i];
02135                                         int cj = count[j];
02136 
02137                                         int q = nz_start[i];
02138                                         int k;
02139                                         for(k=0;k<ci;k++)
02140                                                 if(nonzero[si+k])
02141                                                         model.sv_coef[j-1][q++] = f[p].alpha[k];
02142                                         q = nz_start[j];
02143                                         for(k=0;k<cj;k++)
02144                                                 if(nonzero[sj+k])
02145                                                         model.sv_coef[i][q++] = f[p].alpha[ci+k];
02146                                         ++p;
02147                                 }
02148                 }
02149                 return model;
02150         }
02151         
02152         // Stratified cross validation
02153         public static void svm_cross_validation(svm_problem prob, svm_parameter param, int nr_fold, double[] target)
02154         {
02155                 int i;
02156                 int[] fold_start = new int[nr_fold+1];
02157                 int l = prob.l;
02158                 int[] perm = new int[l];
02159                 
02160                 // stratified cv may not give leave-one-out rate
02161                 // Each class to l folds -> some folds may have zero elements
02162                 if((param.svm_type == svm_parameter.C_SVC ||
02163                     param.svm_type == svm_parameter.NU_SVC) && nr_fold < l)
02164                 {
02165                         int[] tmp_nr_class = new int[1];
02166                         int[][] tmp_label = new int[1][];
02167                         int[][] tmp_start = new int[1][];
02168                         int[][] tmp_count = new int[1][];
02169 
02170                         svm_group_classes(prob,tmp_nr_class,tmp_label,tmp_start,tmp_count,perm);
02171 
02172                         int nr_class = tmp_nr_class[0];
02173                         int[] start = tmp_start[0];
02174                         int[] count = tmp_count[0];             
02175 
02176                         // random shuffle and then data grouped by fold using the array perm
02177                         int[] fold_count = new int[nr_fold];
02178                         int c;
02179                         int[] index = new int[l];
02180                         for(i=0;i<l;i++)
02181                                 index[i]=perm[i];
02182                         for (c=0; c<nr_class; c++)
02183                                 for(i=0;i<count[c];i++)
02184                                 {
02185                                         int j = i+rand.nextInt(count[c]-i);
02186                                         do {int _=index[start[c]+j]; index[start[c]+j]=index[start[c]+i]; index[start[c]+i]=_;} while(false);
02187                                 }
02188                         for(i=0;i<nr_fold;i++)
02189                         {
02190                                 fold_count[i] = 0;
02191                                 for (c=0; c<nr_class;c++)
02192                                         fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
02193                         }
02194                         fold_start[0]=0;
02195                         for (i=1;i<=nr_fold;i++)
02196                                 fold_start[i] = fold_start[i-1]+fold_count[i-1];
02197                         for (c=0; c<nr_class;c++)
02198                                 for(i=0;i<nr_fold;i++)
02199                                 {
02200                                         int begin = start[c]+i*count[c]/nr_fold;
02201                                         int end = start[c]+(i+1)*count[c]/nr_fold;
02202                                         for(int j=begin;j<end;j++)
02203                                         {
02204                                                 perm[fold_start[i]] = index[j];
02205                                                 fold_start[i]++;
02206                                         }
02207                                 }
02208                         fold_start[0]=0;
02209                         for (i=1;i<=nr_fold;i++)
02210                                 fold_start[i] = fold_start[i-1]+fold_count[i-1];
02211                 }
02212                 else
02213                 {
02214                         for(i=0;i<l;i++) perm[i]=i;
02215                         for(i=0;i<l;i++)
02216                         {
02217                                 int j = i+rand.nextInt(l-i);
02218                                 do {int _=perm[i]; perm[i]=perm[j]; perm[j]=_;} while(false);
02219                         }
02220                         for(i=0;i<=nr_fold;i++)
02221                                 fold_start[i]=i*l/nr_fold;
02222                 }
02223 
02224                 for(i=0;i<nr_fold;i++)
02225                 {
02226                         int begin = fold_start[i];
02227                         int end = fold_start[i+1];
02228                         int j,k;
02229                         svm_problem subprob = new svm_problem();
02230 
02231                         subprob.l = l-(end-begin);
02232                         subprob.x = new svm_node[subprob.l][];
02233                         subprob.y = new double[subprob.l];
02234 
02235                         k=0;
02236                         for(j=0;j<begin;j++)
02237                         {
02238                                 subprob.x[k] = prob.x[perm[j]];
02239                                 subprob.y[k] = prob.y[perm[j]];
02240                                 ++k;
02241                         }
02242                         for(j=end;j<l;j++)
02243                         {
02244                                 subprob.x[k] = prob.x[perm[j]];
02245                                 subprob.y[k] = prob.y[perm[j]];
02246                                 ++k;
02247                         }
02248                         svm_model submodel = svm_train(subprob,param);
02249                         if(param.probability==1 &&
02250                            (param.svm_type == svm_parameter.C_SVC ||
02251                             param.svm_type == svm_parameter.NU_SVC))
02252                         {
02253                                 double[] prob_estimates= new double[svm_get_nr_class(submodel)];
02254                                 for(j=begin;j<end;j++)
02255                                         target[perm[j]] = svm_predict_probability(submodel,prob.x[perm[j]],prob_estimates);
02256                         }
02257                         else
02258                                 for(j=begin;j<end;j++)
02259                                         target[perm[j]] = svm_predict(submodel,prob.x[perm[j]]);
02260                 }
02261         }
02262 
02263         public static int svm_get_svm_type(svm_model model)
02264         {
02265                 return model.param.svm_type;
02266         }
02267 
02268         public static int svm_get_nr_class(svm_model model)
02269         {
02270                 return model.nr_class;
02271         }
02272 
02273         public static void svm_get_labels(svm_model model, int[] label)
02274         {
02275                 if (model.label != null)
02276                         for(int i=0;i<model.nr_class;i++)
02277                                 label[i] = model.label[i];
02278         }
02279 
02280         public static void svm_get_sv_indices(svm_model model, int[] indices)
02281         {
02282                 if (model.sv_indices != null)
02283                         for(int i=0;i<model.l;i++)
02284                                 indices[i] = model.sv_indices[i];
02285         }
02286 
02287         public static int svm_get_nr_sv(svm_model model)
02288         {
02289                 return model.l;
02290         }
02291 
02292         public static double svm_get_svr_probability(svm_model model)
02293         {
02294                 if ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) &&
02295                     model.probA!=null)
02296                 return model.probA[0];
02297                 else
02298                 {
02299                         System.err.print("Model doesn't contain information for SVR probability inference\n");
02300                         return 0;
02301                 }
02302         }
02303 
02304         public static double svm_predict_values(svm_model model, svm_node[] x, double[] dec_values)
02305         {
02306                 int i;
02307                 if(model.param.svm_type == svm_parameter.ONE_CLASS ||
02308                    model.param.svm_type == svm_parameter.EPSILON_SVR ||
02309                    model.param.svm_type == svm_parameter.NU_SVR)
02310                 {
02311                         double[] sv_coef = model.sv_coef[0];
02312                         double sum = 0;
02313                         for(i=0;i<model.l;i++)
02314                                 sum += sv_coef[i] * Kernel.k_function(x,model.SV[i],model.param);
02315                         sum -= model.rho[0];
02316                         dec_values[0] = sum;
02317 
02318                         if(model.param.svm_type == svm_parameter.ONE_CLASS)
02319                                 return (sum>0)?1:-1;
02320                         else
02321                                 return sum;
02322                 }
02323                 else
02324                 {
02325                         int nr_class = model.nr_class;
02326                         int l = model.l;
02327                 
02328                         double[] kvalue = new double[l];
02329                         for(i=0;i<l;i++)
02330                                 kvalue[i] = Kernel.k_function(x,model.SV[i],model.param);
02331 
02332                         int[] start = new int[nr_class];
02333                         start[0] = 0;
02334                         for(i=1;i<nr_class;i++)
02335                                 start[i] = start[i-1]+model.nSV[i-1];
02336 
02337                         int[] vote = new int[nr_class];
02338                         for(i=0;i<nr_class;i++)
02339                                 vote[i] = 0;
02340 
02341                         int p=0;
02342                         for(i=0;i<nr_class;i++)
02343                                 for(int j=i+1;j<nr_class;j++)
02344                                 {
02345                                         double sum = 0;
02346                                         int si = start[i];
02347                                         int sj = start[j];
02348                                         int ci = model.nSV[i];
02349                                         int cj = model.nSV[j];
02350                                 
02351                                         int k;
02352                                         double[] coef1 = model.sv_coef[j-1];
02353                                         double[] coef2 = model.sv_coef[i];
02354                                         for(k=0;k<ci;k++)
02355                                                 sum += coef1[si+k] * kvalue[si+k];
02356                                         for(k=0;k<cj;k++)
02357                                                 sum += coef2[sj+k] * kvalue[sj+k];
02358                                         sum -= model.rho[p];
02359                                         dec_values[p] = sum;                                    
02360 
02361                                         if(dec_values[p] > 0)
02362                                                 ++vote[i];
02363                                         else
02364                                                 ++vote[j];
02365                                         p++;
02366                                 }
02367 
02368                         int vote_max_idx = 0;
02369                         for(i=1;i<nr_class;i++)
02370                                 if(vote[i] > vote[vote_max_idx])
02371                                         vote_max_idx = i;
02372 
02373                         return model.label[vote_max_idx];
02374                 }
02375         }
02376 
02377         public static double svm_predict(svm_model model, svm_node[] x)
02378         {
02379                 int nr_class = model.nr_class;
02380                 double[] dec_values;
02381                 if(model.param.svm_type == svm_parameter.ONE_CLASS ||
02382                                 model.param.svm_type == svm_parameter.EPSILON_SVR ||
02383                                 model.param.svm_type == svm_parameter.NU_SVR)
02384                         dec_values = new double[1];
02385                 else
02386                         dec_values = new double[nr_class*(nr_class-1)/2];
02387                 double pred_result = svm_predict_values(model, x, dec_values);
02388                 return pred_result;
02389         }
02390 
02391         public static double svm_predict_probability(svm_model model, svm_node[] x, double[] prob_estimates)
02392         {
02393                 if ((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) &&
02394                     model.probA!=null && model.probB!=null)
02395                 {
02396                         int i;
02397                         int nr_class = model.nr_class;
02398                         double[] dec_values = new double[nr_class*(nr_class-1)/2];
02399                         svm_predict_values(model, x, dec_values);
02400 
02401                         double min_prob=1e-7;
02402                         double[][] pairwise_prob=new double[nr_class][nr_class];
02403                         
02404                         int k=0;
02405                         for(i=0;i<nr_class;i++)
02406                                 for(int j=i+1;j<nr_class;j++)
02407                                 {
02408                                         pairwise_prob[i][j]=Math.min(Math.max(sigmoid_predict(dec_values[k],model.probA[k],model.probB[k]),min_prob),1-min_prob);
02409                                         pairwise_prob[j][i]=1-pairwise_prob[i][j];
02410                                         k++;
02411                                 }
02412                         multiclass_probability(nr_class,pairwise_prob,prob_estimates);
02413 
02414                         int prob_max_idx = 0;
02415                         for(i=1;i<nr_class;i++)
02416                                 if(prob_estimates[i] > prob_estimates[prob_max_idx])
02417                                         prob_max_idx = i;
02418                         return model.label[prob_max_idx];
02419                 }
02420                 else 
02421                         return svm_predict(model, x);
02422         }
02423 
02424         static final String svm_type_table[] =
02425         {
02426                 "c_svc","nu_svc","one_class","epsilon_svr","nu_svr",
02427         };
02428 
02429         static final String kernel_type_table[]=
02430         {
02431                 "linear","polynomial","rbf","sigmoid","precomputed"
02432         };
02433 
02434         public static void svm_save_model(String model_file_name, svm_model model) throws IOException
02435         {
02436                 DataOutputStream fp = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(model_file_name)));
02437 
02438                 svm_parameter param = model.param;
02439 
02440                 fp.writeBytes("svm_type "+svm_type_table[param.svm_type]+"\n");
02441                 fp.writeBytes("kernel_type "+kernel_type_table[param.kernel_type]+"\n");
02442 
02443                 if(param.kernel_type == svm_parameter.POLY)
02444                         fp.writeBytes("degree "+param.degree+"\n");
02445 
02446                 if(param.kernel_type == svm_parameter.POLY ||
02447                    param.kernel_type == svm_parameter.RBF ||
02448                    param.kernel_type == svm_parameter.SIGMOID)
02449                         fp.writeBytes("gamma "+param.gamma+"\n");
02450 
02451                 if(param.kernel_type == svm_parameter.POLY ||
02452                    param.kernel_type == svm_parameter.SIGMOID)
02453                         fp.writeBytes("coef0 "+param.coef0+"\n");
02454 
02455                 int nr_class = model.nr_class;
02456                 int l = model.l;
02457                 fp.writeBytes("nr_class "+nr_class+"\n");
02458                 fp.writeBytes("total_sv "+l+"\n");
02459         
02460                 {
02461                         fp.writeBytes("rho");
02462                         for(int i=0;i<nr_class*(nr_class-1)/2;i++)
02463                                 fp.writeBytes(" "+model.rho[i]);
02464                         fp.writeBytes("\n");
02465                 }
02466         
02467                 if(model.label != null)
02468                 {
02469                         fp.writeBytes("label");
02470                         for(int i=0;i<nr_class;i++)
02471                                 fp.writeBytes(" "+model.label[i]);
02472                         fp.writeBytes("\n");
02473                 }
02474 
02475                 if(model.probA != null) // regression has probA only
02476                 {
02477                         fp.writeBytes("probA");
02478                         for(int i=0;i<nr_class*(nr_class-1)/2;i++)
02479                                 fp.writeBytes(" "+model.probA[i]);
02480                         fp.writeBytes("\n");
02481                 }
02482                 if(model.probB != null) 
02483                 {
02484                         fp.writeBytes("probB");
02485                         for(int i=0;i<nr_class*(nr_class-1)/2;i++)
02486                                 fp.writeBytes(" "+model.probB[i]);
02487                         fp.writeBytes("\n");
02488                 }
02489 
02490                 if(model.nSV != null)
02491                 {
02492                         fp.writeBytes("nr_sv");
02493                         for(int i=0;i<nr_class;i++)
02494                                 fp.writeBytes(" "+model.nSV[i]);
02495                         fp.writeBytes("\n");
02496                 }
02497 
02498                 fp.writeBytes("SV\n");
02499                 double[][] sv_coef = model.sv_coef;
02500                 svm_node[][] SV = model.SV;
02501 
02502                 for(int i=0;i<l;i++)
02503                 {
02504                         for(int j=0;j<nr_class-1;j++)
02505                                 fp.writeBytes(sv_coef[j][i]+" ");
02506 
02507                         svm_node[] p = SV[i];
02508                         if(param.kernel_type == svm_parameter.PRECOMPUTED)
02509                                 fp.writeBytes("0:"+(int)(p[0].value));
02510                         else    
02511                                 for(int j=0;j<p.length;j++)
02512                                         fp.writeBytes(p[j].index+":"+p[j].value+" ");
02513                         fp.writeBytes("\n");
02514                 }
02515 
02516                 fp.close();
02517         }
02518 
02519         private static double atof(String s)
02520         {
02521                 return Double.valueOf(s).doubleValue();
02522         }
02523 
02524         private static int atoi(String s)
02525         {
02526                 return Integer.parseInt(s);
02527         }
02528 
02529         public static svm_model svm_load_model(String model_file_name) throws IOException
02530         {
02531                 return svm_load_model(new BufferedReader(new FileReader(model_file_name)));
02532         }
02533 
02534         public static svm_model svm_load_model(BufferedReader fp) throws IOException
02535         {
02536                 // read parameters
02537 
02538                 svm_model model = new svm_model();
02539                 svm_parameter param = new svm_parameter();
02540                 model.param = param;
02541                 model.rho = null;
02542                 model.probA = null;
02543                 model.probB = null;
02544                 model.label = null;
02545                 model.nSV = null;
02546 
02547                 while(true)
02548                 {
02549                         String cmd = fp.readLine();
02550                         String arg = cmd.substring(cmd.indexOf(' ')+1);
02551 
02552                         if(cmd.startsWith("svm_type"))
02553                         {
02554                                 int i;
02555                                 for(i=0;i<svm_type_table.length;i++)
02556                                 {
02557                                         if(arg.indexOf(svm_type_table[i])!=-1)
02558                                         {
02559                                                 param.svm_type=i;
02560                                                 break;
02561                                         }
02562                                 }
02563                                 if(i == svm_type_table.length)
02564                                 {
02565                                         System.err.print("unknown svm type.\n");
02566                                         return null;
02567                                 }
02568                         }
02569                         else if(cmd.startsWith("kernel_type"))
02570                         {
02571                                 int i;
02572                                 for(i=0;i<kernel_type_table.length;i++)
02573                                 {
02574                                         if(arg.indexOf(kernel_type_table[i])!=-1)
02575                                         {
02576                                                 param.kernel_type=i;
02577                                                 break;
02578                                         }
02579                                 }
02580                                 if(i == kernel_type_table.length)
02581                                 {
02582                                         System.err.print("unknown kernel function.\n");
02583                                         return null;
02584                                 }
02585                         }
02586                         else if(cmd.startsWith("degree"))
02587                                 param.degree = atoi(arg);
02588                         else if(cmd.startsWith("gamma"))
02589                                 param.gamma = atof(arg);
02590                         else if(cmd.startsWith("coef0"))
02591                                 param.coef0 = atof(arg);
02592                         else if(cmd.startsWith("nr_class"))
02593                                 model.nr_class = atoi(arg);
02594                         else if(cmd.startsWith("total_sv"))
02595                                 model.l = atoi(arg);
02596                         else if(cmd.startsWith("rho"))
02597                         {
02598                                 int n = model.nr_class * (model.nr_class-1)/2;
02599                                 model.rho = new double[n];
02600                                 StringTokenizer st = new StringTokenizer(arg);
02601                                 for(int i=0;i<n;i++)
02602                                         model.rho[i] = atof(st.nextToken());
02603                         }
02604                         else if(cmd.startsWith("label"))
02605                         {
02606                                 int n = model.nr_class;
02607                                 model.label = new int[n];
02608                                 StringTokenizer st = new StringTokenizer(arg);
02609                                 for(int i=0;i<n;i++)
02610                                         model.label[i] = atoi(st.nextToken());                                  
02611                         }
02612                         else if(cmd.startsWith("probA"))
02613                         {
02614                                 int n = model.nr_class*(model.nr_class-1)/2;
02615                                 model.probA = new double[n];
02616                                 StringTokenizer st = new StringTokenizer(arg);
02617                                 for(int i=0;i<n;i++)
02618                                         model.probA[i] = atof(st.nextToken());                                  
02619                         }
02620                         else if(cmd.startsWith("probB"))
02621                         {
02622                                 int n = model.nr_class*(model.nr_class-1)/2;
02623                                 model.probB = new double[n];
02624                                 StringTokenizer st = new StringTokenizer(arg);
02625                                 for(int i=0;i<n;i++)
02626                                         model.probB[i] = atof(st.nextToken());                                  
02627                         }
02628                         else if(cmd.startsWith("nr_sv"))
02629                         {
02630                                 int n = model.nr_class;
02631                                 model.nSV = new int[n];
02632                                 StringTokenizer st = new StringTokenizer(arg);
02633                                 for(int i=0;i<n;i++)
02634                                         model.nSV[i] = atoi(st.nextToken());
02635                         }
02636                         else if(cmd.startsWith("SV"))
02637                         {
02638                                 break;
02639                         }
02640                         else
02641                         {
02642                                 System.err.print("unknown text in model file: ["+cmd+"]\n");
02643                                 return null;
02644                         }
02645                 }
02646 
02647                 // read sv_coef and SV
02648 
02649                 int m = model.nr_class - 1;
02650                 int l = model.l;
02651                 model.sv_coef = new double[m][l];
02652                 model.SV = new svm_node[l][];
02653 
02654                 for(int i=0;i<l;i++)
02655                 {
02656                         String line = fp.readLine();
02657                         StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
02658 
02659                         for(int k=0;k<m;k++)
02660                                 model.sv_coef[k][i] = atof(st.nextToken());
02661                         int n = st.countTokens()/2;
02662                         model.SV[i] = new svm_node[n];
02663                         for(int j=0;j<n;j++)
02664                         {
02665                                 model.SV[i][j] = new svm_node();
02666                                 model.SV[i][j].index = atoi(st.nextToken());
02667                                 model.SV[i][j].value = atof(st.nextToken());
02668                         }
02669                 }
02670 
02671                 fp.close();
02672                 return model;
02673         }
02674 
02675         public static String svm_check_parameter(svm_problem prob, svm_parameter param)
02676         {
02677                 // svm_type
02678 
02679                 int svm_type = param.svm_type;
02680                 if(svm_type != svm_parameter.C_SVC &&
02681                    svm_type != svm_parameter.NU_SVC &&
02682                    svm_type != svm_parameter.ONE_CLASS &&
02683                    svm_type != svm_parameter.EPSILON_SVR &&
02684                    svm_type != svm_parameter.NU_SVR)
02685                 return "unknown svm type";
02686 
02687                 // kernel_type, degree
02688         
02689                 int kernel_type = param.kernel_type;
02690                 if(kernel_type != svm_parameter.LINEAR &&
02691                    kernel_type != svm_parameter.POLY &&
02692                    kernel_type != svm_parameter.RBF &&
02693                    kernel_type != svm_parameter.SIGMOID &&
02694                    kernel_type != svm_parameter.PRECOMPUTED)
02695                         return "unknown kernel type";
02696 
02697                 if(param.gamma < 0)
02698                         return "gamma < 0";
02699 
02700                 if(param.degree < 0)
02701                         return "degree of polynomial kernel < 0";
02702 
02703                 // cache_size,eps,C,nu,p,shrinking
02704 
02705                 if(param.cache_size <= 0)
02706                         return "cache_size <= 0";
02707 
02708                 if(param.eps <= 0)
02709                         return "eps <= 0";
02710 
02711                 if(svm_type == svm_parameter.C_SVC ||
02712                    svm_type == svm_parameter.EPSILON_SVR ||
02713                    svm_type == svm_parameter.NU_SVR)
02714                         if(param.C <= 0)
02715                                 return "C <= 0";
02716 
02717                 if(svm_type == svm_parameter.NU_SVC ||
02718                    svm_type == svm_parameter.ONE_CLASS ||
02719                    svm_type == svm_parameter.NU_SVR)
02720                         if(param.nu <= 0 || param.nu > 1)
02721                                 return "nu <= 0 or nu > 1";
02722 
02723                 if(svm_type == svm_parameter.EPSILON_SVR)
02724                         if(param.p < 0)
02725                                 return "p < 0";
02726 
02727                 if(param.shrinking != 0 &&
02728                    param.shrinking != 1)
02729                         return "shrinking != 0 and shrinking != 1";
02730 
02731                 if(param.probability != 0 &&
02732                    param.probability != 1)
02733                         return "probability != 0 and probability != 1";
02734 
02735                 if(param.probability == 1 &&
02736                    svm_type == svm_parameter.ONE_CLASS)
02737                         return "one-class SVM probability output not supported yet";
02738                 
02739                 // check whether nu-svc is feasible
02740         
02741                 if(svm_type == svm_parameter.NU_SVC)
02742                 {
02743                         int l = prob.l;
02744                         int max_nr_class = 16;
02745                         int nr_class = 0;
02746                         int[] label = new int[max_nr_class];
02747                         int[] count = new int[max_nr_class];
02748 
02749                         int i;
02750                         for(i=0;i<l;i++)
02751                         {
02752                                 int this_label = (int)prob.y[i];
02753                                 int j;
02754                                 for(j=0;j<nr_class;j++)
02755                                         if(this_label == label[j])
02756                                         {
02757                                                 ++count[j];
02758                                                 break;
02759                                         }
02760 
02761                                 if(j == nr_class)
02762                                 {
02763                                         if(nr_class == max_nr_class)
02764                                         {
02765                                                 max_nr_class *= 2;
02766                                                 int[] new_data = new int[max_nr_class];
02767                                                 System.arraycopy(label,0,new_data,0,label.length);
02768                                                 label = new_data;
02769                                                 
02770                                                 new_data = new int[max_nr_class];
02771                                                 System.arraycopy(count,0,new_data,0,count.length);
02772                                                 count = new_data;
02773                                         }
02774                                         label[nr_class] = this_label;
02775                                         count[nr_class] = 1;
02776                                         ++nr_class;
02777                                 }
02778                         }
02779 
02780                         for(i=0;i<nr_class;i++)
02781                         {
02782                                 int n1 = count[i];
02783                                 for(int j=i+1;j<nr_class;j++)
02784                                 {
02785                                         int n2 = count[j];
02786                                         if(param.nu*(n1+n2)/2 > Math.min(n1,n2))
02787                                                 return "specified nu is infeasible";
02788                                 }
02789                         }
02790                 }
02791 
02792                 return null;
02793         }
02794 
02795         public static int svm_check_probability_model(svm_model model)
02796         {
02797                 if (((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) &&
02798                 model.probA!=null && model.probB!=null) ||
02799                 ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) &&
02800                  model.probA!=null))
02801                         return 1;
02802                 else
02803                         return 0;
02804         }
02805 
02806         public static void svm_set_print_string_function(svm_print_interface print_func)
02807         {
02808                 if (print_func == null)
02809                         svm_print_string = svm_print_stdout;
02810                 else 
02811                         svm_print_string = print_func;
02812         }
02813 }


ml_classifiers
Author(s): Scott Niekum
autogenerated on Thu Aug 27 2015 13:59:04