FIXED: Generator saving/loading of parameters, display of parameters, noise in regres...
[mldemos:mldemos.git] / MLDemos / datagenerator.cpp
1 #include "datagenerator.h"
2 #include "ui_datagenerator.h"
3 #include <QDebug>
4
5 using namespace std;
6
7 DataGenerator::DataGenerator(Canvas *canvas, QMutex *mutex, QWidget *parent) :
8     QDialog(parent),
9     ui(new Ui::DataGenerator)
10 {
11     ui->setupUi(this);
12     this->canvas = canvas;
13     this->mutex = mutex;
14     connect(ui->addButton, SIGNAL(clicked()), this, SLOT(Generate()));
15     connect(ui->generatorCombo, SIGNAL(currentIndexChanged(int)), this, SLOT(OptionsChanged()));
16     OptionsChanged();
17 }
18
19 DataGenerator::~DataGenerator()
20 {
21     delete ui;
22 }
23
24 void DataGenerator::OptionsChanged()
25 {
26     int type = ui->generatorCombo->currentIndex();
27     ui->gridCountLabel->setText("Grid Count");
28     ui->radiusLabel->setText("Radius");
29     ui->classesLabel->setText("Class");
30     ui->dimLabel->setText("Dim");
31     ui->classesCount->setEnabled(true);
32     switch(type)
33     {
34     case 0: // checkerboard
35         ui->radiusLabel->setText("Size");
36         break;
37     case 1: // concentric circles
38         ui->gridCountLabel->setText("Circles");
39         ui->radiusLabel->setText("Radius");
40         break;
41     case 2: // swiss roll
42         ui->gridCountLabel->setText("Swirls");
43         ui->radiusLabel->setText("Radius");
44         break;
45     case 3: // sinc
46         ui->gridCountLabel->setText("Noise");
47         ui->radiusLabel->setText("Width");
48         ui->classesCount->setEnabled(false);
49         break;
50     case 4: // gaussian
51         ui->gridCountLabel->setText("Noise");
52         ui->radiusLabel->setText("Width");
53         ui->classesCount->setEnabled(false);
54         break;
55     case 5: // cosine
56         ui->gridCountLabel->setText("Noise");
57         ui->radiusLabel->setText("Width");
58         ui->classesCount->setEnabled(false);
59         break;
60     }
61 }
62
63 void DataGenerator::Generate()
64 {
65     int count = ui->countSpin->value();
66     int dim = ui->dimSpin->value();
67     int gridCount = ui->gridCountSpin->value();
68     int classesCount = ui->classesCount->value();
69     float radius = ui->radiusSpin->value();
70     int type = ui->generatorCombo->currentIndex();
71     vector<fvec> samples;
72     ivec labels;
73     fvec sample(dim);
74     switch(type)
75     {
76     case 0: // checkerboard
77     {
78         dim = 2;
79         sample.resize(dim);
80         int samplesPerCell = count/(gridCount*gridCount);
81         int label = 0;
82         float xStart=0, xStop=0, yStart=0, yStop=0;
83         FOR(y, gridCount)
84         {
85             yStart = y*radius;
86             yStop = yStart + radius;
87             FOR(x, gridCount)
88             {
89                 xStart = x*radius;
90                 xStop = xStart + radius;
91                 FOR(i, samplesPerCell)
92                 {
93                     sample[0] = drand48()*(xStop-xStart) + xStart;
94                     sample[1] = drand48()*(yStop-yStart) + yStart;
95                     samples.push_back(sample);
96                     labels.push_back(label);
97                 }
98                 label = (label+1)%classesCount;
99             }
100             if(!(gridCount%classesCount)) label = (label+1)%classesCount;
101         }
102     }
103         break;
104     case 1: // concentric circles
105     {
106         int samplesPerCircle = count/(gridCount*classesCount);
107         int cnt = 0;
108         FOR(i, gridCount)
109         {
110             FOR(c, classesCount)
111             {
112                 float radStart = radius*cnt/(float)(gridCount*classesCount);
113                 float radStop = radius*(cnt+1)/(float)(gridCount*classesCount);
114                 FOR(j, samplesPerCircle)
115                 {
116                     float angle = drand48()*2*M_PI;
117                     float rad = drand48()*(radStop-radStart) + radStart;
118                     sample[0] = cos(angle)*rad;
119                     sample[1] = sin(angle)*rad;
120                     samples.push_back(sample);
121                     labels.push_back(c);
122                 }
123                 cnt++;
124             }
125         }
126     }
127         break;
128     case 2: // swiss roll
129     {
130         dim = 2;
131         sample.resize(dim);
132         int samplesPerClass = count / classesCount;
133         FOR(c, classesCount)
134         {
135             FOR(i, samplesPerClass)
136             {
137                 float x = i/(float)samplesPerClass*M_PI*2*gridCount;
138                 sample[0] = x * cosf(x + M_PI*2/classesCount*c)*radius;
139                 sample[1] = x * sinf(x + M_PI*2/classesCount*c)*radius;
140                 samples.push_back(sample);
141                 labels.push_back(c);
142             }
143         }
144     }
145         break;
146     case 3: // sinc
147     {
148         dim = 2;
149         sample.resize(dim);
150         FOR(i, count)
151         {
152             float x = (i/(float)count*2 - 1)*radius*2*M_PI;
153             float y = sinf(M_PI*x) / (x*M_PI);
154             if(gridCount > 1) y += drand48()*((gridCount-1)/(float)32);
155             sample[0] = x;
156             sample[1] = y;
157             samples.push_back(sample);
158             labels.push_back(0);
159         }
160     }
161         break;
162     case 4: // gaussian
163     {
164         dim = 2;
165         sample.resize(dim);
166         FOR(i, count)
167         {
168             float x = (i/(float)count*2 - 1)*10*radius;
169             float y = exp(-0.5f*x*x);
170             if(gridCount > 1) y += drand48()*((gridCount-1)/(float)32);
171             sample[0] = x;
172             sample[1] = y;
173             samples.push_back(sample);
174             labels.push_back(0);
175         }
176     }
177         break;
178     case 5: // cosine
179     {
180         dim = 2;
181         sample.resize(dim);
182         FOR(i, count)
183         {
184             float x = (i/(float)count*2 - 1)*radius*2*M_PI;
185             float y = cos(x);
186             if(gridCount > 1) y += drand48()*((gridCount-1)/(float)32);
187             sample[0] = x;
188             sample[1] = y;
189             samples.push_back(sample);
190             labels.push_back(0);
191         }
192     }
193         break;
194     }
195     mutex->lock();
196     canvas->data->Clear();
197     canvas->ResetSamples();
198     canvas->data->AddSamples(samples, labels);
199     canvas->FitToData();
200     canvas->repaint();
201     mutex->unlock();
202 }