svm.java
Go to the documentation of this file.
00001 
00002 
00003 
00004 
00005 package libsvm;
00006 import java.io.*;
00007 import java.util.*;
00008 
/**
 * Kernel column cache with least-recently-used (LRU) eviction.
 *
 * <p>{@code l} is the total number of data items, one cacheable column per item.
 * {@code size_} is the cache size limit in bytes. Internally the remaining budget
 * is tracked in floats (4 bytes each), after subtracting a fixed 16-byte
 * bookkeeping allowance per entry. It is never allowed to drop below two full
 * columns.
 */
class Cache {
    private final int l;
    /** Remaining budget, measured in floats (not bytes). */
    private long size;

    /** One slot per data item; linked into the circular LRU list while it holds data. */
    private static final class head_t
    {
        head_t prev, next;  // a circular list
        float[] data;
        int len;            // data[0,len) is cached in this entry
    }

    private final head_t[] head;
    /** Sentinel of the circular LRU list: next is least recently used, prev is most recently used. */
    private final head_t lru_head;

    Cache(int l_, long size_)
    {
        l = l_;
        size = size_;
        head = new head_t[l];
        for (int i = 0; i < l; i++) head[i] = new head_t();
        size /= 4;                           // bytes -> floats
        size -= (long) l * (16 / 4);         // sizeof(head_t) == 16; long math so huge l cannot overflow
        size = Math.max(size, 2 * (long) l); // cache must be large enough for two columns
        lru_head = new head_t();
        lru_head.next = lru_head.prev = lru_head;
    }

    /** Unlinks h from the LRU list. h's own prev/next pointers are left untouched. */
    private void lru_delete(head_t h)
    {
        // delete from current location
        h.prev.next = h.next;
        h.next.prev = h.prev;
    }

    /** Links h in as the most recently used entry, just before the sentinel. */
    private void lru_insert(head_t h)
    {
        // insert to last position
        h.next = lru_head;
        h.prev = lru_head.prev;
        h.prev.next = h;
        h.next.prev = h;
    }

    /**
     * Requests entries [0,len) of column {@code index}.
     *
     * <p>If the entry holds fewer than {@code len} values, least recently used
     * columns are evicted until the new space fits. The array is then grown,
     * keeping the already-cached prefix. The entry becomes the most recently used.
     *
     * @param index column to fetch
     * @param data  single-element array used as an out-parameter (java: simulated
     *              pointer); data[0] receives the cached array
     * @param len   number of leading entries required
     * @return a position p such that [p,len) still has to be filled by the caller;
     *         p >= len when nothing needs to be filled
     */
    int get_data(int index, float[][] data, int len)
    {
        head_t h = head[index];
        if (h.len > 0) lru_delete(h);
        int more = len - h.len;

        if (more > 0)
        {
            // free old space: evict from the least recently used end
            while (size < more)
            {
                head_t old = lru_head.next;
                lru_delete(old);
                size += old.len;
                old.data = null;
                old.len = 0;
            }

            // allocate new space, keeping the already-cached prefix
            float[] new_data = new float[len];
            if (h.data != null) System.arraycopy(h.data, 0, new_data, 0, h.len);
            h.data = new_data;
            size -= more;

            // report the old length as the first position the caller must fill
            int alreadyCached = h.len;
            h.len = len;
            len = alreadyCached;
        }

        lru_insert(h);
        data[0] = h.data;
        return len;
    }

    /**
     * Exchanges items i and j: swaps their cached columns and swaps rows i and j
     * inside every other cached column. A column long enough to hold row
     * min(i,j) but not row max(i,j) cannot be fixed in place, so it is dropped.
     */
    void swap_index(int i, int j)
    {
        if (i == j) return;

        if (head[i].len > 0) lru_delete(head[i]);
        if (head[j].len > 0) lru_delete(head[j]);

        float[] tmpData = head[i].data;
        head[i].data = head[j].data;
        head[j].data = tmpData;

        int tmpLen = head[i].len;
        head[i].len = head[j].len;
        head[j].len = tmpLen;

        if (head[i].len > 0) lru_insert(head[i]);
        if (head[j].len > 0) lru_insert(head[j]);

        int lo = Math.min(i, j);
        int hi = Math.max(i, j);
        // lru_delete leaves h.next intact, so advancing after a delete is safe.
        for (head_t h = lru_head.next; h != lru_head; h = h.next)
        {
            if (h.len > lo)
            {
                if (h.len > hi)
                {
                    float tmp = h.data[lo];
                    h.data[lo] = h.data[hi];
                    h.data[hi] = tmp;
                }
                else
                {
                    // give up
                    lru_delete(h);
                    size += h.len;
                    h.data = null;
                    h.len = 0;
                }
            }
        }
    }
}
00121 
00122 //
00123 // Kernel evaluation
00124 //
00125 // the static method k_function is for doing single kernel evaluation
00126 // the constructor of Kernel prepares to calculate the l*l kernel matrix
00127 // the member function get_Q is for getting one column from the Q Matrix
00128 //
/**
 * Abstract view of the matrix Q consumed by the SMO solver.
 */
abstract class QMatrix {

    /**
     * Returns one column of Q. Callers only read entries [0, len) of the result.
     */
    abstract float[] get_Q(int column, int len);

    /**
     * Returns the per-variable diagonal values. NOTE(review): the solver uses
     * QD[i] as Q[i][i]; confirm in each implementation.
     */
    abstract double[] get_QD();

