REMOVED: the DrawModel function from classifierInterface (they all do exactly the...
[mldemos:baraks-mldemos.git] / _AlgorithmsPlugins / KernelMethods / interfaceSVMClassifier.cpp
1 /*********************************************************************
2 MLDemos: A User-Friendly visualization toolkit for machine learning
3 Copyright (C) 2010  Basilio Noris
4 Contact: mldemos@b4silio.com
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public License,
8 version 3 as published by the Free Software Foundation.
9
10 This library is distributed in the hope that it will be useful, but
11 WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 Lesser General Public License for more details.
14
15 You should have received a copy of the GNU Lesser General Public
16 License along with this library; if not, write to the Free
17 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18 *********************************************************************/
19 #include "interfaceSVMClassifier.h"
20 #include <QPixmap>
21 #include <QBitmap>
22 #include <QPainter>
23 #include <QDebug>
24
25 using namespace std;
26
27 ClassSVM::ClassSVM()
28 {
29     params = new Ui::Parameters();
30     params->setupUi(widget = new QWidget());
31     connect(params->svmTypeCombo, SIGNAL(currentIndexChanged(int)), this, SLOT(ChangeOptions()));
32     connect(params->kernelTypeCombo, SIGNAL(currentIndexChanged(int)), this, SLOT(ChangeOptions()));
33     ChangeOptions();
34 }
35
36 void ClassSVM::ChangeOptions()
37 {
38     int C = params->svmCSpin->value();
39     params->maxSVSpin->setVisible(false);
40     params->labelMaxSV->setVisible(false);
41     params->svmCSpin->setRange(0.0001, 1.0);
42     params->svmCSpin->setSingleStep(0.0001);
43     params->svmCSpin->setDecimals(4);
44     params->optimizeCheck->setVisible(true);
45     if(C > 1) params->svmCSpin->setValue(0.001);
46     switch(params->svmTypeCombo->currentIndex())
47     {
48     case 0: // C-SVM
49         params->svmCSpin->setRange(0.1, 9999);
50         params->svmCSpin->setSingleStep(1);
51         params->svmCSpin->setDecimals(1);
52         params->svmCSpin->setValue(C);
53         if(params->svmCSpin->value() < 1) params->svmCSpin->setValue(100);
54         params->svmTypeLabel->setText("C");
55         if(params->kernelTypeCombo->count() < 4) params->kernelTypeCombo->addItem("Sigmoid");
56         break;
57     case 1: // Nu-SVM
58         params->svmTypeLabel->setText("Nu");
59         if(params->kernelTypeCombo->count() < 4) params->kernelTypeCombo->addItem("Sigmoid");
60         break;
61     case 2: // Pegasos
62         params->optimizeCheck->setVisible(false);
63         params->svmTypeLabel->setText("lambda");
64         params->maxSVSpin->setVisible(true);
65         params->labelMaxSV->setVisible(true);
66         if(params->kernelTypeCombo->count() > 3) params->kernelTypeCombo->removeItem(3);
67         break;
68     }
69     switch(params->kernelTypeCombo->currentIndex())
70     {
71     case 0: // linear
72         params->kernelDegSpin->setVisible(false);
73         params->labelDegree->setVisible(false);
74         params->kernelWidthSpin->setVisible(false);
75         params->labelWidth->setVisible(false);
76         break;
77     case 1: // poly
78         params->kernelDegSpin->setVisible(true);
79         params->labelDegree->setVisible(true);
80         params->kernelWidthSpin->setVisible(false);
81         params->labelWidth->setVisible(false);
82         break;
83     case 2: // RBF
84         params->kernelDegSpin->setVisible(false);
85         params->labelDegree->setVisible(false);
86         params->kernelWidthSpin->setVisible(true);
87         params->labelWidth->setVisible(true);
88         break;
89     case 3: // SIGMOID
90         params->kernelDegSpin->setEnabled(false);
91         params->labelDegree->setVisible(false);
92         params->kernelWidthSpin->setEnabled(true);
93         params->labelWidth->setVisible(true);
94         break;
95     }
96 }
97
98 QString ClassSVM::GetAlgoString()
99 {
100     double C = params->svmCSpin->value();
101     int sv = params->maxSVSpin->value();
102     int kernelType = params->kernelTypeCombo->currentIndex();
103     float kernelGamma = params->kernelWidthSpin->value();
104     float kernelDegree = params->kernelDegSpin->value();
105     bool bOptimize = params->optimizeCheck->isChecked();
106
107     QString algo;
108     switch(params->svmTypeCombo->currentIndex())
109     {
110     case 0: // C-SVM
111         algo += "C-SVM";
112         algo += QString(" %1").arg(C);
113         break;
114     case 1: // Nu-SVM
115         algo += "Nu-SVM";
116         algo += QString(" %1").arg(C);
117         break;
118     case 2: // Pegasos
119         algo += "Pegasos";
120         algo += QString(" %1 %2").arg(C).arg(sv);
121         break;
122     }
123     switch(kernelType)
124     {
125     case 0:
126         algo += " Lin";
127         break;
128     case 1:
129         algo += QString(" Pol %1").arg(kernelDegree);
130         break;
131     case 2:
132         algo += QString(" RBF %1").arg(kernelGamma);
133         break;
134     case 3:
135         algo += QString(" Sig %1").arg(kernelGamma);
136         break;
137     }
138     if(bOptimize) algo += QString(" Opt");
139     return algo;
140 }
141
142 void ClassSVM::SetParams(Classifier *classifier)
143 {
144     if(!classifier) return;
145     int svmType = params->svmTypeCombo->currentIndex();
146     int maxSV = params->maxSVSpin->value();
147     float svmC = params->svmCSpin->value();
148     int kernelType = params->kernelTypeCombo->currentIndex();
149     float kernelGamma = params->kernelWidthSpin->value();
150     float kernelDegree = params->kernelDegSpin->value();
151     bool bOptimize = params->optimizeCheck->isChecked();
152
153     ClassifierPegasos *pegasos = dynamic_cast<ClassifierPegasos *>(classifier);
154     if(pegasos) pegasos->SetParams(svmC, max(2,(int)maxSV), kernelType, kernelGamma, kernelDegree);
155
156     ClassifierSVM *svm = dynamic_cast<ClassifierSVM *>(classifier);
157     if(svm)
158     {
159         switch(svmType)
160         {
161         case 0:
162             svm->param.svm_type = C_SVC;
163             break;
164         case 1:
165             svm->param.svm_type = NU_SVC;
166             break;
167         }
168         switch(kernelType)
169         {
170         case 0:
171             svm->param.kernel_type = LINEAR;
172             break;
173         case 1:
174             svm->param.kernel_type = POLY;
175             break;
176         case 2:
177             svm->param.kernel_type = RBF;
178             break;
179         case 3:
180             svm->param.kernel_type = SIGMOID;
181             break;
182         }
183         svm->param.C = svm->param.nu = svmC;
184         svm->param.gamma = 1 / kernelGamma;
185         svm->param.coef0 = 0;
186         svm->param.degree = kernelDegree;
187         svm->bOptimize = bOptimize;
188     }
189 }
190
191 fvec ClassSVM::GetParams()
192 {
193     int svmType = params->svmTypeCombo->currentIndex();
194     int maxSV = params->maxSVSpin->value();
195     float svmC = params->svmCSpin->value();
196     int kernelType = params->kernelTypeCombo->currentIndex();
197     float kernelGamma = params->kernelWidthSpin->value();
198     float kernelDegree = params->kernelDegSpin->value();
199     bool bOptimize = params->optimizeCheck->isChecked();
200
201     if(svmType == 2) // pegasos
202     {
203         fvec par(5);
204         par[0] = svmC;
205         par[1] = maxSV;
206         par[2] = kernelType;
207         par[3] = kernelGamma;
208         par[4] = kernelDegree;
209         return par;
210     }
211     else
212     {
213         fvec par(6);
214         par[0] = svmType;
215         par[1] = svmC;
216         par[2] = kernelType;
217         par[3] = kernelGamma;
218         par[4] = kernelDegree;
219         par[5] = bOptimize;
220         return par;
221     }
222 }
223
224 void ClassSVM::SetParams(Classifier *classifier, fvec parameters)
225 {
226     if(!classifier) return;
227     int svmType = params->svmTypeCombo->currentIndex();
228     float svmC, kernelGamma;
229     int maxSV, kernelType, kernelDegree;
230     bool bOptimize;
231     if(svmType == 2) // pegasos
232     {
233         svmC = parameters.size() > 0 ? parameters[0] : 1;
234         maxSV = parameters.size() > 1 ? parameters[1] : 0;
235         kernelType = parameters.size() > 2 ? parameters[2] : 0;
236         kernelGamma = parameters.size() > 3 ? parameters[3] : 0;
237         kernelDegree = parameters.size() > 4 ? parameters[4] : 0;
238     }
239     else
240     {
241         svmType = parameters.size() > 0 ? parameters[0] : 0;
242         svmC = parameters.size() > 1 ? parameters[1] : 1;
243         kernelType = parameters.size() > 2 ? parameters[2] : 0;
244         kernelGamma = parameters.size() > 3 ? parameters[3] : 0;
245         kernelDegree = parameters.size() > 4 ? parameters[4] : 0;
246         bOptimize = parameters.size() > 5 ? parameters[5] : 0;
247     }
248
249     ClassifierPegasos *pegasos = dynamic_cast<ClassifierPegasos *>(classifier);
250     if(pegasos) pegasos->SetParams(svmC, max(2,(int)maxSV), kernelType, kernelGamma, kernelDegree);
251
252     ClassifierSVM *svm = dynamic_cast<ClassifierSVM *>(classifier);
253     if(svm)
254     {
255         switch(svmType)
256         {
257         case 0:
258             svm->param.svm_type = C_SVC;
259             break;
260         case 1:
261             svm->param.svm_type = NU_SVC;
262             break;
263         }
264         switch(kernelType)
265         {
266         case 0:
267             svm->param.kernel_type = LINEAR;
268             break;
269         case 1:
270             svm->param.kernel_type = POLY;
271             break;
272         case 2:
273             svm->param.kernel_type = RBF;
274             break;
275         case 3:
276             svm->param.kernel_type = SIGMOID;
277             break;
278         }
279         svm->param.C = svm->param.nu = svmC;
280         svm->param.gamma = 1 / kernelGamma;
281         svm->param.coef0 = 0;
282         svm->param.degree = kernelDegree;
283         svm->bOptimize = bOptimize;
284     }
285 }
286
287 void ClassSVM::GetParameterList(std::vector<QString> &parameterNames,
288                                 std::vector<QString> &parameterTypes,
289                                 std::vector< std::vector<QString> > &parameterValues)
290 {
291     int svmType = params->svmTypeCombo->currentIndex();
292     if(svmType == 2)
293     {
294         parameterNames.push_back("Penalty (C)");
295         parameterNames.push_back("Max SV");
296         parameterNames.push_back("Kernel Type");
297         parameterNames.push_back("Kernel Width");
298         parameterNames.push_back("Kernel Degree");
299         parameterTypes.push_back("Real");
300         parameterTypes.push_back("Integer");
301         parameterTypes.push_back("List");
302         parameterTypes.push_back("Real");
303         parameterTypes.push_back("Integer");
304         parameterValues.push_back(vector<QString>());
305         parameterValues.back().push_back("0");
306         parameterValues.back().push_back("9999999999999");
307         parameterValues.push_back(vector<QString>());
308         parameterValues.back().push_back("1");
309         parameterValues.back().push_back("999999999");
310         parameterValues.push_back(vector<QString>());
311         parameterValues.back().push_back("Linear");
312         parameterValues.back().push_back("Poly");
313         parameterValues.back().push_back("RBF");
314         parameterValues.push_back(vector<QString>());
315         parameterValues.back().push_back("0.00000001f");
316         parameterValues.back().push_back("9999999");
317         parameterValues.push_back(vector<QString>());
318         parameterValues.back().push_back("1");
319         parameterValues.back().push_back("150");
320     }
321     else
322     {
323         parameterNames.push_back("SVM Type");
324         parameterNames.push_back("Penalty (C) / Nu");
325         parameterNames.push_back("Kernel Type");
326         parameterNames.push_back("Kernel Width");
327         parameterNames.push_back("Kernel Degree");
328         parameterNames.push_back("Optimize Kernel");
329         parameterTypes.push_back("List");
330         parameterTypes.push_back("Real");
331         parameterTypes.push_back("List");
332         parameterTypes.push_back("Real");
333         parameterTypes.push_back("Integer");
334         parameterTypes.push_back("List");
335         parameterValues.push_back(vector<QString>());
336         parameterValues.back().push_back("Epsilon-SVM");
337         parameterValues.back().push_back("Nu-SVM");
338         parameterValues.push_back(vector<QString>());
339         parameterValues.back().push_back("1");
340         parameterValues.back().push_back("999999999");
341         parameterValues.push_back(vector<QString>());
342         parameterValues.back().push_back("Linear");
343         parameterValues.back().push_back("Poly");
344         parameterValues.back().push_back("RBF");
345         parameterValues.push_back(vector<QString>());
346         parameterValues.back().push_back("0.00000001f");
347         parameterValues.back().push_back("9999999");
348         parameterValues.push_back(vector<QString>());
349         parameterValues.back().push_back("1");
350         parameterValues.back().push_back("150");
351         parameterValues.push_back(vector<QString>());
352         parameterValues.back().push_back("False");
353         parameterValues.back().push_back("True");
354     }
355 }
356
357 Classifier *ClassSVM::GetClassifier()
358 {
359     int svmType = params->svmTypeCombo->currentIndex();
360     Classifier *classifier = 0;
361     switch(svmType)
362     {
363     case 2:
364         classifier = new ClassifierPegasos();
365         break;
366     default:
367         classifier = new ClassifierSVM();
368         break;
369     }
370     SetParams(classifier);
371     return classifier;
372 }
373
374 void ClassSVM::DrawInfo(Canvas *canvas, QPainter &painter, Classifier *classifier)
375 {
376     painter.setRenderHint(QPainter::Antialiasing);
377
378     if(dynamic_cast<ClassifierPegasos*>(classifier))
379     {
380         // we want to draw the support vectors
381         vector<fvec> sv = dynamic_cast<ClassifierPegasos*>(classifier)->GetSVs();
382         int radius = 9;
383         FOR(i, sv.size())
384         {
385             QPointF point = canvas->toCanvasCoords(sv[i]);
386             painter.setPen(QPen(Qt::black,6));
387             painter.drawEllipse(point, radius, radius);
388             painter.setPen(QPen(Qt::white,4));
389             painter.drawEllipse(point, radius, radius);
390         }
391     }
392     else if(dynamic_cast<ClassifierSVM*>(classifier))
393     {
394         int dim = canvas->data->GetDimCount();
395         int xIndex = canvas->xIndex, yIndex = canvas->yIndex;
396         // we want to draw the support vectors
397         svm_model *svm = dynamic_cast<ClassifierSVM*>(classifier)->GetModel();
398         if(svm)
399         {
400             f32 sv[2];
401             FOR(i, svm->l)
402             {
403                 FOR(j, 2)
404                 {
405                     sv[j] = (f32)svm->SV[i][j].value;
406                 }
407                 int radius = 9;
408                 QPointF point = canvas->toCanvasCoords(sv[xIndex], sv[yIndex]);
409                 if(abs((*svm->sv_coef)[i]) == svm->param.C)
410                 {
411                     painter.setPen(QPen(Qt::black, 6));
412                     painter.drawEllipse(point, radius, radius);
413                     painter.setPen(QPen(Qt::white,4));
414                     painter.drawEllipse(point, radius, radius);
415                 }
416                 else
417                 {
418                     painter.setPen(QPen(Qt::white, 6));
419                     painter.drawEllipse(point, radius, radius);
420                     painter.setPen(QPen(Qt::black,4));
421                     painter.drawEllipse(point, radius, radius);
422                 }
423             }
424         }
425     }
426 }
427
428 void ClassSVM::DrawGL(Canvas *canvas, GLWidget *glw, Classifier *classifier)
429 {
430     int xInd = canvas->xIndex;
431     int yInd = canvas->yIndex;
432     int zInd = canvas->zIndex;
433     GLObject o;
434     o.objectType = "Samples";
435     o.style = "rings,pointsize:24";
436     vector<fvec> svs;
437     if(dynamic_cast<ClassifierPegasos*>(classifier))
438     {
439         // we want to draw the support vectors
440         svs = dynamic_cast<ClassifierPegasos*>(classifier)->GetSVs();
441     }
442     else if(dynamic_cast<ClassifierSVM*>(classifier))
443     {
444         int dim = canvas->data->GetDimCount();
445         // we want to draw the support vectors
446         svm_model *svm = dynamic_cast<ClassifierSVM*>(classifier)->GetModel();
447         if(svm)
448         {
449             fvec sv(dim);
450             FOR(i, svm->l)
451             {
452                 FOR(d, dim) sv[d] = svm->SV[i][d].value;
453                 svs.push_back(sv);
454             }
455         }
456     }
457     FOR(i, svs.size())
458     {
459         o.vertices.append(QVector3D(svs[i][xInd],svs[i][yInd],svs[i][zInd]));
460         o.colors.append(QVector4D(0,0,0,1));
461     }
462     glw->mutex->lock();
463     glw->objects.push_back(o);
464     glw->mutex->unlock();
465 }
466
467 void ClassSVM::SaveOptions(QSettings &settings)
468 {
469     settings.setValue("kernelDeg", params->kernelDegSpin->value());
470     settings.setValue("kernelType", params->kernelTypeCombo->currentIndex());
471     settings.setValue("kernelWidth", params->kernelWidthSpin->value());
472     settings.setValue("svmC", params->svmCSpin->value());
473     settings.setValue("svmType", params->svmTypeCombo->currentIndex());
474     settings.setValue("optimizeCheck", params->optimizeCheck->isChecked());
475     settings.setValue("maxSVSpin", params->maxSVSpin->value());
476 }
477
478 bool ClassSVM::LoadOptions(QSettings &settings)
479 {
480     if(settings.contains("kernelDeg")) params->kernelDegSpin->setValue(settings.value("kernelDeg").toFloat());
481     if(settings.contains("kernelType")) params->kernelTypeCombo->setCurrentIndex(settings.value("kernelType").toInt());
482     if(settings.contains("kernelWidth")) params->kernelWidthSpin->setValue(settings.value("kernelWidth").toFloat());
483     if(settings.contains("svmC")) params->svmCSpin->setValue(settings.value("svmC").toFloat());
484     if(settings.contains("svmType")) params->svmTypeCombo->setCurrentIndex(settings.value("svmType").toInt());
485     if(settings.contains("optimizeCheck")) params->optimizeCheck->setChecked(settings.value("optimizeCheck").toInt());
486     if(settings.contains("maxSVSpin")) params->maxSVSpin->setValue(settings.value("maxSVSpin").toInt());
487     ChangeOptions();
488     return true;
489 }
490
491 void ClassSVM::SaveParams(QTextStream &file)
492 {
493     file << "classificationOptions" << ":" << "kernelDeg" << " " << params->kernelDegSpin->value() << "\n";
494     file << "classificationOptions" << ":" << "kernelType" << " " << params->kernelTypeCombo->currentIndex() << "\n";
495     file << "classificationOptions" << ":" << "kernelWidth" << " " << params->kernelWidthSpin->value() << "\n";
496     file << "classificationOptions" << ":" << "svmC" << " " << params->svmCSpin->value() << "\n";
497     file << "classificationOptions" << ":" << "svmType" << " " << params->svmTypeCombo->currentIndex() << "\n";
498     file << "classificationOptions" << ":" << "optimizeCheck" << " " << params->optimizeCheck->isChecked() << "\n";
499     file << "classificationOptions" << ":" << "maxSVSpin" << " " << params->maxSVSpin->value() << "\n";
500 }
501
502 bool ClassSVM::LoadParams(QString name, float value)
503 {
504     if(name.endsWith("kernelDeg")) params->kernelDegSpin->setValue((int)value);
505     if(name.endsWith("kernelType")) params->kernelTypeCombo->setCurrentIndex((int)value);
506     if(name.endsWith("kernelWidth")) params->kernelWidthSpin->setValue(value);
507     if(name.endsWith("svmC")) params->svmCSpin->setValue(value);
508     if(name.endsWith("svmType")) params->svmTypeCombo->setCurrentIndex((int)value);
509     if(name.endsWith("optimizeCheck")) params->optimizeCheck->setChecked((int)value);
510     if(name.endsWith("maxSVSpin")) params->maxSVSpin->setValue((int)value);
511     ChangeOptions();
512     return true;
513 }