ADDED: display of ASVM support vector (alpha + gamma)
[mldemos:baraks-mldemos.git] / _AlgorithmsPlugins / ASVM / interfaceASVMDynamic.cpp
1 /*********************************************************************
2 MLDemos: A User-Friendly visualization toolkit for machine learning
3 Copyright (C) 2010  Basilio Noris
4 Contact: mldemos@b4silio.com
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public License,
8 version 3 as published by the Free Software Foundation.
9
10 This library is distributed in the hope that it will be useful, but
11 WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 Lesser General Public License for more details.
14
15 You should have received a copy of the GNU Lesser General Public
16 License along with this library; if not, write to the Free
17 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18 *********************************************************************/
19 #include "interfaceASVMDynamic.h"
20 #include "drawUtils.h"
21 #include <QPixmap>
22 #include <QBitmap>
23 #include <QPainter>
24 #include <QDebug>
25 #include <qcontour.h>
26
27 using namespace std;
28
29 DynamicASVM::DynamicASVM()
30 {
31     params = new Ui::ParametersASVM();
32     params->setupUi(widget = new QWidget());
33 }
34
35 void DynamicASVM::SetParams(Dynamical *dynamical)
36 {
37     if(!dynamical) return;
38
39     int clusters = params->gmmCount->value();
40     double alphaTol = params->alphaTolSpin->value();
41     double betaTol = params->betaTolSpin->value();
42     double betaRelax = params->betaRelaxSpin->value();
43     double Cparam = params->CSpin->value();
44     double kernelWidth = params->kernelWidthSpin->value();
45     double epsilon = params->epsilonSpin->value();
46     int maxIteration = params->iterationCount->value();
47
48     DynamicalASVM *asvm = dynamic_cast<DynamicalASVM *>(dynamical);
49     if(!asvm) return;
50
51     asvm->SetParams(clusters, kernelWidth, Cparam, alphaTol, betaTol, betaRelax, epsilon, maxIteration);
52 }
53
54 Dynamical *DynamicASVM::GetDynamical()
55 {
56     DynamicalASVM *dynamical = new DynamicalASVM();
57     SetParams(dynamical);
58     return dynamical;
59 }
60
61 void DynamicASVM::DrawInfo(Canvas *canvas, QPainter &painter, Dynamical *dynamical)
62 {
63     if(!canvas || !dynamical) return;
64     painter.setRenderHint(QPainter::Antialiasing);
65
66     DynamicalASVM *asvm = dynamic_cast<DynamicalASVM*>(dynamical);
67     if(!asvm) return;
68
69     /*
70     // we display the gmms
71     vector<Gmm*> gmms = asvm->gmms;
72     FOR(c, gmms.size())
73     {
74         Gmm gmm = *(gmms[c]);
75
76         int xIndex = canvas->xIndex;
77         int yIndex = canvas->yIndex;
78         int dim = gmm.dim;
79         float mean[2];
80         float sigma[3];
81         painter.setBrush(Qt::NoBrush);
82         FOR(i, gmm.nstates)
83         {
84             float* bigSigma = new float[dim*dim];
85             float* bigMean = new float[dim];
86             gmm.getCovariance(i, bigSigma, false);
87             sigma[0] = bigSigma[xIndex*dim + xIndex];
88             sigma[1] = bigSigma[yIndex*dim + xIndex];
89             sigma[2] = bigSigma[yIndex*dim + yIndex];
90             gmm.getMean(i, bigMean);
91             mean[0] = bigMean[xIndex];
92             mean[1] = bigMean[yIndex];
93             delete [] bigSigma;
94             delete [] bigMean;
95
96             painter.setPen(QPen(Qt::black, 1));
97             DrawEllipse(mean, sigma, 1, &painter, canvas);
98             painter.setPen(QPen(Qt::black, 0.5));
99             DrawEllipse(mean, sigma, 2, &painter, canvas);
100             QPointF point = canvas->toCanvasCoords(mean[0],mean[1]);
101             painter.setPen(QPen(Qt::black, 4));
102             painter.drawEllipse(point, 2, 2);
103             painter.setPen(QPen(Qt::white, 2));
104             painter.drawEllipse(point, 2, 2);
105         }
106     }
107     */
108
109     // we display the support vectors
110     painter.setPen(QPen(Qt::black, 1.5));
111     FOR(i, asvm->asvms.size())
112     {
113         FOR(j, asvm->asvms[i].numAlpha)
114         {
115             double *sv = asvm->asvms[i].svalpha[j];
116             fvec sample(asvm->asvms[i].dim);
117             FOR(d,asvm->asvms[i].dim) sample[d] = sv[d];
118             QPointF point = canvas->toCanvasCoords(sample);
119             painter.drawLine(point-QPointF(-5,-5),point-QPointF(5,-5));
120             painter.drawLine(point-QPointF(5,-5),point-QPointF(5,5));
121             painter.drawLine(point-QPointF(5,5),point-QPointF(-5,5));
122             painter.drawLine(point-QPointF(-5,5),point-QPointF(-5,-5));
123         }
124         FOR(j, asvm->asvms[i].numBeta)
125         {
126             double *sv = asvm->asvms[i].svbeta[j];
127             fvec sample(asvm->asvms[i].dim);
128             FOR(d,asvm->asvms[i].dim) sample[d] = sv[d];
129             QPointF point = canvas->toCanvasCoords(sample);
130             painter.drawLine(point-QPointF(-5,-5),point-QPointF(5,-5));
131             painter.drawLine(point-QPointF(5,-5),point-QPointF(0,5));
132             painter.drawLine(point-QPointF(-5,-5),point-QPointF(0,5));
133         }
134     }
135
136     // we display the contour lines of the svm classifier
137     int W = painter.viewport().width();
138     int H = painter.viewport().height();
139     int w = 129;
140     int h = 129;
141     int classCount = asvm->classCount;
142
143     // we draw the contours of the classification function
144     double **valueList = new double*[classCount];
145     FOR(c, classCount)
146     {
147         valueList[c] = new double[w*h];
148         FOR(i, w)
149         {
150             FOR(j, h)
151             {
152                 valueList[c][j*w+i] = 0.;
153             }
154         }
155     }
156
157     FOR(i, w)
158     {
159         FOR(j, h)
160         {
161             int x = i*W/w;
162             int y = j*H/h;
163             fvec sample = canvas->fromCanvas(x,y);
164             fvec res = asvm->Classify(sample);
165             int c = res[0];
166             double value = res[1];
167             // to avoid some numerical weird stuff
168             value = value*1000.;
169             if(c < classCount && c >= 0)
170             {
171                 valueList[c][j*w + i] = value;
172             }
173         }
174     }
175
176     FOR(k, classCount)
177     {
178         QContour contour(valueList[k], w, h);
179         contour.bDrawColorbar = false;
180         int classColor = asvm->inverseMap.count(k) ? asvm->inverseMap[k] : k;
181         contour.plotColor = classColor ? SampleColor[(classColor)%SampleColorCnt] : Qt::black;
182         contour.plotThickness = 4;
183         contour.style = Qt::DotLine;
184
185         double vmin, vmax;
186         contour.GetLimits(vmin, vmax);
187         vmin += (vmax - vmin)/30; // we take out the smallest levels to avoid numerical issues
188         contour.SetLimits(vmin, vmax);
189         contour.Paint(painter, 10);
190         delete [] valueList[k];
191         valueList[k] = 0;
192     }
193     delete [] valueList;
194 }
195
196 void DynamicASVM::SaveModel(QString filename, Dynamical *dynamical)
197 {
198     DynamicalASVM *asvm = dynamic_cast<DynamicalASVM*>(dynamical);
199     if(!asvm) return;
200     asvm->SaveModel(filename.toStdString());
201 }
202
203 bool DynamicASVM::LoadModel(QString filename, Dynamical *dynamical)
204 {
205     DynamicalASVM *asvm = dynamic_cast<DynamicalASVM*>(dynamical);
206     if(!asvm) return false;
207     return asvm->LoadModel(filename.toStdString());
208 }
209
210 void DynamicASVM::SaveOptions(QSettings &settings)
211 {
212     settings.setValue("gmmCount", params->gmmCount->value());
213     settings.setValue("alphaTol", params->alphaTolSpin->value());
214     settings.setValue("betaTol", params->betaTolSpin->value());
215     settings.setValue("betaRelax", params->betaRelaxSpin->value());
216     settings.setValue("Cparam", params->CSpin->value());
217     settings.setValue("kernelWidth", params->kernelWidthSpin->value());
218     settings.setValue("epsilon", params->epsilonSpin->value());
219     settings.setValue("iterationCount", params->iterationCount->value());
220 }
221
222 bool DynamicASVM::LoadOptions(QSettings &settings)
223 {
224     if(settings.contains("gmmCount")) params->gmmCount->setValue(settings.value("gmmCount").toInt());
225     if(settings.contains("alphaTol")) params->alphaTolSpin->setValue(settings.value("alphaTol").toDouble());
226     if(settings.contains("betaTol")) params->betaTolSpin->setValue(settings.value("betaTol").toDouble());
227     if(settings.contains("betaRelax")) params->betaRelaxSpin->setValue(settings.value("betaRelax").toDouble());
228     if(settings.contains("Cparam")) params->CSpin->setValue(settings.value("Cparam").toDouble());
229     if(settings.contains("kernelWidth")) params->kernelWidthSpin->setValue(settings.value("kernelWidth").toDouble());
230     if(settings.contains("epsilon")) params->epsilonSpin->setValue(settings.value("epsilon").toDouble());
231     if(settings.contains("iterationCount")) params->iterationCount->setValue(settings.value("iterationCount").toInt());
232     return true;
233 }
234
235 void DynamicASVM::SaveParams(QTextStream &file)
236 {
237     file << "dynamicalOptions" << ":" << "gmmCount" << " " << params->gmmCount->value() << "\n";
238     file << "dynamicalOptions" << ":" << "alphaTol" << " " << params->alphaTolSpin->value() << "\n";
239     file << "dynamicalOptions" << ":" << "betaTol" << " " << params->betaTolSpin->value() << "\n";
240     file << "dynamicalOptions" << ":" << "betaRelax" << " " << params->betaRelaxSpin->value() << "\n";
241     file << "dynamicalOptions" << ":" << "Cparam" << " " << params->CSpin->value() << "\n";
242     file << "dynamicalOptions" << ":" << "kernelWidth" << " " << params->kernelWidthSpin->value() << "\n";
243     file << "dynamicalOptions" << ":" << "epsilon" << " " << params->epsilonSpin->value() << "\n";
244     file << "dynamicalOptions" << ":" << "iterationCount" << " " << params->iterationCount->value() << "\n";
245 }
246
247 bool DynamicASVM::LoadParams(QString name, float value)
248 {
249     if(name.endsWith("gmmCount")) params->gmmCount->setValue((int)value);
250     if(name.endsWith("alphaTol")) params->alphaTolSpin->setValue((double)value);
251     if(name.endsWith("betaTol")) params->betaTolSpin->setValue((double)value);
252     if(name.endsWith("betaRelax")) params->betaRelaxSpin->setValue((double)value);
253     if(name.endsWith("Cparam")) params->CSpin->setValue((double)value);
254     if(name.endsWith("kernelWidth")) params->kernelWidthSpin->setValue((double)value);
255     if(name.endsWith("epsilon")) params->epsilonSpin->setValue((double)value);
256     if(name.endsWith("iterationCount")) params->iterationCount->setValue((int)value);
257     return true;
258 }
259
260 Q_EXPORT_PLUGIN2(mld_ASVM, DynamicASVM)