Merging requests
[mldemos:mldemos.git] / _AlgorithmsPlugins / KernelMethods / interfaceSVMClassifier.cpp
1 /*********************************************************************
2 MLDemos: A User-Friendly visualization toolkit for machine learning
3 Copyright (C) 2010  Basilio Noris
4 Contact: mldemos@b4silio.com
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public License,
8 version 3 as published by the Free Software Foundation.
9
10 This library is distributed in the hope that it will be useful, but
11 WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 Lesser General Public License for more details.
14
15 You should have received a copy of the GNU Lesser General Public
16 License along with this library; if not, write to the Free
17 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18 *********************************************************************/
19 #include "interfaceSVMClassifier.h"
20 #include <QPixmap>
21 #include <QBitmap>
22 #include <QPainter>
23 #include <QDebug>
24
25 using namespace std;
26
27 ClassSVM::ClassSVM()
28 {
29     params = new Ui::Parameters();
30     params->setupUi(widget = new QWidget());
31     ardLabel = 0;
32     connect(params->svmTypeCombo, SIGNAL(currentIndexChanged(int)), this, SLOT(ChangeOptions()));
33     connect(params->kernelTypeCombo, SIGNAL(currentIndexChanged(int)), this, SLOT(ChangeOptions()));
34     connect(params->optimizeCheck, SIGNAL(clicked()), this, SLOT(ChangeOptions()));
35     connect(params->kernelButton, SIGNAL(clicked()), this, SLOT(DisplayARDKernel()));
36     ChangeOptions();
37 }
38
39 ClassSVM::~ClassSVM()
40 {
41     delete params;
42     DEL(ardLabel);
43 }
44
45 void ClassSVM::ChangeOptions()
46 {
47     int C = params->svmCSpin->value();
48     params->maxSVSpin->setVisible(false);
49     params->labelMaxSV->setVisible(false);
50     params->svmCSpin->setRange(0.0001, 1.0);
51     params->svmCSpin->setSingleStep(0.0001);
52     params->svmCSpin->setDecimals(4);
53     params->optimizeCheck->setVisible(true);
54     if(C > 1) params->svmCSpin->setValue(0.001);
55     switch(params->svmTypeCombo->currentIndex())
56     {
57     case 0: // C-SVM
58         params->svmCSpin->setRange(0.1, 9999);
59         params->svmCSpin->setSingleStep(1);
60         params->svmCSpin->setDecimals(1);
61         params->svmCSpin->setValue(C);
62         if(params->svmCSpin->value() < 1) params->svmCSpin->setValue(100);
63         params->svmTypeLabel->setText("C");
64         if(params->kernelTypeCombo->count() < 4) params->kernelTypeCombo->addItem("Sigmoid");
65         break;
66     case 1: // Nu-SVM
67         params->svmTypeLabel->setText("Nu");
68         if(params->kernelTypeCombo->count() < 4) params->kernelTypeCombo->addItem("Sigmoid");
69         break;
70     case 2: // Pegasos
71         params->optimizeCheck->setVisible(false);
72         params->svmTypeLabel->setText("lambda");
73         params->maxSVSpin->setVisible(true);
74         params->labelMaxSV->setVisible(true);
75         if(params->kernelTypeCombo->count() > 3) params->kernelTypeCombo->removeItem(3);
76         break;
77     }
78     switch(params->kernelTypeCombo->currentIndex())
79     {
80     case 0: // linear
81         params->kernelDegSpin->setVisible(false);
82         params->labelDegree->setVisible(false);
83         params->kernelWidthSpin->setVisible(false);
84         params->labelWidth->setVisible(false);
85         break;
86     case 1: // poly
87         params->kernelDegSpin->setVisible(true);
88         params->labelDegree->setVisible(true);
89         params->kernelWidthSpin->setVisible(false);
90         params->labelWidth->setVisible(false);
91         break;
92     case 2: // RBF
93         params->kernelDegSpin->setVisible(false);
94         params->labelDegree->setVisible(false);
95         params->kernelWidthSpin->setVisible(true);
96         params->labelWidth->setVisible(true);
97         break;
98     case 3: // SIGMOID
99         params->kernelDegSpin->setEnabled(false);
100         params->labelDegree->setVisible(false);
101         params->kernelWidthSpin->setEnabled(true);
102         params->labelWidth->setVisible(true);
103         break;
104     }
105     params->kernelButton->setVisible(params->optimizeCheck->isChecked());
106 }
107
108 void ClassSVM::DisplayARDKernel()
109 {
110     if(!ardLabel)
111     {
112         ardLabel = new QLabel();
113         ardLabel->setScaledContents(true);
114     }
115     QPixmap pixmap(200,200);
116     //QBitmap bitmap(pixmap.width(), pixmap.height());
117     //bitmap.clear();
118     //pixmap.setMask(bitmap);
119     pixmap.fill(Qt::transparent);
120     QPainter painter(&pixmap);
121     painter.setRenderHint(QPainter::Antialiasing);
122     if(ardKernel.size())
123     {
124         QPointF center(pixmap.width()/2,pixmap.height()/2);
125
126         int dim = ardKernel.size();
127         float maxVal = -FLT_MAX;
128         FOR(d, dim) maxVal = max(maxVal, ardKernel[d]);
129         float maxRadius = pixmap.width()/2*0.75;
130         painter.setPen(QPen(Qt::black,0.5));
131         FOR(d, dim)
132         {
133             float angle = d*2*M_PI/dim;
134             float radius = pixmap.width()/2*0.8;
135             QPointF point(cosf(angle)*radius, sinf(angle)*radius);
136             painter.drawLine(center, center + point);
137         }
138         QPolygonF poly;
139         FOR(d, dim+1)
140         {
141             float value = ardKernel[d%dim];
142             float angle = d*2*M_PI/dim;
143             float radius = value/maxVal*maxRadius;
144             QPointF point(cosf(angle)*radius, sinf(angle)*radius);
145             poly << center + point;
146         }
147         painter.setBrush(Qt::red);
148         painter.setPen(Qt::NoPen);
149         painter.setOpacity(0.3);
150         painter.drawPolygon(poly);
151         painter.setBrush(Qt::NoBrush);
152         painter.setPen(QPen(Qt::red,2));
153         painter.drawPolygon(poly);
154     }
155     ardLabel->setPixmap(pixmap);
156     ardLabel->show();
157 }
158
159 QString ClassSVM::GetAlgoString()
160 {
161     double C = params->svmCSpin->value();
162     int sv = params->maxSVSpin->value();
163     int kernelType = params->kernelTypeCombo->currentIndex();
164     float kernelGamma = params->kernelWidthSpin->value();
165     float kernelDegree = params->kernelDegSpin->value();
166     bool bOptimize = params->optimizeCheck->isChecked();
167
168     QString algo;
169     switch(params->svmTypeCombo->currentIndex())
170     {
171     case 0: // C-SVM
172         algo += "C-SVM";
173         algo += QString(" %1").arg(C);
174         break;
175     case 1: // Nu-SVM
176         algo += "Nu-SVM";
177         algo += QString(" %1").arg(C);
178         break;
179     case 2: // Pegasos
180         algo += "Pegasos";
181         algo += QString(" %1 %2").arg(C).arg(sv);
182         break;
183     }
184     switch(kernelType)
185     {
186     case 0:
187         algo += " Lin";
188         break;
189     case 1:
190         algo += QString(" Pol %1").arg(kernelDegree);
191         break;
192     case 2:
193         algo += QString(" RBF %1").arg(kernelGamma);
194         break;
195     case 3:
196         algo += QString(" Sig %1").arg(kernelGamma);
197         break;
198     }
199     if(bOptimize) algo += QString(" Opt");
200     return algo;
201 }
202
203 void ClassSVM::SetParams(Classifier *classifier)
204 {
205     if(!classifier) return;
206     int svmType = params->svmTypeCombo->currentIndex();
207     int maxSV = params->maxSVSpin->value();
208     float svmC = params->svmCSpin->value();
209     int kernelType = params->kernelTypeCombo->currentIndex();
210     float kernelGamma = params->kernelWidthSpin->value();
211     float kernelDegree = params->kernelDegSpin->value();
212     bool bOptimize = params->optimizeCheck->isChecked();
213
214     ClassifierPegasos *pegasos = dynamic_cast<ClassifierPegasos *>(classifier);
215     if(pegasos) pegasos->SetParams(svmC, max(2,(int)maxSV), kernelType, kernelGamma, kernelDegree);
216
217     ClassifierSVM *svm = dynamic_cast<ClassifierSVM *>(classifier);
218     if(svm)
219     {
220         switch(svmType)
221         {
222         case 0:
223             svm->param.svm_type = C_SVC;
224             break;
225         case 1:
226             svm->param.svm_type = NU_SVC;
227             break;
228         }
229         switch(kernelType)
230         {
231         case 0:
232             svm->param.kernel_type = LINEAR;
233             break;
234         case 1:
235             svm->param.kernel_type = POLY;
236             break;
237         case 2:
238             svm->param.kernel_type = RBF;
239             break;
240         case 3:
241             svm->param.kernel_type = SIGMOID;
242             break;
243         }
244         svm->param.C = svm->param.nu = svmC;
245         svm->param.gamma = 1 / kernelGamma;
246         svm->param.coef0 = 0;
247         svm->param.degree = kernelDegree;
248         svm->bOptimize = bOptimize;
249     }
250 }
251
252 fvec ClassSVM::GetParams()
253 {
254     int svmType = params->svmTypeCombo->currentIndex();
255     int maxSV = params->maxSVSpin->value();
256     float svmC = params->svmCSpin->value();
257     int kernelType = params->kernelTypeCombo->currentIndex();
258     float kernelGamma = params->kernelWidthSpin->value();
259     float kernelDegree = params->kernelDegSpin->value();
260     bool bOptimize = params->optimizeCheck->isChecked();
261
262     if(svmType == 2) // pegasos
263     {
264         fvec par(5);
265         par[0] = svmC;
266         par[1] = maxSV;
267         par[2] = kernelType;
268         par[3] = kernelGamma;
269         par[4] = kernelDegree;
270         return par;
271     }
272     else
273     {
274         fvec par(6);
275         par[0] = svmType;
276         par[1] = svmC;
277         par[2] = kernelType;
278         par[3] = kernelGamma;
279         par[4] = kernelDegree;
280         par[5] = bOptimize;
281         return par;
282     }
283 }
284
285 void ClassSVM::SetParams(Classifier *classifier, fvec parameters)
286 {
287     if(!classifier) return;
288     int svmType = params->svmTypeCombo->currentIndex();
289     float svmC, kernelGamma;
290     int maxSV, kernelType, kernelDegree;
291     bool bOptimize;
292     if(svmType == 2) // pegasos
293     {
294         svmC = parameters.size() > 0 ? parameters[0] : 1;
295         maxSV = parameters.size() > 1 ? parameters[1] : 0;
296         kernelType = parameters.size() > 2 ? parameters[2] : 0;
297         kernelGamma = parameters.size() > 3 ? parameters[3] : 0.1;
298         kernelDegree = parameters.size() > 4 ? parameters[4] : 0;
299     }
300     else
301     {
302         svmType = parameters.size() > 0 ? parameters[0] : 0;
303         svmC = parameters.size() > 1 ? parameters[1] : 1;
304         kernelType = parameters.size() > 2 ? parameters[2] : 0;
305         kernelGamma = parameters.size() > 3 ? parameters[3] : 0.1;
306         kernelDegree = parameters.size() > 4 ? parameters[4] : 0;
307         bOptimize = parameters.size() > 5 ? parameters[5] : 0;
308     }
309
310     ClassifierPegasos *pegasos = dynamic_cast<ClassifierPegasos *>(classifier);
311     if(pegasos) pegasos->SetParams(svmC, max(2,(int)maxSV), kernelType, kernelGamma, kernelDegree);
312
313     ClassifierSVM *svm = dynamic_cast<ClassifierSVM *>(classifier);
314     if(svm)
315     {
316         switch(svmType)
317         {
318         case 0:
319             svm->param.svm_type = C_SVC;
320             break;
321         case 1:
322             svm->param.svm_type = NU_SVC;
323             break;
324         }
325         switch(kernelType)
326         {
327         case 0:
328             svm->param.kernel_type = LINEAR;
329             break;
330         case 1:
331             svm->param.kernel_type = POLY;
332             break;
333         case 2:
334             svm->param.kernel_type = RBF;
335             break;
336         case 3:
337             svm->param.kernel_type = SIGMOID;
338             break;
339         }
340         svm->param.C = svm->param.nu = svmC;
341         svm->param.gamma = 1. / kernelGamma;
342         svm->param.coef0 = 0;
343         svm->param.degree = kernelDegree;
344         svm->bOptimize = bOptimize;
345     }
346 }
347
348 void ClassSVM::GetParameterList(std::vector<QString> &parameterNames,
349                                 std::vector<QString> &parameterTypes,
350                                 std::vector< std::vector<QString> > &parameterValues)
351 {
352     int svmType = params->svmTypeCombo->currentIndex();
353     if(svmType == 2)
354     {
355         parameterNames.push_back("Penalty (C)");
356         parameterNames.push_back("Max SV");
357         parameterNames.push_back("Kernel Type");
358         parameterNames.push_back("Kernel Width");
359         parameterNames.push_back("Kernel Degree");
360         parameterTypes.push_back("Real");
361         parameterTypes.push_back("Integer");
362         parameterTypes.push_back("List");
363         parameterTypes.push_back("Real");
364         parameterTypes.push_back("Integer");
365         parameterValues.push_back(vector<QString>());
366         parameterValues.back().push_back("0");
367         parameterValues.back().push_back("9999999999999");
368         parameterValues.push_back(vector<QString>());
369         parameterValues.back().push_back("1");
370         parameterValues.back().push_back("999999999");
371         parameterValues.push_back(vector<QString>());
372         parameterValues.back().push_back("Linear");
373         parameterValues.back().push_back("Poly");
374         parameterValues.back().push_back("RBF");
375         parameterValues.push_back(vector<QString>());
376         parameterValues.back().push_back("0.00000001f");
377         parameterValues.back().push_back("9999999");
378         parameterValues.push_back(vector<QString>());
379         parameterValues.back().push_back("1");
380         parameterValues.back().push_back("150");
381     }
382     else
383     {
384         parameterNames.push_back("SVM Type");
385         parameterNames.push_back("Penalty (C) / Nu");
386         parameterNames.push_back("Kernel Type");
387         parameterNames.push_back("Kernel Width");
388         parameterNames.push_back("Kernel Degree");
389         parameterNames.push_back("Optimize Kernel");
390         parameterTypes.push_back("List");
391         parameterTypes.push_back("Real");
392         parameterTypes.push_back("List");
393         parameterTypes.push_back("Real");
394         parameterTypes.push_back("Integer");
395         parameterTypes.push_back("List");
396         parameterValues.push_back(vector<QString>());
397         parameterValues.back().push_back("Epsilon-SVM");
398         parameterValues.back().push_back("Nu-SVM");
399         parameterValues.push_back(vector<QString>());
400         parameterValues.back().push_back("1");
401         parameterValues.back().push_back("999999999");
402         parameterValues.push_back(vector<QString>());
403         parameterValues.back().push_back("Linear");
404         parameterValues.back().push_back("Poly");
405         parameterValues.back().push_back("RBF");
406         parameterValues.push_back(vector<QString>());
407         parameterValues.back().push_back("0.00000001f");
408         parameterValues.back().push_back("9999999");
409         parameterValues.push_back(vector<QString>());
410         parameterValues.back().push_back("1");
411         parameterValues.back().push_back("150");
412         parameterValues.push_back(vector<QString>());
413         parameterValues.back().push_back("False");
414         parameterValues.back().push_back("True");
415     }
416 }
417
418 Classifier *ClassSVM::GetClassifier()
419 {
420     int svmType = params->svmTypeCombo->currentIndex();
421     Classifier *classifier = 0;
422     switch(svmType)
423     {
424     case 2:
425         classifier = new ClassifierPegasos();
426         break;
427     default:
428         classifier = new ClassifierSVM();
429         break;
430     }
431     SetParams(classifier);
432     ardKernel.clear();
433     ardNames.clear();
434     return classifier;
435 }
436
437 void ClassSVM::DrawInfo(Canvas *canvas, QPainter &painter, Classifier *classifier)
438 {
439     ardKernel.clear();
440     ardNames.clear();
441     painter.setRenderHint(QPainter::Antialiasing);
442
443     if(dynamic_cast<ClassifierPegasos*>(classifier))
444     {
445         // we want to draw the support vectors
446         vector<fvec> sv = dynamic_cast<ClassifierPegasos*>(classifier)->GetSVs();
447         int radius = 9;
448         int dim = canvas->data->GetDimCount();
449         FOR(i, sv.size())
450         {
451             fvec sample = sv[i];
452             if(canvas->sourceDims.size())
453             {
454                 fvec newSample(dim, 0);
455                 FOR(d, canvas->sourceDims.size())
456                 {
457                     newSample[canvas->sourceDims[d]] = sample[d];
458                 }
459                 sample = newSample;
460             }
461             QPointF point = canvas->toCanvasCoords(sample);
462             painter.setPen(QPen(Qt::black,6));
463             painter.drawEllipse(point, radius, radius);
464             painter.setPen(QPen(Qt::white,4));
465             painter.drawEllipse(point, radius, radius);
466         }
467     }
468     else if(dynamic_cast<ClassifierSVM*>(classifier))
469     {
470         int dim = canvas->data->GetDimCount();
471         // we want to draw the support vectors
472         svm_model *svm = dynamic_cast<ClassifierSVM*>(classifier)->GetModel();
473         if(svm)
474         {
475             if(svm->param.kernel_type == RBFWEIGH)
476             {
477                 ardKernel.resize(svm->param.kernel_dim,1.f);
478                 FOR(d, svm->param.kernel_dim) ardKernel[d] = svm->param.kernel_weight[d];
479                 ardNames = canvas->dimNames;
480                 if(ardLabel && ardLabel->isVisible()) DisplayARDKernel();
481             }
482             fvec sample(dim,0);
483             FOR(i, svm->l)
484             {
485                 if(canvas->sourceDims.size())
486                 {
487                     FOR(d, canvas->sourceDims.size()) sample[canvas->sourceDims[d]] = svm->SV[i][d].value;
488                 }
489                 else
490                 {
491                     FOR(d, dim) sample[d] = (f32)svm->SV[i][d].value;
492                 }
493                 int radius = 9;
494                 QPointF point = canvas->toCanvasCoords(sample);
495                 if(abs((*svm->sv_coef)[i]) == svm->param.C)
496                 {
497                     painter.setPen(QPen(Qt::black, 6));
498                     painter.drawEllipse(point, radius, radius);
499                     painter.setPen(QPen(Qt::white,4));
500                     painter.drawEllipse(point, radius, radius);
501                 }
502                 else
503                 {
504                     painter.setPen(QPen(Qt::white, 6));
505                     painter.drawEllipse(point, radius, radius);
506                     painter.setPen(QPen(Qt::black,4));
507                     painter.drawEllipse(point, radius, radius);
508                 }
509             }
510         }
511     }
512 }
513
514 void ClassSVM::DrawGL(Canvas *canvas, GLWidget *glw, Classifier *classifier)
515 {
516     int xInd = canvas->xIndex;
517     int yInd = canvas->yIndex;
518     int zInd = canvas->zIndex;
519     GLObject o;
520     o.objectType = "Samples";
521     o.style = "rings,pointsize:24";
522     vector<fvec> svs;
523     int dim = canvas->data->GetDimCount();
524     if(dynamic_cast<ClassifierPegasos*>(classifier))
525     {
526         // we want to draw the support vectors
527         svs = dynamic_cast<ClassifierPegasos*>(classifier)->GetSVs();
528     }
529     else if(dynamic_cast<ClassifierSVM*>(classifier))
530     {
531         // we want to draw the support vectors
532         svm_model *svm = dynamic_cast<ClassifierSVM*>(classifier)->GetModel();
533         if(svm)
534         {
535             fvec sv(dim);
536             FOR(i, svm->l)
537             {
538                 FOR(d, dim) sv[d] = svm->SV[i][d].value;
539                 svs.push_back(sv);
540             }
541         }
542     }
543     FOR(i, svs.size())
544     {
545         o.vertices.append(QVector3D(svs[i][xInd],svs[i][yInd], zInd>=0 && zInd<dim ? svs[i][zInd] : 0));
546         o.colors.append(QVector4D(0,0,0,1));
547     }
548     glw->mutex->lock();
549     glw->AddObject(o);
550     glw->mutex->unlock();
551 }
552
553 void ClassSVM::SaveOptions(QSettings &settings)
554 {
555     settings.setValue("kernelDeg", params->kernelDegSpin->value());
556     settings.setValue("kernelType", params->kernelTypeCombo->currentIndex());
557     settings.setValue("kernelWidth", params->kernelWidthSpin->value());
558     settings.setValue("svmC", params->svmCSpin->value());
559     settings.setValue("svmType", params->svmTypeCombo->currentIndex());
560     settings.setValue("optimizeCheck", params->optimizeCheck->isChecked());
561     settings.setValue("maxSVSpin", params->maxSVSpin->value());
562 }
563
564 bool ClassSVM::LoadOptions(QSettings &settings)
565 {
566     if(settings.contains("kernelDeg")) params->kernelDegSpin->setValue(settings.value("kernelDeg").toFloat());
567     if(settings.contains("kernelType")) params->kernelTypeCombo->setCurrentIndex(settings.value("kernelType").toInt());
568     if(settings.contains("kernelWidth")) params->kernelWidthSpin->setValue(settings.value("kernelWidth").toFloat());
569     if(settings.contains("svmC")) params->svmCSpin->setValue(settings.value("svmC").toFloat());
570     if(settings.contains("svmType")) params->svmTypeCombo->setCurrentIndex(settings.value("svmType").toInt());
571     if(settings.contains("optimizeCheck")) params->optimizeCheck->setChecked(settings.value("optimizeCheck").toInt());
572     if(settings.contains("maxSVSpin")) params->maxSVSpin->setValue(settings.value("maxSVSpin").toInt());
573     ChangeOptions();
574     return true;
575 }
576
577 void ClassSVM::SaveParams(QTextStream &file)
578 {
579     file << "classificationOptions" << ":" << "kernelDeg" << " " << params->kernelDegSpin->value() << "\n";
580     file << "classificationOptions" << ":" << "kernelType" << " " << params->kernelTypeCombo->currentIndex() << "\n";
581     file << "classificationOptions" << ":" << "kernelWidth" << " " << params->kernelWidthSpin->value() << "\n";
582     file << "classificationOptions" << ":" << "svmC" << " " << params->svmCSpin->value() << "\n";
583     file << "classificationOptions" << ":" << "svmType" << " " << params->svmTypeCombo->currentIndex() << "\n";
584     file << "classificationOptions" << ":" << "optimizeCheck" << " " << params->optimizeCheck->isChecked() << "\n";
585     file << "classificationOptions" << ":" << "maxSVSpin" << " " << params->maxSVSpin->value() << "\n";
586 }
587
588 bool ClassSVM::LoadParams(QString name, float value)
589 {
590     if(name.endsWith("kernelDeg")) params->kernelDegSpin->setValue((int)value);
591     if(name.endsWith("kernelType")) params->kernelTypeCombo->setCurrentIndex((int)value);
592     if(name.endsWith("kernelWidth")) params->kernelWidthSpin->setValue(value);
593     if(name.endsWith("svmC")) params->svmCSpin->setValue(value);
594     if(name.endsWith("svmType")) params->svmTypeCombo->setCurrentIndex((int)value);
595     if(name.endsWith("optimizeCheck")) params->optimizeCheck->setChecked((int)value);
596     if(name.endsWith("maxSVSpin")) params->maxSVSpin->setValue((int)value);
597     ChangeOptions();
598     return true;
599 }