REMOVED: the DrawModel function from classifierInterface (they all do exactly the...
[mldemos:baraks-mldemos.git] / Core / drawSVG.cpp
1 /*********************************************************************\r
2 MLDemos: A User-Friendly visualization toolkit for machine learning\r
3 Copyright (C) 2010  Basilio Noris\r
4 Contact: mldemos@b4silio.com\r
5 \r
6 This library is free software; you can redistribute it and/or\r
7 modify it under the terms of the GNU Lesser General Public\r
8 License as published by the Free Software Foundation; either\r
9 version 2.1 of the License, or (at your option) any later version.\r
10 \r
11 This library is distributed in the hope that it will be useful,\r
12 but WITHOUT ANY WARRANTY; without even the implied warranty of\r
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU\r
14 Library General Public License for more details.\r
15 \r
16 You should have received a copy of the GNU Lesser General Public\r
17 License along with this library; if not, write to the Free\r
18 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.\r
19 *********************************************************************/\r
20 #include <QtSvg>\r
21 #include <QtGui>\r
22 #include <QWidget>\r
23 #include <QSize>\r
24 #include <QPixmap>\r
25 #include <QDebug>\r
26 #include <QMutexLocker>\r
27 \r
28 #include "public.h"\r
29 #include "basicMath.h"\r
30 #include "drawSVG.h"\r
31 \r
32 using namespace std;\r
33 \r
34 DrawSVG::DrawSVG(Canvas *canvas, QMutex *mutex)\r
35         : canvas(canvas),\r
36           classifier(0), regressor(0), dynamical(0), clusterer(0),\r
37       drawClass(0), drawRegr(0), drawDyn(0), drawClust(0), drawProj(0),\r
38           mutex(mutex),\r
39           perm(0), w(0), h(0)\r
40 {\r
41 }\r
42 \r
43 DrawSVG::~DrawSVG()\r
44 {\r
45 \r
46 }\r
47 \r
48 void DrawSVG::Write(QString filename)\r
49 {\r
50         if(!canvas) return;\r
51         QSvgGenerator generator;\r
52         generator.setFileName(filename);\r
53         generator.setSize(QSize(canvas->width(), canvas->height()));\r
54         generator.setTitle("MLDemos screenshot");\r
55         generator.setDescription("Generated with MLDemos");\r
56         QPainter painter;\r
57         painter.begin(&generator);\r
58         // we need to paint the different layers:\r
59         // confidence map\r
60         // samples + trajectories + reward\r
61         canvas->PaintStandard(painter, true);\r
62     if(canvas->bDisplayLearned)\r
63         {\r
64                 // learned model\r
65         painter.setBackgroundMode(Qt::TransparentMode);\r
66         if(classifier) DrawClassificationSamples(canvas, painter, classifier, classifierMulti);\r
67                 if(regressor) drawRegr->DrawModel(canvas, painter, regressor);\r
68                 if(dynamical) drawDyn->DrawModel(canvas, painter, dynamical);\r
69         if(clusterer) drawClust->DrawModel(canvas, painter, clusterer);\r
70         if(projector) drawProj->DrawModel(canvas, painter, projector);\r
71         if(dynamical)\r
72                 {\r
73                         int cnt = 10000; // not too many or it will make unreadable files\r
74                         int steps = 8;\r
75                         VectorsFast(cnt, steps, painter);\r
76                 }\r
77                 if(maximizer)\r
78                 {\r
79                         Maximization(painter);\r
80                 }\r
81         }\r
82 \r
83         if(canvas->bDisplayInfo)\r
84         {\r
85                 // model info\r
86         painter.setBackgroundMode(Qt::TransparentMode);\r
87         if(classifier) drawClass->DrawInfo(canvas, painter, classifier);\r
88                 if(regressor) drawRegr->DrawInfo(canvas, painter, regressor);\r
89                 if(dynamical) drawDyn->DrawInfo(canvas, painter, dynamical);\r
90                 if(clusterer) drawClust->DrawInfo(canvas, painter, clusterer);\r
91         if(projector) drawProj->DrawInfo(canvas, painter, projector);\r
92     }\r
93         painter.end();\r
94 }\r
95 \r
96 void DrawSVG::Vectors(int count, int steps, QPainter &painter)\r
97 {\r
98         if(!dynamical) return;\r
99         float dT = dynamical->dT;// * (dynamical->count/100.f);\r
100         //float dT = 0.02f;\r
101         fvec sample;\r
102         sample.resize(2,0);\r
103         int w = canvas->width();\r
104         int h = canvas->height();\r
105 \r
106         painter.setRenderHint(QPainter::Antialiasing, true);\r
107         painter.setRenderHint(QPainter::HighQualityAntialiasing, true);\r
108         vector<Obstacle> obstacles = canvas->data->GetObstacles();\r
109 \r
110         QPointF oldPoint(-FLT_MAX,-FLT_MAX);\r
111         QPointF oldPointUp(-FLT_MAX,-FLT_MAX);\r
112         QPointF oldPointDown(-FLT_MAX,-FLT_MAX);\r
113         FOR(i, count)\r
114         {\r
115                 QPointF samplePre(rand()/(float)RAND_MAX * w, rand()/(float)RAND_MAX * h);\r
116                 sample = canvas->toSampleCoords(samplePre);\r
117                 float color = (rand()/(float)RAND_MAX*0.7f)*255.f;\r
118                 color = 0;\r
119                 QPointF oldPoint = canvas->toCanvasCoords(sample);\r
120                 FOR(j, steps)\r
121                 {\r
122                         fvec res = dynamical->Test(sample);\r
123                         if(dynamical->avoid)\r
124                         {\r
125                                 dynamical->avoid->SetObstacles(obstacles);\r
126                                 fvec newRes = dynamical->avoid->Avoid(sample, res);\r
127                                 res = newRes;\r
128                         }\r
129                         sample += res*dT;\r
130                         float speed = sqrtf(res[0]*res[0] + res[1]*res[1]);\r
131                         QPointF point = canvas->toCanvasCoords(sample);\r
132                         painter.setOpacity(speed);\r
133                         QColor c(color,color,color);\r
134                         painter.setPen(QPen(c, 0.25));\r
135                         painter.drawLine(point, oldPoint);\r
136                         oldPoint = point;\r
137                 }\r
138         }\r
139 }\r
140 \r
141 void DrawSVG::Maximization(QPainter &painter)\r
142 {\r
143         if(!maximizer) return;\r
144         painter.setRenderHint(QPainter::Antialiasing, true);\r
145         painter.setRenderHint(QPainter::HighQualityAntialiasing, true);\r
146     maximizer->Draw(painter);\r
147 }\r
148 \r
149 void DrawSVG::DrawClassificationSamples(Canvas *canvas, QPainter &painter, Classifier *classifier, std::vector<Classifier*> classifierMulti)\r
150 {\r
151     if(!canvas || !classifier) return;\r
152     int w = canvas->width(), h = canvas->height();\r
153 \r
154     // we draw the samples\r
155     painter.setRenderHint(QPainter::Antialiasing, true);\r
156     FOR(i, canvas->data->GetCount())\r
157     {\r
158         fvec sample = canvas->data->GetSample(i);\r
159         int label = canvas->data->GetLabel(i);\r
160         QPointF point = canvas->toCanvasCoords(canvas->data->GetSample(i));\r
161         fvec res;\r
162         if(classifier->IsMultiClass()) res = classifier->TestMulti(sample);\r
163         else if(classifierMulti.size())\r
164         {\r
165             FOR(c, classifierMulti.size())\r
166             {\r
167                 res.push_back(classifierMulti[c]->Test(sample));\r
168             }\r
169         }\r
170         else res.push_back(classifier->Test(sample));\r
171         if(res.size()==1)\r
172         {\r
173             int posClass = 1;\r
174             float response = res[0];\r
175             if(response > 0)\r
176             {\r
177                 if(classifier->classMap[label] == posClass) Canvas::drawSample(painter, point, 9, 1);\r
178                 else Canvas::drawCross(painter, point, 6, 2);\r
179             }\r
180             else\r
181             {\r
182                 if(classifier->classMap[label] != posClass) Canvas::drawSample(painter, point, 9, 0);\r
183                 else Canvas::drawCross(painter, point, 6, 0);\r
184             }\r
185         }\r
186         else\r
187         {\r
188             int max = 0;\r
189             for(int i=1; i<res.size(); i++) if(res[max] < res[i]) max = i;\r
190             int resp = classifier->inverseMap[max];\r
191             if(label == resp) Canvas::drawSample(painter, point, 9, label);\r
192             else Canvas::drawCross(painter, point, 6, label);\r
193         }\r
194     }\r
195 }\r
196 \r
197 void DrawSVG::VectorsFast(int count, int steps, QPainter &painter)\r
198 {\r
199         if(!dynamical) return;\r
200         QPointF oldPoint(-FLT_MAX,-FLT_MAX);\r
201         QPointF oldPointUp(-FLT_MAX,-FLT_MAX);\r
202         QPointF oldPointDown(-FLT_MAX,-FLT_MAX);\r
203         float dT = dynamical->dT;// * (dynamical->count/100.f);\r
204         //float dT = 0.02f;\r
205         fvec sample;\r
206         sample.resize(2,0);\r
207         int w = canvas->width();\r
208         int h = canvas->height();\r
209 \r
210         painter.setRenderHint(QPainter::Antialiasing, true);\r
211         painter.setRenderHint(QPainter::HighQualityAntialiasing, true);\r
212         vector<Obstacle> obstacles = canvas->data->GetObstacles();\r
213         FOR(i, count)\r
214         {\r
215                 QPointF samplePre(rand()/(float)RAND_MAX * w, rand()/(float)RAND_MAX * h);\r
216                 sample = canvas->toSampleCoords(samplePre);\r
217                 float color = (rand()/(float)RAND_MAX*0.7f)*255.f;\r
218                 color = 0;\r
219                 QPointF oldPoint = canvas->toCanvasCoords(sample);\r
220                 FOR(j, steps)\r
221                 {\r
222                         fvec res = dynamical->Test(sample);\r
223                         if(dynamical->avoid)\r
224                         {\r
225                                 dynamical->avoid->SetObstacles(obstacles);\r
226                                 fvec newRes = dynamical->avoid->Avoid(sample, res);\r
227                                 res = newRes;\r
228                         }\r
229                         sample += res*dT;\r
230                         float speed = sqrtf(res[0]*res[0] + res[1]*res[1]);\r
231                         QPointF point = canvas->toCanvasCoords(sample);\r
232                         painter.setOpacity(speed);\r
233                         QColor c(color,color,color);\r
234                         painter.setPen(QPen(c, 0.25));\r
235                         painter.drawLine(point, oldPoint);\r
236                         oldPoint = point;\r
237                 }\r
238         }\r
239 }\r