ADDED: SVM-ARD kernel relevance determination
[mldemos:mldemos.git] / _AlgorithmsPlugins / KernelMethods / svm.cpp
1 /*********************************************************************
2 MLDemos: A User-Friendly visualization toolkit for machine learning
3 Copyright (C) 2010  Basilio Noris
4 Contact: mldemos@b4silio.com
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public
8 License as published by the Free Software Foundation; either
9 version 2.1 of the License, or (at your option) any later version.
10
11 This library is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 Library General Public License for more details.
15
16 You should have received a copy of the GNU Lesser General Public
17 License along with this library; if not, write to the Free
18 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19 *********************************************************************/
20 #include <math.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <ctype.h>
24 #include <float.h>
25 #include <string.h>
26 #include <stdarg.h>
27 #include "svm.h"
28 #ifdef WIN32
29 #pragma warning(disable : 4996)
30 #endif
31
32 #ifndef min
33 template <class T> inline T min(T x,T y) { return (x<y)?x:y; }
34 #endif
35 #ifndef max
36 template <class T> inline T max(T x,T y) { return (x>y)?x:y; }
37 #endif
38 template <class T> inline void swap(T& x, T& y) { T t=x; x=y; y=t; }
39 template <class S, class T> inline void clone(T*& dst, S* src, int n)
40 {
41         dst = new T[n];
42         memcpy((void *)dst,(void *)src,sizeof(T)*n);
43 }
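// integer power by repeated squaring: O(log times) multiplications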
44 inline double powi(double base, int times)
45 {
46         double tmp = base, ret = 1.0;
47
48         for(int t=times; t>0; t/=2)
49         {
50                 if(t%2==1) ret*=tmp;
51                 tmp = tmp * tmp;
52         }
53         return ret;
54 }
55 #define INF HUGE_VAL
56 #define TAU 1e-12
57 #if 0
58 void info(const char *fmt,...)
59 {
60         va_list ap;
61         va_start(ap,fmt);
62         vprintf(fmt,ap);
63         va_end(ap);
64 }
65 void info_flush()
66 {
67         fflush(stdout);
68 }
69 #else
70 void info(const char *fmt,...) {}
71 void info_flush() {}
72 #endif
73
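//
// Least-recently-used cache of kernel-matrix columns. The constructor receives
// a budget in bytes and converts it to a number of Qfloat entries; get_data()
// returns how many entries of a column are already cached, so the caller only
// recomputes the missing tail.
//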
74 Cache::Cache(int l_,long int size_):l(l_),size(size_)
75 {
76         head = (head_t *)calloc(l,sizeof(head_t));      // initialized to 0
77         size /= sizeof(Qfloat);
78         size -= l * sizeof(head_t) / sizeof(Qfloat);
79         size = max(size, 2 * (long int) l);     // cache must be large enough for two columns
80         lru_head.next = lru_head.prev = &lru_head;
81 }
82
83 Cache::~Cache()
84 {
85         for(head_t *h = lru_head.next; h != &lru_head; h=h->next)
86                 delete [] h->data;
87         delete [] head;
88 }
89
90 void Cache::lru_delete(head_t *h)
91 {
92         // delete from current location
93         h->prev->next = h->next;
94         h->next->prev = h->prev;
95 }
96
97 void Cache::lru_insert(head_t *h)
98 {
99         // insert to last position
100         h->next = &lru_head;
101         h->prev = lru_head.prev;
102         h->prev->next = h;
103         h->next->prev = h;
104 }
105
106 int Cache::get_data(const int index, Qfloat **data, int len)
107 {
108         head_t *h = &head[index];
109         if(h->len) lru_delete(h);
110         int more = len - h->len;
111
112         if(more > 0)
113         {
114                 // free old space
115                 while(size < more)
116                 {
117                         head_t *old = lru_head.next;
118                         lru_delete(old);
119             delete [] old->data;
120                         size += old->len;
121                         old->data = 0;
122                         old->len = 0;
123                 }
124
125                 // allocate new space
126                 h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len);
127                 size -= more;
128                 swap(h->len,len);
129         }
130
131         lru_insert(h);
132         *data = h->data;
133         return len;
134 }
135
136 void Cache::swap_index(int i, int j)
137 {
138         if(i==j) return;
139
140         if(head[i].len) lru_delete(&head[i]);
141         if(head[j].len) lru_delete(&head[j]);
142         swap(head[i].data,head[j].data);
143         swap(head[i].len,head[j].len);
144         if(head[i].len) lru_insert(&head[i]);
145         if(head[j].len) lru_insert(&head[j]);
146
147         if(i>j) swap(i,j);
148         for(head_t *h = lru_head.next; h!=&lru_head; h=h->next)
149         {
150                 if(h->len > i)
151                 {
152                         if(h->len > j)
153                                 swap(h->data[i],h->data[j]);
154                         else
155                         {
156                                 // give up
157                                 lru_delete(h);
158                 delete [] h->data;
159                                 size += h->len;
160                                 h->data = 0;
161                                 h->len = 0;
162                         }
163                 }
164         }
165 }
166
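//
// Kernel dispatch. Besides the standard LIBSVM kernels, two ARD variants are
// handled here: RBFWEIGH uses one relevance weight per input dimension
// (param.kernel_weight), RBFWMATRIX uses a full dim x dim weight matrix.
// kernel_norm is an optional multiplicative normalization of the kernel value.
//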
167 Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)
168 :kernel_type(param.kernel_type), degree(param.degree),
169  gamma(param.gamma), coef0(param.coef0), kernel_weight(param.kernel_weight), kernel_norm(param.kernel_norm)
170 {
171         switch(kernel_type)
172         {
173                 case LINEAR:
174                         kernel_function = &Kernel::kernel_linear;
175                         break;
176                 case POLY:
177                         kernel_function = &Kernel::kernel_poly;
178                         break;
179                 case RBF:
180                         kernel_function = &Kernel::kernel_rbf;
181                         break;
182                 case RBFWEIGH:
183                         kernel_function = &Kernel::kernel_rbf_weight;
184                         break;
185                 case RBFWMATRIX:
186                         kernel_function = &Kernel::kernel_rbf_w;
187                         break;
188                 case SIGMOID:
189                         kernel_function = &Kernel::kernel_sigmoid;
190                         break;
191                 case PRECOMPUTED:
192                         kernel_function = &Kernel::kernel_precomputed;
193                         break;
194         }
195
196         clone(x,x_,l);
197
198         dim = param.kernel_dim;
199         if(!dim)
200         {
201                 while(x[0][dim].index != -1) dim++;
202         }
203
204         if(kernel_type == RBF)
205         {
206                 x_square = new double[l];
207                 for(int i=0;i<l;i++)
208                         x_square[i] = dot(x[i],x[i]);
209         }
210         else if(kernel_type == RBFWEIGH)
211         {
212                 x_square = new double[l];
213                 for(int i=0;i<l;i++)
214                         x_square[i] = dot(x[i],x[i], kernel_weight);
215         }
216         else
217                 x_square = 0;
218 }
219
220 Kernel::~Kernel()
221 {
222     delete [] x;
223     delete [] x_square;
224 }
225
226 void Kernel::swap_index(int i, int j) const     // not so const...
227 {
228         swap(x[i],x[j]);
229         if(x_square) swap(x_square[i],x_square[j]);
230 }
231
232 double Kernel::kernel_linear(const int i, const int j) const
233 {
234         if(kernel_norm != 1.) return kernel_norm*dot(x[i],x[j]);
235         return dot(x[i],x[j]);
236 }
237 double Kernel::kernel_poly(const int i, const int j) const
238 {
239         if(kernel_norm != 1.) return kernel_norm*powi(gamma*dot(x[i],x[j])+coef0,degree);
240         return powi(gamma*dot(x[i],x[j])+coef0,degree);
241 }
242 double Kernel::kernel_rbf(const int i, const int j) const
243 {
244         if(kernel_norm != 1.) return kernel_norm*exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
245         return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
246 }
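// ARD (weighted) RBF kernel: K(x,y) = exp(-gamma * sum_d w_d * (x_d - y_d)^2),
// with w_d = kernel_weight[d]. The sparse svm_node lists are walked in parallel,
// as in dot(); a dimension missing from either vector contributes nothing here.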
247 double Kernel::kernel_rbf_weight(const int i, const int j) const
248 {
249     double sum = 0;
250     const svm_node *px = x[i];
251     const svm_node *py = x[j];
252     while(px->index != -1 && py->index != -1)
253     {
254         if(px->index == py->index)
255         {
256             sum += (px->value - py->value) * (px->value - py->value) * kernel_weight[px->index-1];
257             ++px;
258             ++py;
259         }
260         else
261         {
262             if(px->index > py->index)
263                 ++py;
264             else
265                 ++px;
266         }
267     }
268     if(kernel_norm != 1.) return kernel_norm*exp(-gamma*sum);
269     return exp(-gamma*sum);
270 }
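// Full-matrix ARD kernel: K(x,y) = exp(-gamma * (x-y)^T W (x-y)), where W is a
// dim x dim matrix stored as a flat array in kernel_weight and evaluated by
// matrix() below.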
271 double Kernel::kernel_rbf_w(const int i, const int j) const
272 {
273         if(kernel_norm != 1.) return kernel_norm*exp(-gamma*matrix(x[i], x[j], kernel_weight, dim));
274         return exp(-gamma*matrix(x[i], x[j], kernel_weight, dim));
275 }
276 double Kernel::kernel_sigmoid(const int i, const int j) const
277 {
278         return tanh(gamma*dot(x[i],x[j])+coef0);
279 }
280 double Kernel::kernel_precomputed(const int i, const int j) const
281 {
282         return x[i][(int)(x[j][0].value)].value;
283 }
284
285
286
287 double Kernel::dot(const svm_node *px, const svm_node *py)
288 {
289         double sum = 0;
290         while(px->index != -1 && py->index != -1)
291         {
292                 if(px->index == py->index)
293                 {
294                         sum += px->value * py->value;
295                         ++px;
296                         ++py;
297                 }
298                 else
299                 {
300                         if(px->index > py->index)
301                                 ++py;
302                         else
303                                 ++px;
304                 }                       
305         }
306         return sum;
307 }
308
309 double Kernel::dot(const svm_node *px, const svm_node *py, const double *weight)
310 {
311         double sum = 0;
312         while(px->index != -1 && py->index != -1)
313         {
314                 if(px->index == py->index)
315                 {
316                         sum += px->value * py->value * weight[px->index];
317                         ++px;
318                         ++py;
319                 }
320                 else
321                 {
322                         if(px->index > py->index)
323                                 ++py;
324                         else
325                                 ++px;
326                 }                       
327         }
328         return sum;
329 }
330
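// Quadratic form used by RBFWMATRIX: returns (px-py)^T W (px-py) over the first
// `dim` components, with W passed as a flat dim*dim array.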
331 double Kernel::matrix(const svm_node *px, const svm_node *py, const double *W, int dim)
332 {
333         double sum = 0;
334         double *xW = new double[dim];
335         for (int i=0; i<dim; i++)
336         {
337                 xW[i] = 0;
338                 for (int j=0; j<dim; j++)
339                 {
340                         xW[i] += (px[j].value - py[j].value)*W[j*dim + i];
341                 }
342         }
343         for (int i=0; i<dim; i++)
344         {
345                 sum += (px[i].value - py[i].value)*xW[i];
346         }
347         delete [] xW;
348         return sum;
349 }
350
351
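// Stand-alone kernel evaluation used at prediction time. It mirrors the
// training-time member kernels above, including the RBFWEIGH and RBFWMATRIX
// ARD variants, but operates directly on two svm_node arrays rather than on
// indices into the cached training set.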
352 double Kernel::k_function(const svm_node *x, const svm_node *y,
353                           const svm_parameter& param)
354 {
355         switch(param.kernel_type)
356         {
357                 case LINEAR:
358                         return dot(x,y);
359                 case POLY:
360                         return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
361                 case RBF:
362                 {
363                         double sum = 0;
364                         while(x->index != -1 && y->index !=-1)
365                         {
366                                 if(x->index == y->index)
367                                 {
368                                         double d = x->value - y->value;
369                                         sum += d*d;
370                                         ++x;
371                                         ++y;
372                                 }
373                                 else
374                                 {
375                                         if(x->index > y->index)
376                                         {       
377                                                 sum += y->value * y->value;
378                                                 ++y;
379                                         }
380                                         else
381                                         {
382                                                 sum += x->value * x->value;
383                                                 ++x;
384                                         }
385                                 }
386                         }
387
388                         while(x->index != -1)
389                         {
390                                 sum += x->value * x->value;
391                                 ++x;
392                         }
393
394                         while(y->index != -1)
395                         {
396                                 sum += y->value * y->value;
397                                 ++y;
398                         }
399                         if(param.normalizeKernel) return param.kernel_norm*exp(-param.gamma*sum);
400                         return exp(-param.gamma*sum);
401                 }
402                 case RBFWEIGH:
403                         {
404                                 double sum = 0;
405                                 while(x->index != -1 && y->index !=-1)
406                                 {
407                                         if(x->index == y->index)
408                                         {
409                                                 double d = x->value - y->value;
410                         sum += d*d*param.kernel_weight[x->index-1];
411                                                 ++x;
412                                                 ++y;
413                                         }
414                                         else
415                                         {
416                                                 if(x->index > y->index)
417                                                 {       
418                             sum += y->value * y->value * param.kernel_weight[y->index-1];
419                                                         ++y;
420                                                 }
421                                                 else
422                                                 {
423                             sum += x->value * x->value * param.kernel_weight[x->index-1];
424                                                         ++x;
425                                                 }
426                                         }
427                                 }
428
429                                 while(x->index != -1)
430                                 {
431                     sum += x->value * x->value * param.kernel_weight[x->index-1];
432                                         ++x;
433                                 }
434
435                                 while(y->index != -1)
436                                 {
437                     sum += y->value * y->value * param.kernel_weight[y->index-1];
438                                         ++y;
439                                 }
440                                 if(param.normalizeKernel) return param.kernel_norm*exp(-param.gamma*sum);
441                                 return exp(-param.gamma*sum);
442                         }
443                 case RBFWMATRIX:
444                         {
445                                 int l = param.kernel_dim;
446                                 if(!l)
447                                 {
448                                         while(x[l].index != -1) l++;
449                                 }
450                                 double sum = matrix(x, y, param.kernel_weight, l);
451                                 if(param.normalizeKernel) return param.kernel_norm*exp(-param.gamma*sum);
452                                 return exp(-param.gamma*sum);
453                         }
454                 case SIGMOID:
455                         return tanh(param.gamma*dot(x,y)+param.coef0);
456                 case PRECOMPUTED:  //x: test (validation), y: SV
457                         return x[(int)(y->value)].value;
458                 default:
459                         return 0;  // Unreachable 
460         }
461 }
462
463 // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
464 // Solves:
465 //
466 //      min 0.5(\alpha^T Q \alpha) + p^T \alpha
467 //
468 //              y^T \alpha = \delta
469 //              y_i = +1 or -1
470 //              0 <= alpha_i <= Cp for y_i = 1
471 //              0 <= alpha_i <= Cn for y_i = -1
472 //
473 // Given:
474 //
475 //      Q, p, y, Cp, Cn, and an initial feasible point \alpha
476 //      l is the size of vectors and matrices
477 //      eps is the stopping tolerance
478 //
479 // solution will be put in \alpha, objective value will be put in obj
480 //
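// Implementation notes: G holds the gradient of the dual objective and is
// updated incrementally after every working-set step; G_bar caches the
// contribution of variables at their upper bound so the full gradient can be
// rebuilt cheaply after shrinking (see reconstruct_gradient()); alpha_status
// records whether each variable sits at 0, at C, or strictly in between.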
481 class Solver {
482 public:
483     Solver() {}
484     virtual ~Solver() {}
485
486         struct SolutionInfo {
487                 double obj;
488                 double rho;
489                 double upper_bound_p;
490                 double upper_bound_n;
491                 double r;       // for Solver_NU
492         };
493
494         void Solve(int l, const Q_Matrix& Q, const double *p_, const schar *y_,
495                    double *alpha_, double Cp, double Cn, double eps,
496                    SolutionInfo* si, int shrinking);
497 protected:
498         int active_size;
499         schar *y;
500         double *G;              // gradient of objective function
501         enum { LOWER_BOUND, UPPER_BOUND, FREE };
502         char *alpha_status;     // LOWER_BOUND, UPPER_BOUND, FREE
503         double *alpha;
504         const Q_Matrix *Q;
505         const Qfloat *QD;
506         double eps;
507         double Cp,Cn;
508         double *p;
509         int *active_set;
510         double *G_bar;          // gradient, if we treat free variables as 0
511         int l;
512         bool unshrinked;        // XXX
513
514         double get_C(int i)
515         {
516                 return (y[i] > 0)? Cp : Cn;
517         }
518         void update_alpha_status(int i)
519         {
520                 if(alpha[i] >= get_C(i))
521                         alpha_status[i] = UPPER_BOUND;
522                 else if(alpha[i] <= 0)
523                         alpha_status[i] = LOWER_BOUND;
524                 else alpha_status[i] = FREE;
525         }
526         bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
527         bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
528     bool is_free(int i) { return alpha_status[i] == FREE; }
529         void swap_index(int i, int j);
530         void reconstruct_gradient();
531         virtual int select_working_set(int &i, int &j);
532         virtual double calculate_rho();
533         virtual void do_shrinking();
534 private:
535         bool be_shrunken(int i, double Gmax1, double Gmax2);    
536 };
537
538 void Solver::swap_index(int i, int j)
539 {
540         Q->swap_index(i,j);
541         swap(y[i],y[j]);
542         swap(G[i],G[j]);
543         swap(alpha_status[i],alpha_status[j]);
544         swap(alpha[i],alpha[j]);
545         swap(p[i],p[j]);
546         swap(active_set[i],active_set[j]);
547         swap(G_bar[i],G_bar[j]);
548 }
549
550 void Solver::reconstruct_gradient()
551 {
552         // reconstruct inactive elements of G from G_bar and free variables
553
554         if(active_size == l) return;
555
556         int i;
557         for(i=active_size;i<l;i++)
558                 G[i] = G_bar[i] + p[i];
559         
560         for(i=0;i<active_size;i++)
561                 if(is_free(i))
562                 {
563                         const Qfloat *Q_i = Q->get_Q(i,l);
564                         double alpha_i = alpha[i];
565                         for(int j=active_size;j<l;j++)
566                                 G[j] += alpha_i * Q_i[j];
567                 }
568 }
569
570 void Solver::Solve(int l, const Q_Matrix& Q, const double *p_, const schar *y_,
571                    double *alpha_, double Cp, double Cn, double eps,
572                    SolutionInfo* si, int shrinking)
573 {
574         this->l = l;
575         this->Q = &Q;
576         QD=Q.get_QD();
577         clone(p, p_,l);
578         clone(y, y_,l);
579         clone(alpha,alpha_,l);
580         this->Cp = Cp;
581         this->Cn = Cn;
582         this->eps = eps;
583         unshrinked = false;
584
585         // initialize alpha_status
586         {
587                 alpha_status = new char[l];
588                 for(int i=0;i<l;i++)
589                         update_alpha_status(i);
590         }
591
592         // initialize active set (for shrinking)
593         {
594                 active_set = new int[l];
595                 for(int i=0;i<l;i++)
596                         active_set[i] = i;
597                 active_size = l;
598         }
599
600         // initialize gradient
601         {
602                 G = new double[l];
603                 G_bar = new double[l];
604                 int i;
605                 for(i=0;i<l;i++)
606                 {
607                         G[i] = p[i];
608                         G_bar[i] = 0;
609                 }
610                 for(i=0;i<l;i++)
611                         if(!is_lower_bound(i))
612                         {
613                                 const Qfloat *Q_i = Q.get_Q(i,l);
614                                 double alpha_i = alpha[i];
615                                 int j;
616                                 for(j=0;j<l;j++)
617                                         G[j] += alpha_i*Q_i[j];
618                                 if(is_upper_bound(i))
619                                         for(j=0;j<l;j++)
620                                                 G_bar[j] += get_C(i) * Q_i[j];
621                         }
622         }
623
624         // optimization step
625
626         int iter = 0;
627         int counter = min(l,1000)+1;
628         while(1)
629         {
630                 // show progress and do shrinking
631
632                 if (iter > 10000)
633                 {
634                         // reconstruct the whole gradient
635                         reconstruct_gradient();
636                         info("O"); info_flush();
637                         break;
638                 }
639
640                 if(--counter == 0)
641                 {
642                         info("."); info_flush();
643                         counter = min(l,1000);
644                         if(shrinking) do_shrinking();
645                 }
646
647                 int i,j;
648                 if(select_working_set(i,j)!=0)
649                 {
650                         info("*"); info_flush();
651                         // reconstruct the whole gradient
652                         reconstruct_gradient();
653                         // reset active set size and check
654                         active_size = l;
655                         if(select_working_set(i,j)!=0)
656                                 break;
657                         else
658                                 counter = 1;    // do shrinking next iteration
659                 }
660                 if(i==-1 || j==-1)
661                 {
662                         reconstruct_gradient();
663                         info("O"); info_flush();
664                         break;
665                 }
666                 
667                 ++iter;
668
669                 // update alpha[i] and alpha[j], handle bounds carefully
670                 const Qfloat *Q_i = Q.get_Q(i,active_size);
671                 const Qfloat *Q_j = Q.get_Q(j,active_size);
672
673                 double C_i = get_C(i);
674                 double C_j = get_C(j);
675
676                 double old_alpha_i = alpha[i];
677                 double old_alpha_j = alpha[j];
678
679                 if(y[i]!=y[j])
680                 {
681                         double quad_coef = Q_i[i]+Q_j[j]+2*Q_i[j];
682                         if (quad_coef <= 0)
683                                 quad_coef = TAU;
684                         double delta = (-G[i]-G[j])/quad_coef;
685                         double diff = alpha[i] - alpha[j];
686                         alpha[i] += delta;
687                         alpha[j] += delta;
688                         
689                         if(diff > 0)
690                         {
691                                 if(alpha[j] < 0)
692                                 {
693                                         alpha[j] = 0;
694                                         alpha[i] = diff;
695                                 }
696                         }
697                         else
698                         {
699                                 if(alpha[i] < 0)
700                                 {
701                                         alpha[i] = 0;
702                                         alpha[j] = -diff;
703                                 }
704                         }
705                         if(diff > C_i - C_j)
706                         {
707                                 if(alpha[i] > C_i)
708                                 {
709                                         alpha[i] = C_i;
710                                         alpha[j] = C_i - diff;
711                                 }
712                         }
713                         else
714                         {
715                                 if(alpha[j] > C_j)
716                                 {
717                                         alpha[j] = C_j;
718                                         alpha[i] = C_j + diff;
719                                 }
720                         }
721                 }
722                 else
723                 {
724                         double quad_coef = Q_i[i]+Q_j[j]-2*Q_i[j];
725                         if (quad_coef <= 0)
726                                 quad_coef = TAU;
727                         double delta = (G[i]-G[j])/quad_coef;
728                         double sum = alpha[i] + alpha[j];
729                         alpha[i] -= delta;
730                         alpha[j] += delta;
731
732                         if(sum > C_i)
733                         {
734                                 if(alpha[i] > C_i)
735                                 {
736                                         alpha[i] = C_i;
737                                         alpha[j] = sum - C_i;
738                                 }
739                         }
740                         else
741                         {
742                                 if(alpha[j] < 0)
743                                 {
744                                         alpha[j] = 0;
745                                         alpha[i] = sum;
746                                 }
747                         }
748                         if(sum > C_j)
749                         {
750                                 if(alpha[j] > C_j)
751                                 {
752                                         alpha[j] = C_j;
753                                         alpha[i] = sum - C_j;
754                                 }
755                         }
756                         else
757                         {
758                                 if(alpha[i] < 0)
759                                 {
760                                         alpha[i] = 0;
761                                         alpha[j] = sum;
762                                 }
763                         }
764                 }
765
766                 // update G
767
768                 double delta_alpha_i = alpha[i] - old_alpha_i;
769                 double delta_alpha_j = alpha[j] - old_alpha_j;
770                 
771                 for(int k=0;k<active_size;k++)
772                 {
773                         G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
774                 }
775
776                 // update alpha_status and G_bar
777                 {
778                         bool ui = is_upper_bound(i);
779                         bool uj = is_upper_bound(j);
780                         update_alpha_status(i);
781                         update_alpha_status(j);
782                         int k;
783                         if(ui != is_upper_bound(i))
784                         {
785                                 Q_i = Q.get_Q(i,l);
786                                 if(ui)
787                                         for(k=0;k<l;k++)
788                                                 G_bar[k] -= C_i * Q_i[k];
789                                 else
790                                         for(k=0;k<l;k++)
791                                                 G_bar[k] += C_i * Q_i[k];
792                         }
793
794                         if(uj != is_upper_bound(j))
795                         {
796                                 Q_j = Q.get_Q(j,l);
797                                 if(uj)
798                                         for(k=0;k<l;k++)
799                                                 G_bar[k] -= C_j * Q_j[k];
800                                 else
801                                         for(k=0;k<l;k++)
802                                                 G_bar[k] += C_j * Q_j[k];
803                         }
804                 }
805         }
806
807         // calculate rho
808
809         si->rho = calculate_rho();
810
811         // calculate objective value
812         {
813                 double v = 0;
814                 int i;
815                 for(i=0;i<l;i++)
816                         v += alpha[i] * (G[i] + p[i]);
817
818                 si->obj = v/2;
819         }
820
821         // put back the solution
822         {
823                 for(int i=0;i<l;i++)
824                         alpha_[active_set[i]] = alpha[i];
825         }
826
827         // juggle everything back
828         /*{
829                 for(int i=0;i<l;i++)
830                         while(active_set[i] != i)
831                                 swap_index(i,active_set[i]);
832                                 // or Q.swap_index(i,active_set[i]);
833         }*/
834
835         si->upper_bound_p = Cp;
836         si->upper_bound_n = Cn;
837
838     info("\noptimization finished, #iter = %d\n",iter);
839
840     delete [] p;
841     delete [] y;
842     delete [] alpha;
843     delete [] alpha_status;
844     delete [] active_set;
845     delete [] G;
846     delete [] G_bar;
847 }
848
849 // return 1 if already optimal, return 0 otherwise
850 int Solver::select_working_set(int &out_i, int &out_j)
851 {
852         // return i,j such that
853         // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
854         // j: minimizes the decrease of obj value
855         //    (if quadratic coefficient <= 0, replace it with tau)
856         //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
857         
858         double Gmax = -INF;
859         double Gmax2 = -INF;
860         int Gmax_idx = -1;
861         int Gmin_idx = -1;
862         double obj_diff_min = INF;
863
864         for(int t=0;t<active_size;t++)
865                 if(y[t]==+1)    
866                 {
867                         if(!is_upper_bound(t))
868                                 if(-G[t] >= Gmax)
869                                 {
870                                         Gmax = -G[t];
871                                         Gmax_idx = t;
872                                 }
873                 }
874                 else
875                 {
876                         if(!is_lower_bound(t))
877                                 if(G[t] >= Gmax)
878                                 {
879                                         Gmax = G[t];
880                                         Gmax_idx = t;
881                                 }
882                 }
883
884         int i = Gmax_idx;
885         const Qfloat *Q_i = NULL;
886         if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
887                 Q_i = Q->get_Q(i,active_size);
888
889         for(int j=0;j<active_size;j++)
890         {
891                 if(y[j]==+1)
892                 {
893                         if (!is_lower_bound(j))
894                         {
895                                 double grad_diff=Gmax+G[j];
896                                 if (G[j] >= Gmax2)
897                                         Gmax2 = G[j];
898                                 if (grad_diff > 0)
899                                 {
900                                         double obj_diff; 
901                                         double quad_coef=Q_i[i]+QD[j]-2.0*y[i]*Q_i[j];
902                                         if (quad_coef > 0)
903                                                 obj_diff = -(grad_diff*grad_diff)/quad_coef;
904                                         else
905                                                 obj_diff = -(grad_diff*grad_diff)/TAU;
906
907                                         if (obj_diff <= obj_diff_min)
908                                         {
909                                                 Gmin_idx=j;
910                                                 obj_diff_min = obj_diff;
911                                         }
912                                 }
913                         }
914                 }
915                 else
916                 {
917                         if (!is_upper_bound(j))
918                         {
919                                 double grad_diff= Gmax-G[j];
920                                 if (-G[j] >= Gmax2)
921                                         Gmax2 = -G[j];
922                                 if (grad_diff > 0)
923                                 {
924                                         double obj_diff; 
925                                         double quad_coef=Q_i[i]+QD[j]+2.0*y[i]*Q_i[j];
926                                         if (quad_coef > 0)
927                                                 obj_diff = -(grad_diff*grad_diff)/quad_coef;
928                                         else
929                                                 obj_diff = -(grad_diff*grad_diff)/TAU;
930
931                                         if (obj_diff <= obj_diff_min)
932                                         {
933                                                 Gmin_idx=j;
934                                                 obj_diff_min = obj_diff;
935                                         }
936                                 }
937                         }
938                 }
939         }
940
941         if(Gmax+Gmax2 < eps)
942                 return 1;
943
944         out_i = Gmax_idx;
945         out_j = Gmin_idx;
946         return 0;
947 }
948
949 bool Solver::be_shrunken(int i, double Gmax1, double Gmax2)
950 {
951         if(is_upper_bound(i))
952         {
953                 if(y[i]==+1)
954                         return(-G[i] > Gmax1);
955                 else
956                         return(-G[i] > Gmax2);
957         }
958         else if(is_lower_bound(i))
959         {
960                 if(y[i]==+1)
961                         return(G[i] > Gmax2);
962                 else    
963                         return(G[i] > Gmax1);
964         }
965         else
966                 return(false);
967 }
968
969 void Solver::do_shrinking()
970 {
971         int i;
972         double Gmax1 = -INF;            // max { -y_i * grad(f)_i | i in I_up(\alpha) }
973         double Gmax2 = -INF;            // max { y_i * grad(f)_i | i in I_low(\alpha) }
974
975         // find maximal violating pair first
976         for(i=0;i<active_size;i++)
977         {
978                 if(y[i]==+1)    
979                 {
980                         if(!is_upper_bound(i))  
981                         {
982                                 if(-G[i] >= Gmax1)
983                                         Gmax1 = -G[i];
984                         }
985                         if(!is_lower_bound(i))  
986                         {
987                                 if(G[i] >= Gmax2)
988                                         Gmax2 = G[i];
989                         }
990                 }
991                 else    
992                 {
993                         if(!is_upper_bound(i))  
994                         {
995                                 if(-G[i] >= Gmax2)
996                                         Gmax2 = -G[i];
997                         }
998                         if(!is_lower_bound(i))  
999                         {
1000                                 if(G[i] >= Gmax1)
1001                                         Gmax1 = G[i];
1002                         }
1003                 }
1004         }
1005
1006         // shrink
1007
1008         for(i=0;i<active_size;i++)
1009                 if (be_shrunken(i, Gmax1, Gmax2))
1010                 {
1011                         active_size--;
1012                         while (active_size > i)
1013                         {
1014                                 if (!be_shrunken(active_size, Gmax1, Gmax2))
1015                                 {
1016                                         swap_index(i,active_size);
1017                                         break;
1018                                 }
1019                                 active_size--;
1020                         }
1021                 }
1022
1023         // unshrink, check all variables again before final iterations
1024
1025         if(unshrinked || Gmax1 + Gmax2 > eps*10) return;
1026         
1027         unshrinked = true;
1028         reconstruct_gradient();
1029
1030         for(i=l-1;i>=active_size;i--)
1031                 if (!be_shrunken(i, Gmax1, Gmax2))
1032                 {
1033                         while (active_size < i)
1034                         {
1035                                 if (be_shrunken(active_size, Gmax1, Gmax2))
1036                                 {
1037                                         swap_index(i,active_size);
1038                                         break;
1039                                 }
1040                                 active_size++;
1041                         }
1042                         active_size++;
1043                 }
1044 }
1045
1046 double Solver::calculate_rho()
1047 {
1048         double r;
1049         int nr_free = 0;
1050         double ub = INF, lb = -INF, sum_free = 0;
1051         for(int i=0;i<active_size;i++)
1052         {
1053                 double yG = y[i]*G[i];
1054
1055                 if(is_upper_bound(i))
1056                 {
1057                         if(y[i]==-1)
1058                                 ub = min(ub,yG);
1059                         else
1060                                 lb = max(lb,yG);
1061                 }
1062                 else if(is_lower_bound(i))
1063                 {
1064                         if(y[i]==+1)
1065                                 ub = min(ub,yG);
1066                         else
1067                                 lb = max(lb,yG);
1068                 }
1069                 else
1070                 {
1071                         ++nr_free;
1072                         sum_free += yG;
1073                 }
1074         }
1075
1076         if(nr_free>0)
1077                 r = sum_free/nr_free;
1078         else
1079                 r = (ub+lb)/2;
1080
1081         return r;
1082 }
1083
1084 //
1085 // Solver for nu-svm classification and regression
1086 //
1087 // additional constraint: e^T \alpha = constant
1088 //
1089 class Solver_NU : public Solver
1090 {
1091 public:
1092         Solver_NU() {}
1093         void Solve(int l, const Q_Matrix& Q, const double *p, const schar *y,
1094                    double *alpha, double Cp, double Cn, double eps,
1095                    SolutionInfo* si, int shrinking)
1096         {
1097                 this->si = si;
1098                 Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
1099         }
1100 private:
1101         SolutionInfo *si;
1102         int select_working_set(int &i, int &j);
1103         double calculate_rho();
1104         bool be_shrunken(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4);
1105         void do_shrinking();
1106 };
1107
1108 // return 1 if already optimal, return 0 otherwise
1109 int Solver_NU::select_working_set(int &out_i, int &out_j)
1110 {
1111         // return i,j such that y_i = y_j and
1112         // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
1113         // j: minimizes the decrease of obj value
1114         //    (if quadratic coefficient <= 0, replace it with tau)
1115         //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
1116
1117         double Gmaxp = -INF;
1118         double Gmaxp2 = -INF;
1119         int Gmaxp_idx = -1;
1120
1121         double Gmaxn = -INF;
1122         double Gmaxn2 = -INF;
1123         int Gmaxn_idx = -1;
1124
1125         int Gmin_idx = -1;
1126         double obj_diff_min = INF;
1127
1128         for(int t=0;t<active_size;t++)
1129                 if(y[t]==+1)
1130                 {
1131                         if(!is_upper_bound(t))
1132                                 if(-G[t] >= Gmaxp)
1133                                 {
1134                                         Gmaxp = -G[t];
1135                                         Gmaxp_idx = t;
1136                                 }
1137                 }
1138                 else
1139                 {
1140                         if(!is_lower_bound(t))
1141                                 if(G[t] >= Gmaxn)
1142                                 {
1143                                         Gmaxn = G[t];
1144                                         Gmaxn_idx = t;
1145                                 }
1146                 }
1147
1148         int ip = Gmaxp_idx;
1149         int in = Gmaxn_idx;
1150         const Qfloat *Q_ip = NULL;
1151         const Qfloat *Q_in = NULL;
1152         if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1
1153                 Q_ip = Q->get_Q(ip,active_size);
1154         if(in != -1)
1155                 Q_in = Q->get_Q(in,active_size);
1156
1157         for(int j=0;j<active_size;j++)
1158         {
1159                 if(y[j]==+1)
1160                 {
1161                         if (!is_lower_bound(j)) 
1162                         {
1163                                 double grad_diff=Gmaxp+G[j];
1164                                 if (G[j] >= Gmaxp2)
1165                                         Gmaxp2 = G[j];
1166                                 if (grad_diff > 0)
1167                                 {
1168                                         double obj_diff; 
1169                                         double quad_coef = Q_ip[ip]+QD[j]-2*Q_ip[j];
1170                                         if (quad_coef > 0)
1171                                                 obj_diff = -(grad_diff*grad_diff)/quad_coef;
1172                                         else
1173                                                 obj_diff = -(grad_diff*grad_diff)/TAU;
1174
1175                                         if (obj_diff <= obj_diff_min)
1176                                         {
1177                                                 Gmin_idx=j;
1178                                                 obj_diff_min = obj_diff;
1179                                         }
1180                                 }
1181                         }
1182                 }
1183                 else
1184                 {
1185                         if (!is_upper_bound(j))
1186                         {
1187                                 double grad_diff=Gmaxn-G[j];
1188                                 if (-G[j] >= Gmaxn2)
1189                                         Gmaxn2 = -G[j];
1190                                 if (grad_diff > 0)
1191                                 {
1192                                         double obj_diff; 
1193                                         double quad_coef = Q_in[in]+QD[j]-2*Q_in[j];
1194                                         if (quad_coef > 0)
1195                                                 obj_diff = -(grad_diff*grad_diff)/quad_coef;
1196                                         else
1197                                                 obj_diff = -(grad_diff*grad_diff)/TAU;
1198
1199                                         if (obj_diff <= obj_diff_min)
1200                                         {
1201                                                 Gmin_idx=j;
1202                                                 obj_diff_min = obj_diff;
1203                                         }
1204                                 }
1205                         }
1206                 }
1207         }
1208
1209         if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps)
1210                 return 1;
1211
1212         if (y[Gmin_idx] == +1)
1213                 out_i = Gmaxp_idx;
1214         else
1215                 out_i = Gmaxn_idx;
1216         out_j = Gmin_idx;
1217
1218         return 0;
1219 }
1220
1221 bool Solver_NU::be_shrunken(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
1222 {
1223         if(is_upper_bound(i))
1224         {
1225                 if(y[i]==+1)
1226                         return(-G[i] > Gmax1);
1227                 else    
1228                         return(-G[i] > Gmax4);
1229         }
1230         else if(is_lower_bound(i))
1231         {
1232                 if(y[i]==+1)
1233                         return(G[i] > Gmax2);
1234                 else    
1235                         return(G[i] > Gmax3);
1236         }
1237         else
1238                 return(false);
1239 }
1240
1241 void Solver_NU::do_shrinking()
1242 {
1243         double Gmax1 = -INF;    // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
1244         double Gmax2 = -INF;    // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
1245         double Gmax3 = -INF;    // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
1246         double Gmax4 = -INF;    // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
1247
1248         // find maximal violating pair first
1249         int i;
1250         for(i=0;i<active_size;i++)
1251         {
1252                 if(!is_upper_bound(i))
1253                 {
1254                         if(y[i]==+1)
1255                         {
1256                                 if(-G[i] > Gmax1) Gmax1 = -G[i];
1257                         }
1258                         else    if(-G[i] > Gmax4) Gmax4 = -G[i];
1259                 }
1260                 if(!is_lower_bound(i))
1261                 {
1262                         if(y[i]==+1)
1263                         {       
1264                                 if(G[i] > Gmax2) Gmax2 = G[i];
1265                         }
1266                         else    if(G[i] > Gmax3) Gmax3 = G[i];
1267                 }
1268         }
1269
1270         // shrinking
1271
1272         for(i=0;i<active_size;i++)
1273                 if (be_shrunken(i, Gmax1, Gmax2, Gmax3, Gmax4))
1274                 {
1275                         active_size--;
1276                         while (active_size > i)
1277                         {
1278                                 if (!be_shrunken(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
1279                                 {
1280                                         swap_index(i,active_size);
1281                                         break;
1282                                 }
1283                                 active_size--;
1284                         }
1285                 }
1286
1287         // unshrink, check all variables again before final iterations
1288
1289         if(unshrinked || max(Gmax1+Gmax2,Gmax3+Gmax4) > eps*10) return;
1290         
1291         unshrinked = true;
1292         reconstruct_gradient();
1293
1294         for(i=l-1;i>=active_size;i--)
1295                 if (!be_shrunken(i, Gmax1, Gmax2, Gmax3, Gmax4))
1296                 {
1297                         while (active_size < i)
1298                         {
1299                                 if (be_shrunken(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
1300                                 {
1301                                         swap_index(i,active_size);
1302                                         break;
1303                                 }
1304                                 active_size++;
1305                         }
1306                         active_size++;
1307                 }
1308 }
1309
1310 double Solver_NU::calculate_rho()
1311 {
1312         int nr_free1 = 0,nr_free2 = 0;
1313         double ub1 = INF, ub2 = INF;
1314         double lb1 = -INF, lb2 = -INF;
1315         double sum_free1 = 0, sum_free2 = 0;
1316
1317         for(int i=0;i<active_size;i++)
1318         {
1319                 if(y[i]==+1)
1320                 {
1321                         if(is_upper_bound(i))
1322                                 lb1 = max(lb1,G[i]);
1323                         else if(is_lower_bound(i))
1324                                 ub1 = min(ub1,G[i]);
1325                         else
1326                         {
1327                                 ++nr_free1;
1328                                 sum_free1 += G[i];
1329                         }
1330                 }
1331                 else
1332                 {
1333                         if(is_upper_bound(i))
1334                                 lb2 = max(lb2,G[i]);
1335                         else if(is_lower_bound(i))
1336                                 ub2 = min(ub2,G[i]);
1337                         else
1338                         {
1339                                 ++nr_free2;
1340                                 sum_free2 += G[i];
1341                         }
1342                 }
1343         }
1344
1345         double r1,r2;
1346         if(nr_free1 > 0)
1347                 r1 = sum_free1/nr_free1;
1348         else
1349                 r1 = (ub1+lb1)/2;
1350         
1351         if(nr_free2 > 0)
1352                 r2 = sum_free2/nr_free2;
1353         else
1354                 r2 = (ub2+lb2)/2;
1355         
1356         si->r = (r1+r2)/2;
1357         return (r1-r2)/2;
1358 }
1359
1360 //
1361 // Q matrices for various formulations
1362 //
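// SVC_Q:       Q_ij = y_i * y_j * K(x_i, x_j)   (C-SVC and nu-SVC)
// ONE_CLASS_Q: Q_ij = K(x_i, x_j)
// SVR_Q:       the regression problem is doubled to 2l variables; sign[] and
//              index[] map each solver variable back to its training point.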
1363 class SVC_Q: public Kernel
1364 {
1365 public:
1366         SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
1367         :Kernel(prob.l, prob.x, param)
1368         {
1369                 clone(y,y_,prob.l);
1370                 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1371                 QD = new Qfloat[prob.l];
1372                 for(int i=0;i<prob.l;i++)
1373                         QD[i]= (Qfloat)(this->*kernel_function)(i,i);
1374         }
1375         
1376         Qfloat *get_Q(int i, int len) const
1377         {
1378                 Qfloat *data;
1379                 int start;
1380                 if((start = cache->get_data(i,&data,len)) < len)
1381                 {
1382                         for(int j=start;j<len;j++)
1383                                 data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
1384                 }
1385                 return data;
1386         }
1387
1388         Qfloat *get_QD() const
1389         {
1390                 return QD;
1391         }
1392
1393         void swap_index(int i, int j) const
1394         {
1395                 cache->swap_index(i,j);
1396                 Kernel::swap_index(i,j);
1397                 swap(y[i],y[j]);
1398                 swap(QD[i],QD[j]);
1399         }
1400
1401         ~SVC_Q()
1402         {
1403         delete [] y;
1404                 delete cache;
1405         delete [] QD;
1406         }
1407 private:
1408         schar *y;
1409         Cache *cache;
1410         Qfloat *QD;
1411 };
1412
1413 class ONE_CLASS_Q: public Kernel
1414 {
1415 public:
1416         ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
1417         :Kernel(prob.l, prob.x, param)
1418         {
1419                 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1420                 QD = new Qfloat[prob.l];
1421                 for(int i=0;i<prob.l;i++)
1422                         QD[i]= (Qfloat)(this->*kernel_function)(i,i);
1423         }
1424         
1425         Qfloat *get_Q(int i, int len) const
1426         {
1427                 Qfloat *data;
1428                 int start;
1429                 if((start = cache->get_data(i,&data,len)) < len)
1430                 {
1431                         for(int j=start;j<len;j++)
1432                                 data[j] = (Qfloat)(this->*kernel_function)(i,j);
1433                 }
1434                 return data;
1435         }
1436
1437         Qfloat *get_QD() const
1438         {
1439                 return QD;
1440         }
1441
1442         void swap_index(int i, int j) const
1443         {
1444                 cache->swap_index(i,j);
1445                 Kernel::swap_index(i,j);
1446                 swap(QD[i],QD[j]);
1447         }
1448
1449         ~ONE_CLASS_Q()
1450         {
1451                 delete cache;
1452         delete [] QD;
1453         }
1454 private:
1455         Cache *cache;
1456         Qfloat *QD;
1457 };
1458
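// Regression Q matrix: each training point appears twice (indices k and k+l)
// with opposite signs. get_Q() fetches or recomputes the cached column of the
// underlying point, then reorders it through one of two rotating buffers so the
// solver sees a consistent 2l-length row.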
1459 class SVR_Q: public Kernel
1460 {
1461 public:
1462         SVR_Q(const svm_problem& prob, const svm_parameter& param)
1463         :Kernel(prob.l, prob.x, param)
1464         {
1465                 l = prob.l;
1466                 cache = new Cache(l,(long int)(param.cache_size*(1<<20)));
1467                 QD = new Qfloat[2*l];
1468                 sign = new schar[2*l];
1469                 index = new int[2*l];
1470                 for(int k=0;k<l;k++)
1471                 {
1472                         sign[k] = 1;
1473                         sign[k+l] = -1;
1474                         index[k] = k;
1475                         index[k+l] = k;
1476                         QD[k]= (Qfloat)(this->*kernel_function)(k,k);
1477                         QD[k+l]=QD[k];
1478                 }
1479                 buffer[0] = new Qfloat[2*l];
1480                 buffer[1] = new Qfloat[2*l];
1481                 next_buffer = 0;
1482         }
1483
1484         void swap_index(int i, int j) const
1485         {
1486                 swap(sign[i],sign[j]);
1487                 swap(index[i],index[j]);
1488                 swap(QD[i],QD[j]);
1489         }
1490         
1491         Qfloat *get_Q(int i, int len) const
1492         {
1493                 Qfloat *data;
1494                 int real_i = index[i];
1495                 if(cache->get_data(real_i,&data,l) < l)
1496                 {
1497                         for(int j=0;j<l;j++)
1498                                 data[j] = (Qfloat)(this->*kernel_function)(real_i,j);
1499                 }
1500
1501                 // reorder and copy
1502                 Qfloat *buf = buffer[next_buffer];
1503                 next_buffer = 1 - next_buffer;
1504                 schar si = sign[i];
1505                 for(int j=0;j<len;j++)
1506                         buf[j] = si * sign[j] * data[index[j]];
1507                 return buf;
1508         }
1509
1510         Qfloat *get_QD() const
1511         {
1512                 return QD;
1513         }
1514
1515         ~SVR_Q()
1516         {
1517                 delete cache;
1518         delete [] sign;
1519         delete [] index;
1520         delete [] buffer[0];
1521         delete [] buffer[1];
1522         delete [] QD;
1523         }
1524 private:
1525         int l;
1526         Cache *cache;
1527         schar *sign;
1528         int *index;
1529         mutable int next_buffer;
1530         Qfloat *buffer[2];
1531         Qfloat *QD;
1532 };
1533
1534 svm_parameter& svm_parameter::operator= (const svm_parameter &param) {
1535     if (this == &param) return *this;
1536     svm_type = param.svm_type;
1537     kernel_type = param.kernel_type;
1538     degree = param.degree;
1539     gamma = param.gamma;
1540     coef0 = param.coef0;
1541     kernel_dim = param.kernel_dim;
1542     if(kernel_dim)
1543     {
1544         if(kernel_weight) delete [] kernel_weight;
1545         kernel_weight = new double[kernel_dim];
1546         memcpy(kernel_weight, param.kernel_weight, kernel_dim*sizeof(double));
1547     }
1548     normalizeKernel = param.normalizeKernel;
1549     kernel_norm = param.kernel_norm;
1550     cache_size = param.cache_size;
1551     eps = param.eps;
1552     C = param.C;
1553     nr_weight = param.nr_weight;
1554     if(nr_weight)
1555     {
1556         if(weight) delete [] weight;
1557         if(weight_label) delete [] weight_label;
1558         weight = new double[nr_weight]; memcpy(weight, param.weight, nr_weight*sizeof(double));
1559         weight_label = new int[nr_weight]; memcpy(weight_label, param.weight_label, nr_weight*sizeof(int));
1560     }
1561     nu = param.nu;
1562     p = param.p;
1563     shrinking = param.shrinking;
1564     probability = param.probability;
1565     return *this;
1566 }
1567
1568 //
1569 // construct and solve various formulations
1570 //
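// For reference, solve_c_svc below sets up the standard C-SVC dual
//     min_a  0.5*a'*Q*a - e'*a   subject to   0 <= a_i <= C_+ or C_-,  y'*a = 0
// with e the all-ones vector and Q_ij = y_i*y_j*K(x_i,x_j); on return, alpha has
// already been multiplied by the labels y.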
1571 static void solve_c_svc(
1572         const svm_problem *prob, const svm_parameter* param,
1573         double *alpha, Solver::SolutionInfo* si, double Cp, double Cn)
1574 {
1575         int l = prob->l;
1576         double *minus_ones = new double[l];
1577         schar *y = new schar[l];
1578
1579         int i;
1580
1581         for(i=0;i<l;i++)
1582         {
1583                 alpha[i] = 0;
1584                 minus_ones[i] = -1;
1585                 if(prob->y[i] > 0) y[i] = +1; else y[i]=-1;
1586         }
1587
1588         Solver s;
1589         s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
1590                 alpha, Cp, Cn, param->eps, si, param->shrinking);
1591         double sum_alpha=0;
1592         for(i=0;i<l;i++)
1593                 sum_alpha += alpha[i];
1594
1595         if (Cp==Cn)
1596                 info("nu = %f\n", sum_alpha/(Cp*prob->l));
1597
1598         for(i=0;i<l;i++)
1599                 alpha[i] *= y[i];
1600
1601     delete [] minus_ones;
1602     delete [] y;
1603 }
1604
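// nu-SVC: nu replaces C as the regularization parameter (an upper bound on the
// fraction of margin errors and a lower bound on the fraction of support vectors).
// The problem is solved with box constraints [0,1] and the solution is rescaled by
// r, which is why the effective C = 1/r is reported below.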
1605 static void solve_nu_svc(
1606         const svm_problem *prob, const svm_parameter *param,
1607         double *alpha, Solver::SolutionInfo* si)
1608 {
1609         int i;
1610         int l = prob->l;
1611         double nu = param->nu;
1612
1613         schar *y = new schar[l];
1614
1615         for(i=0;i<l;i++)
1616                 if(prob->y[i]>0)
1617                         y[i] = +1;
1618                 else
1619                         y[i] = -1;
1620
1621         double sum_pos = nu*l/2;
1622         double sum_neg = nu*l/2;
1623
1624         for(i=0;i<l;i++)
1625                 if(y[i] == +1)
1626                 {
1627                         alpha[i] = min(1.0,sum_pos);
1628                         sum_pos -= alpha[i];
1629                 }
1630                 else
1631                 {
1632                         alpha[i] = min(1.0,sum_neg);
1633                         sum_neg -= alpha[i];
1634                 }
1635
1636         double *zeros = new double[l];
1637
1638         for(i=0;i<l;i++)
1639                 zeros[i] = 0;
1640
1641         Solver_NU s;
1642         s.Solve(l, SVC_Q(*prob,*param,y), zeros, y,
1643                 alpha, 1.0, 1.0, param->eps, si,  param->shrinking);
1644         double r = si->r;
1645
1646         info("C = %f\n",1/r);
1647
1648         for(i=0;i<l;i++)
1649                 alpha[i] *= y[i]/r;
1650
1651         si->rho /= r;
1652         si->obj /= (r*r);
1653         si->upper_bound_p = 1/r;
1654         si->upper_bound_n = 1/r;
1655
1656     delete [] y;
1657     delete [] zeros;
1658 }
1659
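// one-class SVM: estimates the support of the data distribution. Initially nu*l
// alphas are set to the upper bound 1, one alpha takes the fractional remainder,
// and the rest start at 0.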
1660 static void solve_one_class(
1661         const svm_problem *prob, const svm_parameter *param,
1662         double *alpha, Solver::SolutionInfo* si)
1663 {
1664         int l = prob->l;
1665         double *zeros = new double[l];
1666         schar *ones = new schar[l];
1667         int i;
1668
1669         int n = (int)(param->nu*prob->l);       // # of alpha's at upper bound
1670
1671         for(i=0;i<n;i++)
1672                 alpha[i] = 1;
1673         if(n<prob->l)
1674                 alpha[n] = param->nu * prob->l - n;
1675         for(i=n+1;i<l;i++)
1676                 alpha[i] = 0;
1677
1678         for(i=0;i<l;i++)
1679         {
1680                 zeros[i] = 0;
1681                 ones[i] = 1;
1682         }
1683
1684         Solver s;
1685         s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones,
1686                 alpha, 1.0, 1.0, param->eps, si, param->shrinking);
1687
1688     delete [] zeros;
1689     delete [] ones;
1690 }
1691
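// epsilon-SVR: the dual has 2*l variables (alpha_i and alpha_i*). The first l
// entries of alpha2 carry label +1 and linear term (param->p - y_i), the second l
// carry label -1 and (param->p + y_i), where param->p is the epsilon tube width.
// The final coefficient for point i is alpha_i - alpha_i*.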
1692 static void solve_epsilon_svr(
1693         const svm_problem *prob, const svm_parameter *param,
1694         double *alpha, Solver::SolutionInfo* si)
1695 {
1696         int l = prob->l;
1697         double *alpha2 = new double[2*l];
1698         double *linear_term = new double[2*l];
1699         schar *y = new schar[2*l];
1700         int i;
1701
1702         for(i=0;i<l;i++)
1703         {
1704                 alpha2[i] = 0;
1705                 linear_term[i] = param->p - prob->y[i];
1706                 y[i] = 1;
1707
1708                 alpha2[i+l] = 0;
1709                 linear_term[i+l] = param->p + prob->y[i];
1710                 y[i+l] = -1;
1711         }
1712
1713         Solver s;
1714         s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1715                 alpha2, param->C, param->C, param->eps, si, param->shrinking);
1716
1717         double sum_alpha = 0;
1718         for(i=0;i<l;i++)
1719         {
1720                 alpha[i] = alpha2[i] - alpha2[i+l];
1721                 sum_alpha += fabs(alpha[i]);
1722         }
1723         info("nu = %f\n",sum_alpha/(param->C*l));
1724
1725     delete [] alpha2;
1726     delete [] linear_term;
1727     delete [] y;
1728 }
1729
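// nu-SVR: epsilon is not supplied but determined by the optimization (reported
// below as -si->r). The alpha and alpha* halves each start with total mass
// C*nu*l/2, filled greedily up to C per point.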
1730 static void solve_nu_svr(
1731         const svm_problem *prob, const svm_parameter *param,
1732         double *alpha, Solver::SolutionInfo* si)
1733 {
1734         int l = prob->l;
1735         double C = param->C;
1736         double *alpha2 = new double[2*l];
1737         double *linear_term = new double[2*l];
1738         schar *y = new schar[2*l];
1739         int i;
1740
1741         double sum = C * param->nu * l / 2;
1742         for(i=0;i<l;i++)
1743         {
1744                 alpha2[i] = alpha2[i+l] = min(sum,C);
1745                 sum -= alpha2[i];
1746
1747                 linear_term[i] = - prob->y[i];
1748                 y[i] = 1;
1749
1750                 linear_term[i+l] = prob->y[i];
1751                 y[i+l] = -1;
1752         }
1753
1754         Solver_NU s;
1755         s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1756                 alpha2, C, C, param->eps, si, param->shrinking);
1757
1758         info("epsilon = %f\n",-si->r);
1759
1760         for(i=0;i<l;i++)
1761                 alpha[i] = alpha2[i] - alpha2[i+l];
1762
1763     delete [] alpha2;
1764     delete [] linear_term;
1765     delete [] y;
1766 }
1767
1768 //
1769 // decision_function
1770 //
1771 struct decision_function
1772 {
1773         double *alpha;
1774         double rho;
1775         double eps;
1776 };
1777
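// svm_train_one: dispatches to the solver matching param->svm_type, then reports
// the objective value, rho, the number of support vectors (nSV) and the number of
// bounded support vectors (nBSV, alphas at the upper bound). The alpha array is
// returned inside the decision_function and owned by the caller.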
1778 decision_function svm_train_one(
1779         const svm_problem *prob, const svm_parameter *param,
1780         double Cp, double Cn)
1781 {
1782     double *alpha = new double[prob->l];
1783         Solver::SolutionInfo si;
1784         switch(param->svm_type)
1785         {
1786                 case C_SVC:
1787                         solve_c_svc(prob,param,alpha,&si,Cp,Cn);
1788                         break;
1789                 case NU_SVC:
1790                         solve_nu_svc(prob,param,alpha,&si);
1791                         break;
1792                 case ONE_CLASS:
1793                         solve_one_class(prob,param,alpha,&si);
1794                         break;
1795                 case EPSILON_SVR:
1796                         solve_epsilon_svr(prob,param,alpha,&si);
1797                         break;
1798                 case NU_SVR:
1799                         solve_nu_svr(prob,param,alpha,&si);
1800                         break;
1801         }
1802
1803         info("obj = %f, rho = %f\n",si.obj,si.rho);
1804
1805         // output SVs
1806
1807         int nSV = 0;
1808         int nBSV = 0;
1809         for(int i=0;i<prob->l;i++)
1810         {
1811                 if(fabs(alpha[i]) > 0)
1812                 {
1813                         ++nSV;
1814                         if(prob->y[i] > 0)
1815                         {
1816                                 if(fabs(alpha[i]) >= si.upper_bound_p)
1817                                         ++nBSV;
1818                         }
1819                         else
1820                         {
1821                                 if(fabs(alpha[i]) >= si.upper_bound_n)
1822                                         ++nBSV;
1823                         }
1824                 }
1825         }
1826
1827         info("nSV = %d, nBSV = %d\n",nSV,nBSV);
1828
1829         decision_function f;
1830         f.alpha = alpha;
1831         f.rho = si.rho;
1832         f.eps = si.r;
1833         return f;
1834 }
1835
1836
1837 // Platt's binary SVM Probabilistic Output: an improvement from Lin et al.
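// Fits the parameters A and B of the sigmoid P(y=1|f) = 1/(1+exp(A*f+B)) to the
// decision values f, using Newton's method with a backtracking line search on the
// regularized negative log-likelihood (the Lin/Lin/Weng variant of Platt scaling).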
1838 void sigmoid_train(
1839         int l, const double *dec_values, const double *labels, 
1840         double& A, double& B)
1841 {
1842         double prior1=0, prior0 = 0;
1843         int i;
1844
1845         for (i=0;i<l;i++)
1846                 if (labels[i] > 0) prior1+=1;
1847                 else prior0+=1;
1848         
1849         int max_iter=100;       // Maximal number of iterations
1850         double min_step=1e-10;  // Minimal step taken in line search
1851         double sigma=1e-12;     // For numerically strict PD of Hessian
1852         double eps=1e-5;
1853         double hiTarget=(prior1+1.0)/(prior1+2.0);
1854         double loTarget=1/(prior0+2.0);
1855     double *t=new double[l];
1856         double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
1857         double newA,newB,newf,d1,d2;
1858         int iter; 
1859         
1860         // Initial point and initial function value
1861         A=0.0; B=log((prior0+1.0)/(prior1+1.0));
1862         double fval = 0.0;
1863
1864         for (i=0;i<l;i++)
1865         {
1866                 if (labels[i]>0) t[i]=hiTarget;
1867                 else t[i]=loTarget;
1868                 fApB = dec_values[i]*A+B;
1869                 if (fApB>=0)
1870                         fval += t[i]*fApB + log(1+exp(-fApB));
1871                 else
1872                         fval += (t[i] - 1)*fApB +log(1+exp(fApB));
1873         }
1874         for (iter=0;iter<max_iter;iter++)
1875         {
1876                 // Update Gradient and Hessian (use H' = H + sigma I)
1877                 h11=sigma; // numerically ensures strict PD
1878                 h22=sigma;
1879                 h21=0.0;g1=0.0;g2=0.0;
1880                 for (i=0;i<l;i++)
1881                 {
1882                         fApB = dec_values[i]*A+B;
1883                         if (fApB >= 0)
1884                         {
1885                                 p=exp(-fApB)/(1.0+exp(-fApB));
1886                                 q=1.0/(1.0+exp(-fApB));
1887                         }
1888                         else
1889                         {
1890                                 p=1.0/(1.0+exp(fApB));
1891                                 q=exp(fApB)/(1.0+exp(fApB));
1892                         }
1893                         d2=p*q;
1894                         h11+=dec_values[i]*dec_values[i]*d2;
1895                         h22+=d2;
1896                         h21+=dec_values[i]*d2;
1897                         d1=t[i]-p;
1898                         g1+=dec_values[i]*d1;
1899                         g2+=d1;
1900                 }
1901
1902                 // Stopping Criteria
1903                 if (fabs(g1)<eps && fabs(g2)<eps)
1904                         break;
1905
1906                 // Finding Newton direction: -inv(H') * g
1907                 det=h11*h22-h21*h21;
1908                 dA=-(h22*g1 - h21 * g2) / det;
1909                 dB=-(-h21*g1+ h11 * g2) / det;
1910                 gd=g1*dA+g2*dB;
1911
1912
1913                 stepsize = 1;           // Line Search
1914                 while (stepsize >= min_step)
1915                 {
1916                         newA = A + stepsize * dA;
1917                         newB = B + stepsize * dB;
1918
1919                         // New function value
1920                         newf = 0.0;
1921                         for (i=0;i<l;i++)
1922                         {
1923                                 fApB = dec_values[i]*newA+newB;
1924                                 if (fApB >= 0)
1925                                         newf += t[i]*fApB + log(1+exp(-fApB));
1926                                 else
1927                                         newf += (t[i] - 1)*fApB +log(1+exp(fApB));
1928                         }
1929                         // Check sufficient decrease
1930                         if (newf<fval+0.0001*stepsize*gd)
1931                         {
1932                                 A=newA;B=newB;fval=newf;
1933                                 break;
1934                         }
1935                         else
1936                                 stepsize = stepsize / 2.0;
1937                 }
1938
1939                 if (stepsize < min_step)
1940                 {
1941                         info("Line search fails in two-class probability estimates\n");
1942                         break;
1943                 }
1944         }
1945
1946         if (iter>=max_iter)
1947                 info("Reaching maximal iterations in two-class probability estimates\n");
1948     delete [] t;
1949 }
1950
1951 double sigmoid_predict(double decision_value, double A, double B)
1952 {
1953         double fApB = decision_value*A+B;
1954         if (fApB >= 0)
1955                 return exp(-fApB)/(1.0+exp(-fApB));
1956         else
1957                 return 1.0/(1+exp(fApB)) ;
1958 }
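// Usage sketch (illustrative only, not part of the training path): for a binary
// model trained with probability estimates, a class probability could be read off as
//     double dec;
//     svm_predict_values(model, x, &dec);
//     double p_first = sigmoid_predict(dec, model->probA[0], model->probB[0]);
// where model (a trained svm_model*) and x (an svm_node* pattern) are assumed to
// exist, and p_first is roughly the probability of the pair's first label,
// model->label[0]. Multiclass prediction combines such pairwise estimates via
// multiclass_probability below.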
1959
1960 // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
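// Given pairwise estimates r[i][j] ~ P(y=i | y=i or j), this solves for the class
// probabilities p by minimizing sum_i sum_{j!=i} (r[j][i]*p[i] - r[i][j]*p[j])^2
// subject to sum_t p[t] = 1, using the fixed-point iteration below on the matrix Q.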
1961 void multiclass_probability(int k, double **r, double *p)
1962 {
1963         int t,j;
1964         int iter = 0, max_iter=max(100,k);
1965     double **Q=new double*[k];
1966     double *Qp=new double[k];
1967         double pQp, eps=0.005/k;
1968         
1969         for (t=0;t<k;t++)
1970         {
1971                 p[t]=1.0/k;  // Valid if k = 1
1972         Q[t]=new double[k];
1973                 Q[t][t]=0;
1974                 for (j=0;j<t;j++)
1975                 {
1976                         Q[t][t]+=r[j][t]*r[j][t];
1977                         Q[t][j]=Q[j][t];
1978                 }
1979                 for (j=t+1;j<k;j++)
1980                 {
1981                         Q[t][t]+=r[j][t]*r[j][t];
1982                         Q[t][j]=-r[j][t]*r[t][j];
1983                 }
1984         }
1985         for (iter=0;iter<max_iter;iter++)
1986         {
1987                 // stopping condition: recalculate Qp and pQp for numerical accuracy
1988                 pQp=0;
1989                 for (t=0;t<k;t++)
1990                 {
1991                         Qp[t]=0;
1992                         for (j=0;j<k;j++)
1993                                 Qp[t]+=Q[t][j]*p[j];
1994                         pQp+=p[t]*Qp[t];
1995                 }
1996                 double max_error=0;
1997                 for (t=0;t<k;t++)
1998                 {
1999                         double error=fabs(Qp[t]-pQp);
2000                         if (error>max_error)
2001                                 max_error=error;
2002                 }
2003                 if (max_error<eps) break;
2004                 
2005                 for (t=0;t<k;t++)
2006                 {
2007                         double diff=(-Qp[t]+pQp)/Q[t][t];
2008                         p[t]+=diff;
2009                         pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
2010                         for (j=0;j<k;j++)
2011                         {
2012                                 Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
2013                                 p[j]/=(1+diff);
2014                         }
2015                 }
2016         }
2017         if (iter>=max_iter)
2018                 info("Exceeds max_iter in multiclass_prob\n");
2019     for(t=0;t<k;t++) delete [] Q[t];
2020     delete [] Q;
2021     delete [] Qp;
2022 }
2023
2024 // Cross-validation decision values for probability estimates
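// Runs an internal 5-fold cross-validation: each fold is predicted by a sub-model
// trained on the remaining folds (with C=1 and class weights Cp/Cn), and the
// resulting out-of-fold decision values are fed to sigmoid_train to fit probA/probB.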
2025 void svm_binary_svc_probability(
2026         const svm_problem *prob, const svm_parameter *param,
2027         double Cp, double Cn, double& probA, double& probB)
2028 {
2029         int i;
2030         int nr_fold = 5;
2031     int *perm = new int[prob->l];
2032     double *dec_values = new double[prob->l];
2033
2034         // random shuffle
2035         for(i=0;i<prob->l;i++) perm[i]=i;
2036         for(i=0;i<prob->l;i++)
2037         {
2038                 int j = i+rand()%(prob->l-i);
2039                 swap(perm[i],perm[j]);
2040         }
2041         for(i=0;i<nr_fold;i++)
2042         {
2043                 int begin = i*prob->l/nr_fold;
2044                 int end = (i+1)*prob->l/nr_fold;
2045                 int j,k;
2046                 struct svm_problem subprob;
2047
2048                 subprob.l = prob->l-(end-begin);
2049         subprob.x = new struct svm_node*[subprob.l];
2050         subprob.y = new double[subprob.l];
2051                         
2052                 k=0;
2053                 for(j=0;j<begin;j++)
2054                 {
2055                         subprob.x[k] = prob->x[perm[j]];
2056                         subprob.y[k] = prob->y[perm[j]];
2057                         ++k;
2058                 }
2059                 for(j=end;j<prob->l;j++)
2060                 {
2061                         subprob.x[k] = prob->x[perm[j]];
2062                         subprob.y[k] = prob->y[perm[j]];
2063                         ++k;
2064                 }
2065                 int p_count=0,n_count=0;
2066                 for(j=0;j<k;j++)
2067                         if(subprob.y[j]>0)
2068                                 p_count++;
2069                         else
2070                                 n_count++;
2071
2072                 if(p_count==0 && n_count==0)
2073                         for(j=begin;j<end;j++)
2074                                 dec_values[perm[j]] = 0;
2075                 else if(p_count > 0 && n_count == 0)
2076                         for(j=begin;j<end;j++)
2077                                 dec_values[perm[j]] = 1;
2078                 else if(p_count == 0 && n_count > 0)
2079                         for(j=begin;j<end;j++)
2080                                 dec_values[perm[j]] = -1;
2081                 else
2082                 {
2083                         svm_parameter subparam = *param;
2084                         subparam.probability=0;
2085                         subparam.C=1.0;
2086                         subparam.nr_weight=2;
2087             subparam.weight_label = new int[2];
2088             subparam.weight = new double[2];
2089                         subparam.weight_label[0]=+1;
2090                         subparam.weight_label[1]=-1;
2091                         subparam.weight[0]=Cp;
2092                         subparam.weight[1]=Cn;
2093                         struct svm_model *submodel = svm_train(&subprob,&subparam);
2094                         for(j=begin;j<end;j++)
2095                         {
2096                                 svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]])); 
2097                                 // ensure +1/-1 label order; this is why the generic CV subroutine is not used here
2098                                 dec_values[perm[j]] *= submodel->label[0];
2099                         }               
2100                         svm_destroy_model(submodel);
2101                         svm_destroy_param(&subparam);
2102                 }
2103         delete [] subprob.x;
2104         delete [] subprob.y;
2105         }               
2106         sigmoid_train(prob->l,dec_values,prob->y,probA,probB);
2107     delete [] dec_values;
2108     delete [] perm;
2109 }
2110
2111 // Return parameter of a Laplace distribution 
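// Estimates the scale sigma of a Laplace noise model for the regression residuals:
// 5-fold cross-validated residuals are computed, points farther than 5 standard
// deviations are discarded, and sigma is the mean absolute error of the rest.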
2112 double svm_svr_probability(
2113         const svm_problem *prob, const svm_parameter *param)
2114 {
2115         int i;
2116         int nr_fold = 5;
2117     double *ymv = new double[prob->l];
2118         double mae = 0;
2119
2120         svm_parameter newparam = *param;
2121         newparam.probability = 0;
2122         svm_cross_validation(prob,&newparam,nr_fold,ymv);
2123         for(i=0;i<prob->l;i++)
2124         {
2125                 ymv[i]=prob->y[i]-ymv[i];
2126                 mae += fabs(ymv[i]);
2127         }               
2128         mae /= prob->l;
2129         double std=sqrt(2*mae*mae);
2130         int count=0;
2131         mae=0;
2132         for(i=0;i<prob->l;i++)
2133                 if (fabs(ymv[i]) > 5*std) 
2134                         count=count+1;
2135                 else 
2136                         mae+=fabs(ymv[i]);
2137         mae /= (prob->l-count);
2138         info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae);
2139     delete [] ymv;
2140         return mae;
2141 }
2142
2143
2144 // label: class labels, start: starting index of each class, count: number of data points per class, perm: indices into the original data
2145 // perm, length l, must be allocated before calling this subroutine
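// On return, label/start/count are arrays of length nr_class allocated with new[]
// (the caller is responsible for deleting them), and perm lists the original
// indices grouped class by class.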
2146 void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
2147 {
2148         int l = prob->l;
2149         int max_nr_class = 16;
2150         int nr_class = 0;
2151     int *label = new int[max_nr_class];
2152     int *count = new int[max_nr_class];
2153     int *data_label = new int[l];
2154         int i;
2155
2156         for(i=0;i<l;i++)
2157         {
2158                 int this_label = (int)prob->y[i];
2159                 int j;
2160                 for(j=0;j<nr_class;j++)
2161                 {
2162                         if(this_label == label[j])
2163                         {
2164                                 ++count[j];
2165                                 break;
2166                         }
2167                 }
2168                 data_label[i] = j;
2169                 if(j == nr_class)
2170                 {
2171                         if(nr_class == max_nr_class)
2172                         {
2173                                 max_nr_class *= 2;  // grow label/count with new[] + memcpy (they were allocated with new[], so realloc would be undefined)
2174                                 { int *new_label = new int[max_nr_class]; memcpy(new_label,label,nr_class*sizeof(int)); delete [] label; label = new_label; }
2175                                 { int *new_count = new int[max_nr_class]; memcpy(new_count,count,nr_class*sizeof(int)); delete [] count; count = new_count; }
2176                         }
2177                         label[nr_class] = this_label;
2178                         count[nr_class] = 1;
2179                         ++nr_class;
2180                 }
2181         }
2182
2183     int *start = new int[nr_class];
2184         start[0] = 0;
2185         for(i=1;i<nr_class;i++)
2186                 start[i] = start[i-1]+count[i-1];
2187         for(i=0;i<l;i++)
2188         {
2189                 perm[start[data_label[i]]] = i;
2190                 ++start[data_label[i]];
2191         }
2192         start[0] = 0;
2193         for(i=1;i<nr_class;i++)
2194                 start[i] = start[i-1]+count[i-1];
2195
2196         *nr_class_ret = nr_class;
2197         *label_ret = label;
2198         *start_ret = start;
2199         *count_ret = count;
2200     delete [] data_label;
2201 }
2202
2203 //
2204 // Interface functions
2205 //
2206 svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
2207 {
2208     svm_model *model = new svm_model;
2209         model->param = *param;
2210         model->free_sv = 0;     // XXX
2211
2212         if(param->svm_type == ONE_CLASS ||
2213            param->svm_type == EPSILON_SVR ||
2214            param->svm_type == NU_SVR)
2215         {
2216                 // regression or one-class-svm
2217                 model->nr_class = 2;
2218                 model->label = NULL;
2219                 model->nSV = NULL;
2220                 model->probA = NULL; model->probB = NULL;
2221         model->sv_coef = new double*[1];
2222
2223                 if(param->probability && 
2224                    (param->svm_type == EPSILON_SVR ||
2225                     param->svm_type == NU_SVR))
2226                 {
2227             model->probA = new double[1];
2228                         model->probA[0] = svm_svr_probability(prob,param);
2229                 }
2230                 decision_function f = svm_train_one(prob,param,0,0);
2231         model->rho = new double[1];
2232                 model->rho[0] = f.rho;
2233         model->eps = new double[1];
2234                 model->eps[0] = f.eps;
2235
2236                 int nSV = 0;
2237                 int i;
2238                 for(i=0;i<prob->l;i++)
2239                         if(fabs(f.alpha[i]) > 0) ++nSV;
2240                 model->l = nSV;
2241         model->SV = new svm_node*[nSV];
2242         model->sv_coef[0] = new double[nSV];
2243                 int j = 0;
2244                 for(i=0;i<prob->l;i++)
2245                         if(fabs(f.alpha[i]) > 0)
2246                         {
2247                                 model->SV[j] = prob->x[i];
2248                                 model->sv_coef[0][j] = f.alpha[i];
2249                                 ++j;
2250                         }               
2251
2252         delete [] f.alpha;
2253         }
2254         else
2255         {
2256                 // classification
2257                 int l = prob->l;
2258                 int nr_class;
2259                 int *label = NULL;
2260                 int *start = NULL;
2261                 int *count = NULL;
2262         int *perm = new int[l];
2263
2264                 // group training data of the same class
2265                 svm_group_classes(prob,&nr_class,&label,&start,&count,perm);            
2266         svm_node **x = new svm_node*[l];
2267                 int i;
2268                 for(i=0;i<l;i++)
2269                         x[i] = prob->x[perm[i]];
2270
2271                 // calculate weighted C
2272
2273         double *weighted_C = new double[nr_class];
2274                 for(i=0;i<nr_class;i++)
2275                         weighted_C[i] = param->C;
2276                 for(i=0;i<param->nr_weight;i++)
2277                 {       
2278                         int j;
2279                         for(j=0;j<nr_class;j++)
2280                                 if(param->weight_label[i] == label[j])
2281                                         break;
2282                         if(j == nr_class)
2283                                 fprintf(stderr,"warning: class label %d specified in weight is not found\n", param->weight_label[i]);
2284                         else
2285                                 weighted_C[j] *= param->weight[i];
2286                 }
2287
2288                 // train k*(k-1)/2 models
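                // One-vs-one scheme: a binary sub-problem is built for every class pair (i,j),
                // i < j, with labels +1 for class i and -1 for class j; the resulting decision
                // functions are stored in f[p] with p running over the pairs in that order.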
2289                 
2290         bool *nonzero = new bool[l];
2291                 for(i=0;i<l;i++)
2292                         nonzero[i] = false;
2293         decision_function *f = new decision_function[nr_class*(nr_class-1)/2];
2294
2295                 double *probA=NULL,*probB=NULL;
2296                 if (param->probability)
2297                 {
2298             probA = new double[nr_class*(nr_class-1)/2];
2299             probB = new double[nr_class*(nr_class-1)/2];
2300                 }
2301
2302                 int p = 0;
2303                 for(i=0;i<nr_class;i++)
2304                         for(int j=i+1;j<nr_class;j++)
2305                         {
2306                                 svm_problem sub_prob;
2307                                 int si = start[i], sj = start[j];
2308                                 int ci = count[i], cj = count[j];
2309                                 sub_prob.l = ci+cj;
2310                 sub_prob.x = new svm_node*[sub_prob.l];
2311                 sub_prob.y = new double[sub_prob.l];
2312                                 int k;
2313                                 for(k=0;k<ci;k++)
2314                                 {
2315                                         sub_prob.x[k] = x[si+k];
2316                                         sub_prob.y[k] = +1;
2317                                 }
2318                                 for(k=0;k<cj;k++)
2319                                 {
2320                                         sub_prob.x[ci+k] = x[sj+k];
2321                                         sub_prob.y[ci+k] = -1;
2322                                 }
2323
2324                                 if(param->probability)
2325                                         svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);
2326                                 f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);
2327                                 for(k=0;k<ci;k++)
2328                                         if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
2329                                                 nonzero[si+k] = true;
2330                                 for(k=0;k<cj;k++)
2331                                         if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
2332                                                 nonzero[sj+k] = true;
2333                 delete [] sub_prob.x;
2334                 delete [] sub_prob.y;
2335                                 ++p;
2336                         }
2337
2338                 // build output
2339
2340                 model->nr_class = nr_class;
2341                 
2342         model->label = new int[nr_class];
2343                 for(i=0;i<nr_class;i++)
2344                         model->label[i] = label[i];
2345                 
2346         model->rho = new double[nr_class*(nr_class-1)/2];
2347                 for(i=0;i<nr_class*(nr_class-1)/2;i++)
2348                         model->rho[i] = f[i].rho;
2349
2350                 if(param->probability)
2351                 {
2352             model->probA = new double [nr_class*(nr_class-1)/2];
2353             model->probB = new double [nr_class*(nr_class-1)/2];
2354                         for(i=0;i<nr_class*(nr_class-1)/2;i++)
2355                         {
2356                                 model->probA[i] = probA[i];
2357                                 model->probB[i] = probB[i];
2358                         }
2359                 }
2360                 else
2361                 {
2362                         model->probA=NULL;
2363                         model->probB=NULL;
2364                 }
2365
2366                 int total_sv = 0;
2367         int *nz_count = new int[nr_class];
2368         model->nSV = new int[nr_class];
2369                 for(i=0;i<nr_class;i++)
2370                 {
2371                         int nSV = 0;
2372                         for(int j=0;j<count[i];j++)
2373                                 if(nonzero[start[i]+j])
2374                                 {       
2375                                         ++nSV;
2376                                         ++total_sv;
2377                                 }
2378                         model->nSV[i] = nSV;
2379                         nz_count[i] = nSV;
2380                 }
2381                 
2382                 info("Total nSV = %d\n",total_sv);
2383
2384                 model->l = total_sv;
2385         model->SV = new svm_node *[total_sv];
2386                 p = 0;
2387                 for(i=0;i<l;i++)
2388                         if(nonzero[i]) model->SV[p++] = x[i];
2389
2390         int *nz_start = new int[nr_class];
2391                 nz_start[0] = 0;
2392                 for(i=1;i<nr_class;i++)
2393                         nz_start[i] = nz_start[i-1]+nz_count[i-1];
2394
2395         model->sv_coef = new double * [nr_class-1];
2396                 for(i=0;i<nr_class-1;i++)
2397             model->sv_coef[i] = new double [total_sv];
2398
2399                 p = 0;
2400                 for(i=0;i<nr_class;i++)
2401                         for(int j=i+1;j<nr_class;j++)
2402                         {
2403                                 // classifier (i,j): coefficients with
2404                                 // i are in sv_coef[j-1][nz_start[i]...],
2405                                 // j are in sv_coef[i][nz_start[j]...]
2406
2407                                 int si = start[i];
2408                                 int sj = start[j];
2409                                 int ci = count[i];
2410                                 int cj = count[j];
2411                                 
2412                                 int q = nz_start[i];
2413                                 int k;
2414                                 for(k=0;k<ci;k++)
2415                                         if(nonzero[si+k])
2416                                                 model->sv_coef[j-1][q++] = f[p].alpha[k];
2417                                 q = nz_start[j];
2418                                 for(k=0;k<cj;k++)
2419                                         if(nonzero[sj+k])
2420                                                 model->sv_coef[i][q++] = f[p].alpha[ci+k];
2421                                 ++p;
2422                         }
2423                 
2424         delete [] label;
2425         delete [] probA;
2426         delete [] probB;
2427         delete [] count;
2428         delete [] perm;
2429         delete [] start;
2430         delete [] x;
2431         delete [] weighted_C;
2432         delete [] nonzero;
2433                 for(i=0;i<nr_class*(nr_class-1)/2;i++)
2434             delete [] f[i].alpha;
2435         delete [] f;
2436         delete [] nz_count;
2437         delete [] nz_start;
2438         }
2439         return model;
2440 }
2441
2442 // Stratified cross validation
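// For C-SVC/nu-SVC the folds are built per class (shuffle within each class, then
// spread each class across the folds) so that every fold keeps roughly the original
// class proportions; for regression and one-class a plain random split is used.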
2443 void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
2444 {
2445         int i;
2446     int *fold_start = new int [nr_fold+1];
2447         int l = prob->l;
2448     int *perm = new int[l];
2449         int nr_class;
2450
2451         // stratified CV may not give the leave-one-out rate:
2452         // spreading each class over l folds can leave some folds with no elements of a class
2453         if((param->svm_type == C_SVC ||
2454             param->svm_type == NU_SVC) && nr_fold < l)
2455         {
2456                 int *start = NULL;
2457                 int *label = NULL;
2458                 int *count = NULL;
2459                 svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
2460
2461                 // random shuffle and then data grouped by fold using the array perm
2462         int *fold_count = new int[nr_fold];
2463                 int c;
2464         int *index = new int[l];
2465                 for(i=0;i<l;i++)
2466                         index[i]=perm[i];
2467                 for (c=0; c<nr_class; c++) 
2468                         for(i=0;i<count[c];i++)
2469                         {
2470                                 int j = i+rand()%(count[c]-i);
2471                                 swap(index[start[c]+j],index[start[c]+i]);
2472                         }
2473                 for(i=0;i<nr_fold;i++)
2474                 {
2475                         fold_count[i] = 0;
2476                         for (c=0; c<nr_class;c++)
2477                                 fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
2478                 }
2479                 fold_start[0]=0;
2480                 for (i=1;i<=nr_fold;i++)
2481                         fold_start[i] = fold_start[i-1]+fold_count[i-1];
2482                 for (c=0; c<nr_class;c++)
2483                         for(i=0;i<nr_fold;i++)
2484                         {
2485                                 int begin = start[c]+i*count[c]/nr_fold;
2486                                 int end = start[c]+(i+1)*count[c]/nr_fold;
2487                                 for(int j=begin;j<end;j++)
2488                                 {
2489                                         perm[fold_start[i]] = index[j];
2490                                         fold_start[i]++;
2491                                 }
2492                         }
2493                 fold_start[0]=0;
2494                 for (i=1;i<=nr_fold;i++)
2495                         fold_start[i] = fold_start[i-1]+fold_count[i-1];
2496         delete [] start;
2497         delete [] label;
2498         delete [] count;
2499         delete [] index;
2500         delete [] fold_count;
2501         }
2502         else
2503         {
2504                 for(i=0;i<l;i++) perm[i]=i;
2505                 for(i=0;i<l;i++)
2506                 {
2507                         int j = i+rand()%(l-i);
2508                         swap(perm[i],perm[j]);
2509                 }
2510                 for(i=0;i<=nr_fold;i++)
2511                         fold_start[i]=i*l/nr_fold;
2512         }
2513
2514         for(i=0;i<nr_fold;i++)
2515         {
2516                 int begin = fold_start[i];
2517                 int end = fold_start[i+1];
2518                 int j,k;
2519                 struct svm_problem subprob;
2520
2521                 subprob.l = l-(end-begin);
2522         subprob.x = new struct svm_node*[subprob.l];
2523         subprob.y = new double[subprob.l];