    /** Exchanges the roles of variables i and j. */
    abstract void swap_index(int i, int j);
}
00134 
00135 abstract class Kernel extends QMatrix {
00136         private svm_node[][] x;
00137         private final double[] x_square;
00138 
00139         // svm_parameter
00140         private final int kernel_type;
00141         private final int degree;
00142         private final double gamma;
00143         private final double coef0;
00144 
00145         abstract float[] get_Q(int column, int len);
00146         abstract double[] get_QD();
00147 
00148         void swap_index(int i, int j)
00149         {
00150                 do {svm_node[] _=x[i]; x[i]=x[j]; x[j]=_;} while(false);
00151                 if(x_square != null) do {double _=x_square[i]; x_square[i]=x_square[j]; x_square[j]=_;} while(false);
00152         }
00153 
00154         private static double powi(double base, int times)
00155         {
00156                 double tmp = base, ret = 1.0;
00157 
00158                 for(int t=times; t>0; t/=2)
00159                 {
00160                         if(t%2==1) ret*=tmp;
00161                         tmp = tmp * tmp;
00162                 }
00163                 return ret;
00164         }
00165 
00166         double kernel_function(int i, int j)
00167         {
00168                 switch(kernel_type)
00169                 {
00170                         case svm_parameter.LINEAR:
00171                                 return dot(x[i],x[j]);
00172                         case svm_parameter.POLY:
00173                                 return powi(gamma*dot(x[i],x[j])+coef0,degree);
00174                         case svm_parameter.RBF:
00175                                 return Math.exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
00176                         case svm_parameter.SIGMOID:
00177                                 return Math.tanh(gamma*dot(x[i],x[j])+coef0);
00178                         case svm_parameter.PRECOMPUTED:
00179                                 return x[i][(int)(x[j][0].value)].value;
00180                         default:
00181                                 return 0;       // java
00182                 }
00183         }
00184 
00185         Kernel(int l, svm_node[][] x_, svm_parameter param)
00186         {
00187                 this.kernel_type = param.kernel_type;
00188                 this.degree = param.degree;
00189                 this.gamma = param.gamma;
00190                 this.coef0 = param.coef0;
00191 
00192                 x = (svm_node[][])x_.clone();
00193 
00194                 if(kernel_type == svm_parameter.RBF)
00195                 {
00196                         x_square = new double[l];
00197                         for(int i=0;i<l;i++)
00198                                 x_square[i] = dot(x[i],x[i]);
00199                 }
00200                 else x_square = null;
00201         }
00202 
00203         static double dot(svm_node[] x, svm_node[] y)
00204         {
00205                 double sum = 0;
00206                 int xlen = x.length;
00207                 int ylen = y.length;
00208                 int i = 0;
00209                 int j = 0;
00210                 while(i < xlen && j < ylen)
00211                 {
00212                         if(x[i].index == y[j].index)
00213                                 sum += x[i++].value * y[j++].value;
00214                         else
00215                         {
00216                                 if(x[i].index > y[j].index)
00217                                         ++j;
00218                                 else
00219                                         ++i;
00220                         }
00221                 }
00222                 return sum;
00223         }
00224 
00225         static double k_function(svm_node[] x, svm_node[] y,
00226                                         svm_parameter param)
00227         {
00228                 switch(param.kernel_type)
00229                 {
00230                         case svm_parameter.LINEAR:
00231                                 return dot(x,y);
00232                         case svm_parameter.POLY:
00233                                 return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
00234                         case svm_parameter.RBF:
00235                         {
00236                                 double sum = 0;
00237                                 int xlen = x.length;
00238                                 int ylen = y.length;
00239                                 int i = 0;
00240                                 int j = 0;
00241                                 while(i < xlen && j < ylen)
00242                                 {
00243                                         if(x[i].index == y[j].index)
00244                                         {
00245                                                 double d = x[i++].value - y[j++].value;
00246                                                 sum += d*d;
00247                                         }
00248                                         else if(x[i].index > y[j].index)
00249                                         {
00250                                                 sum += y[j].value * y[j].value;
00251                                                 ++j;
00252                                         }
00253                                         else
00254                                         {
00255                                                 sum += x[i].value * x[i].value;
00256                                                 ++i;
00257                                         }
00258                                 }
00259 
00260                                 while(i < xlen)
00261                                 {
00262                                         sum += x[i].value * x[i].value;
00263                                         ++i;
00264                                 }
00265 
00266                                 while(j < ylen)
00267                                 {
00268                                         sum += y[j].value * y[j].value;
00269                                         ++j;
00270                                 }
00271 
00272                                 return Math.exp(-param.gamma*sum);
00273                         }
00274                         case svm_parameter.SIGMOID:
00275                                 return Math.tanh(param.gamma*dot(x,y)+param.coef0);
00276                         case svm_parameter.PRECOMPUTED:
00277                                 return  x[(int)(y[0].value)].value;
00278                         default:
00279                                 return 0;       // java
00280                 }
00281         }
00282 }
00283 
00284 // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
00285 // Solves:
00286 //
00287 //      min 0.5(\alpha^T Q \alpha) + p^T \alpha
00288 //
00289 //              y^T \alpha = \delta
00290 //              y_i = +1 or -1
00291 //              0 <= alpha_i <= Cp for y_i = 1
00292 //              0 <= alpha_i <= Cn for y_i = -1
00293 //
00294 // Given:
00295 //
00296 //      Q, p, y, Cp, Cn, and an initial feasible point \alpha
00297 //      l is the size of vectors and matrices
00298 //      eps is the stopping tolerance
00299 //
00300 // solution will be put in \alpha, objective value will be put in obj
00301 //
00302 class Solver {
00303         int active_size;
00304         byte[] y;
00305         double[] G;             // gradient of objective function
00306         static final byte LOWER_BOUND = 0;
00307         static final byte UPPER_BOUND = 1;
00308         static final byte FREE = 2;
00309         byte[] alpha_status;    // LOWER_BOUND, UPPER_BOUND, FREE
00310         double[] alpha;
00311         QMatrix Q;
00312         double[] QD;
00313         double eps;
00314         double Cp,Cn;
00315         double[] p;
00316         int[] active_set;
00317         double[] G_bar;         // gradient, if we treat free variables as 0
00318         int l;
00319         boolean unshrink;       // XXX
00320         
00321         static final double INF = java.lang.Double.POSITIVE_INFINITY;
00322 
00323         double get_C(int i)
00324         {
00325                 return (y[i] > 0)? Cp : Cn;
00326         }
00327         void update_alpha_status(int i)
00328         {
00329                 if(alpha[i] >= get_C(i))
00330                         alpha_status[i] = UPPER_BOUND;
00331                 else if(alpha[i] <= 0)
00332                         alpha_status[i] = LOWER_BOUND;
00333                 else alpha_status[i] = FREE;
00334         }
00335         boolean is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
00336         boolean is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
00337         boolean is_free(int i) {  return alpha_status[i] == FREE; }
00338 
00339         // java: information about solution except alpha,
00340         // because we cannot return multiple values otherwise...
00341         static class SolutionInfo {
00342                 double obj;
00343                 double rho;
00344                 double upper_bound_p;
00345                 double upper_bound_n;
00346                 double r;       // for Solver_NU
00347         }
00348 
00349         void swap_index(int i, int j)
00350         {
00351                 Q.swap_index(i,j);
00352                 do {byte _=y[i]; y[i]=y[j]; y[j]=_;} while(false);
00353                 do {double _=G[i]; G[i]=G[j]; G[j]=_;} while(false);
00354                 do {byte _=alpha_status[i]; alpha_status[i]=alpha_status[j]; alpha_status[j]=_;} while(false);
00355                 do {double _=alpha[i]; alpha[i]=alpha[j]; alpha[j]=_;} while(false);
00356                 do {double _=p[i]; p[i]=p[j]; p[j]=_;} while(false);
00357                 do {int _=active_set[i]; active_set[i]=active_set[j]; active_set[j]=_;} while(false);
00358                 do {double _=G_bar[i]; G_bar[i]=G_bar[j]; G_bar[j]=_;} while(false);
00359         }
00360 
00361         void reconstruct_gradient()
00362         {
00363                 // reconstruct inactive elements of G from G_bar and free variables
00364 
00365                 if(active_size == l) return;
00366 
00367                 int i,j;
00368                 int nr_free = 0;
00369 
00370                 for(j=active_size;j<l;j++)
00371                         G[j] = G_bar[j] + p[j];
00372 
00373                 for(j=0;j<active_size;j++)
00374                         if(is_free(j))
00375                                 nr_free++;
00376 
00377                 if(2*nr_free < active_size)
00378                         svm.info("\nWARNING: using -h 0 may be faster\n");
00379 
00380                 if (nr_free*l > 2*active_size*(l-active_size))
00381                 {
00382                         for(i=active_size;i<l;i++)
00383                         {
00384                                 float[] Q_i = Q.get_Q(i,active_size);
00385                                 for(j=0;j<active_size;j++)
00386                                         if(is_free(j))
00387                                                 G[i] += alpha[j] * Q_i[j];
00388                         }       
00389                 }
00390                 else
00391                 {
00392                         for(i=0;i<active_size;i++)
00393                                 if(is_free(i))
00394                                 {
00395                                         float[] Q_i = Q.get_Q(i,l);
00396                                         double alpha_i = alpha[i];
00397                                         for(j=active_size;j<l;j++)
00398                                                 G[j] += alpha_i * Q_i[j];
00399                                 }
00400                 }
00401         }
00402 
/**
 * Solves
 *   min 0.5 * alpha^T Q alpha + p^T alpha
 *   subject to y^T alpha = constant and 0 <= alpha_i <= C_i (Cp if y_i = +1, Cn if y_i = -1)
 * with SMO. Each iteration picks a pair (i, j) via select_working_set, solves the
 * two-variable sub-problem in closed form, clips the result to the box and
 * updates the gradient. The equality constraint is preserved because every step
 * keeps alpha[i] - alpha[j] (opposite labels) or alpha[i] + alpha[j] (equal
 * labels) unchanged.
 *
 * @param l         number of variables
 * @param Q         Q matrix; columns come from get_Q, QD from get_QD
 * @param p_        linear term; cloned, so the caller's array is not modified
 * @param y_        labels (+1 / -1); cloned, so the caller's array is not modified
 * @param alpha_    initial feasible point; overwritten with the solution, in the
 *                  caller's original variable order
 * @param Cp        upper bound for variables with y_i = +1
 * @param Cn        upper bound for variables with y_i = -1
 * @param eps       stopping tolerance (stored in the eps field)
 * @param si        output; rho, obj, upper_bound_p and upper_bound_n are filled in
 * @param shrinking nonzero to call do_shrinking periodically
 */
00403         void Solve(int l, QMatrix Q, double[] p_, byte[] y_,
00404                    double[] alpha_, double Cp, double Cn, double eps, SolutionInfo si, int shrinking)
00405         {
00406                 this.l = l;
00407                 this.Q = Q;
00408                 QD = Q.get_QD();
00409                 p = (double[])p_.clone();
00410                 y = (byte[])y_.clone();
00411                 alpha = (double[])alpha_.clone();
00412                 this.Cp = Cp;
00413                 this.Cn = Cn;
00414                 this.eps = eps;
00415                 this.unshrink = false;
00416 
00417                 // initialize alpha_status
00418                 {
00419                         alpha_status = new byte[l];
00420                         for(int i=0;i<l;i++)
00421                                 update_alpha_status(i);
00422                 }
00423 
00424                 // initialize active set (for shrinking)
                      // active_set starts as the identity permutation; swaps during shrinking reorder it
00425                 {
00426                         active_set = new int[l];
00427                         for(int i=0;i<l;i++)
00428                                 active_set[i] = i;
00429                         active_size = l;
00430                 }
00431 
00432                 // initialize gradient
                      // G = p + sum over non-lower-bound i of alpha_i * Q_i
                      // G_bar = sum over upper-bound i of C_i * Q_i
00433                 {
00434                         G = new double[l];
00435                         G_bar = new double[l];
00436                         int i;
00437                         for(i=0;i<l;i++)
00438                         {
00439                                 G[i] = p[i];
00440                                 G_bar[i] = 0;
00441                         }
00442                         for(i=0;i<l;i++)
00443                                 if(!is_lower_bound(i))
00444                                 {
00445                                         float[] Q_i = Q.get_Q(i,l);
00446                                         double alpha_i = alpha[i];
00447                                         int j;
00448                                         for(j=0;j<l;j++)
00449                                                 G[j] += alpha_i*Q_i[j];
00450                                         if(is_upper_bound(i))
00451                                                 for(j=0;j<l;j++)
00452                                                         G_bar[j] += get_C(i) * Q_i[j];
00453                                 }
00454                 }
00455 
00456                 // optimization step
00457 
                      // Iteration cap: max(10^7, 100*l); 100*l is clamped to Integer.MAX_VALUE so it cannot overflow.
00458                 int iter = 0;
00459                 int max_iter = Math.max(10000000, l>Integer.MAX_VALUE/100 ? Integer.MAX_VALUE : 100*l);
00460                 int counter = Math.min(l,1000)+1;
00461                 int[] working_set = new int[2];
00462 
00463                 while(iter < max_iter)
00464                 {
00465                         // show progress and do shrinking
00466 
00467                         if(--counter == 0)
00468                         {
00469                                 counter = Math.min(l,1000);
00470                                 if(shrinking!=0) do_shrinking();
00471                                 svm.info(".");
00472                         }
00473 
                              // Nonzero means no violating pair on the active set. Re-check on the
                              // full set (after rebuilding G) before declaring convergence.
00474                         if(select_working_set(working_set)!=0)
00475                         {
00476                                 // reconstruct the whole gradient
00477                                 reconstruct_gradient();
00478                                 // reset active set size and check
00479                                 active_size = l;
00480                                 svm.info("*");
00481                                 if(select_working_set(working_set)!=0)
00482                                         break;
00483                                 else
00484                                         counter = 1;    // do shrinking next iteration
00485                         }
00486                         
00487                         int i = working_set[0];
00488                         int j = working_set[1];
00489 
00490                         ++iter;
00491 
00492                         // update alpha[i] and alpha[j], handle bounds carefully
00493 
00494                         float[] Q_i = Q.get_Q(i,active_size);
00495                         float[] Q_j = Q.get_Q(j,active_size);
00496 
00497                         double C_i = get_C(i);
00498                         double C_j = get_C(j);
00499 
00500                         double old_alpha_i = alpha[i];
00501                         double old_alpha_j = alpha[j];
00502 
                              // Opposite labels: alpha[i] - alpha[j] (diff) stays fixed; step both by delta,
                              // then clip back into [0,C_i] x [0,C_j] along that line.
                              // Non-positive curvature is replaced by 1e-12 so the step stays finite.
00503                         if(y[i]!=y[j])
00504                         {
00505                                 double quad_coef = QD[i]+QD[j]+2*Q_i[j];
00506                                 if (quad_coef <= 0)
00507                                         quad_coef = 1e-12;
00508                                 double delta = (-G[i]-G[j])/quad_coef;
00509                                 double diff = alpha[i] - alpha[j];
00510                                 alpha[i] += delta;
00511                                 alpha[j] += delta;
00512                         
00513                                 if(diff > 0)
00514                                 {
00515                                         if(alpha[j] < 0)
00516                                         {
00517                                                 alpha[j] = 0;
00518                                                 alpha[i] = diff;
00519                                         }
00520                                 }
00521                                 else
00522                                 {
00523                                         if(alpha[i] < 0)
00524                                         {
00525                                                 alpha[i] = 0;
00526                                                 alpha[j] = -diff;
00527                                         }
00528                                 }
00529                                 if(diff > C_i - C_j)
00530                                 {
00531                                         if(alpha[i] > C_i)
00532                                         {
00533                                                 alpha[i] = C_i;
00534                                                 alpha[j] = C_i - diff;
00535                                         }
00536                                 }
00537                                 else
00538                                 {
00539                                         if(alpha[j] > C_j)
00540                                         {
00541                                                 alpha[j] = C_j;
00542                                                 alpha[i] = C_j + diff;
00543                                         }
00544                                 }
00545                         }
00546                         else
00547                         {
                                      // Equal labels: alpha[i] + alpha[j] (sum) stays fixed; move delta from i to j, then clip.
00548                                 double quad_coef = QD[i]+QD[j]-2*Q_i[j];
00549                                 if (quad_coef <= 0)
00550                                         quad_coef = 1e-12;
00551                                 double delta = (G[i]-G[j])/quad_coef;
00552                                 double sum = alpha[i] + alpha[j];
00553                                 alpha[i] -= delta;
00554                                 alpha[j] += delta;
00555 
00556                                 if(sum > C_i)
00557                                 {
00558                                         if(alpha[i] > C_i)
00559                                         {
00560                                                 alpha[i] = C_i;
00561                                                 alpha[j] = sum - C_i;
00562                                         }
00563                                 }
00564                                 else
00565                                 {
00566                                         if(alpha[j] < 0)
00567                                         {
00568                                                 alpha[j] = 0;
00569                                                 alpha[i] = sum;
00570                                         }
00571                                 }
00572                                 if(sum > C_j)
00573                                 {
00574                                         if(alpha[j] > C_j)
00575                                         {
00576                                                 alpha[j] = C_j;
00577                                                 alpha[i] = sum - C_j;
00578                                         }
00579                                 }
00580                                 else
00581                                 {
00582                                         if(alpha[i] < 0)
00583                                         {
00584                                                 alpha[i] = 0;
00585                                                 alpha[j] = sum;
00586                                         }
00587                                 }
00588                         }
00589 
00590                         // update G
                              // Only the active entries [0, active_size) are kept current here;
                              // inactive ones are rebuilt later by reconstruct_gradient().
00591 
00592                         double delta_alpha_i = alpha[i] - old_alpha_i;
00593                         double delta_alpha_j = alpha[j] - old_alpha_j;
00594 
00595                         for(int k=0;k<active_size;k++)
00596                         {
00597                                 G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
00598                         }
00599 
00600                         // update alpha_status and G_bar
                              // G_bar holds the sum of C * Q_k over upper-bound variables, so it only changes when
                              // i or j enters or leaves the upper bound. It needs full-length columns (len l).
00601 
00602                         {
00603                                 boolean ui = is_upper_bound(i);
00604                                 boolean uj = is_upper_bound(j);
00605                                 update_alpha_status(i);
00606                                 update_alpha_status(j);
00607                                 int k;
00608                                 if(ui != is_upper_bound(i))
00609                                 {
00610                                         Q_i = Q.get_Q(i,l);
00611                                         if(ui)
00612                                                 for(k=0;k<l;k++)
00613                                                         G_bar[k] -= C_i * Q_i[k];
00614                                         else
00615                                                 for(k=0;k<l;k++)
00616                                                         G_bar[k] += C_i * Q_i[k];
00617                                 }
00618 
00619                                 if(uj != is_upper_bound(j))
00620                                 {
00621                                         Q_j = Q.get_Q(j,l);
00622                                         if(uj)
00623                                                 for(k=0;k<l;k++)
00624                                                         G_bar[k] -= C_j * Q_j[k];
00625                                         else
00626                                                 for(k=0;k<l;k++)
00627                                                         G_bar[k] += C_j * Q_j[k];
00628                                 }
00629                         }
00630 
00631                 }
00632                 
00633                 if(iter >= max_iter)
00634                 {
00635                         if(active_size < l)
00636                         {
00637                                 // reconstruct the whole gradient to calculate objective value
00638                                 reconstruct_gradient();
00639                                 active_size = l;
00640                                 svm.info("*");
00641                         }
00642                         svm.info("\nWARNING: reaching max number of iterations");
00643                 }
00644 
00645                 // calculate rho
00646 
00647                 si.rho = calculate_rho();
00648 
00649                 // calculate objective value
                      // G holds the full gradient Q*alpha + p here, so
                      // 0.5 * alpha^T (G + p) = 0.5 * alpha^T Q alpha + p^T alpha.
00650                 {
00651                         double v = 0;
00652                         int i;
00653                         for(i=0;i<l;i++)
00654                                 v += alpha[i] * (G[i] + p[i]);
00655 
00656                         si.obj = v/2;
00657                 }
00658 
00659                 // put back the solution
                      // Undo the permutation introduced by swap_index: position i holds original variable active_set[i].
00660                 {
00661                         for(int i=0;i<l;i++)
00662                                 alpha_[active_set[i]] = alpha[i];
00663                 }
00664 
00665                 si.upper_bound_p = Cp;
00666                 si.upper_bound_n = Cn;
00667 
00668                 svm.info("\noptimization finished, #iter = "+iter+"\n");
00669         }
00670 
00671         // return 1 if already optimal, return 0 otherwise
00672         int select_working_set(int[] working_set)
00673         {
00674                 // return i,j such that
00675                 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
00676                 // j: mimimizes the decrease of obj value
00677                 //    (if quadratic coefficeint <= 0, replace it with tau)
00678                 //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
00679                 
00680                 double Gmax = -INF;
00681                 double Gmax2 = -INF;
00682                 int Gmax_idx = -1;
00683                 int Gmin_idx = -1;
00684                 double obj_diff_min = INF;
00685         
00686                 for(int t=0;t<active_size;t++)
00687                         if(y[t]==+1)    
00688                         {
00689                                 if(!is_upper_bound(t))
00690                                         if(-G[t] >= Gmax)
00691                                         {
00692                                                 Gmax = -G[t];
00693                                                 Gmax_idx = t;
00694                                         }
00695                         }
00696                         else
00697                         {
00698                                 if(!is_lower_bound(t))
00699                                         if(G[t] >= Gmax)
00700                                         {
00701                                                 Gmax = G[t];
00702                                                 Gmax_idx = t;
00703                                         }
00704                         }
00705         
00706                 int i = Gmax_idx;
00707                 float[] Q_i = null;
00708                 if(i != -1) // null Q_i not accessed: Gmax=-INF if i=-1
00709                         Q_i = Q.get_Q(i,active_size);
00710         
00711                 for(int j=0;j<active_size;j++)
00712                 {
00713                         if(y[j]==+1)
00714                         {
00715                                 if (!is_lower_bound(j))
00716                                 {
00717                                         double grad_diff=Gmax+G[j];
00718                                         if (G[j] >= Gmax2)
00719                                                 Gmax2 = G[j];
00720                                         if (grad_diff > 0)
00721                                         {
00722                                                 double obj_diff; 
00723                                                 double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
00724                                                 if (quad_coef > 0)
00725                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00726                                                 else
00727                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00728         
00729                                                 if (obj_diff <= obj_diff_min)
00730                                                 {
00731                                                         Gmin_idx=j;
00732                                                         obj_diff_min = obj_diff;
00733                                                 }
00734                                         }
00735                                 }
00736                         }
00737                         else
00738                         {
00739                                 if (!is_upper_bound(j))
00740                                 {
00741                                         double grad_diff= Gmax-G[j];
00742                                         if (-G[j] >= Gmax2)
00743                                                 Gmax2 = -G[j];
00744                                         if (grad_diff > 0)
00745                                         {
00746                                                 double obj_diff; 
00747                                                 double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
00748                                                 if (quad_coef > 0)
00749                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00750                                                 else
00751                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00752         
00753                                                 if (obj_diff <= obj_diff_min)
00754                                                 {
00755                                                         Gmin_idx=j;
00756                                                         obj_diff_min = obj_diff;
00757                                                 }
00758                                         }
00759                                 }
00760                         }
00761                 }
00762 
00763                 if(Gmax+Gmax2 < eps)
00764                         return 1;
00765 
00766                 working_set[0] = Gmax_idx;
00767                 working_set[1] = Gmin_idx;
00768                 return 0;
00769         }
00770 
00771         private boolean be_shrunk(int i, double Gmax1, double Gmax2)
00772         {       
00773                 if(is_upper_bound(i))
00774                 {
00775                         if(y[i]==+1)
00776                                 return(-G[i] > Gmax1);
00777                         else
00778                                 return(-G[i] > Gmax2);
00779                 }
00780                 else if(is_lower_bound(i))
00781                 {
00782                         if(y[i]==+1)
00783                                 return(G[i] > Gmax2);
00784                         else    
00785                                 return(G[i] > Gmax1);
00786                 }
00787                 else
00788                         return(false);
00789         }
00790 
00791         void do_shrinking()
00792         {
00793                 int i;
00794                 double Gmax1 = -INF;            // max { -y_i * grad(f)_i | i in I_up(\alpha) }
00795                 double Gmax2 = -INF;            // max { y_i * grad(f)_i | i in I_low(\alpha) }
00796 
00797                 // find maximal violating pair first
00798                 for(i=0;i<active_size;i++)
00799                 {
00800                         if(y[i]==+1)
00801                         {
00802                                 if(!is_upper_bound(i))  
00803                                 {
00804                                         if(-G[i] >= Gmax1)
00805                                                 Gmax1 = -G[i];
00806                                 }
00807                                 if(!is_lower_bound(i))
00808                                 {
00809                                         if(G[i] >= Gmax2)
00810                                                 Gmax2 = G[i];
00811                                 }
00812                         }
00813                         else            
00814                         {
00815                                 if(!is_upper_bound(i))  
00816                                 {
00817                                         if(-G[i] >= Gmax2)
00818                                                 Gmax2 = -G[i];
00819                                 }
00820                                 if(!is_lower_bound(i))  
00821                                 {
00822                                         if(G[i] >= Gmax1)
00823                                                 Gmax1 = G[i];
00824                                 }
00825                         }
00826                 }
00827 
00828                 if(unshrink == false && Gmax1 + Gmax2 <= eps*10) 
00829                 {
00830                         unshrink = true;
00831                         reconstruct_gradient();
00832                         active_size = l;
00833                 }
00834 
00835                 for(i=0;i<active_size;i++)
00836                         if (be_shrunk(i, Gmax1, Gmax2))
00837                         {
00838                                 active_size--;
00839                                 while (active_size > i)
00840                                 {
00841                                         if (!be_shrunk(active_size, Gmax1, Gmax2))
00842                                         {
00843                                                 swap_index(i,active_size);
00844                                                 break;
00845                                         }
00846                                         active_size--;
00847                                 }
00848                         }
00849         }
00850 
00851         double calculate_rho()
00852         {
00853                 double r;
00854                 int nr_free = 0;
00855                 double ub = INF, lb = -INF, sum_free = 0;
00856                 for(int i=0;i<active_size;i++)
00857                 {
00858                         double yG = y[i]*G[i];
00859 
00860                         if(is_lower_bound(i))
00861                         {
00862                                 if(y[i] > 0)
00863                                         ub = Math.min(ub,yG);
00864                                 else
00865                                         lb = Math.max(lb,yG);
00866                         }
00867                         else if(is_upper_bound(i))
00868                         {
00869                                 if(y[i] < 0)
00870                                         ub = Math.min(ub,yG);
00871                                 else
00872                                         lb = Math.max(lb,yG);
00873                         }
00874                         else
00875                         {
00876                                 ++nr_free;
00877                                 sum_free += yG;
00878                         }
00879                 }
00880 
00881                 if(nr_free>0)
00882                         r = sum_free/nr_free;
00883                 else
00884                         r = (ub+lb)/2;
00885 
00886                 return r;
00887         }
00888 
00889 }
00890 
00891 //
00892 // Solver for nu-svm classification and regression
00893 //
00894 // additional constraint: e^T \alpha = constant
00895 //
00896 final class Solver_NU extends Solver
00897 {
00898         private SolutionInfo si;
00899 
00900         void Solve(int l, QMatrix Q, double[] p, byte[] y,
00901                    double[] alpha, double Cp, double Cn, double eps,
00902                    SolutionInfo si, int shrinking)
00903         {
00904                 this.si = si;
00905                 super.Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
00906         }
00907 
00908         // return 1 if already optimal, return 0 otherwise
00909         int select_working_set(int[] working_set)
00910         {
00911                 // return i,j such that y_i = y_j and
00912                 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
00913                 // j: minimizes the decrease of obj value
00914                 //    (if quadratic coefficeint <= 0, replace it with tau)
00915                 //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
00916         
00917                 double Gmaxp = -INF;
00918                 double Gmaxp2 = -INF;
00919                 int Gmaxp_idx = -1;
00920         
00921                 double Gmaxn = -INF;
00922                 double Gmaxn2 = -INF;
00923                 int Gmaxn_idx = -1;
00924         
00925                 int Gmin_idx = -1;
00926                 double obj_diff_min = INF;
00927         
00928                 for(int t=0;t<active_size;t++)
00929                         if(y[t]==+1)
00930                         {
00931                                 if(!is_upper_bound(t))
00932                                         if(-G[t] >= Gmaxp)
00933                                         {
00934                                                 Gmaxp = -G[t];
00935                                                 Gmaxp_idx = t;
00936                                         }
00937                         }
00938                         else
00939                         {
00940                                 if(!is_lower_bound(t))
00941                                         if(G[t] >= Gmaxn)
00942                                         {
00943                                                 Gmaxn = G[t];
00944                                                 Gmaxn_idx = t;
00945                                         }
00946                         }
00947         
00948                 int ip = Gmaxp_idx;
00949                 int in = Gmaxn_idx;
00950                 float[] Q_ip = null;
00951                 float[] Q_in = null;
00952                 if(ip != -1) // null Q_ip not accessed: Gmaxp=-INF if ip=-1
00953                         Q_ip = Q.get_Q(ip,active_size);
00954                 if(in != -1)
00955                         Q_in = Q.get_Q(in,active_size);
00956         
00957                 for(int j=0;j<active_size;j++)
00958                 {
00959                         if(y[j]==+1)
00960                         {
00961                                 if (!is_lower_bound(j)) 
00962                                 {
00963                                         double grad_diff=Gmaxp+G[j];
00964                                         if (G[j] >= Gmaxp2)
00965                                                 Gmaxp2 = G[j];
00966                                         if (grad_diff > 0)
00967                                         {
00968                                                 double obj_diff; 
00969                                                 double quad_coef = QD[ip]+QD[j]-2*Q_ip[j];
00970                                                 if (quad_coef > 0)
00971                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00972                                                 else
00973                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00974         
00975                                                 if (obj_diff <= obj_diff_min)
00976                                                 {
00977                                                         Gmin_idx=j;
00978                                                         obj_diff_min = obj_diff;
00979                                                 }
00980                                         }
00981                                 }
00982                         }
00983                         else
00984                         {
00985                                 if (!is_upper_bound(j))
00986                                 {
00987                                         double grad_diff=Gmaxn-G[j];
00988                                         if (-G[j] >= Gmaxn2)
00989                                                 Gmaxn2 = -G[j];
00990                                         if (grad_diff > 0)
00991                                         {
00992                                                 double obj_diff; 
00993                                                 double quad_coef = QD[in]+QD[j]-2*Q_in[j];
00994                                                 if (quad_coef > 0)
00995                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00996                                                 else
00997                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00998         
00999                                                 if (obj_diff <= obj_diff_min)
01000                                                 {
01001                                                         Gmin_idx=j;
01002                                                         obj_diff_min = obj_diff;
01003                                                 }
01004                                         }
01005                                 }
01006                         }
01007                 }
01008 
01009                 if(Math.max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps)
01010                         return 1;
01011         
01012                 if(y[Gmin_idx] == +1)
01013                         working_set[0] = Gmaxp_idx;
01014                 else
01015                         working_set[0] = Gmaxn_idx;
01016                 working_set[1] = Gmin_idx;
01017         
01018                 return 0;
01019         }
01020 
01021         private boolean be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
01022         {
01023                 if(is_upper_bound(i))
01024                 {
01025                         if(y[i]==+1)
01026                                 return(-G[i] > Gmax1);
01027                         else    
01028                                 return(-G[i] > Gmax4);
01029                 }
01030                 else if(is_lower_bound(i))
01031                 {
01032                         if(y[i]==+1)
01033                                 return(G[i] > Gmax2);
01034                         else    
01035                                 return(G[i] > Gmax3);
01036                 }
01037                 else
01038                         return(false);
01039         }
01040 
01041         void do_shrinking()
01042         {
01043                 double Gmax1 = -INF;    // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
01044                 double Gmax2 = -INF;    // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
01045                 double Gmax3 = -INF;    // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
01046                 double Gmax4 = -INF;    // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
01047  
01048                 // find maximal violating pair first
01049                 int i;
01050                 for(i=0;i<active_size;i++)
01051                 {
01052                         if(!is_upper_bound(i))
01053                         {
01054                                 if(y[i]==+1)
01055                                 {
01056                                         if(-G[i] > Gmax1) Gmax1 = -G[i];
01057                                 }
01058                                 else    if(-G[i] > Gmax4) Gmax4 = -G[i];
01059                         }
01060                         if(!is_lower_bound(i))
01061                         {
01062                                 if(y[i]==+1)
01063                                 {       
01064                                         if(G[i] > Gmax2) Gmax2 = G[i];
01065                                 }
01066                                 else    if(G[i] > Gmax3) Gmax3 = G[i];
01067                         }
01068                 }
01069 
01070                 if(unshrink == false && Math.max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) 
01071                 {
01072                         unshrink = true;
01073                         reconstruct_gradient();
01074                         active_size = l;
01075                 }
01076 
01077                 for(i=0;i<active_size;i++)
01078                         if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
01079                         {
01080                                 active_size--;
01081                                 while (active_size > i)
01082                                 {
01083                                         if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
01084                                         {
01085                                                 swap_index(i,active_size);
01086                                                 break;
01087                                         }
01088                                         active_size--;
01089                                 }
01090                         }
01091         }
01092         
01093         double calculate_rho()
01094         {
01095                 int nr_free1 = 0,nr_free2 = 0;
01096                 double ub1 = INF, ub2 = INF;
01097                 double lb1 = -INF, lb2 = -INF;
01098                 double sum_free1 = 0, sum_free2 = 0;
01099 
01100                 for(int i=0;i<active_size;i++)
01101                 {
01102                         if(y[i]==+1)
01103                         {
01104                                 if(is_lower_bound(i))
01105                                         ub1 = Math.min(ub1,G[i]);
01106                                 else if(is_upper_bound(i))
01107                                         lb1 = Math.max(lb1,G[i]);
01108                                 else
01109                                 {
01110                                         ++nr_free1;
01111                                         sum_free1 += G[i];
01112                                 }
01113                         }
01114                         else
01115                         {
01116                                 if(is_lower_bound(i))
01117                                         ub2 = Math.min(ub2,G[i]);
01118                                 else if(is_upper_bound(i))
01119                                         lb2 = Math.max(lb2,G[i]);
01120                                 else
01121                                 {
01122                                         ++nr_free2;
01123                                         sum_free2 += G[i];
01124                                 }
01125                         }
01126                 }
01127 
01128                 double r1,r2;
01129                 if(nr_free1 > 0)
01130                         r1 = sum_free1/nr_free1;
01131                 else
01132                         r1 = (ub1+lb1)/2;
01133 
01134                 if(nr_free2 > 0)
01135                         r2 = sum_free2/nr_free2;
01136                 else
01137                         r2 = (ub2+lb2)/2;
01138 
01139                 si.r = (r1+r2)/2;
01140                 return (r1-r2)/2;
01141         }
01142 }
01143 
01144 //
01145 // Q matrices for various formulations
01146 //
01147 class SVC_Q extends Kernel
01148 {
01149         private final byte[] y;
01150         private final Cache cache;
01151         private final double[] QD;
01152 
01153         SVC_Q(svm_problem prob, svm_parameter param, byte[] y_)
01154         {
01155                 super(prob.l, prob.x, param);
01156                 y = (byte[])y_.clone();
01157                 cache = new Cache(prob.l,(long)(param.cache_size*(1<<20)));
01158                 QD = new double[prob.l];
01159                 for(int i=0;i<prob.l;i++)
01160                         QD[i] = kernel_function(i,i);
01161         }
01162 
01163         float[] get_Q(int i, int len)
01164         {
01165                 float[][] data = new float[1][];
01166                 int start, j;
01167                 if((start = cache.get_data(i,data,len)) < len)
01168                 {
01169                         for(j=start;j<len;j++)
01170                                 data[0][j] = (float)(y[i]*y[j]*kernel_function(i,j));
01171                 }
01172                 return data[0];
01173         }
01174 
01175         double[] get_QD()
01176         {
01177                 return QD;
01178         }
01179 
01180         void swap_index(int i, int j)
01181         {
01182                 cache.swap_index(i,j);
01183                 super.swap_index(i,j);
01184                 do {byte _=y[i]; y[i]=y[j]; y[j]=_;} while(false);
01185                 do {double _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
01186         }
01187 }
01188 
01189 class ONE_CLASS_Q extends Kernel
01190 {
01191         private final Cache cache;
01192         private final double[] QD;
01193 
01194         ONE_CLASS_Q(svm_problem prob, svm_parameter param)
01195         {
01196                 super(prob.l, prob.x, param);
01197                 cache = new Cache(prob.l,(long)(param.cache_size*(1<<20)));
01198                 QD = new double[prob.l];
01199                 for(int i=0;i<prob.l;i++)
01200                         QD[i] = kernel_function(i,i);
01201         }
01202 
01203         float[] get_Q(int i, int len)
01204         {
01205                 float[][] data = new float[1][];
01206                 int start, j;
01207                 if((start = cache.get_data(i,data,len)) < len)
01208                 {
01209                         for(j=start;j<len;j++)
01210                                 data[0][j] = (float)kernel_function(i,j);
01211                 }
01212                 return data[0];
01213         }
01214 
01215         double[] get_QD()
01216         {
01217                 return QD;
01218         }
01219 
01220         void swap_index(int i, int j)
01221         {
01222                 cache.swap_index(i,j);
01223                 super.swap_index(i,j);
01224                 do {double _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
01225         }
01226 }
01227 
01228 class SVR_Q extends Kernel
01229 {
01230         private final int l;
01231         private final Cache cache;
01232         private final byte[] sign;
01233         private final int[] index;
01234         private int next_buffer;
01235         private float[][] buffer;
01236         private final double[] QD;
01237 
01238         SVR_Q(svm_problem prob, svm_parameter param)
01239         {
01240                 super(prob.l, prob.x, param);
01241                 l = prob.l;
01242                 cache = new Cache(l,(long)(param.cache_size*(1<<20)));
01243                 QD = new double[2*l];
01244                 sign = new byte[2*l];
01245                 index = new int[2*l];
01246                 for(int k=0;k<l;k++)
01247                 {
01248                         sign[k] = 1;
01249                         sign[k+l] = -1;
01250                         index[k] = k;
01251                         index[k+l] = k;
01252                         QD[k] = kernel_function(k,k);
01253                         QD[k+l] = QD[k];
01254                 }
01255                 buffer = new float[2][2*l];
01256                 next_buffer = 0;
01257         }
01258 
01259         void swap_index(int i, int j)
01260         {
01261                 do {byte _=sign[i]; sign[i]=sign[j]; sign[j]=_;} while(false);
01262                 do {int _=index[i]; index[i]=index[j]; index[j]=_;} while(false);
01263                 do {double _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
01264         }
01265 
01266         float[] get_Q(int i, int len)
01267         {
01268                 float[][] data = new float[1][];
01269                 int j, real_i = index[i];
01270                 if(cache.get_data(real_i,data,l) < l)
01271                 {
01272                         for(j=0;j<l;j++)
01273                                 data[0][j] = (float)kernel_function(real_i,j);
01274                 }
01275 
01276                 // reorder and copy
01277                 float buf[] = buffer[next_buffer];
01278                 next_buffer = 1 - next_buffer;
01279                 byte si = sign[i];
01280                 for(j=0;j<len;j++)
01281                         buf[j] = (float) si * sign[j] * data[0][index[j]];
01282                 return buf;
01283         }
01284 
01285         double[] get_QD()
01286         {
01287                 return QD;
01288         }
01289 }
01290 
01291 public class svm {
01292         //
01293         // construct and solve various formulations
01294         //
01295         public static final int LIBSVM_VERSION=312; 
01296         public static final Random rand = new Random();
01297 
01298         private static svm_print_interface svm_print_stdout = new svm_print_interface()
01299         {
01300                 public void print(String s)
01301                 {
01302                         System.out.print(s);
01303                         System.out.flush();
01304                 }
01305         };
01306 
01307         private static svm_print_interface svm_print_string = svm_print_stdout;
01308 
01309         static void info(String s) 
01310         {
01311                 svm_print_string.print(s);
01312         }
01313 
01314         private static void solve_c_svc(svm_problem prob, svm_parameter param,
01315                                         double[] alpha, Solver.SolutionInfo si,
01316                                         double Cp, double Cn)
01317         {
01318                 int l = prob.l;
01319                 double[] minus_ones = new double[l];
01320                 byte[] y = new byte[l];
01321 
01322                 int i;
01323 
01324                 for(i=0;i<l;i++)
01325                 {
01326                         alpha[i] = 0;
01327                         minus_ones[i] = -1;
01328                         if(prob.y[i] > 0) y[i] = +1; else y[i] = -1;
01329                 }
01330 
01331                 Solver s = new Solver();
01332                 s.Solve(l, new SVC_Q(prob,param,y), minus_ones, y,
01333                         alpha, Cp, Cn, param.eps, si, param.shrinking);
01334 
01335                 double sum_alpha=0;
01336                 for(i=0;i<l;i++)
01337                         sum_alpha += alpha[i];
01338 
01339                 if (Cp==Cn)
01340                         svm.info("nu = "+sum_alpha/(Cp*prob.l)+"\n");
01341 
01342                 for(i=0;i<l;i++)
01343                         alpha[i] *= y[i];
01344         }
01345 
01346         private static void solve_nu_svc(svm_problem prob, svm_parameter param,
01347                                         double[] alpha, Solver.SolutionInfo si)
01348         {
01349                 int i;
01350                 int l = prob.l;
01351                 double nu = param.nu;
01352 
01353                 byte[] y = new byte[l];
01354 
01355                 for(i=0;i<l;i++)
01356                         if(prob.y[i]>0)
01357                                 y[i] = +1;
01358                         else
01359                                 y[i] = -1;
01360 
01361                 double sum_pos = nu*l/2;
01362                 double sum_neg = nu*l/2;
01363 
01364                 for(i=0;i<l;i++)
01365                         if(y[i] == +1)
01366                         {
01367                                 alpha[i] = Math.min(1.0,sum_pos);
01368                                 sum_pos -= alpha[i];
01369                         }
01370                         else
01371                         {
01372                                 alpha[i] = Math.min(1.0,sum_neg);
01373                                 sum_neg -= alpha[i];
01374                         }
01375 
01376                 double[] zeros = new double[l];
01377 
01378                 for(i=0;i<l;i++)
01379                         zeros[i] = 0;
01380 
01381                 Solver_NU s = new Solver_NU();
01382                 s.Solve(l, new SVC_Q(prob,param,y), zeros, y,
01383                         alpha, 1.0, 1.0, param.eps, si, param.shrinking);
01384                 double r = si.r;
01385 
01386                 svm.info("C = "+1/r+"\n");
01387 
01388                 for(i=0;i<l;i++)
01389                         alpha[i] *= y[i]/r;
01390 
01391                 si.rho /= r;
01392                 si.obj /= (r*r);
01393                 si.upper_bound_p = 1/r;
01394                 si.upper_bound_n = 1/r;
01395         }
01396 
01397         private static void solve_one_class(svm_problem prob, svm_parameter param,
01398                                         double[] alpha, Solver.SolutionInfo si)
01399         {
01400                 int l = prob.l;
01401                 double[] zeros = new double[l];
01402                 byte[] ones = new byte[l];
01403                 int i;
01404 
01405                 int n = (int)(param.nu*prob.l); // # of alpha's at upper bound
01406 
01407                 for(i=0;i<n;i++)
01408                         alpha[i] = 1;
01409                 if(n<prob.l)
01410                         alpha[n] = param.nu * prob.l - n;
01411                 for(i=n+1;i<l;i++)
01412                         alpha[i] = 0;
01413 
01414                 for(i=0;i<l;i++)
01415                 {
01416                         zeros[i] = 0;
01417                         ones[i] = 1;
01418                 }
01419 
01420                 Solver s = new Solver();
01421                 s.Solve(l, new ONE_CLASS_Q(prob,param), zeros, ones,
01422                         alpha, 1.0, 1.0, param.eps, si, param.shrinking);
01423         }
01424 
01425         private static void solve_epsilon_svr(svm_problem prob, svm_parameter param,
01426                                         double[] alpha, Solver.SolutionInfo si)
01427         {
01428                 int l = prob.l;
01429                 double[] alpha2 = new double[2*l];
01430                 double[] linear_term = new double[2*l];
01431                 byte[] y = new byte[2*l];
01432                 int i;
01433 
01434                 for(i=0;i<l;i++)
01435                 {
01436                         alpha2[i] = 0;
01437                         linear_term[i] = param.p - prob.y[i];
01438                         y[i] = 1;
01439 
01440                         alpha2[i+l] = 0;
01441                         linear_term[i+l] = param.p + prob.y[i];
01442                         y[i+l] = -1;
01443                 }
01444 
01445                 Solver s = new Solver();
01446                 s.Solve(2*l, new SVR_Q(prob,param), linear_term, y,
01447                         alpha2, param.C, param.C, param.eps, si, param.shrinking);
01448 
01449                 double sum_alpha = 0;
01450                 for(i=0;i<l;i++)
01451                 {
01452                         alpha[i] = alpha2[i] - alpha2[i+l];
01453                         sum_alpha += Math.abs(alpha[i]);
01454                 }
01455                 svm.info("nu = "+sum_alpha/(param.C*l)+"\n");
01456         }
01457 
01458         private static void solve_nu_svr(svm_problem prob, svm_parameter param,
01459                                         double[] alpha, Solver.SolutionInfo si)
01460         {
01461                 int l = prob.l;
01462                 double C = param.C;
01463                 double[] alpha2 = new double[2*l];
01464                 double[] linear_term = new double[2*l];
01465                 byte[] y = new byte[2*l];
01466                 int i;
01467 
01468                 double sum = C * param.nu * l / 2;
01469                 for(i=0;i<l;i++)
01470                 {
01471                         alpha2[i] = alpha2[i+l] = Math.min(sum,C);
01472                         sum -= alpha2[i];
01473                         
01474                         linear_term[i] = - prob.y[i];
01475                         y[i] = 1;
01476 
01477                         linear_term[i+l] = prob.y[i];
01478                         y[i+l] = -1;
01479                 }
01480 
01481                 Solver_NU s = new Solver_NU();
01482                 s.Solve(2*l, new SVR_Q(prob,param), linear_term, y,
01483                         alpha2, C, C, param.eps, si, param.shrinking);
01484 
01485                 svm.info("epsilon = "+(-si.r)+"\n");
01486                 
01487                 for(i=0;i<l;i++)
01488                         alpha[i] = alpha2[i] - alpha2[i+l];
01489         }
01490 
01491         //
01492         // decision_function
01493         //
01494         static class decision_function
01495         {
01496                 double[] alpha;
01497                 double rho;     
01498         };
01499 
01500         static decision_function svm_train_one(
01501                 svm_problem prob, svm_parameter param,
01502                 double Cp, double Cn)
01503         {
01504                 double[] alpha = new double[prob.l];
01505                 Solver.SolutionInfo si = new Solver.SolutionInfo();
01506                 switch(param.svm_type)
01507                 {
01508                         case svm_parameter.C_SVC:
01509                                 solve_c_svc(prob,param,alpha,si,Cp,Cn);
01510                                 break;
01511                         case svm_parameter.NU_SVC:
01512                                 solve_nu_svc(prob,param,alpha,si);
01513                                 break;
01514                         case svm_parameter.ONE_CLASS:
01515                                 solve_one_class(prob,param,alpha,si);
01516                                 break;
01517                         case svm_parameter.EPSILON_SVR:
01518                                 solve_epsilon_svr(prob,param,alpha,si);
01519                                 break;
01520                         case svm_parameter.NU_SVR:
01521                                 solve_nu_svr(prob,param,alpha,si);
01522                                 break;
01523                 }
01524 
01525                 svm.info("obj = "+si.obj+", rho = "+si.rho+"\n");
01526 
01527                 // output SVs
01528 
01529                 int nSV = 0;
01530                 int nBSV = 0;
01531                 for(int i=0;i<prob.l;i++)
01532                 {
01533                         if(Math.abs(alpha[i]) > 0)
01534                         {
01535                                 ++nSV;
01536                                 if(prob.y[i] > 0)
01537                                 {
01538                                         if(Math.abs(alpha[i]) >= si.upper_bound_p)
01539                                         ++nBSV;
01540                                 }
01541                                 else
01542                                 {
01543                                         if(Math.abs(alpha[i]) >= si.upper_bound_n)
01544                                                 ++nBSV;
01545                                 }
01546                         }
01547                 }
01548 
01549                 svm.info("nSV = "+nSV+", nBSV = "+nBSV+"\n");
01550 
01551                 decision_function f = new decision_function();
01552                 f.alpha = alpha;
01553                 f.rho = si.rho;
01554                 return f;
01555         }
01556 
01557         // Platt's binary SVM Probablistic Output: an improvement from Lin et al.
01558         private static void sigmoid_train(int l, double[] dec_values, double[] labels, 
01559                                   double[] probAB)
01560         {
01561                 double A, B;
01562                 double prior1=0, prior0 = 0;
01563                 int i;
01564 
01565                 for (i=0;i<l;i++)
01566                         if (labels[i] > 0) prior1+=1;
01567                         else prior0+=1;
01568         
01569                 int max_iter=100;       // Maximal number of iterations
01570                 double min_step=1e-10;  // Minimal step taken in line search
01571                 double sigma=1e-12;     // For numerically strict PD of Hessian
01572                 double eps=1e-5;
01573                 double hiTarget=(prior1+1.0)/(prior1+2.0);
01574                 double loTarget=1/(prior0+2.0);
01575                 double[] t= new double[l];
01576                 double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
01577                 double newA,newB,newf,d1,d2;
01578                 int iter; 
01579         
01580                 // Initial Point and Initial Fun Value
01581                 A=0.0; B=Math.log((prior0+1.0)/(prior1+1.0));
01582                 double fval = 0.0;
01583 
01584                 for (i=0;i<l;i++)
01585                 {
01586                         if (labels[i]>0) t[i]=hiTarget;
01587                         else t[i]=loTarget;
01588                         fApB = dec_values[i]*A+B;
01589                         if (fApB>=0)
01590                                 fval += t[i]*fApB + Math.log(1+Math.exp(-fApB));
01591                         else
01592                                 fval += (t[i] - 1)*fApB +Math.log(1+Math.exp(fApB));
01593                 }
01594                 for (iter=0;iter<max_iter;iter++)
01595                 {
01596                         // Update Gradient and Hessian (use H' = H + sigma I)
01597                         h11=sigma; // numerically ensures strict PD
01598                         h22=sigma;
01599                         h21=0.0;g1=0.0;g2=0.0;
01600                         for (i=0;i<l;i++)
01601                         {
01602                                 fApB = dec_values[i]*A+B;
01603                                 if (fApB >= 0)
01604                                 {
01605                                         p=Math.exp(-fApB)/(1.0+Math.exp(-fApB));
01606                                         q=1.0/(1.0+Math.exp(-fApB));
01607                                 }
01608                                 else
01609                                 {
01610                                         p=1.0/(1.0+Math.exp(fApB));
01611                                         q=Math.exp(fApB)/(1.0+Math.exp(fApB));
01612                                 }
01613                                 d2=p*q;
01614                                 h11+=dec_values[i]*dec_values[i]*d2;
01615                                 h22+=d2;
01616                                 h21+=dec_values[i]*d2;
01617                                 d1=t[i]-p;
01618                                 g1+=dec_values[i]*d1;
01619                                 g2+=d1;
01620                         }
01621 
01622                         // Stopping Criteria
01623                         if (Math.abs(g1)<eps && Math.abs(g2)<eps)
01624                                 break;
01625                         
01626                         // Finding Newton direction: -inv(H') * g
01627                         det=h11*h22-h21*h21;
01628                         dA=-(h22*g1 - h21 * g2) / det;
01629                         dB=-(-h21*g1+ h11 * g2) / det;
01630                         gd=g1*dA+g2*dB;
01631 
01632 
01633                         stepsize = 1;           // Line Search
01634                         while (stepsize >= min_step)
01635                         {
01636                                 newA = A + stepsize * dA;
01637                                 newB = B + stepsize * dB;
01638 
01639                                 // New function value
01640                                 newf = 0.0;
01641                                 for (i=0;i<l;i++)
01642                                 {
01643                                         fApB = dec_values[i]*newA+newB;
01644                                         if (fApB >= 0)
01645                                                 newf += t[i]*fApB + Math.log(1+Math.exp(-fApB));
01646                                         else
01647                                                 newf += (t[i] - 1)*fApB +Math.log(1+Math.exp(fApB));
01648                                 }
01649                                 // Check sufficient decrease
01650                                 if (newf<fval+0.0001*stepsize*gd)
01651                                 {
01652                                         A=newA;B=newB;fval=newf;
01653                                         break;
01654                                 }
01655                                 else
01656                                         stepsize = stepsize / 2.0;
01657                         }
01658                         
01659                         if (stepsize < min_step)
01660                         {
01661                                 svm.info("Line search fails in two-class probability estimates\n");
01662                                 break;
01663                         }
01664                 }
01665                 
01666                 if (iter>=max_iter)
01667                         svm.info("Reaching maximal iterations in two-class probability estimates\n");
01668                 probAB[0]=A;probAB[1]=B;
01669         }
01670 
01671         private static double sigmoid_predict(double decision_value, double A, double B)
01672         {
01673                 double fApB = decision_value*A+B;
01674                 if (fApB >= 0)
01675                         return Math.exp(-fApB)/(1.0+Math.exp(-fApB));
01676                 else
01677                         return 1.0/(1+Math.exp(fApB)) ;
01678         }
01679 
01680         // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
01681         private static void multiclass_probability(int k, double[][] r, double[] p)
01682         {
01683                 int t,j;
01684                 int iter = 0, max_iter=Math.max(100,k);
01685                 double[][] Q=new double[k][k];
01686                 double[] Qp=new double[k];
01687                 double pQp, eps=0.005/k;
01688         
01689                 for (t=0;t<k;t++)
01690                 {
01691                         p[t]=1.0/k;  // Valid if k = 1
01692                         Q[t][t]=0;
01693                         for (j=0;j<t;j++)
01694                         {
01695                                 Q[t][t]+=r[j][t]*r[j][t];
01696                                 Q[t][j]=Q[j][t];
01697                         }
01698                         for (j=t+1;j<k;j++)
01699                         {
01700                                 Q[t][t]+=r[j][t]*r[j][t];
01701                                 Q[t][j]=-r[j][t]*r[t][j];
01702                         }
01703                 }
01704                 for (iter=0;iter<max_iter;iter++)
01705                 {
01706                         // stopping condition, recalculate QP,pQP for numerical accuracy
01707                         pQp=0;
01708                         for (t=0;t<k;t++)
01709                         {
01710                                 Qp[t]=0;
01711                                 for (j=0;j<k;j++)
01712                                         Qp[t]+=Q[t][j]*p[j];
01713                                 pQp+=p[t]*Qp[t];
01714                         }
01715                         double max_error=0;
01716                         for (t=0;t<k;t++)
01717                         {
01718                                 double error=Math.abs(Qp[t]-pQp);
01719                                 if (error>max_error)
01720                                         max_error=error;
01721                         }
01722                         if (max_error<eps) break;
01723                 
01724                         for (t=0;t<k;t++)
01725                         {
01726                                 double diff=(-Qp[t]+pQp)/Q[t][t];
01727                                 p[t]+=diff;
01728                                 pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
01729                                 for (j=0;j<k;j++)
01730                                 {
01731                                         Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
01732                                         p[j]/=(1+diff);
01733                                 }
01734                         }
01735                 }
01736                 if (iter>=max_iter)
01737                         svm.info("Exceeds max_iter in multiclass_prob\n");
01738         }
01739 
01740         // Cross-validation decision values for probability estimates
01741         private static void svm_binary_svc_probability(svm_problem prob, svm_parameter param, double Cp, double Cn, double[] probAB)
01742         {
01743                 int i;
01744                 int nr_fold = 5;
01745                 int[] perm = new int[prob.l];
01746                 double[] dec_values = new double[prob.l];
01747 
01748                 // random shuffle
01749                 for(i=0;i<prob.l;i++) perm[i]=i;
01750                 for(i=0;i<prob.l;i++)
01751                 {
01752                         int j = i+rand.nextInt(prob.l-i);
01753                         do {int _=perm[i]; perm[i]=perm[j]; perm[j]=_;} while(false);
01754                 }
01755                 for(i=0;i<nr_fold;i++)
01756                 {
01757                         int begin = i*prob.l/nr_fold;
01758                         int end = (i+1)*prob.l/nr_fold;
01759                         int j,k;
01760                         svm_problem subprob = new svm_problem();
01761 
01762                         subprob.l = prob.l-(end-begin);
01763                         subprob.x = new svm_node[subprob.l][];
01764                         subprob.y = new double[subprob.l];
01765                         
01766                         k=0;
01767                         for(j=0;j<begin;j++)
01768                         {
01769                                 subprob.x[k] = prob.x[perm[j]];
01770                                 subprob.y[k] = prob.y[perm[j]];
01771                                 ++k;
01772                         }
01773                         for(j=end;j<prob.l;j++)
01774                         {
01775                                 subprob.x[k] = prob.x[perm[j]];
01776                                 subprob.y[k] = prob.y[perm[j]];
01777                                 ++k;
01778                         }
01779                         int p_count=0,n_count=0;
01780                         for(j=0;j<k;j++)
01781                                 if(subprob.y[j]>0)
01782                                         p_count++;
01783                                 else
01784                                         n_count++;
01785                         
01786                         if(p_count==0 && n_count==0)
01787                                 for(j=begin;j<end;j++)
01788                                         dec_values[perm[j]] = 0;
01789                         else if(p_count > 0 && n_count == 0)
01790                                 for(j=begin;j<end;j++)
01791                                         dec_values[perm[j]] = 1;
01792                         else if(p_count == 0 && n_count > 0)
01793                                 for(j=begin;j<end;j++)
01794                                         dec_values[perm[j]] = -1;
01795                         else
01796                         {
01797                                 svm_parameter subparam = (svm_parameter)param.clone();
01798                                 subparam.probability=0;
01799                                 subparam.C=1.0;
01800                                 subparam.nr_weight=2;
01801                                 subparam.weight_label = new int[2];
01802                                 subparam.weight = new double[2];
01803                                 subparam.weight_label[0]=+1;
01804                                 subparam.weight_label[1]=-1;
01805                                 subparam.weight[0]=Cp;
01806                                 subparam.weight[1]=Cn;
01807                                 svm_model submodel = svm_train(subprob,subparam);
01808                                 for(j=begin;j<end;j++)
01809                                 {
01810                                         double[] dec_value=new double[1];
01811                                         svm_predict_values(submodel,prob.x[perm[j]],dec_value);
01812                                         dec_values[perm[j]]=dec_value[0];
01813                                         // ensure +1 -1 order; reason not using CV subroutine
01814                                         dec_values[perm[j]] *= submodel.label[0];
01815                                 }               
01816                         }
01817                 }               
01818                 sigmoid_train(prob.l,dec_values,prob.y,probAB);
01819         }
01820 
01821         // Return parameter of a Laplace distribution 
01822         private static double svm_svr_probability(svm_problem prob, svm_parameter param)
01823         {
01824                 int i;
01825                 int nr_fold = 5;
01826                 double[] ymv = new double[prob.l];
01827                 double mae = 0;
01828 
01829                 svm_parameter newparam = (svm_parameter)param.clone();
01830                 newparam.probability = 0;
01831                 svm_cross_validation(prob,newparam,nr_fold,ymv);
01832                 for(i=0;i<prob.l;i++)
01833                 {
01834                         ymv[i]=prob.y[i]-ymv[i];
01835                         mae += Math.abs(ymv[i]);
01836                 }               
01837                 mae /= prob.l;
01838                 double std=Math.sqrt(2*mae*mae);
01839                 int count=0;
01840                 mae=0;
01841                 for(i=0;i<prob.l;i++)
01842                         if (Math.abs(ymv[i]) > 5*std) 
01843                                 count=count+1;
01844                         else 
01845                                 mae+=Math.abs(ymv[i]);
01846                 mae /= (prob.l-count);
01847                 svm.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+mae+"\n");
01848                 return mae;
01849         }
01850 
01851         // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
01852         // perm, length l, must be allocated before calling this subroutine
01853         private static void svm_group_classes(svm_problem prob, int[] nr_class_ret, int[][] label_ret, int[][] start_ret, int[][] count_ret, int[] perm)
01854         {
01855                 int l = prob.l;
01856                 int max_nr_class = 16;
01857                 int nr_class = 0;
01858                 int[] label = new int[max_nr_class];
01859                 int[] count = new int[max_nr_class];
01860                 int[] data_label = new int[l];
01861                 int i;
01862 
01863                 for(i=0;i<l;i++)
01864                 {
01865                         int this_label = (int)(prob.y[i]);
01866                         int j;
01867                         for(j=0;j<nr_class;j++)
01868                         {
01869                                 if(this_label == label[j])
01870                                 {
01871                                         ++count[j];
01872                                         break;
01873                                 }
01874                         }
01875                         data_label[i] = j;
01876                         if(j == nr_class)
01877                         {
01878                                 if(nr_class == max_nr_class)
01879                                 {
01880                                         max_nr_class *= 2;
01881                                         int[] new_data = new int[max_nr_class];
01882                                         System.arraycopy(label,0,new_data,0,label.length);
01883                                         label = new_data;
01884                                         new_data = new int[max_nr_class];
01885                                         System.arraycopy(count,0,new_data,0,count.length);
01886                                         count = new_data;                                       
01887                                 }
01888                                 label[nr_class] = this_label;
01889                                 count[nr_class] = 1;
01890                                 ++nr_class;
01891                         }
01892                 }
01893 
01894                 int[] start = new int[nr_class];
01895                 start[0] = 0;
01896                 for(i=1;i<nr_class;i++)
01897                         start[i] = start[i-1]+count[i-1];
01898                 for(i=0;i<l;i++)
01899                 {
01900                         perm[start[data_label[i]]] = i;
01901                         ++start[data_label[i]];
01902                 }
01903                 start[0] = 0;
01904                 for(i=1;i<nr_class;i++)
01905                         start[i] = start[i-1]+count[i-1];
01906 
01907                 nr_class_ret[0] = nr_class;
01908                 label_ret[0] = label;
01909                 start_ret[0] = start;
01910                 count_ret[0] = count;
01911         }
01912 
01913         //
01914         // Interface functions
01915         //
01916         public static svm_model svm_train(svm_problem prob, svm_parameter param)
01917         {
01918                 svm_model model = new svm_model();
01919                 model.param = param;
01920 
01921                 if(param.svm_type == svm_parameter.ONE_CLASS ||
01922                    param.svm_type == svm_parameter.EPSILON_SVR ||
01923                    param.svm_type == svm_parameter.NU_SVR)
01924                 {
01925                         // regression or one-class-svm
01926                         model.nr_class = 2;
01927                         model.label = null;
01928                         model.nSV = null;
01929                         model.probA = null; model.probB = null;
01930                         model.sv_coef = new double[1][];
01931 
01932                         if(param.probability == 1 &&
01933                            (param.svm_type == svm_parameter.EPSILON_SVR ||
01934                             param.svm_type == svm_parameter.NU_SVR))
01935                         {
01936                                 model.probA = new double[1];
01937                                 model.probA[0] = svm_svr_probability(prob,param);
01938                         }
01939 
01940                         decision_function f = svm_train_one(prob,param,0,0);
01941                         model.rho = new double[1];
01942                         model.rho[0] = f.rho;
01943 
01944                         int nSV = 0;
01945                         int i;
01946                         for(i=0;i<prob.l;i++)
01947                                 if(Math.abs(f.alpha[i]) > 0) ++nSV;
01948                         model.l = nSV;
01949                         model.SV = new svm_node[nSV][];
01950                         model.sv_coef[0] = new double[nSV];
01951                         int j = 0;
01952                         for(i=0;i<prob.l;i++)
01953                                 if(Math.abs(f.alpha[i]) > 0)
01954                                 {
01955                                         model.SV[j] = prob.x[i];
01956                                         model.sv_coef[0][j] = f.alpha[i];
01957                                         ++j;
01958                                 }
01959                 }
01960                 else
01961                 {
01962                         // classification
01963                         int l = prob.l;
01964                         int[] tmp_nr_class = new int[1];
01965                         int[][] tmp_label = new int[1][];
01966                         int[][] tmp_start = new int[1][];
01967                         int[][] tmp_count = new int[1][];                       
01968                         int[] perm = new int[l];
01969 
01970                         // group training data of the same class
01971                         svm_group_classes(prob,tmp_nr_class,tmp_label,tmp_start,tmp_count,perm);
01972                         int nr_class = tmp_nr_class[0];                 
01973                         int[] label = tmp_label[0];
01974                         int[] start = tmp_start[0];
01975                         int[] count = tmp_count[0];
01976                         
01977                         if(nr_class == 1) 
01978                                 svm.info("WARNING: training data in only one class. See README for details.\n");
01979                         
01980                         svm_node[][] x = new svm_node[l][];
01981                         int i;
01982                         for(i=0;i<l;i++)
01983                                 x[i] = prob.x[perm[i]];
01984 
01985                         // calculate weighted C
01986 
01987                         double[] weighted_C = new double[nr_class];
01988                         for(i=0;i<nr_class;i++)
01989                                 weighted_C[i] = param.C;
01990                         for(i=0;i<param.nr_weight;i++)
01991                         {
01992                                 int j;
01993                                 for(j=0;j<nr_class;j++)
01994                                         if(param.weight_label[i] == label[j])
01995                                                 break;
01996                                 if(j == nr_class)
01997                                         System.err.print("WARNING: class label "+param.weight_label[i]+" specified in weight is not found\n");
01998                                 else
01999                                         weighted_C[j] *= param.weight[i];
02000                         }
02001 
02002                         // train k*(k-1)/2 models
02003 
02004                         boolean[] nonzero = new boolean[l];
02005                         for(i=0;i<l;i++)
02006                                 nonzero[i] = false;
02007                         decision_function[] f = new decision_function[nr_class*(nr_class-1)/2];
02008 
02009                         double[] probA=null,probB=null;
02010                         if (param.probability == 1)
02011                         {
02012                                 probA=new double[nr_class*(nr_class-1)/2];
02013                                 probB=new double[nr_class*(nr_class-1)/2];
02014                         }
02015 
02016                         int p = 0;
02017                         for(i=0;i<nr_class;i++)
02018                                 for(int j=i+1;j<nr_class;j++)
02019                                 {
02020                                         svm_problem sub_prob = new svm_problem();
02021                                         int si = start[i], sj = start[j];
02022                                         int ci = count[i], cj = count[j];
02023                                         sub_prob.l = ci+cj;
02024                                         sub_prob.x = new svm_node[sub_prob.l][];
02025                                         sub_prob.y = new double[sub_prob.l];
02026                                         int k;
02027                                         for(k=0;k<ci;k++)
02028                                         {
02029                                                 sub_prob.x[k] = x[si+k];
02030                                                 sub_prob.y[k] = +1;
02031                                         }
02032                                         for(k=0;k<cj;k++)
02033                                         {
02034                                                 sub_prob.x[ci+k] = x[sj+k];
02035                                                 sub_prob.y[ci+k] = -1;
02036                                         }
02037 
02038                                         if(param.probability == 1)
02039                                         {
02040                                                 double[] probAB=new double[2];
02041                                                 svm_binary_svc_probability(sub_prob,param,weighted_C[i],weighted_C[j],probAB);
02042                                                 probA[p]=probAB[0];
02043                                                 probB[p]=probAB[1];
02044                                         }
02045 
02046                                         f[p] = svm_train_one(sub_prob,param,weighted_C[i],weighted_C[j]);
02047                                         for(k=0;k<ci;k++)
02048                                                 if(!nonzero[si+k] && Math.abs(f[p].alpha[k]) > 0)
02049                                                         nonzero[si+k] = true;
02050                                         for(k=0;k<cj;k++)
02051                                                 if(!nonzero[sj+k] && Math.abs(f[p].alpha[ci+k]) > 0)
02052                                                         nonzero[sj+k] = true;
02053                                         ++p;
02054                                 }
02055 
02056                         // build output
02057 
02058                         model.nr_class = nr_class;
02059 
02060                         model.label = new int[nr_class];
02061                         for(i=0;i<nr_class;i++)
02062                                 model.label[i] = label[i];
02063 
02064                         model.rho = new double[nr_class*(nr_class-1)/2];
02065                         for(i=0;i<nr_class*(nr_class-1)/2;i++)
02066                                 model.rho[i] = f[i].rho;
02067 
02068                         if(param.probability == 1)
02069                         {
02070                                 model.probA = new double[nr_class*(nr_class-1)/2];
02071                                 model.probB = new double[nr_class*(nr_class-1)/2];
02072                                 for(i=0;i<nr_class*(nr_class-1)/2;i++)
02073                                 {
02074                                         model.probA[i] = probA[i];
02075                                         model.probB[i] = probB[i];
02076                                 }
02077                         }
02078                         else
02079                         {
02080                                 model.probA=null;
02081                                 model.probB=null;
02082                         }
02083 
02084                         int nnz = 0;
02085                         int[] nz_count = new int[nr_class];
02086                         model.nSV = new int[nr_class];
02087                         for(i=0;i<nr_class;i++)
02088                         {
02089                                 int nSV = 0;
02090                                 for(int j=0;j<count[i];j++)
02091                                         if(nonzero[start[i]+j])
02092                                         {
02093                                                 ++nSV;
02094                                                 ++nnz;
02095                                         }
02096                                 model.nSV[i] = nSV;
02097                                 nz_count[i] = nSV;
02098                         }
02099 
02100                         svm.info("Total nSV = "+nnz+"\n");
02101 
02102                         model.l = nnz;
02103                         model.SV = new svm_node[nnz][];
02104                         p = 0;
02105                         for(i=0;i<l;i++)
02106                                 if(nonzero[i]) model.SV[p++] = x[i];
02107 
02108                         int[] nz_start = new int[nr_class];
02109                         nz_start[0] = 0;
02110                         for(i=1;i<nr_class;i++)
02111                                 nz_start[i] = nz_start[i-1]+nz_count[i-1];
02112 
02113                         model.sv_coef = new double[nr_class-1][];
02114                         for(i=0;i<nr_class-1;i++)
02115                                 model.sv_coef[i] = new double[nnz];
02116 
02117                         p = 0;
02118                         for(i=0;i<nr_class;i++)
02119                                 for(int j=i+1;j<nr_class;j++)
02120                                 {
02121                                         // classifier (i,j): coefficients with
02122                                         // i are in sv_coef[j-1][nz_start[i]...],
02123                                         // j are in sv_coef[i][nz_start[j]...]
02124 
02125                                         int si = start[i];
02126                                         int sj = start[j];
02127                                         int ci = count[i];
02128                                         int cj = count[j];
02129 
02130                                         int q = nz_start[i];
02131                                         int k;
02132                                         for(k=0;k<ci;k++)
02133                                                 if(nonzero[si+k])
02134                                                         model.sv_coef[j-1][q++] = f[p].alpha[k];
02135                                         q = nz_start[j];
02136                                         for(k=0;k<cj;k++)
02137                                                 if(nonzero[sj+k])
02138                                                         model.sv_coef[i][q++] = f[p].alpha[ci+k];
02139                                         ++p;
02140                                 }
02141                 }
02142                 return model;
02143         }
02144         
02145         // Stratified cross validation
02146         public static void svm_cross_validation(svm_problem prob, svm_parameter param, int nr_fold, double[] target)
02147         {
02148                 int i;
02149                 int[] fold_start = new int[nr_fold+1];
02150                 int l = prob.l;
02151                 int[] perm = new int[l];
02152                 
02153                 // stratified cv may not give leave-one-out rate
02154                 // Each class to l folds -> some folds may have zero elements
02155                 if((param.svm_type == svm_parameter.C_SVC ||
02156                     param.svm_type == svm_parameter.NU_SVC) && nr_fold < l)
02157                 {
02158                         int[] tmp_nr_class = new int[1];
02159                         int[][] tmp_label = new int[1][];
02160                         int[][] tmp_start = new int[1][];
02161                         int[][] tmp_count = new int[1][];
02162 
02163                         svm_group_classes(prob,tmp_nr_class,tmp_label,tmp_start,tmp_count,perm);
02164 
02165                         int nr_class = tmp_nr_class[0];
02166                         int[] start = tmp_start[0];
02167                         int[] count = tmp_count[0];             
02168 
02169                         // random shuffle and then data grouped by fold using the array perm
02170                         int[] fold_count = new int[nr_fold];
02171                         int c;
02172                         int[] index = new int[l];
02173                         for(i=0;i<l;i++)
02174                                 index[i]=perm[i];
02175                         for (c=0; c<nr_class; c++)
02176                                 for(i=0;i<count[c];i++)
02177                                 {
02178                                         int j = i+rand.nextInt(count[c]-i);
02179                                         do {int _=index[start[c]+j]; index[start[c]+j]=index[start[c]+i]; index[start[c]+i]=_;} while(false);
02180                                 }
02181                         for(i=0;i<nr_fold;i++)
02182                         {
02183                                 fold_count[i] = 0;
02184                                 for (c=0; c<nr_class;c++)
02185                                         fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
02186                         }
02187                         fold_start[0]=0;
02188                         for (i=1;i<=nr_fold;i++)
02189                                 fold_start[i] = fold_start[i-1]+fold_count[i-1];
02190                         for (c=0; c<nr_class;c++)
02191                                 for(i=0;i<nr_fold;i++)
02192                                 {
02193                                         int begin = start[c]+i*count[c]/nr_fold;
02194                                         int end = start[c]+(i+1)*count[c]/nr_fold;
02195                                         for(int j=begin;j<end;j++)
02196                                         {
02197                                                 perm[fold_start[i]] = index[j];
02198                                                 fold_start[i]++;
02199                                         }
02200                                 }
02201                         fold_start[0]=0;
02202                         for (i=1;i<=nr_fold;i++)
02203                                 fold_start[i] = fold_start[i-1]+fold_count[i-1];
02204                 }
02205                 else
02206                 {
02207                         for(i=0;i<l;i++) perm[i]=i;
02208                         for(i=0;i<l;i++)
02209                         {
02210                                 int j = i+rand.nextInt(l-i);
02211                                 do {int _=perm[i]; perm[i]=perm[j]; perm[j]=_;} while(false);
02212                         }
02213                         for(i=0;i<=nr_fold;i++)
02214                                 fold_start[i]=i*l/nr_fold;
02215                 }
02216 
02217                 for(i=0;i<nr_fold;i++)
02218                 {
02219                         int begin = fold_start[i];
02220                         int end = fold_start[i+1];
02221                         int j,k;
02222                         svm_problem subprob = new svm_problem();
02223 
02224                         subprob.l = l-(end-begin);
02225                         subprob.x = new svm_node[subprob.l][];
02226                         subprob.y = new double[subprob.l];
02227 
02228                         k=0;
02229                         for(j=0;j<begin;j++)
02230                         {
02231                                 subprob.x[k] = prob.x[perm[j]];
02232                                 subprob.y[k] = prob.y[perm[j]];
02233                                 ++k;
02234                         }
02235                         for(j=end;j<l;j++)
02236                         {
02237                                 subprob.x[k] = prob.x[perm[j]];
02238                                 subprob.y[k] = prob.y[perm[j]];
02239                                 ++k;
02240                         }
02241                         svm_model submodel = svm_train(subprob,param);
02242                         if(param.probability==1 &&
02243                            (param.svm_type == svm_parameter.C_SVC ||
02244                             param.svm_type == svm_parameter.NU_SVC))
02245                         {
02246                                 double[] prob_estimates= new double[svm_get_nr_class(submodel)];
02247                                 for(j=begin;j<end;j++)
02248                                         target[perm[j]] = svm_predict_probability(submodel,prob.x[perm[j]],prob_estimates);
02249                         }
02250                         else
02251                                 for(j=begin;j<end;j++)
02252                                         target[perm[j]] = svm_predict(submodel,prob.x[perm[j]]);
02253                 }
02254         }
02255 
02256         public static int svm_get_svm_type(svm_model model)
02257         {
02258                 return model.param.svm_type;
02259         }
02260 
02261         public static int svm_get_nr_class(svm_model model)
02262         {
02263                 return model.nr_class;
02264         }
02265 
02266         public static void svm_get_labels(svm_model model, int[] label)
02267         {
02268                 if (model.label != null)
02269                         for(int i=0;i<model.nr_class;i++)
02270                                 label[i] = model.label[i];
02271         }
02272 
02273         public static double svm_get_svr_probability(svm_model model)
02274         {
02275                 if ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) &&
02276                     model.probA!=null)
02277                 return model.probA[0];
02278                 else
02279                 {
02280                         System.err.print("Model doesn't contain information for SVR probability inference\n");
02281                         return 0;
02282                 }
02283         }
02284 
02285         public static double svm_predict_values(svm_model model, svm_node[] x, double[] dec_values)
02286         {
02287                 int i;
02288                 if(model.param.svm_type == svm_parameter.ONE_CLASS ||
02289                    model.param.svm_type == svm_parameter.EPSILON_SVR ||
02290                    model.param.svm_type == svm_parameter.NU_SVR)
02291                 {
02292                         double[] sv_coef = model.sv_coef[0];
02293                         double sum = 0;
02294                         for(i=0;i<model.l;i++)
02295                                 sum += sv_coef[i] * Kernel.k_function(x,model.SV[i],model.param);
02296                         sum -= model.rho[0];
02297                         dec_values[0] = sum;
02298 
02299                         if(model.param.svm_type == svm_parameter.ONE_CLASS)
02300                                 return (sum>0)?1:-1;
02301                         else
02302                                 return sum;
02303                 }
02304                 else
02305                 {
02306                         int nr_class = model.nr_class;
02307                         int l = model.l;
02308                 
02309                         double[] kvalue = new double[l];
02310                         for(i=0;i<l;i++)
02311                                 kvalue[i] = Kernel.k_function(x,model.SV[i],model.param);
02312 
02313                         int[] start = new int[nr_class];
02314                         start[0] = 0;
02315                         for(i=1;i<nr_class;i++)
02316                                 start[i] = start[i-1]+model.nSV[i-1];
02317 
02318                         int[] vote = new int[nr_class];
02319                         for(i=0;i<nr_class;i++)
02320                                 vote[i] = 0;
02321 
02322                         int p=0;
02323                         for(i=0;i<nr_class;i++)
02324                                 for(int j=i+1;j<nr_class;j++)
02325                                 {
02326                                         double sum = 0;
02327                                         int si = start[i];
02328                                         int sj = start[j];
02329                                         int ci = model.nSV[i];
02330                                         int cj = model.nSV[j];
02331                                 
02332                                         int k;
02333                                         double[] coef1 = model.sv_coef[j-1];
02334                                         double[] coef2 = model.sv_coef[i];
02335                                         for(k=0;k<ci;k++)
02336                                                 sum += coef1[si+k] * kvalue[si+k];
02337                                         for(k=0;k<cj;k++)
02338                                                 sum += coef2[sj+k] * kvalue[sj+k];
02339                                         sum -= model.rho[p];
02340                                         dec_values[p] = sum;                                    
02341 
02342                                         if(dec_values[p] > 0)
02343                                                 ++vote[i];
02344                                         else
02345                                                 ++vote[j];
02346                                         p++;
02347                                 }
02348 
02349                         int vote_max_idx = 0;
02350                         for(i=1;i<nr_class;i++)
02351                                 if(vote[i] > vote[vote_max_idx])
02352                                         vote_max_idx = i;
02353 
02354                         return model.label[vote_max_idx];
02355                 }
02356         }
02357 
02358         public static double svm_predict(svm_model model, svm_node[] x)
02359         {
02360                 int nr_class = model.nr_class;
02361                 double[] dec_values;
02362                 if(model.param.svm_type == svm_parameter.ONE_CLASS ||
02363                                 model.param.svm_type == svm_parameter.EPSILON_SVR ||
02364                                 model.param.svm_type == svm_parameter.NU_SVR)
02365                         dec_values = new double[1];
02366                 else
02367                         dec_values = new double[nr_class*(nr_class-1)/2];
02368                 double pred_result = svm_predict_values(model, x, dec_values);
02369                 return pred_result;
02370         }
02371 
02372         public static double svm_predict_probability(svm_model model, svm_node[] x, double[] prob_estimates)
02373         {
02374                 if ((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) &&
02375                     model.probA!=null && model.probB!=null)
02376                 {
02377                         int i;
02378                         int nr_class = model.nr_class;
02379                         double[] dec_values = new double[nr_class*(nr_class-1)/2];
02380                         svm_predict_values(model, x, dec_values);
02381 
02382                         double min_prob=1e-7;
02383                         double[][] pairwise_prob=new double[nr_class][nr_class];
02384                         
02385                         int k=0;
02386                         for(i=0;i<nr_class;i++)
02387                                 for(int j=i+1;j<nr_class;j++)
02388                                 {
02389                                         pairwise_prob[i][j]=Math.min(Math.max(sigmoid_predict(dec_values[k],model.probA[k],model.probB[k]),min_prob),1-min_prob);
02390                                         pairwise_prob[j][i]=1-pairwise_prob[i][j];
02391                                         k++;
02392                                 }
02393                         multiclass_probability(nr_class,pairwise_prob,prob_estimates);
02394 
02395                         int prob_max_idx = 0;
02396                         for(i=1;i<nr_class;i++)
02397                                 if(prob_estimates[i] > prob_estimates[prob_max_idx])
02398                                         prob_max_idx = i;
02399                         return model.label[prob_max_idx];
02400                 }
02401                 else 
02402                         return svm_predict(model, x);
02403         }
02404 
02405         static final String svm_type_table[] =
02406         {
02407                 "c_svc","nu_svc","one_class","epsilon_svr","nu_svr",
02408         };
02409 
02410         static final String kernel_type_table[]=
02411         {
02412                 "linear","polynomial","rbf","sigmoid","precomputed"
02413         };
02414 
02415         public static void svm_save_model(String model_file_name, svm_model model) throws IOException
02416         {
02417                 DataOutputStream fp = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(model_file_name)));
02418 
02419                 svm_parameter param = model.param;
02420 
02421                 fp.writeBytes("svm_type "+svm_type_table[param.svm_type]+"\n");
02422                 fp.writeBytes("kernel_type "+kernel_type_table[param.kernel_type]+"\n");
02423 
02424                 if(param.kernel_type == svm_parameter.POLY)
02425                         fp.writeBytes("degree "+param.degree+"\n");
02426 
02427                 if(param.kernel_type == svm_parameter.POLY ||
02428                    param.kernel_type == svm_parameter.RBF ||
02429                    param.kernel_type == svm_parameter.SIGMOID)
02430                         fp.writeBytes("gamma "+param.gamma+"\n");
02431 
02432                 if(param.kernel_type == svm_parameter.POLY ||
02433                    param.kernel_type == svm_parameter.SIGMOID)
02434                         fp.writeBytes("coef0 "+param.coef0+"\n");
02435 
02436                 int nr_class = model.nr_class;
02437                 int l = model.l;
02438                 fp.writeBytes("nr_class "+nr_class+"\n");
02439                 fp.writeBytes("total_sv "+l+"\n");
02440         
02441                 {
02442                         fp.writeBytes("rho");
02443                         for(int i=0;i<nr_class*(nr_class-1)/2;i++)
02444                                 fp.writeBytes(" "+model.rho[i]);
02445                         fp.writeBytes("\n");
02446                 }
02447         
02448                 if(model.label != null)
02449                 {
02450                         fp.writeBytes("label");
02451                         for(int i=0;i<nr_class;i++)
02452                                 fp.writeBytes(" "+model.label[i]);
02453                         fp.writeBytes("\n");
02454                 }
02455 
02456                 if(model.probA != null) // regression has probA only
02457                 {
02458                         fp.writeBytes("probA");
02459                         for(int i=0;i<nr_class*(nr_class-1)/2;i++)
02460                                 fp.writeBytes(" "+model.probA[i]);
02461                         fp.writeBytes("\n");
02462                 }
02463                 if(model.probB != null) 
02464                 {
02465                         fp.writeBytes("probB");
02466                         for(int i=0;i<nr_class*(nr_class-1)/2;i++)
02467                                 fp.writeBytes(" "+model.probB[i]);
02468                         fp.writeBytes("\n");
02469                 }
02470 
02471                 if(model.nSV != null)
02472                 {
02473                         fp.writeBytes("nr_sv");
02474                         for(int i=0;i<nr_class;i++)
02475                                 fp.writeBytes(" "+model.nSV[i]);
02476                         fp.writeBytes("\n");
02477                 }
02478 
02479                 fp.writeBytes("SV\n");
02480                 double[][] sv_coef = model.sv_coef;
02481                 svm_node[][] SV = model.SV;
02482 
02483                 for(int i=0;i<l;i++)
02484                 {
02485                         for(int j=0;j<nr_class-1;j++)
02486                                 fp.writeBytes(sv_coef[j][i]+" ");
02487 
02488                         svm_node[] p = SV[i];
02489                         if(param.kernel_type == svm_parameter.PRECOMPUTED)
02490                                 fp.writeBytes("0:"+(int)(p[0].value));
02491                         else    
02492                                 for(int j=0;j<p.length;j++)
02493                                         fp.writeBytes(p[j].index+":"+p[j].value+" ");
02494                         fp.writeBytes("\n");
02495                 }
02496 
02497                 fp.close();
02498         }
02499 
02500         private static double atof(String s)
02501         {
02502                 return Double.valueOf(s).doubleValue();
02503         }
02504 
02505         private static int atoi(String s)
02506         {
02507                 return Integer.parseInt(s);
02508         }
02509 
02510         public static svm_model svm_load_model(String model_file_name) throws IOException
02511         {
02512                 return svm_load_model(new BufferedReader(new FileReader(model_file_name)));
02513         }
02514 
02515         public static svm_model svm_load_model(BufferedReader fp) throws IOException
02516         {
02517                 // read parameters
02518 
02519                 svm_model model = new svm_model();
02520                 svm_parameter param = new svm_parameter();
02521                 model.param = param;
02522                 model.rho = null;
02523                 model.probA = null;
02524                 model.probB = null;
02525                 model.label = null;
02526                 model.nSV = null;
02527 
02528                 while(true)
02529                 {
02530                         String cmd = fp.readLine();
02531                         String arg = cmd.substring(cmd.indexOf(' ')+1);
02532 
02533                         if(cmd.startsWith("svm_type"))
02534                         {
02535                                 int i;
02536                                 for(i=0;i<svm_type_table.length;i++)
02537                                 {
02538                                         if(arg.indexOf(svm_type_table[i])!=-1)
02539                                         {
02540                                                 param.svm_type=i;
02541                                                 break;
02542                                         }
02543                                 }
02544                                 if(i == svm_type_table.length)
02545                                 {
02546                                         System.err.print("unknown svm type.\n");
02547                                         return null;
02548                                 }
02549                         }
02550                         else if(cmd.startsWith("kernel_type"))
02551                         {
02552                                 int i;
02553                                 for(i=0;i<kernel_type_table.length;i++)
02554                                 {
02555                                         if(arg.indexOf(kernel_type_table[i])!=-1)
02556                                         {
02557                                                 param.kernel_type=i;
02558                                                 break;
02559                                         }
02560                                 }
02561                                 if(i == kernel_type_table.length)
02562                                 {
02563                                         System.err.print("unknown kernel function.\n");
02564                                         return null;
02565                                 }
02566                         }
02567                         else if(cmd.startsWith("degree"))
02568                                 param.degree = atoi(arg);
02569                         else if(cmd.startsWith("gamma"))
02570                                 param.gamma = atof(arg);
02571                         else if(cmd.startsWith("coef0"))
02572                                 param.coef0 = atof(arg);
02573                         else if(cmd.startsWith("nr_class"))
02574                                 model.nr_class = atoi(arg);
02575                         else if(cmd.startsWith("total_sv"))
02576                                 model.l = atoi(arg);
02577                         else if(cmd.startsWith("rho"))
02578                         {
02579                                 int n = model.nr_class * (model.nr_class-1)/2;
02580                                 model.rho = new double[n];
02581                                 StringTokenizer st = new StringTokenizer(arg);
02582                                 for(int i=0;i<n;i++)
02583                                         model.rho[i] = atof(st.nextToken());
02584                         }
02585                         else if(cmd.startsWith("label"))
02586                         {
02587                                 int n = model.nr_class;
02588                                 model.label = new int[n];
02589                                 StringTokenizer st = new StringTokenizer(arg);
02590                                 for(int i=0;i<n;i++)
02591                                         model.label[i] = atoi(st.nextToken());                                  
02592                         }
02593                         else if(cmd.startsWith("probA"))
02594                         {
02595                                 int n = model.nr_class*(model.nr_class-1)/2;
02596                                 model.probA = new double[n];
02597                                 StringTokenizer st = new StringTokenizer(arg);
02598                                 for(int i=0;i<n;i++)
02599                                         model.probA[i] = atof(st.nextToken());                                  
02600                         }
02601                         else if(cmd.startsWith("probB"))
02602                         {
02603                                 int n = model.nr_class*(model.nr_class-1)/2;
02604                                 model.probB = new double[n];
02605                                 StringTokenizer st = new StringTokenizer(arg);
02606                                 for(int i=0;i<n;i++)
02607                                         model.probB[i] = atof(st.nextToken());                                  
02608                         }
02609                         else if(cmd.startsWith("nr_sv"))
02610                         {
02611                                 int n = model.nr_class;
02612                                 model.nSV = new int[n];
02613                                 StringTokenizer st = new StringTokenizer(arg);
02614                                 for(int i=0;i<n;i++)
02615                                         model.nSV[i] = atoi(st.nextToken());
02616                         }
02617                         else if(cmd.startsWith("SV"))
02618                         {
02619                                 break;
02620                         }
02621                         else
02622                         {
02623                                 System.err.print("unknown text in model file: ["+cmd+"]\n");
02624                                 return null;
02625                         }
02626                 }
02627 
02628                 // read sv_coef and SV
02629 
02630                 int m = model.nr_class - 1;
02631                 int l = model.l;
02632                 model.sv_coef = new double[m][l];
02633                 model.SV = new svm_node[l][];
02634 
02635                 for(int i=0;i<l;i++)
02636                 {
02637                         String line = fp.readLine();
02638                         StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
02639 
02640                         for(int k=0;k<m;k++)
02641                                 model.sv_coef[k][i] = atof(st.nextToken());
02642                         int n = st.countTokens()/2;
02643                         model.SV[i] = new svm_node[n];
02644                         for(int j=0;j<n;j++)
02645                         {
02646                                 model.SV[i][j] = new svm_node();
02647                                 model.SV[i][j].index = atoi(st.nextToken());
02648                                 model.SV[i][j].value = atof(st.nextToken());
02649                         }
02650                 }
02651 
02652                 fp.close();
02653                 return model;
02654         }
02655 
02656         public static String svm_check_parameter(svm_problem prob, svm_parameter param)
02657         {
02658                 // svm_type
02659 
02660                 int svm_type = param.svm_type;
02661                 if(svm_type != svm_parameter.C_SVC &&
02662                    svm_type != svm_parameter.NU_SVC &&
02663                    svm_type != svm_parameter.ONE_CLASS &&
02664                    svm_type != svm_parameter.EPSILON_SVR &&
02665                    svm_type != svm_parameter.NU_SVR)
02666                 return "unknown svm type";
02667 
02668                 // kernel_type, degree
02669         
02670                 int kernel_type = param.kernel_type;
02671                 if(kernel_type != svm_parameter.LINEAR &&
02672                    kernel_type != svm_parameter.POLY &&
02673                    kernel_type != svm_parameter.RBF &&
02674                    kernel_type != svm_parameter.SIGMOID &&
02675                    kernel_type != svm_parameter.PRECOMPUTED)
02676                         return "unknown kernel type";
02677 
02678                 if(param.gamma < 0)
02679                         return "gamma < 0";
02680 
02681                 if(param.degree < 0)
02682                         return "degree of polynomial kernel < 0";
02683 
02684                 // cache_size,eps,C,nu,p,shrinking
02685 
02686                 if(param.cache_size <= 0)
02687                         return "cache_size <= 0";
02688 
02689                 if(param.eps <= 0)
02690                         return "eps <= 0";
02691 
02692                 if(svm_type == svm_parameter.C_SVC ||
02693                    svm_type == svm_parameter.EPSILON_SVR ||
02694                    svm_type == svm_parameter.NU_SVR)
02695                         if(param.C <= 0)
02696                                 return "C <= 0";
02697 
02698                 if(svm_type == svm_parameter.NU_SVC ||
02699                    svm_type == svm_parameter.ONE_CLASS ||
02700                    svm_type == svm_parameter.NU_SVR)
02701                         if(param.nu <= 0 || param.nu > 1)
02702                                 return "nu <= 0 or nu > 1";
02703 
02704                 if(svm_type == svm_parameter.EPSILON_SVR)
02705                         if(param.p < 0)
02706                                 return "p < 0";
02707 
02708                 if(param.shrinking != 0 &&
02709                    param.shrinking != 1)
02710                         return "shrinking != 0 and shrinking != 1";
02711 
02712                 if(param.probability != 0 &&
02713                    param.probability != 1)
02714                         return "probability != 0 and probability != 1";
02715 
02716                 if(param.probability == 1 &&
02717                    svm_type == svm_parameter.ONE_CLASS)
02718                         return "one-class SVM probability output not supported yet";
02719                 
02720                 // check whether nu-svc is feasible
02721         
02722                 if(svm_type == svm_parameter.NU_SVC)
02723                 {
02724                         int l = prob.l;
02725                         int max_nr_class = 16;
02726                         int nr_class = 0;
02727                         int[] label = new int[max_nr_class];
02728                         int[] count = new int[max_nr_class];
02729 
02730                         int i;
02731                         for(i=0;i<l;i++)
02732                         {
02733                                 int this_label = (int)prob.y[i];
02734                                 int j;
02735                                 for(j=0;j<nr_class;j++)
02736                                         if(this_label == label[j])
02737                                         {
02738                                                 ++count[j];
02739                                                 break;
02740                                         }
02741 
02742                                 if(j == nr_class)
02743                                 {
02744                                         if(nr_class == max_nr_class)
02745                                         {
02746                                                 max_nr_class *= 2;
02747                                                 int[] new_data = new int[max_nr_class];
02748                                                 System.arraycopy(label,0,new_data,0,label.length);
02749                                                 label = new_data;
02750                                                 
02751                                                 new_data = new int[max_nr_class];
02752                                                 System.arraycopy(count,0,new_data,0,count.length);
02753                                                 count = new_data;
02754                                         }
02755                                         label[nr_class] = this_label;
02756                                         count[nr_class] = 1;
02757                                         ++nr_class;
02758                                 }
02759                         }
02760 
02761                         for(i=0;i<nr_class;i++)
02762                         {
02763                                 int n1 = count[i];
02764                                 for(int j=i+1;j<nr_class;j++)
02765                                 {
02766                                         int n2 = count[j];
02767                                         if(param.nu*(n1+n2)/2 > Math.min(n1,n2))
02768                                                 return "specified nu is infeasible";
02769                                 }
02770                         }
02771                 }
02772 
02773                 return null;
02774         }
02775 
02776         public static int svm_check_probability_model(svm_model model)
02777         {
02778                 if (((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) &&
02779                 model.probA!=null && model.probB!=null) ||
02780                 ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) &&
02781                  model.probA!=null))
02782                         return 1;
02783                 else
02784                         return 0;
02785         }
02786 
02787         public static void svm_set_print_string_function(svm_print_interface print_func)
02788         {
02789                 if (print_func == null)
02790                         svm_print_string = svm_print_stdout;
02791                 else 
02792                         svm_print_string = print_func;
02793         }
02794 }


haf_grasping
Author(s): David Fischinger
autogenerated on Thu Jun 6 2019 18:35:09