- Integrated the file import into the main interface (not as a plugin anymore)
[mldemos:mldemos.git] / _AlgorithmsPlugins / KernelMethods / interfaceSVMRegress.cpp
1 /*********************************************************************
2 MLDemos: A User-Friendly visualization toolkit for machine learning
3 Copyright (C) 2010  Basilio Noris
4 Contact: mldemos@b4silio.com
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public License,
8 version 3 as published by the Free Software Foundation.
9
10 This library is distributed in the hope that it will be useful, but
11 WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 Lesser General Public License for more details.
14
15 You should have received a copy of the GNU Lesser General Public
16 License along with this library; if not, write to the Free
17 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18 *********************************************************************/
19 #include "interfaceSVMRegress.h"
20 #include <QPixmap>
21 #include <QBitmap>
22 #include <QPainter>
23
24 using namespace std;
25
26 RegrSVM::RegrSVM()
27 {
28         params = new Ui::ParametersRegr();
29         params->setupUi(widget = new QWidget());
30     connect(params->svmTypeCombo, SIGNAL(currentIndexChanged(int)), this, SLOT(ChangeOptions()));
31     connect(params->kernelTypeCombo, SIGNAL(currentIndexChanged(int)), this, SLOT(ChangeOptions()));
32 }
33
34 void RegrSVM::ChangeOptions()
35 {
36         params->svmCLabel->setText("C");
37     params->svmPSpin->setRange(0.0001, 1.0);
38         params->svmPSpin->setSingleStep(0.01);
39         params->svmPSpin->setDecimals(4);
40         params->svmCSpin->setEnabled(true);
41         params->svmCSpin->setRange(0.1, 9999.9);
42         params->svmCSpin->setDecimals(1);
43         switch(params->svmTypeCombo->currentIndex())
44         {
45         case 0: // C-SVM
46                 params->svmEpsLabel->setText("eps");
47         params->svmPSpin->setRange(0.0001, 100.0);
48         break;
49         case 1: // Nu-SVM
50                 params->svmEpsLabel->setText("Nu");
51                 break;
52         case 2: // RVM
53                 params->svmCSpin->setEnabled(false);
54                 params->svmEpsLabel->setText("eps");
55                 break;
56         case 3: // SOGP
57                 params->svmEpsLabel->setText("Noise");
58                 params->svmCLabel->setText("Capacity");
59                 params->svmCSpin->setRange(-1, 500);
60                 params->svmCSpin->setDecimals(0);
61                 params->svmPSpin->setRange(0.001, 1.0);
62                 params->svmPSpin->setSingleStep(0.01);
63                 params->svmPSpin->setDecimals(3);
64                 break;
65         case 4:
66                 params->svmEpsLabel->setText("Tolerance");
67                 params->svmCLabel->setText("Capacity");
68                 params->svmCSpin->setRange(0, 1000);
69                 params->svmCSpin->setDecimals(0);
70                 params->svmPSpin->setRange(0.0001, 1.0);
71                 params->svmPSpin->setSingleStep(0.001);
72                 params->svmPSpin->setDecimals(4);
73                 break;
74         }
75     switch(params->kernelTypeCombo->currentIndex())
76     {
77     case 0: // linear
78         params->kernelDegSpin->setEnabled(false);
79         break;
80     case 1: // poly
81         params->kernelDegSpin->setEnabled(true);
82         params->kernelWidthSpin->setEnabled(true);
83         break;
84     case 2: // RBF
85         params->kernelDegSpin->setEnabled(false);
86         params->kernelWidthSpin->setEnabled(true);
87         break;
88     }
89 }
90
91 void RegrSVM::SetParams(Regressor *regressor)
92 {
93         if(!regressor) return;
94         int kernelMethod = params->svmTypeCombo->currentIndex();
95         float svmC = params->svmCSpin->value();
96         int kernelType = params->kernelTypeCombo->currentIndex();
97         float kernelGamma = params->kernelWidthSpin->value();
98         float kernelDegree = params->kernelDegSpin->value();
99         float svmP = params->svmPSpin->value();
100
101         if(kernelMethod == 2) // rvm
102         {
103                 RegressorRVM *rvm = (RegressorRVM*)regressor;
104                 rvm->SetParams(svmP, kernelType, kernelGamma, kernelDegree);
105         }
106         else if(kernelMethod == 3) // sogp
107         {
108                 RegressorGPR *gpr = (RegressorGPR*)regressor;
109                 int capacity = svmC;
110                 double kernelNoise = svmP;
111                 gpr->SetParams(kernelGamma, kernelNoise, capacity, kernelType, kernelDegree);
112         }
113         else if(kernelMethod == 4 ) // KRLS
114         {
115                 RegressorKRLS *krls = (RegressorKRLS*)regressor;
116                 int capacity = svmC;
117                 double epsilon = svmP;
118                 krls->SetParams(epsilon, capacity, kernelType, kernelGamma, kernelDegree);
119         }
120         else
121         {
122                 RegressorSVR *svm = (RegressorSVR*)regressor;
123                 switch(kernelMethod)
124                 {
125                 case 0:
126                         svm->param.svm_type = EPSILON_SVR;
127                         break;
128                 case 1:
129                         svm->param.svm_type = NU_SVR;
130                         break;
131                 }
132                 switch(kernelType)
133                 {
134                 case 0:
135                         svm->param.kernel_type = LINEAR;
136                         break;
137                 case 1:
138                         svm->param.kernel_type = POLY;
139                         break;
140                 case 2:
141                         svm->param.kernel_type = RBF;
142                         break;
143                 }
144                 svm->param.C = svmC;
145                 svm->param.nu = svmP;
146                 svm->param.p = svmP;
147                 svm->param.gamma = 1 / kernelGamma;
148                 svm->param.degree = kernelDegree;
149         }
150 }
151
152 QString RegrSVM::GetAlgoString()
153 {
154         int kernelMethod = params->svmTypeCombo->currentIndex();
155         float svmC = params->svmCSpin->value();
156         int kernelType = params->kernelTypeCombo->currentIndex();
157         float kernelGamma = params->kernelWidthSpin->value();
158         float kernelDegree = params->kernelDegSpin->value();
159         float svmP = params->svmPSpin->value();
160
161         QString algo;
162         switch(kernelMethod)
163         {
164         case 0:
165                 algo += "eps-SVM";
166                 algo += QString(" %1 %2").arg(svmC).arg(svmP);
167                 break;
168         case 1:
169                 algo += "nu-SVM";
170                 algo += QString(" %1 %2").arg(svmC).arg(svmP);
171                 break;
172         case 2:
173                 algo += "RVM";
174                 algo += QString(" %1").arg(svmP);
175                 break;
176         case 3:
177                 algo += "SOGP";
178                 algo += QString(" %1 %2").arg(svmC).arg(svmP);
179                 break;
180         case 4:
181                 algo += "KRLS";
182                 algo += QString(" %1 %2").arg(svmC).arg(svmP);
183                 break;
184         }
185         switch(kernelType)
186         {
187         case 0:
188                 algo += " L";
189                 break;
190         case 1:
191                 algo += QString(" P %1").arg(kernelDegree);
192                 break;
193         case 2:
194                 algo += QString(" R %1").arg(kernelGamma);
195                 break;
196         }
197         return algo;
198 }
199
200 Regressor *RegrSVM::GetRegressor()
201 {
202         int svmType = params->svmTypeCombo->currentIndex();
203         Regressor *regressor = 0;
204         switch(svmType)
205         {
206         case 2:
207                 regressor = new RegressorRVM();
208                 break;
209         case 3:
210                 regressor = new RegressorGPR();
211                 break;
212         case 4:
213                 regressor = new RegressorKRLS();
214                 break;
215         default:
216                 regressor = new RegressorSVR();
217                 break;
218         }
219         SetParams(regressor);
220         return regressor;
221 }
222
223 void DrawArrow( const QPointF &ppt, const QPointF &pt, double sze, QPainter &painter)
224 {
225         QPointF pd, pa, pb;
226         double tangent;
227
228         pd = ppt - pt;
229         if (pd.x() == 0 && pd.y() == 0)
230                 return;
231         tangent = atan2 ((double) pd.y(), (double) pd.x());
232         pa.setX(sze * cos (tangent + M_PI / 7.f) + pt.x());
233         pa.setY(sze * sin (tangent + M_PI / 7.f) + pt.y());
234         pb.setX(sze * cos (tangent - M_PI / 7.f) + pt.x());
235         pb.setY(sze * sin (tangent - M_PI / 7.f) + pt.y());
236         //-- connect the dots...
237         painter.drawLine(pt, ppt);
238         painter.drawLine(pt, pa);
239         painter.drawLine(pt, pb);
240 }
241
242 void RegrSVM::DrawInfo(Canvas *canvas, QPainter &painter, Regressor *regressor)
243 {
244         painter.setRenderHint(QPainter::Antialiasing);
245     int xIndex = canvas->xIndex;
246     int yIndex = canvas->yIndex;
247     if(regressor->type == REGR_RVM || regressor->type == REGR_KRLS)
248         {
249                 vector<fvec> sv = (regressor->type == REGR_KRLS) ?
250                                 ((RegressorKRLS*)regressor)->GetSVs() :
251                                 ((RegressorRVM*)regressor)->GetSVs();
252                 int radius = 9;
253                 painter.setBrush(Qt::NoBrush);
254                 FOR(i, sv.size())
255                 {
256                         QPointF point = canvas->toCanvasCoords(sv[i]);
257                         painter.setPen(QPen(Qt::black,6));
258                         painter.drawEllipse(point, radius, radius);
259                         painter.setPen(QPen(Qt::white,3));
260                         painter.drawEllipse(point, radius, radius);
261                 }
262         }
263         else if(regressor->type == REGR_SVR)
264         {
265                 // we want to draw the support vectors
266                 svm_model *svm = ((RegressorSVR*)regressor)->GetModel();
267                 if(svm)
268                 {
269                         painter.setBrush(Qt::NoBrush);
270                         std::vector<fvec> samples = canvas->data->GetSamples();
271             int dim = canvas->data->GetDimCount();
272             fvec sv(2,0);
273                         FOR(i, svm->l)
274                         {
275                 sv[0] = (f32)svm->SV[i][xIndex].value;
276                                 FOR(j, samples.size())
277                                 {
278                     if(sv[0] == samples[j][xIndex])
279                                         {
280                         sv[1] = samples[j][yIndex];
281                                                 break;
282                                         }
283                                 }
284                                 int radius = 7;
285                 QPointF point = canvas->toCanvasCoords(sv[0],sv[1]);
286                                 if(abs((*svm->sv_coef)[i]) == svm->param.C)
287                                 {
288                                         painter.setPen(QPen(Qt::black, 4));
289                                         painter.drawEllipse(point, radius, radius);
290                                         painter.setPen(Qt::white);
291                                         painter.drawEllipse(point, radius, radius);
292                                 }
293                                 else
294                                 {
295                                         painter.setPen(Qt::black);
296                                         painter.drawEllipse(point, radius, radius);
297                                 }
298                         }
299                 }
300         }
301         else if(regressor->type == REGR_GPR)
302         {
303                 RegressorGPR * gpr = (RegressorGPR*)regressor;
304                 int radius = 8;
305         int dim = canvas->data->GetDimCount()-1;
306                 painter.setBrush(Qt::NoBrush);
307                 FOR(i, gpr->GetBasisCount())
308                 {
309                         fvec basis = gpr->GetBasisVector(i);
310             fvec testBasis(dim+1);
311             FOR(d, dim) testBasis[d] = basis[d];
312             fvec res = gpr->Test(testBasis);
313             QPointF pt1 = canvas->toCanvasCoords(basis[xIndex],res[0]);
314             QPointF pt2 = pt1 + QPointF(0,(basis[dim + xIndex]>0 ? 1 : -1)*radius);
315             QPointF pt3 = pt2 + QPointF(0,(basis[dim + xIndex]>0 ? 1 : -1)*50);
316                         painter.setPen(QPen(Qt::red,3));
317                         painter.drawEllipse(pt1, radius, radius);
318             painter.setPen(QPen(Qt::red,min(4.f,max(fabs(basis[dim + xIndex])/5,0.5f))));
319                         DrawArrow(pt2,pt3,10,painter);
320                 }
321         }
322 }
323
324 void RegrSVM::DrawConfidence(Canvas *canvas, Regressor *regressor)
325 {
326         if(regressor->type == REGR_GPR)
327         {
328                 RegressorGPR *gpr = (RegressorGPR *)regressor;
329                 if(gpr->sogp)
330                 {
331             int w = canvas->width();
332                         int h = canvas->height();
333             int dim = canvas->data->GetDimCount()-1;
334             int outputDim = regressor->outputDim;
335             int xIndex = canvas->xIndex;
336             int yIndex = canvas->yIndex;
337             Matrix _testout;
338             ColumnVector _testin(dim);
339                         QImage density(QSize(256,256), QImage::Format_RGB32);
340                         density.fill(0);
341                         // we draw a density map for the probability
342                         for (int i=0; i < density.width(); i++)
343                         {
344                                 fvec sampleIn = canvas->toSampleCoords(i*w/density.width(),0);
345                 FOR(d, dim) _testin(d+1) = sampleIn[d];
346                 if(outputDim != -1 && outputDim < dim) _testin(outputDim+1) = sampleIn[dim];
347                                 double sigma;
348                                 _testout = gpr->sogp->predict(_testin, sigma);
349                                 sigma = sigma*sigma;
350                                 float testout = _testout(1,1);
351                                 for (int j=0; j< density.height(); j++)
352                                 {
353                     fvec sampleOut = canvas->toSampleCoords(i*w/density.width(),j*h/density.height());
354                     float val = gpr->GetLikelihood(testout, sigma, sampleOut[yIndex]);
355                                         int color = min(255,(int)(128 + val*20));
356                                         density.setPixel(i,j, qRgb(color,color,color));
357                                 }
358                         }
359                         canvas->maps.confidence = QPixmap::fromImage(density.scaled(QSize(w,h),Qt::IgnoreAspectRatio, Qt::SmoothTransformation));
360                 }
361                 else canvas->maps.confidence = QPixmap();
362         }
363         else canvas->maps.confidence = QPixmap();
364 }
365
366 void RegrSVM::DrawModel(Canvas *canvas, QPainter &painter, Regressor *regressor)
367 {
368         painter.setRenderHint(QPainter::Antialiasing, true);
369         int w = canvas->width();
370         int h = canvas->height();
371     int xIndex = canvas->xIndex;
372     fvec sample = canvas->toSampleCoords(0,0);
373     int dim = sample.size();
374     if(dim > 2) return;
375         if(regressor->type == REGR_KRLS || regressor->type == REGR_RVM)
376         {
377                 canvas->maps.confidence = QPixmap();
378                 int steps = w;
379                 QPointF oldPoint(-FLT_MAX,-FLT_MAX);
380                 FOR(x, steps)
381                 {
382                         sample = canvas->toSampleCoords(x,0);
383                         fvec res = regressor->Test(sample);
384                         if(res[0] != res[0]) continue;
385             QPointF point = canvas->toCanvasCoords(sample[xIndex], res[0]);
386                         if(x)
387                         {
388                                 painter.setPen(QPen(Qt::black, 1));
389                                 painter.drawLine(point, oldPoint);
390                                 painter.setPen(QPen(Qt::black, 0.5));
391                                 //                              painter.drawLine(point+QPointF(0,eps*h), oldPoint+QPointF(0,eps*h));
392                                 //                              painter.drawLine(point-QPointF(0,eps*h), oldPoint-QPointF(0,eps*h));
393                         }
394                         oldPoint = point;
395                 }
396         }
397         else if(regressor->type == REGR_SVR)
398         {
399                 canvas->maps.confidence = QPixmap();
400                 svm_parameter params = ((RegressorSVR *)regressor)->param;
401
402                 float eps = params.p;
403                 if(params.svm_type == NU_SVR) eps = ((RegressorSVR *)regressor)->GetModel()->eps[0];
404                 eps = fabs((canvas->toCanvasCoords(eps,0) - canvas->toCanvasCoords(0,0)).x());
405
406                 int steps = w;
407                 QPointF oldPoint(-FLT_MAX,-FLT_MAX);
408                 FOR(x, steps)
409                 {
410                         sample = canvas->toSampleCoords(x,0);
411             int dim = sample.size();
412                         fvec res = regressor->Test(sample);
413                         if(res[0] != res[0]) continue;
414             QPointF point = canvas->toCanvasCoords(sample[xIndex], res[0]);
415                         if(x)
416                         {
417                                 painter.setPen(QPen(Qt::black, 1));
418                                 painter.drawLine(point, oldPoint);
419                                 painter.setPen(QPen(Qt::black, 0.5));
420                                 painter.drawLine(point+QPointF(0,eps), oldPoint+QPointF(0,eps));
421                                 painter.drawLine(point-QPointF(0,eps), oldPoint-QPointF(0,eps));
422                         }
423                         oldPoint = point;
424                 }
425         }
426         else if(regressor->type == REGR_GPR)
427         {
428                 RegressorGPR *gpr = (RegressorGPR *)regressor;
429                 int steps = w;
430                 QPointF oldPoint(-FLT_MAX,-FLT_MAX);
431                 QPointF oldPointUp(-FLT_MAX,-FLT_MAX);
432                 QPointF oldPointDown(-FLT_MAX,-FLT_MAX);
433                 FOR(x, steps)
434                 {
435                         sample = canvas->toSampleCoords(x,0);
436                         fvec res = regressor->Test(sample);
437                         if(res[0] != res[0] || res[1] != res[1]) continue;
438             QPointF point = canvas->toCanvasCoords(sample[xIndex], res[0]);
439             QPointF pointUp = canvas->toCanvasCoords(sample[xIndex],res[0] + res[1]);
440             QPointF pointDown = canvas->toCanvasCoords(sample[xIndex],res[0] - res[1]);
441                         if(x)
442                         {
443                                 painter.setPen(QPen(Qt::black, 1));
444                                 painter.drawLine(point, oldPoint);
445                                 painter.setPen(QPen(Qt::black, 0.5));
446                                 painter.drawLine(pointUp, oldPointUp);
447                                 painter.drawLine(pointDown, oldPointDown);
448                         }
449                         oldPoint = point;
450                         oldPointUp = pointUp;
451                         oldPointDown = pointDown;
452                 }
453         }
454 }
455
456 void RegrSVM::SaveOptions(QSettings &settings)
457 {
458         settings.setValue("kernelDeg", params->kernelDegSpin->value());
459         settings.setValue("kernelType", params->kernelTypeCombo->currentIndex());
460         settings.setValue("kernelWidth", params->kernelWidthSpin->value());
461         settings.setValue("svmC", params->svmCSpin->value());
462         settings.setValue("svmP", params->svmPSpin->value());
463         settings.setValue("svmType", params->svmTypeCombo->currentIndex());
464 }
465
466 bool RegrSVM::LoadOptions(QSettings &settings)
467 {
468         if(settings.contains("kernelDeg")) params->kernelDegSpin->setValue(settings.value("kernelDeg").toFloat());
469         if(settings.contains("kernelType")) params->kernelTypeCombo->setCurrentIndex(settings.value("kernelType").toInt());
470         if(settings.contains("kernelWidth")) params->kernelWidthSpin->setValue(settings.value("kernelWidth").toFloat());
471         if(settings.contains("svmC")) params->svmCSpin->setValue(settings.value("svmC").toFloat());
472         if(settings.contains("svmP")) params->svmPSpin->setValue(settings.value("svmP").toFloat());
473         if(settings.contains("svmType")) params->svmTypeCombo->setCurrentIndex(settings.value("svmType").toInt());
474         return true;
475 }
476
477 void RegrSVM::SaveParams(QTextStream &file)
478 {
479         file << "regressionOptions" << ":" << "kernelDeg" << " " << params->kernelDegSpin->value() << "\n";
480         file << "regressionOptions" << ":" << "kernelType" << " " << params->kernelTypeCombo->currentIndex() << "\n";
481         file << "regressionOptions" << ":" << "kernelWidth" << " " << params->kernelWidthSpin->value() << "\n";
482         file << "regressionOptions" << ":" << "svmC" << " " << params->svmCSpin->value() << "\n";
483         file << "regressionOptions" << ":" << "svmP" << " " << params->svmPSpin->value() << "\n";
484         file << "regressionOptions" << ":" << "svmType" << " " << params->svmTypeCombo->currentIndex() << "\n";
485 }
486
487 bool RegrSVM::LoadParams(QString name, float value)
488 {
489         if(name.endsWith("kernelDeg")) params->kernelDegSpin->setValue((int)value);
490         if(name.endsWith("kernelType")) params->kernelTypeCombo->setCurrentIndex((int)value);
491         if(name.endsWith("kernelWidth")) params->kernelWidthSpin->setValue(value);
492         if(name.endsWith("svmC")) params->svmCSpin->setValue(value);
493         if(name.endsWith("svmP")) params->svmPSpin->setValue(value);
494         if(name.endsWith("svmType")) params->svmTypeCombo->setCurrentIndex((int)value);
495         return true;
496 }