svm.java
Go to the documentation of this file.
00001 
00002 
00003 
00004 
00005 package libsvm;
00006 import java.io.*;
00007 import java.util.*;
00008 
00009 //
00010 // Kernel Cache
00011 //
00012 // l is the number of total data items
00013 // size is the cache size limit in bytes
00014 //
/**
 * Kernel column cache with least-recently-used (LRU) eviction.
 *
 * <p>For each of the {@code l} data items a prefix {@code [0, len)} of its column may be
 * cached. Cached entries are linked in a circular doubly linked list whose sentinel is
 * {@code lru_head}: {@code lru_head.next} is the least recently used entry and
 * {@code lru_head.prev} the most recently used one. Invariant: an entry is in the list
 * if and only if its {@code len > 0}.
 */
class Cache {
    private final int l;
    private long size;  // remaining budget, in floats (the constructor converts from bytes)

    /** Cache entry for one data item; also a node of the LRU list. */
    private static final class head_t
    {
        head_t prev, next;  // a circular list
        float[] data;
        int len;            // data[0,len) is cached in this entry
    }

    private final head_t[] head;
    private final head_t lru_head;

    /**
     * Creates an empty cache.
     *
     * @param l_    number of data items (columns) that may be cached
     * @param size_ cache size limit in bytes
     */
    Cache(int l_, long size_)
    {
        l = l_;
        head = new head_t[l];
        for(int i=0;i<l;i++) head[i] = new head_t();
        size = size_ / 4;                      // bytes -> floats
        size -= (long) l * (16/4);             // sizeof(head_t) == 16; long math avoids int overflow
        size = Math.max(size, 2* (long) l);    // cache must be large enough for two columns
        lru_head = new head_t();
        lru_head.next = lru_head.prev = lru_head;
    }

    /** Unlinks {@code h} from the LRU list. Its own prev/next fields are left untouched. */
    private void lru_delete(head_t h)
    {
        // delete from current location
        h.prev.next = h.next;
        h.next.prev = h.prev;
    }

    /** Links {@code h} at the most-recently-used end of the LRU list. */
    private void lru_insert(head_t h)
    {
        // insert to last position
        h.next = lru_head;
        h.prev = lru_head.prev;
        h.prev.next = h;
        h.next.prev = h;
    }

    /**
     * Requests the prefix {@code [0, len)} of item {@code index}'s column.
     *
     * <p>The column array is returned through {@code data[0]}; a single-element array simulates
     * a pointer. If fewer than {@code len} entries were cached, the array is grown while the
     * cached prefix is kept, and least recently used entries are evicted as needed to stay
     * within budget. A non-empty entry becomes the most recently used one.
     *
     * @param index item whose column is requested
     * @param data  out-parameter; {@code data[0]} receives the column array (null if nothing is cached)
     * @param len   number of leading entries requested
     * @return position {@code p} such that {@code [p, len)} still has to be filled by the caller
     *         ({@code p >= len} if nothing needs to be filled)
     */
    int get_data(int index, float[][] data, int len)
    {
        head_t h = head[index];
        if(h.len > 0) lru_delete(h);
        int more = len - h.len;

        if(more > 0)
        {
            // free old space
            while(size < more)
            {
                head_t old = lru_head.next;
                lru_delete(old);
                size += old.len;
                old.data = null;
                old.len = 0;
            }

            // allocate new space, keeping the already-cached prefix
            float[] new_data = new float[len];
            if(h.data != null) System.arraycopy(h.data,0,new_data,0,h.len);
            h.data = new_data;
            size -= more;
            int cached = h.len;
            h.len = len;
            len = cached;   // the caller must fill from the old cached length onward
        }

        // Only non-empty entries live in the LRU list. Inserting an empty entry would later
        // cause a second insertion, because get_data/swap_index only unlink entries with
        // len > 0, and that corrupts the circular list.
        if(h.len > 0) lru_insert(h);
        data[0] = h.data;
        return len;
    }

    /**
     * Swaps data items {@code i} and {@code j}.
     *
     * <p>The two items' cache entries are exchanged, and rows i and j are swapped inside every
     * cached column. A cached column that covers the smaller index but not the larger one
     * cannot be fixed up, so it is dropped from the cache.
     */
    void swap_index(int i, int j)
    {
        if(i==j) return;

        head_t a = head[i];
        head_t b = head[j];
        if(a.len > 0) lru_delete(a);
        if(b.len > 0) lru_delete(b);
        float[] tmp_data = a.data; a.data = b.data; b.data = tmp_data;
        int tmp_len = a.len; a.len = b.len; b.len = tmp_len;
        if(a.len > 0) lru_insert(a);
        if(b.len > 0) lru_insert(b);

        int lo = Math.min(i, j);
        int hi = Math.max(i, j);
        for(head_t h = lru_head.next; h!=lru_head; h=h.next)
        {
            if(h.len > lo)
            {
                if(h.len > hi)
                {
                    float t = h.data[lo]; h.data[lo] = h.data[hi]; h.data[hi] = t;
                }
                else
                {
                    // give up: lru_delete leaves h.next intact, so iteration can still advance
                    lru_delete(h);
                    size += h.len;
                    h.data = null;
                    h.len = 0;
                }
            }
        }
    }
}
00121 
00122 //
00123 // Kernel evaluation
00124 //
00125 // the static method k_function is for doing single kernel evaluation
00126 // the constructor of Kernel prepares to calculate the l*l kernel matrix
00127 // the member function get_Q is for getting one column from the Q Matrix
00128 //
/**
 * Abstract view of the l-by-l matrix Q consumed by the SMO solver.
 *
 * <p>Implementations hand out single columns on demand, so the full matrix never has to be
 * materialised.
 */
abstract class QMatrix
{
    /**
     * Returns column {@code column} of Q. Callers read only entries {@code [0, len)}.
     */
    abstract float[] get_Q(int column, int len);

    /** Returns the per-variable array the solver uses in place of the diagonal entries Q_ii. */
    abstract double[] get_QD();

