CHANGED: the base zoom factors to match 2D and 3D zooms
[mldemos:allopens-mldemos.git] / _AlgorithmsPlugins / KernelMethods / interfaceRVMClassifier.cpp
1 /*********************************************************************
2 MLDemos: A User-Friendly visualization toolkit for machine learning
3 Copyright (C) 2010  Basilio Noris
4 Contact: mldemos@b4silio.com
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public License,
8 version 3 as published by the Free Software Foundation.
9
10 This library is distributed in the hope that it will be useful, but
11 WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 Lesser General Public License for more details.
14
15 You should have received a copy of the GNU Lesser General Public
16 License along with this library; if not, write to the Free
17 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18 *********************************************************************/
19 #include "interfaceRVMClassifier.h"
20 #include <QPixmap>
21 #include <QBitmap>
22 #include <QPainter>
23 #include <QDebug>
24
25 using namespace std;
26
27 ClassRVM::ClassRVM()
28 {
29     params = new Ui::ParametersRVM();
30     params->setupUi(widget = new QWidget());
31     connect(params->kernelTypeCombo, SIGNAL(currentIndexChanged(int)), this, SLOT(ChangeOptions()));
32     ChangeOptions();
33 }
34
35 void ClassRVM::ChangeOptions()
36 {
37     int C = params->svmCSpin->value();
38     if(C > 1) params->svmCSpin->setValue(0.001);
39     switch(params->kernelTypeCombo->currentIndex())
40     {
41     case 0: // linear
42         params->kernelDegSpin->setVisible(false);
43         params->labelDegree->setVisible(false);
44         params->kernelWidthSpin->setVisible(false);
45         params->labelWidth->setVisible(false);
46         break;
47     case 1: // poly
48         params->kernelDegSpin->setVisible(true);
49         params->labelDegree->setVisible(true);
50         params->kernelWidthSpin->setVisible(false);
51         params->labelWidth->setVisible(false);
52         break;
53     case 2: // RBF
54         params->kernelDegSpin->setVisible(false);
55         params->labelDegree->setVisible(false);
56         params->kernelWidthSpin->setVisible(true);
57         params->labelWidth->setVisible(true);
58         break;
59     case 3: // SIGMOID
60         params->kernelDegSpin->setEnabled(false);
61         params->labelDegree->setVisible(false);
62         params->kernelWidthSpin->setEnabled(true);
63         params->labelWidth->setVisible(true);
64         break;
65     }
66 }
67
68 QString ClassRVM::GetAlgoString()
69 {
70     double C = params->svmCSpin->value();
71     int kernelType = params->kernelTypeCombo->currentIndex();
72     float kernelGamma = params->kernelWidthSpin->value();
73     float kernelDegree = params->kernelDegSpin->value();
74
75     QString algo = QString("RVM %1").arg(C);
76     switch(kernelType)
77     {
78     case 0:
79         algo += " Lin";
80         break;
81     case 1:
82         algo += QString(" Pol %1").arg(kernelDegree);
83         break;
84     case 2:
85         algo += QString(" RBF %1").arg(kernelGamma);
86         break;
87     case 3:
88         algo += QString(" Sig %1").arg(kernelGamma);
89         break;
90     }
91     return algo;
92 }
93
94 void ClassRVM::SetParams(Classifier *classifier)
95 {
96     if(!classifier) return;
97     SetParams(classifier, GetParams());
98 }
99
100 fvec ClassRVM::GetParams()
101 {
102     float svmC = params->svmCSpin->value();
103     int kernelType = params->kernelTypeCombo->currentIndex();
104     float kernelGamma = params->kernelWidthSpin->value();
105     float kernelDegree = params->kernelDegSpin->value();
106
107     fvec par(4);
108     par[0] = svmC;
109     par[1] = kernelType;
110     par[2] = kernelGamma;
111     par[3] = kernelDegree;
112     return par;
113 }
114
115 void ClassRVM::SetParams(Classifier *classifier, fvec parameters)
116 {
117     if(!classifier) return;
118     float svmC = parameters.size() > 0 ? parameters[0] : 1;
119     int kernelType = parameters.size() > 1 ? parameters[1] : 0;
120     float kernelGamma = parameters.size() > 2 ? parameters[2] : 0;
121     int kernelDegree = parameters.size() > 3 ? parameters[3] : 0;
122
123     ClassifierRVM *rvm = dynamic_cast<ClassifierRVM *>(classifier);
124     if(rvm) rvm->SetParams(svmC, kernelType, kernelGamma, kernelDegree);
125 }
126
127 void ClassRVM::GetParameterList(std::vector<QString> &parameterNames,
128                                 std::vector<QString> &parameterTypes,
129                                 std::vector< std::vector<QString> > &parameterValues)
130 {
131     parameterNames.push_back("Penalty (C)");
132     parameterNames.push_back("Kernel Type");
133     parameterNames.push_back("Kernel Width");
134     parameterNames.push_back("Kernel Degree");
135     parameterTypes.push_back("Real");
136     parameterTypes.push_back("List");
137     parameterTypes.push_back("Real");
138     parameterTypes.push_back("Integer");
139     parameterValues.push_back(vector<QString>());
140     parameterValues.back().push_back("0.00000001f");
141     parameterValues.back().push_back("99999999999999");
142     parameterValues.push_back(vector<QString>());
143     parameterValues.back().push_back("Linear");
144     parameterValues.back().push_back("Poly");
145     parameterValues.back().push_back("RBF");
146     parameterValues.push_back(vector<QString>());
147     parameterValues.back().push_back("0.00000001f");
148     parameterValues.back().push_back("9999999");
149     parameterValues.push_back(vector<QString>());
150     parameterValues.back().push_back("1");
151     parameterValues.back().push_back("150");
152 }
153
154 Classifier *ClassRVM::GetClassifier()
155 {
156     Classifier *classifier = 0;
157     classifier = new ClassifierRVM();
158     SetParams(classifier);
159     return classifier;
160 }
161
162 void ClassRVM::DrawInfo(Canvas *canvas, QPainter &painter, Classifier *classifier)
163 {
164     painter.setRenderHint(QPainter::Antialiasing);
165
166     if(!dynamic_cast<ClassifierRVM*>(classifier)) return;
167     // we want to draw the support vectors
168     vector<fvec> sv = dynamic_cast<ClassifierRVM*>(classifier)->GetSVs();
169     int radius = 9;
170     FOR(i, sv.size())
171     {
172         QPointF point = canvas->toCanvasCoords(sv[i]);
173         painter.setPen(QPen(Qt::black,6));
174         painter.drawEllipse(point, radius, radius);
175         painter.setPen(QPen(Qt::white,4));
176         painter.drawEllipse(point, radius, radius);
177     }
178 }
179
180 void ClassRVM::DrawGL(Canvas *canvas, GLWidget *glw, Classifier *classifier)
181 {
182     int xInd = canvas->xIndex;
183     int yInd = canvas->yIndex;
184     int zInd = canvas->zIndex;
185     if(!dynamic_cast<ClassifierRVM*>(classifier)) return;
186     // we want to draw the support vectors
187     vector<fvec> svs = dynamic_cast<ClassifierRVM*>(classifier)->GetSVs();
188     GLObject o;
189     o.objectType = "Samples";
190     o.style = "rings,pointsize:24";
191     FOR(i, svs.size())
192     {
193         o.vertices.append(QVector3D(svs[i][xInd],svs[i][yInd],svs[i][zInd]));
194         o.colors.append(QVector4D(0,0,0,1));
195     }
196     glw->mutex->lock();
197     glw->AddObject(o);
198     glw->mutex->unlock();
199 }
200
201 void ClassRVM::SaveOptions(QSettings &settings)
202 {
203     settings.setValue("kernelDeg", params->kernelDegSpin->value());
204     settings.setValue("kernelType", params->kernelTypeCombo->currentIndex());
205     settings.setValue("kernelWidth", params->kernelWidthSpin->value());
206     settings.setValue("svmC", params->svmCSpin->value());
207 }
208
209 bool ClassRVM::LoadOptions(QSettings &settings)
210 {
211     if(settings.contains("kernelDeg")) params->kernelDegSpin->setValue(settings.value("kernelDeg").toFloat());
212     if(settings.contains("kernelType")) params->kernelTypeCombo->setCurrentIndex(settings.value("kernelType").toInt());
213     if(settings.contains("kernelWidth")) params->kernelWidthSpin->setValue(settings.value("kernelWidth").toFloat());
214     if(settings.contains("svmC")) params->svmCSpin->setValue(settings.value("svmC").toFloat());
215     ChangeOptions();
216     return true;
217 }
218
219 void ClassRVM::SaveParams(QTextStream &file)
220 {
221     file << "classificationOptions" << ":" << "kernelDeg" << " " << params->kernelDegSpin->value() << "\n";
222     file << "classificationOptions" << ":" << "kernelType" << " " << params->kernelTypeCombo->currentIndex() << "\n";
223     file << "classificationOptions" << ":" << "kernelWidth" << " " << params->kernelWidthSpin->value() << "\n";
224     file << "classificationOptions" << ":" << "svmC" << " " << params->svmCSpin->value() << "\n";
225 }
226
227 bool ClassRVM::LoadParams(QString name, float value)
228 {
229     if(name.endsWith("kernelDeg")) params->kernelDegSpin->setValue((int)value);
230     if(name.endsWith("kernelType")) params->kernelTypeCombo->setCurrentIndex((int)value);
231     if(name.endsWith("kernelWidth")) params->kernelWidthSpin->setValue(value);
232     if(name.endsWith("svmC")) params->svmCSpin->setValue(value);
233     ChangeOptions();
234     return true;
235 }