REMOVED: the DrawModel function from classifierInterface (they all do exactly the...
[mldemos:baraks-mldemos.git] / MLDemos / mlstats.cpp
1 /*********************************************************************
2 MLDemos: A User-Friendly visualization toolkit for machine learning
3 Copyright (C) 2010  Basilio Noris
4 Contact: mldemos@b4silio.com
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public
8 License as published by the Free Software Foundation; either
9 version 2.1 of the License, or (at your option) any later version.
10
11 This library is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 Library General Public License for more details.
15
16 You should have received a copy of the GNU Lesser General Public
17 License along with this library; if not, write to the Free
18 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19 *********************************************************************/
20 #include "mldemos.h"
21 #include "basicMath.h"
22 #include "classifier.h"
23 #include "regressor.h"
24 #include "clusterer.h"
25 #include <QDebug>
26 #include <fstream>
27 #include <QPixmap>
28 #include <QBitmap>
29 #include <QBoxLayout>
30 #include <QSettings>
31 #include <QFileDialog>
32 #include <vector>
33 #include <algorithm>
34
35 using namespace std;
36
37 void MLDemos::MouseOnRoc(QMouseEvent *event)
38 {
39     int e;
40     switch( event->button())
41     {
42     case Qt::LeftButton:
43         e = EVENT_LBUTTONUP;
44         break;
45     case Qt::RightButton:
46         e = EVENT_RBUTTONUP;
47         break;
48     }
49     //roc_on_mouse(e, event->x(), event->y(), 0, 0);
50     //rocWidget->ShowImage(GetRocImage());
51     //statsDialog->repaint();
52 }
53
54 void MLDemos::ShowRoc()
55 {
56     if(!classifier) return;
57     SetROCInfo();
58     actionShowStats->setChecked(true);
59     showStats->tabWidget->setCurrentWidget(showStats->rocTab);
60     ShowStatsDialog();
61 }
62
63 void MLDemos::StatsChanged()
64 {
65     int tab = showStats->tabWidget->currentIndex();
66     switch(tab)
67     {
68     case 0:
69         UpdateInfo();
70         break;
71     case 1:
72         SetROCInfo();
73         break;
74     case 2:
75         SetCrossValidationInfo();
76         break;
77     }
78 }
79
80 void PaintData(std::vector<float> data, QPixmap &pm)
81 {
82     QPainter painter(&pm);
83     painter.fillRect(pm.rect(), Qt::white);
84
85     int w = pm.width();
86     int h = pm.height();
87     int cnt = data.size();
88     int pad = 10;
89     QPointF oldPoint;
90     double minVal = FLT_MAX;
91     double maxVal = -FLT_MAX;
92     for(int i=0; i< data.size(); i++)
93     {
94         if(minVal > data[i]) minVal = data[i];
95         if(maxVal < data[i]) maxVal = data[i];
96     }
97     if (minVal == maxVal)
98     {
99         minVal = 0;
100     }
101
102     painter.setBrush(Qt::NoBrush);
103     painter.setPen(QPen(QColor(200,200,200), 0.5));
104     int steps = 10;
105     for(int i=0; i<=steps; i++)
106     {
107         painter.drawLine(QPoint(0, i/(float)steps*(h-2*pad) + pad), QPoint(w, i/(float)steps*(h-2*pad) + pad));
108         painter.drawLine(QPoint(i/(float)steps*w, 0), QPoint(i/(float)steps*w, h));
109     }
110     painter.setRenderHint(QPainter::Antialiasing);
111
112     painter.setPen(QPen(Qt::black, 1.5));
113     for(int i=0; i< data.size(); i++)
114     {
115         float value = data[i];
116         if (value != value) continue;
117         float x = i/(float)cnt*w;
118         float y = (1 - (value-minVal)/(maxVal - minVal)) * (float)(h-2*pad) + pad;
119         QPointF point(x, y);
120         if(i) painter.drawLine(oldPoint, point);
121         //painter.drawEllipse(point, 3, 3);
122         oldPoint = point;
123     }
124     painter.setPen(QPen(Qt::black, 0.5));
125     painter.setBrush(QColor(255,255,255,200));
126     painter.drawRect(QRect(w - 100 - 15,h - 55,110,45));
127     painter.setPen(QPen(Qt::black, 1));
128     painter.drawText(QPointF(w - 107, h-57+20), QString("start: %1").arg(data[0], 3));
129     painter.drawText(QPointF(w - 107, h-57+40), QString("end: %1").arg(data[data.size()-1], 3));
130 }
131
132 void MLDemos::SetROCInfo()
133 {
134     QSize size(showStats->rocWidget->width(),showStats->rocWidget->height());
135     if(classifier && bIsRocNew)
136     {
137         QPixmap rocImage = RocImage(classifier->rocdata, classifier->roclabels, size);
138         bIsRocNew = false;
139         //      rocImage.save("roc.png");
140         rocWidget->ShowImage(rocImage);
141     }
142     if(maximizer)
143     {
144         vector<double> history = maximizer->HistoryValue();
145         vector<float> data;data.resize(history.size());
146         FOR(i, data.size()) data[i] = history[i];
147         if(!data.size()) return;
148         QPixmap pixmap(size);
149         PaintData(data, pixmap);
150         rocWidget->ShowImage(pixmap);
151     }
152 }
153
154 void MLDemos::SetCrossValidationInfo()
155 {
156     if(!bIsCrossNew) return;
157     std::vector<fvec> fmeasures;
158     if(classifier) fmeasures = classifier->crossval;
159     else if(regressor) fmeasures = regressor->crossval;
160
161     if(!fmeasures.size()) return;
162     char txt[255];
163     QString text;
164     text += "Cross-Validation\n";
165     float ratios [] = {.1f,.25f,1.f/3.f,.5f,2.f/3.f,.75f,.9f,1.f};
166     int ratioIndex = classifier ? optionsClassify->traintestRatioCombo->currentIndex() : optionsRegress->traintestRatioCombo->currentIndex();
167     float trainRatio = ratios[ratioIndex];
168     //  if(classifier) sprintf(txt, "%d folds\n", optionsClassify->foldCountSpin->value());
169     //  else sprintf(txt, "%d folds\n", optionsRegress->foldCountSpin->value());
170     text += txt;
171     sprintf(txt,"%d train, %d test samples", (int)(canvas->data->GetCount()*trainRatio), canvas->data->GetCount() - (int)(canvas->data->GetCount()*trainRatio));
172     text += txt + QString("\n\n");
173     text += classifier ? QString("Classification Performance:\n\n") : QString("Regression Error:\n\n");
174     FOR(i, fmeasures.size())
175     {
176         fvec meanStd = MeanStd(fmeasures[i]);
177         fvec quartiles = Quartiles(fmeasures[i]);
178         text += !i ? "Training\n" : "Testing\n";
179         sprintf(txt,"%.3f  %.3f", meanStd[0], meanStd[1]);
180         text += txt + QString(" (meanstd)\n");
181         sprintf(txt,"%.3f %.3f %.3f %.3f %.3f", quartiles[0], quartiles[1], quartiles[2], quartiles[3], quartiles[4]);
182         text += txt + QString(" (quartiles)\n");
183         text += "\n\n";
184     }
185     //  showStats->crossvalidText->setText(text);
186     //    QSize boxSize(showStats->crossvalidWidget->width(),showStats->crossvalidWidget->height());
187     //    QPixmap boxplot = BoxPlot(fmeasures, boxSize);
188     //  boxplot.save("boxplot.png");
189     //  bIsCrossNew = false;
190     //    showStats->crossvalidImage->setPixmap(boxplot);
191 }
192
193
194 void MLDemos::UpdateInfo()
195 {
196     // dataset information
197     int count = canvas->data->GetCount();
198     int pcount = 0, ncount = 0;
199     ivec labels = canvas->data->GetLabels();
200     int posClass = optionsClassify->positiveSpin->value();
201     FOR(i, labels.size())
202     {
203         if(labels[i] == posClass) ++pcount;
204         else ++ncount;
205     }
206
207     // min/max, mean/variance
208     vector<fvec> samples = canvas->data->GetSamples();
209     fvec sMin,sMax,sMean,sSigma;
210     sMin.resize(2,FLT_MAX);
211     sMax.resize(2,-FLT_MAX);
212     sMean.resize(2,0);
213     sSigma.resize(4,0);
214     if(samples.size())
215     {
216         FOR(i,samples.size())
217         {
218             sMin[0] = min(sMin[0],samples[i][0]);
219             sMin[1] = min(sMin[1],samples[i][1]);
220             sMax[0] = max(sMax[0],samples[i][0]);
221             sMax[1] = max(sMax[1],samples[i][1]);
222             sMean += samples[i];
223         }
224         sMean /= samples.size();
225         FOR(i, samples.size())
226         {
227             sSigma[0] += (samples[i][0]-sMean[0])*(samples[i][0]-sMean[0]);
228             sSigma[1] += (samples[i][0]-sMean[0])*(samples[i][1]-sMean[1]);
229             sSigma[3] += (samples[i][1]-sMean[1])*(samples[i][1]-sMean[1]);
230         }
231         sSigma[0] = sqrtf(sSigma[0]/samples.size());
232         sSigma[1] = sqrtf(sSigma[1]/samples.size());
233         if(sSigma[1] != sSigma[1]) sSigma[1] = 0;
234         sSigma[2] = sSigma[1];
235         sSigma[3] = sqrtf(sSigma[3]/samples.size());
236     }
237     else
238     {
239         sMin.clear();sMax.clear();
240         sMin.resize(2,0);
241         sMax.resize(2,0);
242     }
243
244     QString information;
245
246     if(classifier)
247     {
248         information += "Classification Performance:\n" + lastTrainingInfo;
249         information += "\nClassifier: " + QString(classifier->GetInfoString());
250         // we also want to generate the confusion matrix
251         if(classifier->IsMultiClass())
252         {
253             QObjectList children = showStats->informationWidget->children();
254             FOR(i, children.size()) delete children[i];
255             if(!showStats->informationWidget->layout())
256             {
257                 QBoxLayout *layout = new QBoxLayout(QBoxLayout::TopToBottom, showStats->informationWidget);
258                 layout->setContentsMargins(0,0,0,0);
259             }
260             QPixmap confusionPixmap(150,150);
261             QPainter painter(&confusionPixmap);
262
263             QLabel *labelTrain = 0;
264             QLabel *labelTest = 0;
265             if(classifier->confusionMatrix[0].size())
266             {
267                 labelTrain = new QLabel();
268                 confusionPixmap.fill(Qt::white);
269                 map< int,map<int,int> > confusion = classifier->confusionMatrix[0];
270                 int classCount = 0;
271                 map<int,int> maxCount;
272                 for(map<int,map<int,int> >::iterator it = confusion.begin();it != confusion.end();it++)
273                 {
274                     classCount = max(classCount, it->first);
275                     for(map<int,int>::iterator it2=it->second.begin(); it2 != it->second.end(); it2++)
276                     {
277                         maxCount[it->first] = max(maxCount[it->first], it2->second);
278                     }
279                 }
280                 classCount++;
281                 int w = max(1,confusionPixmap.width()/classCount);
282                 int h = max(1,confusionPixmap.height()/classCount);
283                 FOR(c, classCount)
284                 {
285                     int maxCnt = maxCount[c];
286                     int y = c *confusionPixmap.height() / classCount;
287                     FOR(c2, classCount)
288                     {
289                         int x = c2 * confusionPixmap.width() / classCount;
290                         float value = confusion[c][c2] / (float)maxCnt;
291                         painter.fillRect(x, y, w, h, QColor(255, (1.f-value)*255, (1.f-value)*255));
292                     }
293                 }
294                 //showStats->informationWidget->layout()->addWidget(new QLabel("Confusion Matrix"));
295                 showStats->informationWidget->layout()->addWidget(labelTrain);
296                 labelTrain->setPixmap(confusionPixmap);
297             }
298             if(classifier->confusionMatrix[1].size())
299             {
300                 labelTest = new QLabel();
301                 confusionPixmap.fill(Qt::white);
302                 map< int,map<int,int> > confusion = classifier->confusionMatrix[1];
303                 int classCount = 0;
304                 map<int,int> maxCount;
305                 for(map<int,map<int,int> >::iterator it = confusion.begin();it != confusion.end();it++)
306                 {
307                     classCount = max(classCount, it->first);
308                     for(map<int,int>::iterator it2=it->second.begin(); it2 != it->second.end(); it2++)
309                     {
310                         maxCount[it->first] = max(maxCount[it->first], it2->second);
311                     }
312                 }
313                 classCount++;
314                 int w = max(1,confusionPixmap.width()/classCount);
315                 int h = max(1,confusionPixmap.height()/classCount);
316                 FOR(c, classCount)
317                 {
318                     int maxCnt = maxCount[c];
319                     int y = c *confusionPixmap.height() / classCount;
320                     FOR(c2, classCount)
321                     {
322                         int x = c2 * confusionPixmap.width() / classCount;
323                         float value = confusion[c][c2] / (float)maxCnt;
324                         painter.fillRect(x, y, w, h, QColor(255, (1.f-value)*255, (1.f-value)*255));
325                     }
326                 }
327                 //showStats->informationWidget->layout()->addWidget(new QLabel("Confusion (Test)"));
328                 showStats->informationWidget->layout()->addWidget(labelTest);
329                 labelTest->setPixmap(confusionPixmap);
330             }
331             if(labelTrain) labelTrain->show();
332             if(labelTest) labelTest->show();
333             showStats->informationWidget->repaint();
334         }
335     }
336     if(regressor)  information += "\nRegressor: "  + QString(regressor->GetInfoString());
337     if(clusterer)  information += "\nClusterer: "  + QString(clusterer->GetInfoString());
338     if(dynamical)  information += "\nDynamical: "  + QString(dynamical->GetInfoString());
339     if(maximizer)  information += "\nMaximizer: "  + QString(maximizer->GetInfoString());
340
341     information += "\nCurrent Dataset:\n";
342     char string[255];
343     sprintf(string, "    %d Samples\n    %d Positives\n    %d Negatives\n\n", count, pcount, ncount);
344     information += QString(string);
345     information +=    "       Min - Max          Mean  ,    Var\n";
346     sprintf(string,   "    %.3f    %.3f      %.3f   ,   %.3f  %.3f\n", sMin[0], sMax[0], sMean[0], sSigma[0], sSigma[1]);
347     sprintf(string, "%s    %.3f    %.3f      %.3f   ,   %.3f  %.3f\n", string, sMin[1], sMax[1], sMean[1], sSigma[2], sSigma[3]);
348     information += string;
349
350     showStats->infoText->setText(information);
351 }