    /** Exchanges variables {@code i} and {@code j} in the implementation's internal ordering. */
    abstract void swap_index(int i, int j);
}
00134 
00135 abstract class Kernel extends QMatrix {
00136         private svm_node[][] x;
00137         private final double[] x_square;
00138 
00139         // svm_parameter
00140         private final int kernel_type;
00141         private final int degree;
00142         private final double gamma;
00143         private final double coef0;
00144 
00145         abstract float[] get_Q(int column, int len);
00146         abstract double[] get_QD();
00147 
00148         void swap_index(int i, int j)
00149         {
00150                 do {svm_node[] _=x[i]; x[i]=x[j]; x[j]=_;} while(false);
00151                 if(x_square != null) do {double _=x_square[i]; x_square[i]=x_square[j]; x_square[j]=_;} while(false);
00152         }
00153 
00154         private static double powi(double base, int times)
00155         {
00156                 double tmp = base, ret = 1.0;
00157 
00158                 for(int t=times; t>0; t/=2)
00159                 {
00160                         if(t%2==1) ret*=tmp;
00161                         tmp = tmp * tmp;
00162                 }
00163                 return ret;
00164         }
00165 
00166         double kernel_function(int i, int j)
00167         {
00168                 switch(kernel_type)
00169                 {
00170                         case svm_parameter.LINEAR:
00171                                 return dot(x[i],x[j]);
00172                         case svm_parameter.POLY:
00173                                 return powi(gamma*dot(x[i],x[j])+coef0,degree);
00174                         case svm_parameter.RBF:
00175                                 return Math.exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
00176                         case svm_parameter.SIGMOID:
00177                                 return Math.tanh(gamma*dot(x[i],x[j])+coef0);
00178                         case svm_parameter.PRECOMPUTED:
00179                                 return x[i][(int)(x[j][0].value)].value;
00180                         default:
00181                                 return 0;       // java
00182                 }
00183         }
00184 
00185         Kernel(int l, svm_node[][] x_, svm_parameter param)
00186         {
00187                 this.kernel_type = param.kernel_type;
00188                 this.degree = param.degree;
00189                 this.gamma = param.gamma;
00190                 this.coef0 = param.coef0;
00191 
00192                 x = (svm_node[][])x_.clone();
00193 
00194                 if(kernel_type == svm_parameter.RBF)
00195                 {
00196                         x_square = new double[l];
00197                         for(int i=0;i<l;i++)
00198                                 x_square[i] = dot(x[i],x[i]);
00199                 }
00200                 else x_square = null;
00201         }
00202 
00203         static double dot(svm_node[] x, svm_node[] y)
00204         {
00205                 double sum = 0;
00206                 int xlen = x.length;
00207                 int ylen = y.length;
00208                 int i = 0;
00209                 int j = 0;
00210                 while(i < xlen && j < ylen)
00211                 {
00212                         if(x[i].index == y[j].index)
00213                                 sum += x[i++].value * y[j++].value;
00214                         else
00215                         {
00216                                 if(x[i].index > y[j].index)
00217                                         ++j;
00218                                 else
00219                                         ++i;
00220                         }
00221                 }
00222                 return sum;
00223         }
00224 
00225         static double k_function(svm_node[] x, svm_node[] y,
00226                                         svm_parameter param)
00227         {
00228                 switch(param.kernel_type)
00229                 {
00230                         case svm_parameter.LINEAR:
00231                                 return dot(x,y);
00232                         case svm_parameter.POLY:
00233                                 return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
00234                         case svm_parameter.RBF:
00235                         {
00236                                 double sum = 0;
00237                                 int xlen = x.length;
00238                                 int ylen = y.length;
00239                                 int i = 0;
00240                                 int j = 0;
00241                                 while(i < xlen && j < ylen)
00242                                 {
00243                                         if(x[i].index == y[j].index)
00244                                         {
00245                                                 double d = x[i++].value - y[j++].value;
00246                                                 sum += d*d;
00247                                         }
00248                                         else if(x[i].index > y[j].index)
00249                                         {
00250                                                 sum += y[j].value * y[j].value;
00251                                                 ++j;
00252                                         }
00253                                         else
00254                                         {
00255                                                 sum += x[i].value * x[i].value;
00256                                                 ++i;
00257                                         }
00258                                 }
00259 
00260                                 while(i < xlen)
00261                                 {
00262                                         sum += x[i].value * x[i].value;
00263                                         ++i;
00264                                 }
00265 
00266                                 while(j < ylen)
00267                                 {
00268                                         sum += y[j].value * y[j].value;
00269                                         ++j;
00270                                 }
00271 
00272                                 return Math.exp(-param.gamma*sum);
00273                         }
00274                         case svm_parameter.SIGMOID:
00275                                 return Math.tanh(param.gamma*dot(x,y)+param.coef0);
00276                         case svm_parameter.PRECOMPUTED:
00277                                 return  x[(int)(y[0].value)].value;
00278                         default:
00279                                 return 0;       // java
00280                 }
00281         }
00282 }
00283 
00284 // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
00285 // Solves:
00286 //
00287 //      min 0.5(\alpha^T Q \alpha) + p^T \alpha
00288 //
00289 //              y^T \alpha = \delta
00290 //              y_i = +1 or -1
00291 //              0 <= alpha_i <= Cp for y_i = 1
00292 //              0 <= alpha_i <= Cn for y_i = -1
00293 //
00294 // Given:
00295 //
00296 //      Q, p, y, Cp, Cn, and an initial feasible point \alpha
00297 //      l is the size of vectors and matrices
00298 //      eps is the stopping tolerance
00299 //
00300 // solution will be put in \alpha, objective value will be put in obj
00301 //
00302 class Solver {
    // Solver state. Per-variable arrays are indexed by the solver's current
    // (possibly permuted) ordering; swap_index keeps them in sync.
    int active_size;        // variables [0, active_size) are active; the rest are shrunk away
    byte[] y;               // labels, +1 or -1
    double[] G;             // gradient of objective function
    static final byte LOWER_BOUND = 0;
    static final byte UPPER_BOUND = 1;
    static final byte FREE = 2;
    byte[] alpha_status;    // LOWER_BOUND, UPPER_BOUND, FREE
    double[] alpha;         // current iterate (working copy of the caller's alpha)
    QMatrix Q;
    double[] QD;            // Q.get_QD(); used in place of Q_ii in the step sizes
    double eps;             // stopping tolerance
    double Cp,Cn;           // upper bounds for y_i = +1 and y_i = -1 respectively
    double[] p;             // linear term of the objective
    int[] active_set;       // active_set[k] = original index of the variable now at position k
    double[] G_bar;         // gradient, if we treat free variables as 0
    int l;                  // number of variables
    boolean unshrink;       // XXX: reset to false by Solve; read by code not visible here (presumably shrinking) -- confirm
    
    static final double INF = java.lang.Double.POSITIVE_INFINITY;
00322 
00323         double get_C(int i)
00324         {
00325                 return (y[i] > 0)? Cp : Cn;
00326         }
00327         void update_alpha_status(int i)
00328         {
00329                 if(alpha[i] >= get_C(i))
00330                         alpha_status[i] = UPPER_BOUND;
00331                 else if(alpha[i] <= 0)
00332                         alpha_status[i] = LOWER_BOUND;
00333                 else alpha_status[i] = FREE;
00334         }
00335         boolean is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
00336         boolean is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
00337         boolean is_free(int i) {  return alpha_status[i] == FREE; }
00338 
00339         // java: information about solution except alpha,
00340         // because we cannot return multiple values otherwise...
00341         static class SolutionInfo {
00342                 double obj;
00343                 double rho;
00344                 double upper_bound_p;
00345                 double upper_bound_n;
00346                 double r;       // for Solver_NU
00347         }
00348 
00349         void swap_index(int i, int j)
00350         {
00351                 Q.swap_index(i,j);
00352                 do {byte _=y[i]; y[i]=y[j]; y[j]=_;} while(false);
00353                 do {double _=G[i]; G[i]=G[j]; G[j]=_;} while(false);
00354                 do {byte _=alpha_status[i]; alpha_status[i]=alpha_status[j]; alpha_status[j]=_;} while(false);
00355                 do {double _=alpha[i]; alpha[i]=alpha[j]; alpha[j]=_;} while(false);
00356                 do {double _=p[i]; p[i]=p[j]; p[j]=_;} while(false);
00357                 do {int _=active_set[i]; active_set[i]=active_set[j]; active_set[j]=_;} while(false);
00358                 do {double _=G_bar[i]; G_bar[i]=G_bar[j]; G_bar[j]=_;} while(false);
00359         }
00360 
00361         void reconstruct_gradient()
00362         {
00363                 // reconstruct inactive elements of G from G_bar and free variables
00364 
00365                 if(active_size == l) return;
00366 
00367                 int i,j;
00368                 int nr_free = 0;
00369 
00370                 for(j=active_size;j<l;j++)
00371                         G[j] = G_bar[j] + p[j];
00372 
00373                 for(j=0;j<active_size;j++)
00374                         if(is_free(j))
00375                                 nr_free++;
00376 
00377                 if(2*nr_free < active_size)
00378                         svm.info("\nWarning: using -h 0 may be faster\n");
00379 
00380                 if (nr_free*l > 2*active_size*(l-active_size))
00381                 {
00382                         for(i=active_size;i<l;i++)
00383                         {
00384                                 float[] Q_i = Q.get_Q(i,active_size);
00385                                 for(j=0;j<active_size;j++)
00386                                         if(is_free(j))
00387                                                 G[i] += alpha[j] * Q_i[j];
00388                         }       
00389                 }
00390                 else
00391                 {
00392                         for(i=0;i<active_size;i++)
00393                                 if(is_free(i))
00394                                 {
00395                                         float[] Q_i = Q.get_Q(i,l);
00396                                         double alpha_i = alpha[i];
00397                                         for(j=active_size;j<l;j++)
00398                                                 G[j] += alpha_i * Q_i[j];
00399                                 }
00400                 }
00401         }
00402 
    /**
     * Runs the SMO iteration on the problem described in the comment above this class.
     * On return the solution is written back into {@code alpha_}.
     *
     * @param l         number of variables
     * @param Q         the Q matrix (columns via get_Q; get_QD used in place of Q_ii)
     * @param p_        linear term p (cloned; the caller's array is not modified)
     * @param y_        labels, +1 or -1 (cloned; the caller's array is not modified)
     * @param alpha_    initial feasible point on entry; overwritten with the solution,
     *                  in the caller's original variable order
     * @param Cp        upper bound for alpha_i with y_i = +1
     * @param Cn        upper bound for alpha_i with y_i = -1
     * @param eps       stopping tolerance
     * @param si        receives rho, obj, upper_bound_p and upper_bound_n
     * @param shrinking non-zero enables periodic shrinking of the active set
     */
    void Solve(int l, QMatrix Q, double[] p_, byte[] y_,
               double[] alpha_, double Cp, double Cn, double eps, SolutionInfo si, int shrinking)
    {
        this.l = l;
        this.Q = Q;
        QD = Q.get_QD();
        p = (double[])p_.clone();
        y = (byte[])y_.clone();
        alpha = (double[])alpha_.clone();
        this.Cp = Cp;
        this.Cn = Cn;
        this.eps = eps;
        this.unshrink = false;

        // initialize alpha_status
        {
            alpha_status = new byte[l];
            for(int i=0;i<l;i++)
                update_alpha_status(i);
        }

        // initialize active set (for shrinking); initially the identity permutation
        {
            active_set = new int[l];
            for(int i=0;i<l;i++)
                active_set[i] = i;
            active_size = l;
        }

        // initialize gradient: G = p + Q*alpha, and G_bar = sum of C_i * Q_i
        // over the variables at their upper bound
        {
            G = new double[l];
            G_bar = new double[l];
            int i;
            for(i=0;i<l;i++)
            {
                G[i] = p[i];
                G_bar[i] = 0;
            }
            for(i=0;i<l;i++)
                if(!is_lower_bound(i))   // alpha_i == 0 contributes nothing
                {
                    float[] Q_i = Q.get_Q(i,l);
                    double alpha_i = alpha[i];
                    int j;
                    for(j=0;j<l;j++)
                        G[j] += alpha_i*Q_i[j];
                    if(is_upper_bound(i))
                        for(j=0;j<l;j++)
                            G_bar[j] += get_C(i) * Q_i[j];
                }
        }

        // optimization step

        int iter = 0;
        int counter = Math.min(l,1000)+1;
        int[] working_set = new int[2];

        while(true)
        {
            // show progress and do shrinking (every min(l,1000) iterations)

            if(--counter == 0)
            {
                counter = Math.min(l,1000);
                if(shrinking!=0) do_shrinking();
                svm.info(".");
            }

            // a non-zero result means the active set is already optimal;
            // re-check on the full problem before stopping
            if(select_working_set(working_set)!=0)
            {
                // reconstruct the whole gradient
                reconstruct_gradient();
                // reset active set size and check
                active_size = l;
                svm.info("*");
                if(select_working_set(working_set)!=0)
                    break;
                else
                    counter = 1;    // do shrinking next iteration
            }
            
            int i = working_set[0];
            int j = working_set[1];

            ++iter;

            // update alpha[i] and alpha[j], handle bounds carefully

            float[] Q_i = Q.get_Q(i,active_size);
            float[] Q_j = Q.get_Q(j,active_size);

            double C_i = get_C(i);
            double C_j = get_C(j);

            double old_alpha_i = alpha[i];
            double old_alpha_j = alpha[j];

            if(y[i]!=y[j])
            {
                // opposite labels: alpha[i] - alpha[j] (= diff) is preserved; take the
                // unconstrained step, then clip back onto the box along that line
                double quad_coef = QD[i]+QD[j]+2*Q_i[j];
                if (quad_coef <= 0)
                    quad_coef = 1e-12;   // guard against non-positive curvature
                double delta = (-G[i]-G[j])/quad_coef;
                double diff = alpha[i] - alpha[j];
                alpha[i] += delta;
                alpha[j] += delta;
            
                if(diff > 0)
                {
                    if(alpha[j] < 0)
                    {
                        alpha[j] = 0;
                        alpha[i] = diff;
                    }
                }
                else
                {
                    if(alpha[i] < 0)
                    {
                        alpha[i] = 0;
                        alpha[j] = -diff;
                    }
                }
                if(diff > C_i - C_j)
                {
                    if(alpha[i] > C_i)
                    {
                        alpha[i] = C_i;
                        alpha[j] = C_i - diff;
                    }
                }
                else
                {
                    if(alpha[j] > C_j)
                    {
                        alpha[j] = C_j;
                        alpha[i] = C_j + diff;
                    }
                }
            }
            else
            {
                // equal labels: alpha[i] + alpha[j] (= sum) is preserved; take the
                // unconstrained step, then clip back onto the box along that line
                double quad_coef = QD[i]+QD[j]-2*Q_i[j];
                if (quad_coef <= 0)
                    quad_coef = 1e-12;   // guard against non-positive curvature
                double delta = (G[i]-G[j])/quad_coef;
                double sum = alpha[i] + alpha[j];
                alpha[i] -= delta;
                alpha[j] += delta;

                if(sum > C_i)
                {
                    if(alpha[i] > C_i)
                    {
                        alpha[i] = C_i;
                        alpha[j] = sum - C_i;
                    }
                }
                else
                {
                    if(alpha[j] < 0)
                    {
                        alpha[j] = 0;
                        alpha[i] = sum;
                    }
                }
                if(sum > C_j)
                {
                    if(alpha[j] > C_j)
                    {
                        alpha[j] = C_j;
                        alpha[i] = sum - C_j;
                    }
                }
                else
                {
                    if(alpha[i] < 0)
                    {
                        alpha[i] = 0;
                        alpha[j] = sum;
                    }
                }
            }

            // update G (active part only; inactive entries are rebuilt by reconstruct_gradient)

            double delta_alpha_i = alpha[i] - old_alpha_i;
            double delta_alpha_j = alpha[j] - old_alpha_j;

            for(int k=0;k<active_size;k++)
            {
                G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
            }

            // update alpha_status and G_bar: G_bar changes only when a variable
            // enters or leaves its upper bound, and then needs the full column

            {
                boolean ui = is_upper_bound(i);
                boolean uj = is_upper_bound(j);
                update_alpha_status(i);
                update_alpha_status(j);
                int k;
                if(ui != is_upper_bound(i))
                {
                    Q_i = Q.get_Q(i,l);
                    if(ui)
                        for(k=0;k<l;k++)
                            G_bar[k] -= C_i * Q_i[k];
                    else
                        for(k=0;k<l;k++)
                            G_bar[k] += C_i * Q_i[k];
                }

                if(uj != is_upper_bound(j))
                {
                    Q_j = Q.get_Q(j,l);
                    if(uj)
                        for(k=0;k<l;k++)
                            G_bar[k] -= C_j * Q_j[k];
                    else
                        for(k=0;k<l;k++)
                            G_bar[k] += C_j * Q_j[k];
                }
            }

        }

        // calculate rho

        si.rho = calculate_rho();

        // calculate objective value. The loop only exits right after active_size = l and
        // reconstruct_gradient(), so G = Q*alpha + p is complete and
        // sum alpha_i*(G_i + p_i) / 2 = 0.5*alpha^T Q alpha + p^T alpha.
        {
            double v = 0;
            int i;
            for(i=0;i<l;i++)
                v += alpha[i] * (G[i] + p[i]);

            si.obj = v/2;
        }

        // put back the solution, undoing the shrinking permutation
        {
            for(int i=0;i<l;i++)
                alpha_[active_set[i]] = alpha[i];
        }

        si.upper_bound_p = Cp;
        si.upper_bound_n = Cn;

        svm.info("\noptimization finished, #iter = "+iter+"\n");
    }
00657 
00658         // return 1 if already optimal, return 0 otherwise
00659         int select_working_set(int[] working_set)
00660         {
00661                 // return i,j such that
00662                 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
00663                 // j: mimimizes the decrease of obj value
00664                 //    (if quadratic coefficeint <= 0, replace it with tau)
00665                 //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
00666                 
00667                 double Gmax = -INF;
00668                 double Gmax2 = -INF;
00669                 int Gmax_idx = -1;
00670                 int Gmin_idx = -1;
00671                 double obj_diff_min = INF;
00672         
00673                 for(int t=0;t<active_size;t++)
00674                         if(y[t]==+1)    
00675                         {
00676                                 if(!is_upper_bound(t))
00677                                         if(-G[t] >= Gmax)
00678                                         {
00679                                                 Gmax = -G[t];
00680                                                 Gmax_idx = t;
00681                                         }
00682                         }
00683                         else
00684                         {
00685                                 if(!is_lower_bound(t))
00686                                         if(G[t] >= Gmax)
00687                                         {
00688                                                 Gmax = G[t];
00689                                                 Gmax_idx = t;
00690                                         }
00691                         }
00692         
00693                 int i = Gmax_idx;
00694                 float[] Q_i = null;
00695                 if(i != -1) // null Q_i not accessed: Gmax=-INF if i=-1
00696                         Q_i = Q.get_Q(i,active_size);
00697         
00698                 for(int j=0;j<active_size;j++)
00699                 {
00700                         if(y[j]==+1)
00701                         {
00702                                 if (!is_lower_bound(j))
00703                                 {
00704                                         double grad_diff=Gmax+G[j];
00705                                         if (G[j] >= Gmax2)
00706                                                 Gmax2 = G[j];
00707                                         if (grad_diff > 0)
00708                                         {
00709                                                 double obj_diff; 
00710                                                 double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
00711                                                 if (quad_coef > 0)
00712                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00713                                                 else
00714                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00715         
00716                                                 if (obj_diff <= obj_diff_min)
00717                                                 {
00718                                                         Gmin_idx=j;
00719                                                         obj_diff_min = obj_diff;
00720                                                 }
00721                                         }
00722                                 }
00723                         }
00724                         else
00725                         {
00726                                 if (!is_upper_bound(j))
00727                                 {
00728                                         double grad_diff= Gmax-G[j];
00729                                         if (-G[j] >= Gmax2)
00730                                                 Gmax2 = -G[j];
00731                                         if (grad_diff > 0)
00732                                         {
00733                                                 double obj_diff; 
00734                                                 double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
00735                                                 if (quad_coef > 0)
00736                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00737                                                 else
00738                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00739         
00740                                                 if (obj_diff <= obj_diff_min)
00741                                                 {
00742                                                         Gmin_idx=j;
00743                                                         obj_diff_min = obj_diff;
00744                                                 }
00745                                         }
00746                                 }
00747                         }
00748                 }
00749 
00750                 if(Gmax+Gmax2 < eps)
00751                         return 1;
00752 
00753                 working_set[0] = Gmax_idx;
00754                 working_set[1] = Gmin_idx;
00755                 return 0;
00756         }
00757 
00758         private boolean be_shrunk(int i, double Gmax1, double Gmax2)
00759         {       
00760                 if(is_upper_bound(i))
00761                 {
00762                         if(y[i]==+1)
00763                                 return(-G[i] > Gmax1);
00764                         else
00765                                 return(-G[i] > Gmax2);
00766                 }
00767                 else if(is_lower_bound(i))
00768                 {
00769                         if(y[i]==+1)
00770                                 return(G[i] > Gmax2);
00771                         else    
00772                                 return(G[i] > Gmax1);
00773                 }
00774                 else
00775                         return(false);
00776         }
00777 
00778         void do_shrinking()
00779         {
00780                 int i;
00781                 double Gmax1 = -INF;            // max { -y_i * grad(f)_i | i in I_up(\alpha) }
00782                 double Gmax2 = -INF;            // max { y_i * grad(f)_i | i in I_low(\alpha) }
00783 
00784                 // find maximal violating pair first
00785                 for(i=0;i<active_size;i++)
00786                 {
00787                         if(y[i]==+1)
00788                         {
00789                                 if(!is_upper_bound(i))  
00790                                 {
00791                                         if(-G[i] >= Gmax1)
00792                                                 Gmax1 = -G[i];
00793                                 }
00794                                 if(!is_lower_bound(i))
00795                                 {
00796                                         if(G[i] >= Gmax2)
00797                                                 Gmax2 = G[i];
00798                                 }
00799                         }
00800                         else            
00801                         {
00802                                 if(!is_upper_bound(i))  
00803                                 {
00804                                         if(-G[i] >= Gmax2)
00805                                                 Gmax2 = -G[i];
00806                                 }
00807                                 if(!is_lower_bound(i))  
00808                                 {
00809                                         if(G[i] >= Gmax1)
00810                                                 Gmax1 = G[i];
00811                                 }
00812                         }
00813                 }
00814 
00815                 if(unshrink == false && Gmax1 + Gmax2 <= eps*10) 
00816                 {
00817                         unshrink = true;
00818                         reconstruct_gradient();
00819                         active_size = l;
00820                 }
00821 
00822                 for(i=0;i<active_size;i++)
00823                         if (be_shrunk(i, Gmax1, Gmax2))
00824                         {
00825                                 active_size--;
00826                                 while (active_size > i)
00827                                 {
00828                                         if (!be_shrunk(active_size, Gmax1, Gmax2))
00829                                         {
00830                                                 swap_index(i,active_size);
00831                                                 break;
00832                                         }
00833                                         active_size--;
00834                                 }
00835                         }
00836         }
00837 
00838         double calculate_rho()
00839         {
00840                 double r;
00841                 int nr_free = 0;
00842                 double ub = INF, lb = -INF, sum_free = 0;
00843                 for(int i=0;i<active_size;i++)
00844                 {
00845                         double yG = y[i]*G[i];
00846 
00847                         if(is_lower_bound(i))
00848                         {
00849                                 if(y[i] > 0)
00850                                         ub = Math.min(ub,yG);
00851                                 else
00852                                         lb = Math.max(lb,yG);
00853                         }
00854                         else if(is_upper_bound(i))
00855                         {
00856                                 if(y[i] < 0)
00857                                         ub = Math.min(ub,yG);
00858                                 else
00859                                         lb = Math.max(lb,yG);
00860                         }
00861                         else
00862                         {
00863                                 ++nr_free;
00864                                 sum_free += yG;
00865                         }
00866                 }
00867 
00868                 if(nr_free>0)
00869                         r = sum_free/nr_free;
00870                 else
00871                         r = (ub+lb)/2;
00872 
00873                 return r;
00874         }
00875 
00876 }
00877 
00878 //
00879 // Solver for nu-svm classification and regression
00880 //
00881 // additional constraint: e^T \alpha = constant
00882 //
00883 final class Solver_NU extends Solver
00884 {
00885         private SolutionInfo si;
00886 
00887         void Solve(int l, QMatrix Q, double[] p, byte[] y,
00888                    double[] alpha, double Cp, double Cn, double eps,
00889                    SolutionInfo si, int shrinking)
00890         {
00891                 this.si = si;
00892                 super.Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
00893         }
00894 
00895         // return 1 if already optimal, return 0 otherwise
00896         int select_working_set(int[] working_set)
00897         {
00898                 // return i,j such that y_i = y_j and
00899                 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
00900                 // j: minimizes the decrease of obj value
00901                 //    (if quadratic coefficeint <= 0, replace it with tau)
00902                 //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
00903         
00904                 double Gmaxp = -INF;
00905                 double Gmaxp2 = -INF;
00906                 int Gmaxp_idx = -1;
00907         
00908                 double Gmaxn = -INF;
00909                 double Gmaxn2 = -INF;
00910                 int Gmaxn_idx = -1;
00911         
00912                 int Gmin_idx = -1;
00913                 double obj_diff_min = INF;
00914         
00915                 for(int t=0;t<active_size;t++)
00916                         if(y[t]==+1)
00917                         {
00918                                 if(!is_upper_bound(t))
00919                                         if(-G[t] >= Gmaxp)
00920                                         {
00921                                                 Gmaxp = -G[t];
00922                                                 Gmaxp_idx = t;
00923                                         }
00924                         }
00925                         else
00926                         {
00927                                 if(!is_lower_bound(t))
00928                                         if(G[t] >= Gmaxn)
00929                                         {
00930                                                 Gmaxn = G[t];
00931                                                 Gmaxn_idx = t;
00932                                         }
00933                         }
00934         
00935                 int ip = Gmaxp_idx;
00936                 int in = Gmaxn_idx;
00937                 float[] Q_ip = null;
00938                 float[] Q_in = null;
00939                 if(ip != -1) // null Q_ip not accessed: Gmaxp=-INF if ip=-1
00940                         Q_ip = Q.get_Q(ip,active_size);
00941                 if(in != -1)
00942                         Q_in = Q.get_Q(in,active_size);
00943         
00944                 for(int j=0;j<active_size;j++)
00945                 {
00946                         if(y[j]==+1)
00947                         {
00948                                 if (!is_lower_bound(j)) 
00949                                 {
00950                                         double grad_diff=Gmaxp+G[j];
00951                                         if (G[j] >= Gmaxp2)
00952                                                 Gmaxp2 = G[j];
00953                                         if (grad_diff > 0)
00954                                         {
00955                                                 double obj_diff; 
00956                                                 double quad_coef = QD[ip]+QD[j]-2*Q_ip[j];
00957                                                 if (quad_coef > 0)
00958                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00959                                                 else
00960                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00961         
00962                                                 if (obj_diff <= obj_diff_min)
00963                                                 {
00964                                                         Gmin_idx=j;
00965                                                         obj_diff_min = obj_diff;
00966                                                 }
00967                                         }
00968                                 }
00969                         }
00970                         else
00971                         {
00972                                 if (!is_upper_bound(j))
00973                                 {
00974                                         double grad_diff=Gmaxn-G[j];
00975                                         if (-G[j] >= Gmaxn2)
00976                                                 Gmaxn2 = -G[j];
00977                                         if (grad_diff > 0)
00978                                         {
00979                                                 double obj_diff; 
00980                                                 double quad_coef = QD[in]+QD[j]-2*Q_in[j];
00981                                                 if (quad_coef > 0)
00982                                                         obj_diff = -(grad_diff*grad_diff)/quad_coef;
00983                                                 else
00984                                                         obj_diff = -(grad_diff*grad_diff)/1e-12;
00985         
00986                                                 if (obj_diff <= obj_diff_min)
00987                                                 {
00988                                                         Gmin_idx=j;
00989                                                         obj_diff_min = obj_diff;
00990                                                 }
00991                                         }
00992                                 }
00993                         }
00994                 }
00995 
00996                 if(Math.max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps)
00997                         return 1;
00998         
00999                 if(y[Gmin_idx] == +1)
01000                         working_set[0] = Gmaxp_idx;
01001                 else
01002                         working_set[0] = Gmaxn_idx;
01003                 working_set[1] = Gmin_idx;
01004         
01005                 return 0;
01006         }
01007 
01008         private boolean be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
01009         {
01010                 if(is_upper_bound(i))
01011                 {
01012                         if(y[i]==+1)
01013                                 return(-G[i] > Gmax1);
01014                         else    
01015                                 return(-G[i] > Gmax4);
01016                 }
01017                 else if(is_lower_bound(i))
01018                 {
01019                         if(y[i]==+1)
01020                                 return(G[i] > Gmax2);
01021                         else    
01022                                 return(G[i] > Gmax3);
01023                 }
01024                 else
01025                         return(false);
01026         }
01027 
01028         void do_shrinking()
01029         {
01030                 double Gmax1 = -INF;    // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
01031                 double Gmax2 = -INF;    // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
01032                 double Gmax3 = -INF;    // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
01033                 double Gmax4 = -INF;    // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
01034  
01035                 // find maximal violating pair first
01036                 int i;
01037                 for(i=0;i<active_size;i++)
01038                 {
01039                         if(!is_upper_bound(i))
01040                         {
01041                                 if(y[i]==+1)
01042                                 {
01043                                         if(-G[i] > Gmax1) Gmax1 = -G[i];
01044                                 }
01045                                 else    if(-G[i] > Gmax4) Gmax4 = -G[i];
01046                         }
01047                         if(!is_lower_bound(i))
01048                         {
01049                                 if(y[i]==+1)
01050                                 {       
01051                                         if(G[i] > Gmax2) Gmax2 = G[i];
01052                                 }
01053                                 else    if(G[i] > Gmax3) Gmax3 = G[i];
01054                         }
01055                 }
01056 
01057                 if(unshrink == false && Math.max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) 
01058                 {
01059                         unshrink = true;
01060                         reconstruct_gradient();
01061                         active_size = l;
01062                 }
01063 
01064                 for(i=0;i<active_size;i++)
01065                         if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
01066                         {
01067                                 active_size--;
01068                                 while (active_size > i)
01069                                 {
01070                                         if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
01071                                         {
01072                                                 swap_index(i,active_size);
01073                                                 break;
01074                                         }
01075                                         active_size--;
01076                                 }
01077                         }
01078         }
01079         
01080         double calculate_rho()
01081         {
01082                 int nr_free1 = 0,nr_free2 = 0;
01083                 double ub1 = INF, ub2 = INF;
01084                 double lb1 = -INF, lb2 = -INF;
01085                 double sum_free1 = 0, sum_free2 = 0;
01086 
01087                 for(int i=0;i<active_size;i++)
01088                 {
01089                         if(y[i]==+1)
01090                         {
01091                                 if(is_lower_bound(i))
01092                                         ub1 = Math.min(ub1,G[i]);
01093                                 else if(is_upper_bound(i))
01094                                         lb1 = Math.max(lb1,G[i]);
01095                                 else
01096                                 {
01097                                         ++nr_free1;
01098                                         sum_free1 += G[i];
01099                                 }
01100                         }
01101                         else
01102                         {
01103                                 if(is_lower_bound(i))
01104                                         ub2 = Math.min(ub2,G[i]);
01105                                 else if(is_upper_bound(i))
01106                                         lb2 = Math.max(lb2,G[i]);
01107                                 else
01108                                 {
01109                                         ++nr_free2;
01110                                         sum_free2 += G[i];
01111                                 }
01112                         }
01113                 }
01114 
01115                 double r1,r2;
01116                 if(nr_free1 > 0)
01117                         r1 = sum_free1/nr_free1;
01118                 else
01119                         r1 = (ub1+lb1)/2;
01120 
01121                 if(nr_free2 > 0)
01122                         r2 = sum_free2/nr_free2;
01123                 else
01124                         r2 = (ub2+lb2)/2;
01125 
01126                 si.r = (r1+r2)/2;
01127                 return (r1-r2)/2;
01128         }
01129 }
01130 
01131 //
01132 // Q matrices for various formulations
01133 //
01134 class SVC_Q extends Kernel
01135 {
01136         private final byte[] y;
01137         private final Cache cache;
01138         private final double[] QD;
01139 
01140         SVC_Q(svm_problem prob, svm_parameter param, byte[] y_)
01141         {
01142                 super(prob.l, prob.x, param);
01143                 y = (byte[])y_.clone();
01144                 cache = new Cache(prob.l,(long)(param.cache_size*(1<<20)));
01145                 QD = new double[prob.l];
01146                 for(int i=0;i<prob.l;i++)
01147                         QD[i] = kernel_function(i,i);
01148         }
01149 
01150         float[] get_Q(int i, int len)
01151         {
01152                 float[][] data = new float[1][];
01153                 int start, j;
01154                 if((start = cache.get_data(i,data,len)) < len)
01155                 {
01156                         for(j=start;j<len;j++)
01157                                 data[0][j] = (float)(y[i]*y[j]*kernel_function(i,j));
01158                 }
01159                 return data[0];
01160         }
01161 
01162         double[] get_QD()
01163         {
01164                 return QD;
01165         }
01166 
01167         void swap_index(int i, int j)
01168         {
01169                 cache.swap_index(i,j);
01170                 super.swap_index(i,j);
01171                 do {byte _=y[i]; y[i]=y[j]; y[j]=_;} while(false);
01172                 do {double _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
01173         }
01174 }
01175 
01176 class ONE_CLASS_Q extends Kernel
01177 {
01178         private final Cache cache;
01179         private final double[] QD;
01180 
01181         ONE_CLASS_Q(svm_problem prob, svm_parameter param)
01182         {
01183                 super(prob.l, prob.x, param);
01184                 cache = new Cache(prob.l,(long)(param.cache_size*(1<<20)));
01185                 QD = new double[prob.l];
01186                 for(int i=0;i<prob.l;i++)
01187                         QD[i] = kernel_function(i,i);
01188         }
01189 
01190         float[] get_Q(int i, int len)
01191         {
01192                 float[][] data = new float[1][];
01193                 int start, j;
01194                 if((start = cache.get_data(i,data,len)) < len)
01195                 {
01196                         for(j=start;j<len;j++)
01197                                 data[0][j] = (float)kernel_function(i,j);
01198                 }
01199                 return data[0];
01200         }
01201 
01202         double[] get_QD()
01203         {
01204                 return QD;
01205         }
01206 
01207         void swap_index(int i, int j)
01208         {
01209                 cache.swap_index(i,j);
01210                 super.swap_index(i,j);
01211                 do {double _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
01212         }
01213 }
01214 
01215 class SVR_Q extends Kernel
01216 {
01217         private final int l;
01218         private final Cache cache;
01219         private final byte[] sign;
01220         private final int[] index;
01221         private int next_buffer;
01222         private float[][] buffer;
01223         private final double[] QD;
01224 
01225         SVR_Q(svm_problem prob, svm_parameter param)
01226         {
01227                 super(prob.l, prob.x, param);
01228                 l = prob.l;
01229                 cache = new Cache(l,(long)(param.cache_size*(1<<20)));
01230                 QD = new double[2*l];
01231                 sign = new byte[2*l];
01232                 index = new int[2*l];
01233                 for(int k=0;k<l;k++)
01234                 {
01235                         sign[k] = 1;
01236                         sign[k+l] = -1;
01237                         index[k] = k;
01238                         index[k+l] = k;
01239                         QD[k] = kernel_function(k,k);
01240                         QD[k+l] = QD[k];
01241                 }
01242                 buffer = new float[2][2*l];
01243                 next_buffer = 0;
01244         }
01245 
01246         void swap_index(int i, int j)
01247         {
01248                 do {byte _=sign[i]; sign[i]=sign[j]; sign[j]=_;} while(false);
01249                 do {int _=index[i]; index[i]=index[j]; index[j]=_;} while(false);
01250                 do {double _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
01251         }
01252 
01253         float[] get_Q(int i, int len)
01254         {
01255                 float[][] data = new float[1][];
01256                 int j, real_i = index[i];
01257                 if(cache.get_data(real_i,data,l) < l)
01258                 {
01259                         for(j=0;j<l;j++)
01260                                 data[0][j] = (float)kernel_function(real_i,j);
01261                 }
01262 
01263                 // reorder and copy
01264                 float buf[] = buffer[next_buffer];
01265                 next_buffer = 1 - next_buffer;
01266                 byte si = sign[i];
01267                 for(j=0;j<len;j++)
01268                         buf[j] = (float) si * sign[j] * data[0][index[j]];
01269                 return buf;
01270         }
01271 
01272         double[] get_QD()
01273         {
01274                 return QD;
01275         }
01276 }
01277 
01278 public class svm {
01279         //
01280         // construct and solve various formulations
01281         //
01282         public static final int LIBSVM_VERSION=300; 
01283 
01284         private static svm_print_interface svm_print_stdout = new svm_print_interface()
01285         {
01286                 public void print(String s)
01287                 {
01288                         System.out.print(s);
01289                         System.out.flush();
01290                 }
01291         };
01292 
01293         private static svm_print_interface svm_print_string = svm_print_stdout;
01294 
01295         static void info(String s) 
01296         {
01297                 svm_print_string.print(s);
01298         }
01299 
01300         private static void solve_c_svc(svm_problem prob, svm_parameter param,
01301                                         double[] alpha, Solver.SolutionInfo si,
01302                                         double Cp, double Cn)
01303         {
01304                 int l = prob.l;
01305                 double[] minus_ones = new double[l];
01306                 byte[] y = new byte[l];
01307 
01308                 int i;
01309 
01310                 for(i=0;i<l;i++)
01311                 {
01312                         alpha[i] = 0;
01313                         minus_ones[i] = -1;
01314                         if(prob.y[i] > 0) y[i] = +1; else y[i] = -1;
01315                 }
01316 
01317                 Solver s = new Solver();
01318                 s.Solve(l, new SVC_Q(prob,param,y), minus_ones, y,
01319                         alpha, Cp, Cn, param.eps, si, param.shrinking);
01320 
01321                 double sum_alpha=0;
01322                 for(i=0;i<l;i++)
01323                         sum_alpha += alpha[i];
01324 
01325                 if (Cp==Cn)
01326                         svm.info("nu = "+sum_alpha/(Cp*prob.l)+"\n");
01327 
01328                 for(i=0;i<l;i++)
01329                         alpha[i] *= y[i];
01330         }
01331 
01332         private static void solve_nu_svc(svm_problem prob, svm_parameter param,
01333                                         double[] alpha, Solver.SolutionInfo si)
01334         {
01335                 int i;
01336                 int l = prob.l;
01337                 double nu = param.nu;
01338 
01339                 byte[] y = new byte[l];
01340 
01341                 for(i=0;i<l;i++)
01342                         if(prob.y[i]>0)
01343                                 y[i] = +1;
01344                         else
01345                                 y[i] = -1;
01346 
01347                 double sum_pos = nu*l/2;
01348                 double sum_neg = nu*l/2;
01349 
01350                 for(i=0;i<l;i++)
01351                         if(y[i] == +1)
01352                         {
01353                                 alpha[i] = Math.min(1.0,sum_pos);
01354                                 sum_pos -= alpha[i];
01355                         }
01356                         else
01357                         {
01358                                 alpha[i] = Math.min(1.0,sum_neg);
01359                                 sum_neg -= alpha[i];
01360                         }
01361 
01362                 double[] zeros = new double[l];
01363 
01364                 for(i=0;i<l;i++)
01365                         zeros[i] = 0;
01366 
01367                 Solver_NU s = new Solver_NU();
01368                 s.Solve(l, new SVC_Q(prob,param,y), zeros, y,
01369                         alpha, 1.0, 1.0, param.eps, si, param.shrinking);
01370                 double r = si.r;
01371 
01372                 svm.info("C = "+1/r+"\n");
01373 
01374                 for(i=0;i<l;i++)
01375                         alpha[i] *= y[i]/r;
01376 
01377                 si.rho /= r;
01378                 si.obj /= (r*r);
01379                 si.upper_bound_p = 1/r;
01380                 si.upper_bound_n = 1/r;
01381         }
01382 
01383         private static void solve_one_class(svm_problem prob, svm_parameter param,
01384                                         double[] alpha, Solver.SolutionInfo si)
01385         {
01386                 int l = prob.l;
01387                 double[] zeros = new double[l];
01388                 byte[] ones = new byte[l];
01389                 int i;
01390 
01391                 int n = (int)(param.nu*prob.l); // # of alpha's at upper bound
01392 
01393                 for(i=0;i<n;i++)
01394                         alpha[i] = 1;
01395                 if(n<prob.l)
01396                         alpha[n] = param.nu * prob.l - n;
01397                 for(i=n+1;i<l;i++)
01398                         alpha[i] = 0;
01399 
01400                 for(i=0;i<l;i++)
01401                 {
01402                         zeros[i] = 0;
01403                         ones[i] = 1;
01404                 }
01405 
01406                 Solver s = new Solver();
01407                 s.Solve(l, new ONE_CLASS_Q(prob,param), zeros, ones,
01408                         alpha, 1.0, 1.0, param.eps, si, param.shrinking);
01409         }
01410 
01411         private static void solve_epsilon_svr(svm_problem prob, svm_parameter param,
01412                                         double[] alpha, Solver.SolutionInfo si)
01413         {
01414                 int l = prob.l;
01415                 double[] alpha2 = new double[2*l];
01416                 double[] linear_term = new double[2*l];
01417                 byte[] y = new byte[2*l];
01418                 int i;
01419 
01420                 for(i=0;i<l;i++)
01421                 {
01422                         alpha2[i] = 0;
01423                         linear_term[i] = param.p - prob.y[i];
01424                         y[i] = 1;
01425 
01426                         alpha2[i+l] = 0;
01427                         linear_term[i+l] = param.p + prob.y[i];
01428                         y[i+l] = -1;
01429                 }
01430 
01431                 Solver s = new Solver();
01432                 s.Solve(2*l, new SVR_Q(prob,param), linear_term, y,
01433                         alpha2, param.C, param.C, param.eps, si, param.shrinking);
01434 
01435                 double sum_alpha = 0;
01436                 for(i=0;i<l;i++)
01437                 {
01438                         alpha[i] = alpha2[i] - alpha2[i+l];
01439                         sum_alpha += Math.abs(alpha[i]);
01440                 }
01441                 svm.info("nu = "+sum_alpha/(param.C*l)+"\n");
01442         }
01443 
01444         private static void solve_nu_svr(svm_problem prob, svm_parameter param,
01445                                         double[] alpha, Solver.SolutionInfo si)
01446         {
01447                 int l = prob.l;
01448                 double C = param.C;
01449                 double[] alpha2 = new double[2*l];
01450                 double[] linear_term = new double[2*l];
01451                 byte[] y = new byte[2*l];
01452                 int i;
01453 
01454                 double sum = C * param.nu * l / 2;
01455                 for(i=0;i<l;i++)
01456                 {
01457                         alpha2[i] = alpha2[i+l] = Math.min(sum,C);
01458                         sum -= alpha2[i];
01459                         
01460                         linear_term[i] = - prob.y[i];
01461                         y[i] = 1;
01462 
01463                         linear_term[i+l] = prob.y[i];
01464                         y[i+l] = -1;
01465                 }
01466 
01467                 Solver_NU s = new Solver_NU();
01468                 s.Solve(2*l, new SVR_Q(prob,param), linear_term, y,
01469                         alpha2, C, C, param.eps, si, param.shrinking);
01470 
01471                 svm.info("epsilon = "+(-si.r)+"\n");
01472                 
01473                 for(i=0;i<l;i++)
01474                         alpha[i] = alpha2[i] - alpha2[i+l];
01475         }
01476 
01477         //
01478         // decision_function
01479         //
01480         static class decision_function
01481         {
01482                 double[] alpha;
01483                 double rho;     
01484         };
01485 
01486         static decision_function svm_train_one(
01487                 svm_problem prob, svm_parameter param,
01488                 double Cp, double Cn)
01489         {
01490                 double[] alpha = new double[prob.l];
01491                 Solver.SolutionInfo si = new Solver.SolutionInfo();
01492                 switch(param.svm_type)
01493                 {
01494                         case svm_parameter.C_SVC:
01495                                 solve_c_svc(prob,param,alpha,si,Cp,Cn);
01496                                 break;
01497                         case svm_parameter.NU_SVC:
01498                                 solve_nu_svc(prob,param,alpha,si);
01499                                 break;
01500                         case svm_parameter.ONE_CLASS:
01501                                 solve_one_class(prob,param,alpha,si);
01502                                 break;
01503                         case svm_parameter.EPSILON_SVR:
01504                                 solve_epsilon_svr(prob,param,alpha,si);
01505                                 break;
01506                         case svm_parameter.NU_SVR:
01507                                 solve_nu_svr(prob,param,alpha,si);
01508                                 break;
01509                 }
01510 
01511                 svm.info("obj = "+si.obj+", rho = "+si.rho+"\n");
01512 
01513                 // output SVs
01514 
01515                 int nSV = 0;
01516                 int nBSV = 0;
01517                 for(int i=0;i<prob.l;i++)
01518                 {
01519                         if(Math.abs(alpha[i]) > 0)
01520                         {
01521                                 ++nSV;
01522                                 if(prob.y[i] > 0)
01523                                 {
01524                                         if(Math.abs(alpha[i]) >= si.upper_bound_p)
01525                                         ++nBSV;
01526                                 }
01527                                 else
01528                                 {
01529                                         if(Math.abs(alpha[i]) >= si.upper_bound_n)
01530                                                 ++nBSV;
01531                                 }
01532                         }
01533                 }
01534 
01535                 svm.info("nSV = "+nSV+", nBSV = "+nBSV+"\n");
01536 
01537                 decision_function f = new decision_function();
01538                 f.alpha = alpha;
01539                 f.rho = si.rho;
01540                 return f;
01541         }
01542 
01543         // Platt's binary SVM Probablistic Output: an improvement from Lin et al.
01544         private static void sigmoid_train(int l, double[] dec_values, double[] labels, 
01545                                   double[] probAB)
01546         {
01547                 double A, B;
01548                 double prior1=0, prior0 = 0;
01549                 int i;
01550 
01551                 for (i=0;i<l;i++)
01552                         if (labels[i] > 0) prior1+=1;
01553                         else prior0+=1;
01554         
01555                 int max_iter=100;       // Maximal number of iterations
01556                 double min_step=1e-10;  // Minimal step taken in line search
01557                 double sigma=1e-12;     // For numerically strict PD of Hessian
01558                 double eps=1e-5;
01559                 double hiTarget=(prior1+1.0)/(prior1+2.0);
01560                 double loTarget=1/(prior0+2.0);
01561                 double[] t= new double[l];
01562                 double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
01563                 double newA,newB,newf,d1,d2;
01564                 int iter; 
01565         
01566                 // Initial Point and Initial Fun Value
01567                 A=0.0; B=Math.log((prior0+1.0)/(prior1+1.0));
01568                 double fval = 0.0;
01569 
01570                 for (i=0;i<l;i++)
01571                 {
01572                         if (labels[i]>0) t[i]=hiTarget;
01573                         else t[i]=loTarget;
01574                         fApB = dec_values[i]*A+B;
01575                         if (fApB>=0)
01576                                 fval += t[i]*fApB + Math.log(1+Math.exp(-fApB));
01577                         else
01578                                 fval += (t[i] - 1)*fApB +Math.log(1+Math.exp(fApB));
01579                 }
01580                 for (iter=0;iter<max_iter;iter++)
01581                 {
01582                         // Update Gradient and Hessian (use H' = H + sigma I)
01583                         h11=sigma; // numerically ensures strict PD
01584                         h22=sigma;
01585                         h21=0.0;g1=0.0;g2=0.0;
01586                         for (i=0;i<l;i++)
01587                         {
01588                                 fApB = dec_values[i]*A+B;
01589                                 if (fApB >= 0)
01590                                 {
01591                                         p=Math.exp(-fApB)/(1.0+Math.exp(-fApB));
01592                                         q=1.0/(1.0+Math.exp(-fApB));
01593                                 }
01594                                 else
01595                                 {
01596                                         p=1.0/(1.0+Math.exp(fApB));
01597                                         q=Math.exp(fApB)/(1.0+Math.exp(fApB));
01598                                 }
01599                                 d2=p*q;
01600                                 h11+=dec_values[i]*dec_values[i]*d2;
01601                                 h22+=d2;
01602                                 h21+=dec_values[i]*d2;
01603                                 d1=t[i]-p;
01604                                 g1+=dec_values[i]*d1;
01605                                 g2+=d1;
01606                         }
01607 
01608                         // Stopping Criteria
01609                         if (Math.abs(g1)<eps && Math.abs(g2)<eps)
01610                                 break;
01611                         
01612                         // Finding Newton direction: -inv(H') * g
01613                         det=h11*h22-h21*h21;
01614                         dA=-(h22*g1 - h21 * g2) / det;
01615                         dB=-(-h21*g1+ h11 * g2) / det;
01616                         gd=g1*dA+g2*dB;
01617 
01618 
01619                         stepsize = 1;           // Line Search
01620                         while (stepsize >= min_step)
01621                         {
01622                                 newA = A + stepsize * dA;
01623                                 newB = B + stepsize * dB;
01624 
01625                                 // New function value
01626                                 newf = 0.0;
01627                                 for (i=0;i<l;i++)
01628                                 {
01629                                         fApB = dec_values[i]*newA+newB;
01630                                         if (fApB >= 0)
01631                                                 newf += t[i]*fApB + Math.log(1+Math.exp(-fApB));
01632                                         else
01633                                                 newf += (t[i] - 1)*fApB +Math.log(1+Math.exp(fApB));
01634                                 }
01635                                 // Check sufficient decrease
01636                                 if (newf<fval+0.0001*stepsize*gd)
01637                                 {
01638                                         A=newA;B=newB;fval=newf;
01639                                         break;
01640                                 }
01641                                 else
01642                                         stepsize = stepsize / 2.0;
01643                         }
01644                         
01645                         if (stepsize < min_step)
01646                         {
01647                                 svm.info("Line search fails in two-class probability estimates\n");
01648                                 break;
01649                         }
01650                 }
01651                 
01652                 if (iter>=max_iter)
01653                         svm.info("Reaching maximal iterations in two-class probability estimates\n");
01654                 probAB[0]=A;probAB[1]=B;
01655         }
01656 
01657         private static double sigmoid_predict(double decision_value, double A, double B)
01658         {
01659                 double fApB = decision_value*A+B;
01660                 if (fApB >= 0)
01661                         return Math.exp(-fApB)/(1.0+Math.exp(-fApB));
01662                 else
01663                         return 1.0/(1+Math.exp(fApB)) ;
01664         }
01665 
// Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
//
// Combines pairwise estimates r into one probability per class.
//   k : number of classes
//   r : k x k pairwise estimates; presumably r[i][j] estimates P(class i | class
//       i or j) with r[i][j] + r[j][i] == 1 -- confirm with the caller
//   p : output, length >= k; receives the class probabilities
// Builds Q with Q[t][t] = sum over j != t of r[j][t]^2 and
// Q[t][j] = -r[j][t]*r[t][j], starts from the uniform distribution, and applies
// coordinate-wise updates to p, renormalising after each one (so a unit sum is
// preserved). Stops when every |(Qp)[t] - p'Qp| < 0.005/k, or after
// max(100, k) sweeps, in which case a message is logged.
private static void multiclass_probability(int k, double[][] r, double[] p)
{
    int t,j;
    int iter = 0, max_iter=Math.max(100,k);
    double[][] Q=new double[k][k];
    double[] Qp=new double[k];      // Qp = Q * p
    double pQp, eps=0.005/k;        // pQp = p' * Q * p

    // Build Q; the lower triangle is copied from the already-computed upper one.
    for (t=0;t<k;t++)
    {
        p[t]=1.0/k;  // Valid if k = 1
        Q[t][t]=0;
        for (j=0;j<t;j++)
        {
            Q[t][t]+=r[j][t]*r[j][t];
            Q[t][j]=Q[j][t];
        }
        for (j=t+1;j<k;j++)
        {
            Q[t][t]+=r[j][t]*r[j][t];
            Q[t][j]=-r[j][t]*r[t][j];
        }
    }
    for (iter=0;iter<max_iter;iter++)
    {
        // stopping condition, recalculate QP,pQP for numerical accuracy
        pQp=0;
        for (t=0;t<k;t++)
        {
            Qp[t]=0;
            for (j=0;j<k;j++)
                Qp[t]+=Q[t][j]*p[j];
            pQp+=p[t]*Qp[t];
        }
        double max_error=0;
        for (t=0;t<k;t++)
        {
            double error=Math.abs(Qp[t]-pQp);
            if (error>max_error)
                max_error=error;
        }
        if (max_error<eps) break;

        // One sweep of coordinate updates. After moving p[t] by diff, every
        // entry is divided by (1+diff), and pQp / Qp are updated incrementally
        // to match instead of being recomputed from scratch.
        for (t=0;t<k;t++)
        {
            double diff=(-Qp[t]+pQp)/Q[t][t];
            p[t]+=diff;
            pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
            for (j=0;j<k;j++)
            {
                Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
                p[j]/=(1+diff);
            }
        }
    }
    if (iter>=max_iter)
        svm.info("Exceeds max_iter in multiclass_prob\n");
}
01725 
01726         // Cross-validation decision values for probability estimates
01727         private static void svm_binary_svc_probability(svm_problem prob, svm_parameter param, double Cp, double Cn, double[] probAB)
01728         {
01729                 int i;
01730                 int nr_fold = 5;
01731                 int[] perm = new int[prob.l];
01732                 double[] dec_values = new double[prob.l];
01733 
01734                 // random shuffle
01735                 for(i=0;i<prob.l;i++) perm[i]=i;
01736                 for(i=0;i<prob.l;i++)
01737                 {
01738                         int j = i+(int)(Math.random()*(prob.l-i));
01739                         do {int _=perm[i]; perm[i]=perm[j]; perm[j]=_;} while(false);
01740                 }
01741                 for(i=0;i<nr_fold;i++)
01742                 {
01743                         int begin = i*prob.l/nr_fold;
01744                         int end = (i+1)*prob.l/nr_fold;
01745                         int j,k;
01746                         svm_problem subprob = new svm_problem();
01747 
01748                         subprob.l = prob.l-(end-begin);
01749                         subprob.x = new svm_node[subprob.l][];
01750                         subprob.y = new double[subprob.l];
01751                         
01752                         k=0;
01753                         for(j=0;j<begin;j++)
01754                         {
01755                                 subprob.x[k] = prob.x[perm[j]];
01756                                 subprob.y[k] = prob.y[perm[j]];
01757                                 ++k;
01758                         }
01759                         for(j=end;j<prob.l;j++)
01760                         {
01761                                 subprob.x[k] = prob.x[perm[j]];
01762                                 subprob.y[k] = prob.y[perm[j]];
01763                                 ++k;
01764                         }
01765                         int p_count=0,n_count=0;
01766                         for(j=0;j<k;j++)
01767                                 if(subprob.y[j]>0)
01768                                         p_count++;
01769                                 else
01770                                         n_count++;
01771                         
01772                         if(p_count==0 && n_count==0)
01773                                 for(j=begin;j<end;j++)
01774                                         dec_values[perm[j]] = 0;
01775                         else if(p_count > 0 && n_count == 0)
01776                                 for(j=begin;j<end;j++)
01777                                         dec_values[perm[j]] = 1;
01778                         else if(p_count == 0 && n_count > 0)
01779                                 for(j=begin;j<end;j++)
01780                                         dec_values[perm[j]] = -1;
01781                         else
01782                         {
01783                                 svm_parameter subparam = (svm_parameter)param.clone();
01784                                 subparam.probability=0;
01785                                 subparam.C=1.0;
01786                                 subparam.nr_weight=2;
01787                                 subparam.weight_label = new int[2];
01788                                 subparam.weight = new double[2];
01789                                 subparam.weight_label[0]=+1;
01790                                 subparam.weight_label[1]=-1;
01791                                 subparam.weight[0]=Cp;
01792                                 subparam.weight[1]=Cn;
01793                                 svm_model submodel = svm_train(subprob,subparam);
01794                                 for(j=begin;j<end;j++)
01795                                 {
01796                                         double[] dec_value=new double[1];
01797                                         svm_predict_values(submodel,prob.x[perm[j]],dec_value);
01798                                         dec_values[perm[j]]=dec_value[0];
01799                                         // ensure +1 -1 order; reason not using CV subroutine
01800                                         dec_values[perm[j]] *= submodel.label[0];
01801                                 }               
01802                         }
01803                 }               
01804                 sigmoid_train(prob.l,dec_values,prob.y,probAB);
01805         }
01806 
01807         // Return parameter of a Laplace distribution 
01808         private static double svm_svr_probability(svm_problem prob, svm_parameter param)
01809         {
01810                 int i;
01811                 int nr_fold = 5;
01812                 double[] ymv = new double[prob.l];
01813                 double mae = 0;
01814 
01815                 svm_parameter newparam = (svm_parameter)param.clone();
01816                 newparam.probability = 0;
01817                 svm_cross_validation(prob,newparam,nr_fold,ymv);
01818                 for(i=0;i<prob.l;i++)
01819                 {
01820                         ymv[i]=prob.y[i]-ymv[i];
01821                         mae += Math.abs(ymv[i]);
01822                 }               
01823                 mae /= prob.l;
01824                 double std=Math.sqrt(2*mae*mae);
01825                 int count=0;
01826                 mae=0;
01827                 for(i=0;i<prob.l;i++)
01828                         if (Math.abs(ymv[i]) > 5*std) 
01829                                 count=count+1;
01830                         else 
01831                                 mae+=Math.abs(ymv[i]);
01832                 mae /= (prob.l-count);
01833                 svm.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+mae+"\n");
01834                 return mae;
01835         }
01836 
01837         // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
01838         // perm, length l, must be allocated before calling this subroutine
01839         private static void svm_group_classes(svm_problem prob, int[] nr_class_ret, int[][] label_ret, int[][] start_ret, int[][] count_ret, int[] perm)
01840         {
01841                 int l = prob.l;
01842                 int max_nr_class = 16;
01843                 int nr_class = 0;
01844                 int[] label = new int[max_nr_class];
01845                 int[] count = new int[max_nr_class];
01846                 int[] data_label = new int[l];
01847                 int i;
01848 
01849                 for(i=0;i<l;i++)
01850                 {
01851                         int this_label = (int)(prob.y[i]);
01852                         int j;
01853                         for(j=0;j<nr_class;j++)
01854                         {
01855                                 if(this_label == label[j])
01856                                 {
01857                                         ++count[j];
01858                                         break;
01859                                 }
01860                         }
01861                         data_label[i] = j;
01862                         if(j == nr_class)
01863                         {
01864                                 if(nr_class == max_nr_class)
01865                                 {
01866                                         max_nr_class *= 2;
01867                                         int[] new_data = new int[max_nr_class];
01868                                         System.arraycopy(label,0,new_data,0,label.length);
01869                                         label = new_data;
01870                                         new_data = new int[max_nr_class];
01871                                         System.arraycopy(count,0,new_data,0,count.length);
01872                                         count = new_data;                                       
01873                                 }
01874                                 label[nr_class] = this_label;
01875                                 count[nr_class] = 1;
01876                                 ++nr_class;
01877                         }
01878                 }
01879 
01880                 int[] start = new int[nr_class];
01881                 start[0] = 0;
01882                 for(i=1;i<nr_class;i++)
01883                         start[i] = start[i-1]+count[i-1];
01884                 for(i=0;i<l;i++)
01885                 {
01886                         perm[start[data_label[i]]] = i;
01887                         ++start[data_label[i]];
01888                 }
01889                 start[0] = 0;
01890                 for(i=1;i<nr_class;i++)
01891                         start[i] = start[i-1]+count[i-1];
01892 
01893                 nr_class_ret[0] = nr_class;
01894                 label_ret[0] = label;
01895                 start_ret[0] = start;
01896                 count_ret[0] = count;
01897         }
01898 
01899         //
01900         // Interface functions
01901         //
01902         public static svm_model svm_train(svm_problem prob, svm_parameter param)
01903         {
01904                 svm_model model = new svm_model();
01905                 model.param = param;
01906 
01907                 if(param.svm_type == svm_parameter.ONE_CLASS ||
01908                    param.svm_type == svm_parameter.EPSILON_SVR ||
01909                    param.svm_type == svm_parameter.NU_SVR)
01910                 {
01911                         // regression or one-class-svm
01912                         model.nr_class = 2;
01913                         model.label = null;
01914                         model.nSV = null;
01915                         model.probA = null; model.probB = null;
01916                         model.sv_coef = new double[1][];
01917 
01918                         if(param.probability == 1 &&
01919                            (param.svm_type == svm_parameter.EPSILON_SVR ||
01920                             param.svm_type == svm_parameter.NU_SVR))
01921                         {
01922                                 model.probA = new double[1];
01923                                 model.probA[0] = svm_svr_probability(prob,param);
01924                         }
01925 
01926                         decision_function f = svm_train_one(prob,param,0,0);
01927                         model.rho = new double[1];
01928                         model.rho[0] = f.rho;
01929 
01930                         int nSV = 0;
01931                         int i;
01932                         for(i=0;i<prob.l;i++)
01933                                 if(Math.abs(f.alpha[i]) > 0) ++nSV;
01934                         model.l = nSV;
01935                         model.SV = new svm_node[nSV][];
01936                         model.sv_coef[0] = new double[nSV];
01937                         int j = 0;
01938                         for(i=0;i<prob.l;i++)
01939                                 if(Math.abs(f.alpha[i]) > 0)
01940                                 {
01941                                         model.SV[j] = prob.x[i];
01942                                         model.sv_coef[0][j] = f.alpha[i];
01943                                         ++j;
01944                                 }
01945                 }
01946                 else
01947                 {
01948                         // classification
01949                         int l = prob.l;
01950                         int[] tmp_nr_class = new int[1];
01951                         int[][] tmp_label = new int[1][];
01952                         int[][] tmp_start = new int[1][];
01953                         int[][] tmp_count = new int[1][];                       
01954                         int[] perm = new int[l];
01955 
01956                         // group training data of the same class
01957                         svm_group_classes(prob,tmp_nr_class,tmp_label,tmp_start,tmp_count,perm);
01958                         int nr_class = tmp_nr_class[0];                 
01959                         int[] label = tmp_label[0];
01960                         int[] start = tmp_start[0];
01961                         int[] count = tmp_count[0];
01962                         svm_node[][] x = new svm_node[l][];
01963                         int i;
01964                         for(i=0;i<l;i++)
01965                                 x[i] = prob.x[perm[i]];
01966 
01967                         // calculate weighted C
01968 
01969                         double[] weighted_C = new double[nr_class];
01970                         for(i=0;i<nr_class;i++)
01971                                 weighted_C[i] = param.C;
01972                         for(i=0;i<param.nr_weight;i++)
01973                         {
01974                                 int j;
01975                                 for(j=0;j<nr_class;j++)
01976                                         if(param.weight_label[i] == label[j])
01977                                                 break;
01978                                 if(j == nr_class)
01979                                         System.err.print("warning: class label "+param.weight_label[i]+" specified in weight is not found\n");
01980                                 else
01981                                         weighted_C[j] *= param.weight[i];
01982                         }
01983 
01984                         // train k*(k-1)/2 models
01985 
01986                         boolean[] nonzero = new boolean[l];
01987                         for(i=0;i<l;i++)
01988                                 nonzero[i] = false;
01989                         decision_function[] f = new decision_function[nr_class*(nr_class-1)/2];
01990 
01991                         double[] probA=null,probB=null;
01992                         if (param.probability == 1)
01993                         {
01994                                 probA=new double[nr_class*(nr_class-1)/2];
01995                                 probB=new double[nr_class*(nr_class-1)/2];
01996                         }
01997 
01998                         int p = 0;
01999                         for(i=0;i<nr_class;i++)
02000                                 for(int j=i+1;j<nr_class;j++)
02001                                 {
02002                                         svm_problem sub_prob = new svm_problem();
02003                                         int si = start[i], sj = start[j];
02004                                         int ci = count[i], cj = count[j];
02005                                         sub_prob.l = ci+cj;
02006                                         sub_prob.x = new svm_node[sub_prob.l][];
02007                                         sub_prob.y = new double[sub_prob.l];
02008                                         int k;
02009                                         for(k=0;k<ci;k++)
02010                                         {
02011                                                 sub_prob.x[k] = x[si+k];
02012                                                 sub_prob.y[k] = +1;
02013                                         }
02014                                         for(k=0;k<cj;k++)
02015                                         {
02016                                                 sub_prob.x[ci+k] = x[sj+k];
02017                                                 sub_prob.y[ci+k] = -1;
02018                                         }
02019 
02020                                         if(param.probability == 1)
02021                                         {
02022                                                 double[] probAB=new double[2];
02023                                                 svm_binary_svc_probability(sub_prob,param,weighted_C[i],weighted_C[j],probAB);
02024                                                 probA[p]=probAB[0];
02025                                                 probB[p]=probAB[1];
02026                                         }
02027 
02028                                         f[p] = svm_train_one(sub_prob,param,weighted_C[i],weighted_C[j]);
02029                                         for(k=0;k<ci;k++)
02030                                                 if(!nonzero[si+k] && Math.abs(f[p].alpha[k]) > 0)
02031                                                         nonzero[si+k] = true;
02032                                         for(k=0;k<cj;k++)
02033                                                 if(!nonzero[sj+k] && Math.abs(f[p].alpha[ci+k]) > 0)
02034                                                         nonzero[sj+k] = true;
02035                                         ++p;
02036                                 }
02037 
02038                         // build output
02039 
02040                         model.nr_class = nr_class;
02041 
02042                         model.label = new int[nr_class];
02043                         for(i=0;i<nr_class;i++)
02044                                 model.label[i] = label[i];
02045 
02046                         model.rho = new double[nr_class*(nr_class-1)/2];
02047                         for(i=0;i<nr_class*(nr_class-1)/2;i++)
02048                                 model.rho[i] = f[i].rho;
02049 
02050                         if(param.probability == 1)
02051                         {
02052                                 model.probA = new double[nr_class*(nr_class-1)/2];
02053                                 model.probB = new double[nr_class*(nr_class-1)/2];
02054                                 for(i=0;i<nr_class*(nr_class-1)/2;i++)
02055                                 {
02056                                         model.probA[i] = probA[i];
02057                                         model.probB[i] = probB[i];
02058                                 }
02059                         }
02060                         else
02061                         {
02062                                 model.probA=null;
02063                                 model.probB=null;
02064                         }
02065 
02066                         int nnz = 0;
02067                         int[] nz_count = new int[nr_class];
02068                         model.nSV = new int[nr_class];
02069                         for(i=0;i<nr_class;i++)
02070                         {
02071                                 int nSV = 0;
02072                                 for(int j=0;j<count[i];j++)
02073                                         if(nonzero[start[i]+j])
02074                                         {
02075                                                 ++nSV;
02076                                                 ++nnz;
02077                                         }
02078                                 model.nSV[i] = nSV;
02079                                 nz_count[i] = nSV;
02080                         }
02081 
02082                         svm.info("Total nSV = "+nnz+"\n");
02083 
02084                         model.l = nnz;
02085                         model.SV = new svm_node[nnz][];
02086                         p = 0;
02087                         for(i=0;i<l;i++)
02088                                 if(nonzero[i]) model.SV[p++] = x[i];
02089 
02090                         int[] nz_start = new int[nr_class];
02091                         nz_start[0] = 0;
02092                         for(i=1;i<nr_class;i++)
02093                                 nz_start[i] = nz_start[i-1]+nz_count[i-1];
02094 
02095                         model.sv_coef = new double[nr_class-1][];
02096                         for(i=0;i<nr_class-1;i++)
02097                                 model.sv_coef[i] = new double[nnz];
02098 
02099                         p = 0;
02100                         for(i=0;i<nr_class;i++)
02101                                 for(int j=i+1;j<nr_class;j++)
02102                                 {
02103                                         // classifier (i,j): coefficients with
02104                                         // i are in sv_coef[j-1][nz_start[i]...],
02105                                         // j are in sv_coef[i][nz_start[j]...]
02106 
02107                                         int si = start[i];
02108                                         int sj = start[j];
02109                                         int ci = count[i];
02110                                         int cj = count[j];
02111 
02112                                         int q = nz_start[i];
02113                                         int k;
02114                                         for(k=0;k<ci;k++)
02115                                                 if(nonzero[si+k])
02116                                                         model.sv_coef[j-1][q++] = f[p].alpha[k];
02117                                         q = nz_start[j];
02118                                         for(k=0;k<cj;k++)
02119                                                 if(nonzero[sj+k])
02120                                                         model.sv_coef[i][q++] = f[p].alpha[ci+k];
02121                                         ++p;
02122                                 }
02123                 }
02124                 return model;
02125         }
02126         
02127         // Stratified cross validation
02128         public static void svm_cross_validation(svm_problem prob, svm_parameter param, int nr_fold, double[] target)
02129         {
02130                 int i;
02131                 int[] fold_start = new int[nr_fold+1];
02132                 int l = prob.l;
02133                 int[] perm = new int[l];
02134                 
02135                 // stratified cv may not give leave-one-out rate
02136                 // Each class to l folds -> some folds may have zero elements
02137                 if((param.svm_type == svm_parameter.C_SVC ||
02138                     param.svm_type == svm_parameter.NU_SVC) && nr_fold < l)
02139                 {
02140                         int[] tmp_nr_class = new int[1];
02141                         int[][] tmp_label = new int[1][];
02142                         int[][] tmp_start = new int[1][];
02143                         int[][] tmp_count = new int[1][];
02144 
02145                         svm_group_classes(prob,tmp_nr_class,tmp_label,tmp_start,tmp_count,perm);
02146 
02147                         int nr_class = tmp_nr_class[0];
02148                         int[] start = tmp_start[0];
02149                         int[] count = tmp_count[0];             
02150 
02151                         // random shuffle and then data grouped by fold using the array perm
02152                         int[] fold_count = new int[nr_fold];
02153                         int c;
02154                         int[] index = new int[l];
02155                         for(i=0;i<l;i++)
02156                                 index[i]=perm[i];
02157                         for (c=0; c<nr_class; c++)
02158                                 for(i=0;i<count[c];i++)
02159                                 {
02160                                         int j = i+(int)(Math.random()*(count[c]-i));
02161                                         do {int _=index[start[c]+j]; index[start[c]+j]=index[start[c]+i]; index[start[c]+i]=_;} while(false);
02162                                 }
02163                         for(i=0;i<nr_fold;i++)
02164                         {
02165                                 fold_count[i] = 0;
02166                                 for (c=0; c<nr_class;c++)
02167                                         fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
02168                         }
02169                         fold_start[0]=0;
02170                         for (i=1;i<=nr_fold;i++)
02171                                 fold_start[i] = fold_start[i-1]+fold_count[i-1];
02172                         for (c=0; c<nr_class;c++)
02173                                 for(i=0;i<nr_fold;i++)
02174                                 {
02175                                         int begin = start[c]+i*count[c]/nr_fold;
02176                                         int end = start[c]+(i+1)*count[c]/nr_fold;
02177                                         for(int j=begin;j<end;j++)
02178                                         {
02179                                                 perm[fold_start[i]] = index[j];
02180                                                 fold_start[i]++;
02181                                         }
02182                                 }
02183                         fold_start[0]=0;
02184                         for (i=1;i<=nr_fold;i++)
02185                                 fold_start[i] = fold_start[i-1]+fold_count[i-1];
02186                 }
02187                 else
02188                 {
02189                         for(i=0;i<l;i++) perm[i]=i;
02190                         for(i=0;i<l;i++)
02191                         {
02192                                 int j = i+(int)(Math.random()*(l-i));
02193                                 do {int _=perm[i]; perm[i]=perm[j]; perm[j]=_;} while(false);
02194                         }
02195                         for(i=0;i<=nr_fold;i++)
02196                                 fold_start[i]=i*l/nr_fold;
02197                 }
02198 
02199                 for(i=0;i<nr_fold;i++)
02200                 {
02201                         int begin = fold_start[i];
02202                         int end = fold_start[i+1];
02203                         int j,k;
02204                         svm_problem subprob = new svm_problem();
02205 
02206                         subprob.l = l-(end-begin);
02207                         subprob.x = new svm_node[subprob.l][];
02208                         subprob.y = new double[subprob.l];
02209 
02210                         k=0;
02211                         for(j=0;j<begin;j++)
02212                         {
02213                                 subprob.x[k] = prob.x[perm[j]];
02214                                 subprob.y[k] = prob.y[perm[j]];
02215                                 ++k;
02216                         }
02217                         for(j=end;j<l;j++)
02218                         {
02219                                 subprob.x[k] = prob.x[perm[j]];
02220                                 subprob.y[k] = prob.y[perm[j]];
02221                                 ++k;
02222                         }
02223                         svm_model submodel = svm_train(subprob,param);
02224                         if(param.probability==1 &&
02225                            (param.svm_type == svm_parameter.C_SVC ||
02226                             param.svm_type == svm_parameter.NU_SVC))
02227                         {
02228                                 double[] prob_estimates= new double[svm_get_nr_class(submodel)];
02229                                 for(j=begin;j<end;j++)
02230                                         target[perm[j]] = svm_predict_probability(submodel,prob.x[perm[j]],prob_estimates);
02231                         }
02232                         else
02233                                 for(j=begin;j<end;j++)
02234                                         target[perm[j]] = svm_predict(submodel,prob.x[perm[j]]);
02235                 }
02236         }
02237 
02238         public static int svm_get_svm_type(svm_model model)
02239         {
02240                 return model.param.svm_type;
02241         }
02242 
02243         public static int svm_get_nr_class(svm_model model)
02244         {
02245                 return model.nr_class;
02246         }
02247 
02248         public static void svm_get_labels(svm_model model, int[] label)
02249         {
02250                 if (model.label != null)
02251                         for(int i=0;i<model.nr_class;i++)
02252                                 label[i] = model.label[i];
02253         }
02254 
02255         public static double svm_get_svr_probability(svm_model model)
02256         {
02257                 if ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) &&
02258                     model.probA!=null)
02259                 return model.probA[0];
02260                 else
02261                 {
02262                         System.err.print("Model doesn't contain information for SVR probability inference\n");
02263                         return 0;
02264                 }
02265         }
02266 
02267         public static double svm_predict_values(svm_model model, svm_node[] x, double[] dec_values)
02268         {
02269                 if(model.param.svm_type == svm_parameter.ONE_CLASS ||
02270                    model.param.svm_type == svm_parameter.EPSILON_SVR ||
02271                    model.param.svm_type == svm_parameter.NU_SVR)
02272                 {
02273                         double[] sv_coef = model.sv_coef[0];
02274                         double sum = 0;
02275                         for(int i=0;i<model.l;i++)
02276                                 sum += sv_coef[i] * Kernel.k_function(x,model.SV[i],model.param);
02277                         sum -= model.rho[0];
02278                         dec_values[0] = sum;
02279 
02280                         if(model.param.svm_type == svm_parameter.ONE_CLASS)
02281                                 return (sum>0)?1:-1;
02282                         else
02283                                 return sum;
02284                 }
02285                 else
02286                 {
02287                         int i;
02288                         int nr_class = model.nr_class;
02289                         int l = model.l;
02290                 
02291                         double[] kvalue = new double[l];
02292                         for(i=0;i<l;i++)
02293                                 kvalue[i] = Kernel.k_function(x,model.SV[i],model.param);
02294 
02295                         int[] start = new int[nr_class];
02296                         start[0] = 0;
02297                         for(i=1;i<nr_class;i++)
02298                                 start[i] = start[i-1]+model.nSV[i-1];
02299 
02300                         int[] vote = new int[nr_class];
02301                         for(i=0;i<nr_class;i++)
02302                                 vote[i] = 0;
02303 
02304                         int p=0;
02305                         for(i=0;i<nr_class;i++)
02306                                 for(int j=i+1;j<nr_class;j++)
02307                                 {
02308                                         double sum = 0;
02309                                         int si = start[i];
02310                                         int sj = start[j];
02311                                         int ci = model.nSV[i];
02312                                         int cj = model.nSV[j];
02313                                 
02314                                         int k;
02315                                         double[] coef1 = model.sv_coef[j-1];
02316                                         double[] coef2 = model.sv_coef[i];
02317                                         for(k=0;k<ci;k++)
02318                                                 sum += coef1[si+k] * kvalue[si+k];
02319                                         for(k=0;k<cj;k++)
02320                                                 sum += coef2[sj+k] * kvalue[sj+k];
02321                                         sum -= model.rho[p];
02322                                         dec_values[p] = sum;                                    
02323 
02324                                         if(dec_values[p] > 0)
02325                                                 ++vote[i];
02326                                         else
02327                                                 ++vote[j];
02328                                         p++;
02329                                 }
02330 
02331                         int vote_max_idx = 0;
02332                         for(i=1;i<nr_class;i++)
02333                                 if(vote[i] > vote[vote_max_idx])
02334                                         vote_max_idx = i;
02335 
02336                         return model.label[vote_max_idx];
02337                 }
02338         }
02339 
02340         public static double svm_predict(svm_model model, svm_node[] x)
02341         {
02342                 int nr_class = model.nr_class;
02343                 double[] dec_values;
02344                 if(model.param.svm_type == svm_parameter.ONE_CLASS ||
02345                                 model.param.svm_type == svm_parameter.EPSILON_SVR ||
02346                                 model.param.svm_type == svm_parameter.NU_SVR)
02347                         dec_values = new double[1];
02348                 else
02349                         dec_values = new double[nr_class*(nr_class-1)/2];
02350                 double pred_result = svm_predict_values(model, x, dec_values);
02351                 return pred_result;
02352         }
02353 
02354         public static double svm_predict_probability(svm_model model, svm_node[] x, double[] prob_estimates)
02355         {
02356                 if ((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) &&
02357                     model.probA!=null && model.probB!=null)
02358                 {
02359                         int i;
02360                         int nr_class = model.nr_class;
02361                         double[] dec_values = new double[nr_class*(nr_class-1)/2];
02362                         svm_predict_values(model, x, dec_values);
02363 
02364                         double min_prob=1e-7;
02365                         double[][] pairwise_prob=new double[nr_class][nr_class];
02366                         
02367                         int k=0;
02368                         for(i=0;i<nr_class;i++)
02369                                 for(int j=i+1;j<nr_class;j++)
02370                                 {
02371                                         pairwise_prob[i][j]=Math.min(Math.max(sigmoid_predict(dec_values[k],model.probA[k],model.probB[k]),min_prob),1-min_prob);
02372                                         pairwise_prob[j][i]=1-pairwise_prob[i][j];
02373                                         k++;
02374                                 }
02375                         multiclass_probability(nr_class,pairwise_prob,prob_estimates);
02376 
02377                         int prob_max_idx = 0;
02378                         for(i=1;i<nr_class;i++)
02379                                 if(prob_estimates[i] > prob_estimates[prob_max_idx])
02380                                         prob_max_idx = i;
02381                         return model.label[prob_max_idx];
02382                 }
02383                 else 
02384                         return svm_predict(model, x);
02385         }
02386 
02387         static final String svm_type_table[] =
02388         {
02389                 "c_svc","nu_svc","one_class","epsilon_svr","nu_svr",
02390         };
02391 
02392         static final String kernel_type_table[]=
02393         {
02394                 "linear","polynomial","rbf","sigmoid","precomputed"
02395         };
02396 
02397         public static void svm_save_model(String model_file_name, svm_model model) throws IOException
02398         {
02399                 DataOutputStream fp = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(model_file_name)));
02400 
02401                 svm_parameter param = model.param;
02402 
02403                 fp.writeBytes("svm_type "+svm_type_table[param.svm_type]+"\n");
02404                 fp.writeBytes("kernel_type "+kernel_type_table[param.kernel_type]+"\n");
02405 
02406                 if(param.kernel_type == svm_parameter.POLY)
02407                         fp.writeBytes("degree "+param.degree+"\n");
02408 
02409                 if(param.kernel_type == svm_parameter.POLY ||
02410                    param.kernel_type == svm_parameter.RBF ||
02411                    param.kernel_type == svm_parameter.SIGMOID)
02412                         fp.writeBytes("gamma "+param.gamma+"\n");
02413 
02414                 if(param.kernel_type == svm_parameter.POLY ||
02415                    param.kernel_type == svm_parameter.SIGMOID)
02416                         fp.writeBytes("coef0 "+param.coef0+"\n");
02417 
02418                 int nr_class = model.nr_class;
02419                 int l = model.l;
02420                 fp.writeBytes("nr_class "+nr_class+"\n");
02421                 fp.writeBytes("total_sv "+l+"\n");
02422         
02423                 {
02424                         fp.writeBytes("rho");
02425                         for(int i=0;i<nr_class*(nr_class-1)/2;i++)
02426                                 fp.writeBytes(" "+model.rho[i]);
02427                         fp.writeBytes("\n");
02428                 }
02429         
02430                 if(model.label != null)
02431                 {
02432                         fp.writeBytes("label");
02433                         for(int i=0;i<nr_class;i++)
02434                                 fp.writeBytes(" "+model.label[i]);
02435                         fp.writeBytes("\n");
02436                 }
02437 
02438                 if(model.probA != null) // regression has probA only
02439                 {
02440                         fp.writeBytes("probA");
02441                         for(int i=0;i<nr_class*(nr_class-1)/2;i++)
02442                                 fp.writeBytes(" "+model.probA[i]);
02443                         fp.writeBytes("\n");
02444                 }
02445                 if(model.probB != null) 
02446                 {
02447                         fp.writeBytes("probB");
02448                         for(int i=0;i<nr_class*(nr_class-1)/2;i++)
02449                                 fp.writeBytes(" "+model.probB[i]);
02450                         fp.writeBytes("\n");
02451                 }
02452 
02453                 if(model.nSV != null)
02454                 {
02455                         fp.writeBytes("nr_sv");
02456                         for(int i=0;i<nr_class;i++)
02457                                 fp.writeBytes(" "+model.nSV[i]);
02458                         fp.writeBytes("\n");
02459                 }
02460 
02461                 fp.writeBytes("SV\n");
02462                 double[][] sv_coef = model.sv_coef;
02463                 svm_node[][] SV = model.SV;
02464 
02465                 for(int i=0;i<l;i++)
02466                 {
02467                         for(int j=0;j<nr_class-1;j++)
02468                                 fp.writeBytes(sv_coef[j][i]+" ");
02469 
02470                         svm_node[] p = SV[i];
02471                         if(param.kernel_type == svm_parameter.PRECOMPUTED)
02472                                 fp.writeBytes("0:"+(int)(p[0].value));
02473                         else    
02474                                 for(int j=0;j<p.length;j++)
02475                                         fp.writeBytes(p[j].index+":"+p[j].value+" ");
02476                         fp.writeBytes("\n");
02477                 }
02478 
02479                 fp.close();
02480         }
02481 
02482         private static double atof(String s)
02483         {
02484                 return Double.valueOf(s).doubleValue();
02485         }
02486 
02487         private static int atoi(String s)
02488         {
02489                 return Integer.parseInt(s);
02490         }
02491 
02492         public static svm_model svm_load_model(String model_file_name) throws IOException
02493         {
02494                 return svm_load_model(new BufferedReader(new FileReader(model_file_name)));
02495         }
02496 
02497         public static svm_model svm_load_model(BufferedReader fp) throws IOException
02498         {
02499                 // read parameters
02500 
02501                 svm_model model = new svm_model();
02502                 svm_parameter param = new svm_parameter();
02503                 model.param = param;
02504                 model.rho = null;
02505                 model.probA = null;
02506                 model.probB = null;
02507                 model.label = null;
02508                 model.nSV = null;
02509 
02510                 while(true)
02511                 {
02512                         String cmd = fp.readLine();
02513                         String arg = cmd.substring(cmd.indexOf(' ')+1);
02514 
02515                         if(cmd.startsWith("svm_type"))
02516                         {
02517                                 int i;
02518                                 for(i=0;i<svm_type_table.length;i++)
02519                                 {
02520                                         if(arg.indexOf(svm_type_table[i])!=-1)
02521                                         {
02522                                                 param.svm_type=i;
02523                                                 break;
02524                                         }
02525                                 }
02526                                 if(i == svm_type_table.length)
02527                                 {
02528                                         System.err.print("unknown svm type.\n");
02529                                         return null;
02530                                 }
02531                         }
02532                         else if(cmd.startsWith("kernel_type"))
02533                         {
02534                                 int i;
02535                                 for(i=0;i<kernel_type_table.length;i++)
02536                                 {
02537                                         if(arg.indexOf(kernel_type_table[i])!=-1)
02538                                         {
02539                                                 param.kernel_type=i;
02540                                                 break;
02541                                         }
02542                                 }
02543                                 if(i == kernel_type_table.length)
02544                                 {
02545                                         System.err.print("unknown kernel function.\n");
02546                                         return null;
02547                                 }
02548                         }
02549                         else if(cmd.startsWith("degree"))
02550                                 param.degree = atoi(arg);
02551                         else if(cmd.startsWith("gamma"))
02552                                 param.gamma = atof(arg);
02553                         else if(cmd.startsWith("coef0"))
02554                                 param.coef0 = atof(arg);
02555                         else if(cmd.startsWith("nr_class"))
02556                                 model.nr_class = atoi(arg);
02557                         else if(cmd.startsWith("total_sv"))
02558                                 model.l = atoi(arg);
02559                         else if(cmd.startsWith("rho"))
02560                         {
02561                                 int n = model.nr_class * (model.nr_class-1)/2;
02562                                 model.rho = new double[n];
02563                                 StringTokenizer st = new StringTokenizer(arg);
02564                                 for(int i=0;i<n;i++)
02565                                         model.rho[i] = atof(st.nextToken());
02566                         }
02567                         else if(cmd.startsWith("label"))
02568                         {
02569                                 int n = model.nr_class;
02570                                 model.label = new int[n];
02571                                 StringTokenizer st = new StringTokenizer(arg);
02572                                 for(int i=0;i<n;i++)
02573                                         model.label[i] = atoi(st.nextToken());                                  
02574                         }
02575                         else if(cmd.startsWith("probA"))
02576                         {
02577                                 int n = model.nr_class*(model.nr_class-1)/2;
02578                                 model.probA = new double[n];
02579                                 StringTokenizer st = new StringTokenizer(arg);
02580                                 for(int i=0;i<n;i++)
02581                                         model.probA[i] = atof(st.nextToken());                                  
02582                         }
02583                         else if(cmd.startsWith("probB"))
02584                         {
02585                                 int n = model.nr_class*(model.nr_class-1)/2;
02586                                 model.probB = new double[n];
02587                                 StringTokenizer st = new StringTokenizer(arg);
02588                                 for(int i=0;i<n;i++)
02589                                         model.probB[i] = atof(st.nextToken());                                  
02590                         }
02591                         else if(cmd.startsWith("nr_sv"))
02592                         {
02593                                 int n = model.nr_class;
02594                                 model.nSV = new int[n];
02595                                 StringTokenizer st = new StringTokenizer(arg);
02596                                 for(int i=0;i<n;i++)
02597                                         model.nSV[i] = atoi(st.nextToken());
02598                         }
02599                         else if(cmd.startsWith("SV"))
02600                         {
02601                                 break;
02602                         }
02603                         else
02604                         {
02605                                 System.err.print("unknown text in model file: ["+cmd+"]\n");
02606                                 return null;
02607                         }
02608                 }
02609 
02610                 // read sv_coef and SV
02611 
02612                 int m = model.nr_class - 1;
02613                 int l = model.l;
02614                 model.sv_coef = new double[m][l];
02615                 model.SV = new svm_node[l][];
02616 
02617                 for(int i=0;i<l;i++)
02618                 {
02619                         String line = fp.readLine();
02620                         StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
02621 
02622                         for(int k=0;k<m;k++)
02623                                 model.sv_coef[k][i] = atof(st.nextToken());
02624                         int n = st.countTokens()/2;
02625                         model.SV[i] = new svm_node[n];
02626                         for(int j=0;j<n;j++)
02627                         {
02628                                 model.SV[i][j] = new svm_node();
02629                                 model.SV[i][j].index = atoi(st.nextToken());
02630                                 model.SV[i][j].value = atof(st.nextToken());
02631                         }
02632                 }
02633 
02634                 fp.close();
02635                 return model;
02636         }
02637 
02638         public static String svm_check_parameter(svm_problem prob, svm_parameter param)
02639         {
02640                 // svm_type
02641 
02642                 int svm_type = param.svm_type;
02643                 if(svm_type != svm_parameter.C_SVC &&
02644                    svm_type != svm_parameter.NU_SVC &&
02645                    svm_type != svm_parameter.ONE_CLASS &&
02646                    svm_type != svm_parameter.EPSILON_SVR &&
02647                    svm_type != svm_parameter.NU_SVR)
02648                 return "unknown svm type";
02649 
02650                 // kernel_type, degree
02651         
02652                 int kernel_type = param.kernel_type;
02653                 if(kernel_type != svm_parameter.LINEAR &&
02654                    kernel_type != svm_parameter.POLY &&
02655                    kernel_type != svm_parameter.RBF &&
02656                    kernel_type != svm_parameter.SIGMOID &&
02657                    kernel_type != svm_parameter.PRECOMPUTED)
02658                         return "unknown kernel type";
02659 
02660                 if(param.gamma < 0)
02661                         return "gamma < 0";
02662 
02663                 if(param.degree < 0)
02664                         return "degree of polynomial kernel < 0";
02665 
02666                 // cache_size,eps,C,nu,p,shrinking
02667 
02668                 if(param.cache_size <= 0)
02669                         return "cache_size <= 0";
02670 
02671                 if(param.eps <= 0)
02672                         return "eps <= 0";
02673 
02674                 if(svm_type == svm_parameter.C_SVC ||
02675                    svm_type == svm_parameter.EPSILON_SVR ||
02676                    svm_type == svm_parameter.NU_SVR)
02677                         if(param.C <= 0)
02678                                 return "C <= 0";
02679 
02680                 if(svm_type == svm_parameter.NU_SVC ||
02681                    svm_type == svm_parameter.ONE_CLASS ||
02682                    svm_type == svm_parameter.NU_SVR)
02683                         if(param.nu <= 0 || param.nu > 1)
02684                                 return "nu <= 0 or nu > 1";
02685 
02686                 if(svm_type == svm_parameter.EPSILON_SVR)
02687                         if(param.p < 0)
02688                                 return "p < 0";
02689 
02690                 if(param.shrinking != 0 &&
02691                    param.shrinking != 1)
02692                         return "shrinking != 0 and shrinking != 1";
02693 
02694                 if(param.probability != 0 &&
02695                    param.probability != 1)
02696                         return "probability != 0 and probability != 1";
02697 
02698                 if(param.probability == 1 &&
02699                    svm_type == svm_parameter.ONE_CLASS)
02700                         return "one-class SVM probability output not supported yet";
02701                 
02702                 // check whether nu-svc is feasible
02703         
02704                 if(svm_type == svm_parameter.NU_SVC)
02705                 {
02706                         int l = prob.l;
02707                         int max_nr_class = 16;
02708                         int nr_class = 0;
02709                         int[] label = new int[max_nr_class];
02710                         int[] count = new int[max_nr_class];
02711 
02712                         int i;
02713                         for(i=0;i<l;i++)
02714                         {
02715                                 int this_label = (int)prob.y[i];
02716                                 int j;
02717                                 for(j=0;j<nr_class;j++)
02718                                         if(this_label == label[j])
02719                                         {
02720                                                 ++count[j];
02721                                                 break;
02722                                         }
02723 
02724                                 if(j == nr_class)
02725                                 {
02726                                         if(nr_class == max_nr_class)
02727                                         {
02728                                                 max_nr_class *= 2;
02729                                                 int[] new_data = new int[max_nr_class];
02730                                                 System.arraycopy(label,0,new_data,0,label.length);
02731                                                 label = new_data;
02732                                                 
02733                                                 new_data = new int[max_nr_class];
02734                                                 System.arraycopy(count,0,new_data,0,count.length);
02735                                                 count = new_data;
02736                                         }
02737                                         label[nr_class] = this_label;
02738                                         count[nr_class] = 1;
02739                                         ++nr_class;
02740                                 }
02741                         }
02742 
02743                         for(i=0;i<nr_class;i++)
02744                         {
02745                                 int n1 = count[i];
02746                                 for(int j=i+1;j<nr_class;j++)
02747                                 {
02748                                         int n2 = count[j];
02749                                         if(param.nu*(n1+n2)/2 > Math.min(n1,n2))
02750                                                 return "specified nu is infeasible";
02751                                 }
02752                         }
02753                 }
02754 
02755                 return null;
02756         }
02757 
02758         public static int svm_check_probability_model(svm_model model)
02759         {
02760                 if (((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) &&
02761                 model.probA!=null && model.probB!=null) ||
02762                 ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) &&
02763                  model.probA!=null))
02764                         return 1;
02765                 else
02766                         return 0;
02767         }
02768 
02769         public static void svm_set_print_string_function(svm_print_interface print_func)
02770         {
02771                 if (print_func == null)
02772                         svm_print_string = svm_print_stdout;
02773                 else 
02774                         svm_print_string = print_func;
02775         }
02776 }


libsvm3
Author(s): various
autogenerated on Wed Nov 27 2013 11:36:23