The first one
[mldemos:kalians-mldemos.git] / MLDemos / sampleManager.cpp
1 /*********************************************************************\r
2 MLDemos: A User-Friendly visualization toolkit for machine learning\r
3 Copyright (C) 2010  Basilio Noris\r
4 Contact: mldemos@b4silio.com\r
5 \r
6 This library is free software; you can redistribute it and/or\r
7 modify it under the terms of the GNU Lesser General Public\r
8 License as published by the Free Software Foundation; either\r
9 version 2.1 of the License, or (at your option) any later version.\r
10 \r
11 This library is distributed in the hope that it will be useful,\r
12 but WITHOUT ANY WARRANTY; without even the implied warranty of\r
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU\r
14 Library General Public License for more details.\r
15 \r
16 You should have received a copy of the GNU Lesser General Public\r
17 License along with this library; if not, write to the Free\r
18 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.\r
19 *********************************************************************/\r
20 #include "public.h"\r
21 #include "basicMath.h"\r
22 #include "basicOpenCV.h"\r
23 #include "sampleManager.h"\r
24 \r
25 \r
26 using namespace std;\r
27 \r
28 u32 SampleManager::IDCount = 0;\r
29 \r
30 SampleManager::SampleManager(CvSize resolution)\r
31 : size(resolution)\r
32 {\r
33         ID = IDCount++;\r
34         display = NULL;\r
35         perm = NULL;\r
36         bShowing = false;\r
37 }\r
38 \r
39 SampleManager::~SampleManager()\r
40 {\r
41         IMKILL(display);\r
42         Clear();\r
43 }\r
44 \r
45 void SampleManager::CreateSampleImage(IplImage **image, bool bShowLabels, float ratio)\r
46 {\r
47         int cnt = samples.size();\r
48 \r
49         int gridH = (int)(sqrtf(cnt/ratio) + 0.5f);\r
50         int gridW = (cnt / gridH) + (cnt%gridH ? 1 : 0);\r
51 \r
52         CvSize imSize = cvSize(size.width*gridW, size.height*gridH);\r
53         if (!(*image) || (*image)->width != imSize.width || (*image)->height != imSize.height)\r
54         {\r
55                 if((*image)) cvReleaseImage(image);\r
56                 (*image) = cvCreateImage(imSize, 8, 3);\r
57         }\r
58         cvZero((*image));\r
59 \r
60         FOR(i, cnt)\r
61         {\r
62                 if (!samples.at(i)) continue;\r
63                 CvRect rect = cvRect((i%gridW) * size.width, (i/gridW) * size.height, size.width, size.height);\r
64                 cvSetImageROI((*image), rect);\r
65                 cvCopy(samples.at(i), (*image));\r
66                 cvResetImageROI((*image));\r
67                 if(bShowLabels && flags[i] == TEST)\r
68                 {\r
69                         cvSetImageROI((*image), rect);\r
70                         IplImage *black = cvCreateImage(size, 8, 3);\r
71                         cvZero(black);\r
72                         cvAddWeighted((*image),0.5,black,1,0,(*image));\r
73                         IMKILL(black);\r
74                         cvResetImageROI((*image));\r
75                 }\r
76                 if(bShowLabels && labels[i])\r
77                 {\r
78                         cvRectangle((*image), cvPoint(rect.x, rect.y), cvPoint(rect.x+rect.width-1, rect.y+rect.height-1), CV::color[labels[i]%CV::colorCnt], 2, CV_AA); \r
79                 }\r
80         }\r
81 }\r
82 \r
83 void sm_on_mouse( int event, int x, int y, int flags, void* param );\r
84 int params[5];\r
85 void SampleManager::Show()\r
86 {\r
87         int cnt = samples.size();\r
88         if(!cnt) return;\r
89 \r
90         CreateSampleImage(&display, true);\r
91 \r
92 \r
93         params[0] = (intptr_t)((void *)&display);\r
94         params[1] = (intptr_t)((void *)&samples);\r
95         params[2] = (intptr_t)((void *)&labels);\r
96         params[3] = (intptr_t)((void *)&flags);\r
97         params[4] = (intptr_t)((void *)&size);\r
98         char name[255];\r
99         sprintf(name, "collected samples %d", ID);\r
100         cvNamedWindow(name);\r
101         cvShowImage(name, display);\r
102         cvSetMouseCallback(name, sm_on_mouse, (void *) &params);\r
103         bShowing = true;\r
104 }\r
105 \r
106 IplImage *SampleManager::GetSampleImage()\r
107 {\r
108         IplImage *image = NULL;\r
109         if(!GetCount()) return image;\r
110         CreateSampleImage(&image, true);\r
111         return image;\r
112 }\r
113 \r
114 \r
115 void SampleManager::Hide()\r
116 {\r
117         char name[255];\r
118         sprintf(name, "collected samples %d", ID);\r
119         cvDestroyWindow(name);\r
120         bShowing = false;\r
121 }\r
122 \r
123 int SampleManager::GetIndexAt(int x, int y)\r
124 {\r
125         if(display)\r
126         {\r
127                 int gridX = (display->width / size.width);\r
128                 int gridY = (display->height / size.height);\r
129                 int i = (int)(x / (float)display->width * gridX);\r
130                 int j = (int)(y / (float)display->height * gridY);\r
131                 int index = j*gridX + i;\r
132                 return index;\r
133         }\r
134         else\r
135         {\r
136                 float ratio = 1.f;\r
137                 int cnt = samples.size();\r
138                 int gridH = (int)(sqrtf(cnt/ratio) + 0.5f);\r
139                 int gridW = (cnt / gridH) + (cnt%gridH ? 1 : 0);\r
140                 CvSize imSize = cvSize(size.width*gridW, size.height*gridH);\r
141 \r
142                 int gridX = (imSize.width / size.width);\r
143                 int gridY = (imSize.height / size.height);\r
144                 int i = (int)(x / (float)imSize.width * gridX);\r
145                 int j = (int)(y / (float)imSize.height * gridY);\r
146                 int index = j*gridW + i;\r
147                 return index;\r
148         }\r
149 }\r
150 \r
151 void sm_on_mouse( int event, int x, int y, int flags, void* param )\r
152 {\r
153         IplImage *image = (*(IplImage **)(((int *)param)[0]));\r
154         std::vector<IplImage *> *samples = (std::vector<IplImage *> *)(((int *)param)[1]);\r
155         ivec *labels = (ivec *)(((int *)param)[2]);\r
156         std::vector<smFlags> *smflags = (std::vector<smFlags> *)(((int *)param)[3]);\r
157         CvSize size = (*(CvSize *)(((int *)param)[4]));\r
158     if( !image )\r
159         return;\r
160 \r
161     if( image->origin )\r
162         y = image->height - y;\r
163 \r
164         x = x >= (0xffff>>1) ? -(0xffff - x + 1) : x;\r
165         y = y >= (0xffff>>1) ? -(0xffff - y + 1) : y;\r
166 \r
167         if (x < 0) x = 0;\r
168         if (y < 0) y = 0;\r
169 \r
170         unsigned int gridX = (image->width / size.width);\r
171         unsigned int gridY = (image->height / size.height);\r
172         unsigned int i = (int)(x / (float)image->width * gridX);\r
173         unsigned int j = (int)(y / (float)image->height * gridY);\r
174         unsigned int index = j*gridX + i;\r
175 \r
176         if(event == CV_EVENT_LBUTTONDOWN)\r
177     {\r
178         }\r
179         else if(event == CV_EVENT_LBUTTONUP)\r
180         {\r
181                 if(flags & CV_EVENT_FLAG_CTRLKEY)\r
182                 {\r
183                         if(samples->size() == 1)\r
184                         {\r
185                                 IMKILL((*samples)[0]);\r
186                                 samples->clear();\r
187                                 labels->clear();\r
188                                 cvZero(image);\r
189                         }\r
190                         else if(index < samples->size())\r
191                         {\r
192                                 IMKILL((*samples)[index]);\r
193                                 while(index < samples->size()-1)\r
194                                 {\r
195                                         (*samples)[index] = (*samples)[index+1];\r
196                                         (*labels)[index] = (*labels)[index+1];\r
197                                         index++;\r
198                                 }\r
199                                 samples->pop_back();\r
200                                 labels->pop_back();\r
201                         }\r
202                 }\r
203                 else if(flags & CV_EVENT_FLAG_ALTKEY)\r
204                 {\r
205                         if(index < smflags->size())\r
206                         {\r
207                                 if((*smflags)[index] == UNUSED)\r
208                                 {\r
209                                         (*smflags)[index] = TEST;\r
210                                 }\r
211                                 else if ((*smflags)[index] == TEST)\r
212                                 {\r
213                                         (*smflags)[index] = UNUSED;\r
214                                 }\r
215                         }\r
216                 }\r
217                 else\r
218                 {\r
219                         if(index < labels->size())\r
220                         {\r
221                                 if (flags & CV_EVENT_FLAG_SHIFTKEY)\r
222                                 {\r
223                                         u32 newLabel = ((*labels)[index]+1) % 256;\r
224                                         for (u32 i=index; i<labels->size(); i++) (*labels)[i] = newLabel;\r
225                                 }\r
226                                 else (*labels)[index] = ((*labels)[index]+1) % 256;\r
227                         }\r
228                 }\r
229     }\r
230         else if(event == CV_EVENT_RBUTTONUP)\r
231         {\r
232                 if(index < labels->size())\r
233                 {\r
234                                 if (flags & CV_EVENT_FLAG_SHIFTKEY)\r
235                                 {\r
236                                         u32 newLabel = (*labels)[index] ? (*labels)[index]-1 : 255;\r
237                                         for (u32 i=index; i<labels->size(); i++) (*labels)[i] = newLabel;\r
238                                 }\r
239                                 else (*labels)[index] = (*labels)[index] ? (*labels)[index]-1 : 255;\r
240                 }\r
241         }\r
242 }\r
243 \r
244 \r
245 void SampleManager::Clear()\r
246 {\r
247         FOR(i, samples.size())\r
248         {\r
249                 IMKILL(samples[i]);\r
250         }\r
251         samples.clear();\r
252         flags.clear();\r
253         labels.clear();\r
254         KILL(perm);\r
255         if(display) cvZero(display);\r
256 }\r
257 \r
258 void SampleManager::AddSample(IplImage *image, unsigned int label)\r
259 {\r
260         if (!image) return;\r
261 \r
262         IplImage *img = cvCreateImage(size, 8, 3);\r
263         if(image->nChannels == 3) cvResize(image, img, CV_INTER_CUBIC);\r
264         else\r
265         {\r
266                 IplImage *tmp = cvCreateImage(cvGetSize(image), 8, 3);\r
267                 cvCvtColor(image, tmp, CV_GRAY2BGR);\r
268                 cvResize(tmp, img);\r
269                 IMKILL(tmp);\r
270         }\r
271         samples.push_back(img);\r
272         flags.push_back(UNUSED);\r
273         labels.push_back(label);\r
274         KILL(perm);\r
275         perm = randPerm(samples.size());\r
276 }\r
277 \r
278 void SampleManager::AddSample(IplImage *image, CvRect selection, unsigned int label)\r
279 {\r
280         if (!image) return;\r
281         if (selection.x < 0 || selection.y < 0 || !selection.width || !selection.height) return;\r
282         if (selection.x + selection.width > image->width || selection.y + selection.height > image->height) return;\r
283 \r
284         ROI(image, selection);\r
285         IplImage *img = cvCreateImage(size, 8, 3);\r
286         cvResize(image, img, CV_INTER_CUBIC);\r
287         unROI(image);\r
288         samples.push_back(img);\r
289         flags.push_back(UNUSED);\r
290         labels.push_back(label);\r
291         KILL(perm);\r
292         perm = randPerm(samples.size());\r
293 }\r
294 \r
295 \r
296 void SampleManager::AddSamples(std::vector<IplImage *>images)\r
297 {\r
298         FOR(i, images.size())\r
299         {\r
300                 if(images[i])\r
301                 {\r
302                         IplImage *sample = cvCreateImage(size, 8, 3);\r
303                         if(images[i]->width == size.width && images[i]->height == size.height)\r
304                         {\r
305                                 if(images[i]->nChannels == 3) cvCopy(images[i], sample);\r
306                                 else cvCvtColor(images[i], sample, CV_GRAY2BGR);\r
307                         }\r
308                         else\r
309                         {\r
310                                 if(images[i]->nChannels == 3) cvResize(images[i], sample, CV_INTER_CUBIC);\r
311                                 else\r
312                                 {\r
313                                         IplImage *tmp = cvCreateImage(size, 8, 1);\r
314                                         cvResize(images[i], tmp, CV_INTER_CUBIC);\r
315                                         cvCvtColor(tmp, sample, CV_GRAY2BGR);\r
316                                         IMKILL(tmp);\r
317                                 }\r
318                         }\r
319                         samples.push_back(sample);\r
320                         flags.push_back(UNUSED);\r
321                         labels.push_back(0);\r
322                 }\r
323         }\r
324         KILL(perm);\r
325         perm = randPerm(samples.size());\r
326 }\r
327 \r
328 void SampleManager::AddSamples(SampleManager newSamples)\r
329 {\r
330         FOR(i, newSamples.GetSamples().size())\r
331         {\r
332                 samples.push_back(newSamples.GetSample(i));\r
333                 flags.push_back(newSamples.GetFlag(i));\r
334                 labels.push_back(newSamples.GetLabel(i));\r
335         }\r
336         KILL(perm);\r
337         perm = randPerm(samples.size());\r
338 }\r
339 \r
340 void SampleManager::RemoveSample(unsigned int index)\r
341 {\r
342         if(index >= samples.size()) return;\r
343         if(samples.size() == 1)\r
344         {\r
345                 Clear();\r
346                 return;\r
347         }\r
348         IMKILL(samples[index]);\r
349         for (unsigned int i = index; i < samples.size()-1; i++)\r
350         {\r
351                 samples[i] = samples[i+1];\r
352                 labels[i] = labels[i+1];\r
353         }\r
354         samples.pop_back();\r
355         labels.pop_back();\r
356 }\r
357 \r
358 // we compare the current sample with all the ones in the dataset\r
359 // and return the smallest distance\r
360 f32 SampleManager::Compare(IplImage *sample)\r
361 {\r
362         if(!sample) return 1.0f;\r
363         IplImage *s = cvCreateImage(size, 8, 3);\r
364         if(sample->width == size.width && sample->height == size.height)\r
365         {\r
366                 if(sample->nChannels == 3) cvCopy(sample, s);\r
367                 else cvCvtColor(sample, s, CV_GRAY2BGR);\r
368         }\r
369         else if(sample->nChannels == s->nChannels)\r
370         {\r
371                 cvResize(sample, s, CV_INTER_CUBIC);\r
372         }\r
373         else\r
374         {\r
375                 IplImage *tmp = cvCreateImage(cvGetSize(sample), 8, 3);\r
376                 cvCvtColor(sample, tmp, CV_GRAY2BGR);\r
377                 cvResize(tmp, s, CV_INTER_CUBIC);\r
378                 IMKILL(tmp);\r
379         }\r
380 \r
381         // now compute the differences\r
382         f32 minDist = 1.0f;\r
383         u32 index = 0;\r
384         IplImage *diff = cvCloneImage(s);\r
385         FOR(i, samples.size())\r
386         {\r
387                 cvAbsDiff(s, samples[i], diff);\r
388                 f32 dist = (f32)cvSum(diff).val[0] / (f32)(size.width*size.height) / 255.f;\r
389                 if(minDist > dist)\r
390                 {\r
391                         index = i;\r
392                         minDist = dist;\r
393                 }\r
394         }\r
395         IMKILL(diff);\r
396         IMKILL(s);\r
397         return minDist;\r
398 }\r
399 \r
400 void SampleManager::Randomize(int seed)\r
401 {\r
402         KILL(perm);\r
403         if(samples.size()) perm = randPerm(samples.size(), seed);\r
404 }\r
405 \r
406 void SampleManager::ResetFlags()\r
407 {\r
408         FOR(i, samples.size()) flags[i] = UNUSED;\r
409 }\r
410 \r
411 \r
412 std::vector<IplImage *> SampleManager::GetSamples(u32 count, smFlags flag, smFlags replaceWith)\r
413 {\r
414         std::vector<IplImage *> selected;\r
415         if (!samples.size() || !perm) return selected;\r
416 \r
417         if (!count)\r
418         {\r
419                 FOR(i, samples.size())\r
420                 {\r
421                         if ( flags[perm[i]] == flag)\r
422                         {\r
423                                 selected.push_back(samples[perm[i]]);\r
424                                 flags[perm[i]] = replaceWith;\r
425                         }\r
426                 }\r
427                 return selected;\r
428         }\r
429 \r
430         for ( u32 i=0, cnt=0; i < samples.size() && cnt < count; i++ )\r
431         {\r
432                 if ( flags[perm[i]] == flag )\r
433                 {\r
434                         selected.push_back(samples[perm[i]]);\r
435                         flags[perm[i]] = replaceWith;\r
436                         cnt++;\r
437                 }\r
438         }\r
439 \r
440         return selected;\r
441 }\r
442 \r
443 void SampleManager::Save(const char *filename)\r
444 {\r
445         if(!samples.size()) return;\r
446         IplImage *image = NULL;\r
447         u32 sampleCnt = samples.size();\r
448 \r
449         IplImage *labelImg = cvCreateImage(size, 8, 3);\r
450         u32 passes = 1 + (sampleCnt+2) / (size.width*size.height*3);\r
451         u32 cnt = min(size.width*size.height*3, (int)sampleCnt);\r
452         cvZero(labelImg); // we want at least one empty label\r
453         samples.push_back(labelImg);\r
454         FOR(i, passes)\r
455         {\r
456                 cnt = min(size.width*size.height*3, (int)sampleCnt - (int)i*size.width*size.height*3);\r
457                 labelImg = cvCreateImage(size, 8, 3);\r
458                 cvZero(labelImg);\r
459                 FOR(j, cnt)\r
460                 {\r
461                         labelImg->imageData[j] = labels[i*(size.width*size.height*3) + j];\r
462                 }\r
463                 samples.push_back(labelImg);\r
464         }\r
465 \r
466         CreateSampleImage(&image);\r
467         \r
468         // we write down the size of the samples in the last pixel of the image\r
469         cvSet2D(image,image->width-1,image->height-1,CV_RGB(255, size.height, size.width));\r
470 \r
471 \r
472         FOR(i, passes+1)\r
473         {\r
474                 IMKILL(samples[samples.size()-1]);\r
475                 samples.pop_back();\r
476         }\r
477 \r
478         cvSaveImage(filename, image);\r
479         IMKILL(image);\r
480 }\r
481 \r
482 bool SampleManager::Load(const char *filename, CvSize resolution)\r
483 {\r
484         IplImage *image = cvLoadImage(filename);\r
485         if(!image || image->width < resolution.width || image->height < resolution.height) return false;\r
486 \r
487         Clear();\r
488 \r
489         // we try to get the resolution off the image itself\r
490         int last = (image->height-1)*image->widthStep + (image->width-1)*3;\r
491         if(image->imageData[last] == -1) // we have the information!\r
492         {\r
493                 size.width = image->imageData[last-2];\r
494                 size.height = image->imageData[last-1];\r
495         }\r
496         else size = resolution;\r
497 \r
498         int gridW = image->width / size.width;\r
499         int gridH = image->height / size.height;\r
500         int cnt = gridW*gridH;\r
501         bool bDone = false;\r
502         FOR(i, cnt)\r
503         {\r
504                 IplImage *sample = cvCreateImage(size, 8, 3);\r
505                 ROI(image, cvRect((i%gridW) * size.width, (i/gridW) * size.height, size.width, size.height));\r
506                 cvCopy(image, sample);\r
507                 unROI(image);\r
508 \r
509                 if(bDone)\r
510                 {\r
511                         if(labels.size() == samples.size()) // we added all the labels already\r
512                         {\r
513                                 IMKILL(sample);\r
514                                 break;\r
515                         }\r
516                         u32 labelCnt = min((u32)size.width*size.height*3, (u32)samples.size() - (u32)labels.size());\r
517                         FOR(j, labelCnt)\r
518                         {\r
519                                 labels.push_back((u8)sample->imageData[j]);\r
520                         }\r
521                         IMKILL(sample);\r
522                         continue;\r
523                 }\r
524 \r
525                 if (cvSumPixels(sample) == 0)\r
526                 {\r
527                         IMKILL(sample);\r
528                         bDone = true;\r
529                 }\r
530                 else\r
531                 {\r
532                         samples.push_back(sample);\r
533                         flags.push_back(UNUSED);\r
534                         //labels.push_back(0);\r
535                 }\r
536         }\r
537         while(labels.size() < samples.size()) labels.push_back(0);\r
538         KILL(perm);\r
539         perm = randPerm(samples.size());\r
540         return samples.size() > 0;\r
541 }\r
542 \r
543 u32 SampleManager::GetClassCount(ivec classes)\r
544 {\r
545         u32 *counts = new u32[256];\r
546         memset(counts, 0, 256*sizeof(u32));\r
547         FOR(i, classes.size()) counts[classes[i]]++;\r
548         u32 result = 0;\r
549         for (u32 i=1; i<256; i++) result += counts[i] > 0 ? 1 : 0;\r
550         return result;\r
551 }\r
552 \r
553 std::vector<bool> SampleManager::GetFreeFlags()\r
554 {\r
555         std::vector<bool> res;\r
556         FOR(i, flags.size()) res.push_back(flags[i] == UNUSED);\r
557         return res;\r
558 }\r
559 \r
560 float SampleManager::GetTestRatio()\r
561 {\r
562         float ratio = 0;\r
563         FOR(i, flags.size())\r
564         {\r
565                 ratio += flags[i] == TEST ? 1 : 0;\r
566         }\r
567         return ratio /= flags.size();\r
568 }\r
569 \r
570 void SampleManager::RandomTestSet(float ratio, bool bEnsureOnePerClass)\r
571 {\r
572         float minSamples = 5;\r
573         // we want to have at least minSamples samples in order not to get weird results\r
574         ratio = min(ratio, (GetCount()-minSamples)/GetCount());\r
575         u32 *perm = randPerm(GetCount());\r
576         FOR(i, GetCount())\r
577         {\r
578                 SetFlag(perm[i], i < GetCount()*ratio ? TEST : UNUSED);\r
579         }\r
580         delete [] perm;\r
581 \r
582         if(!bEnsureOnePerClass) return;\r
583         // we count how many samples of the positive and negative classes we have\r
584         vector< pair<u32, u32> > counts;\r
585         FOR(i, GetCount())\r
586         {\r
587                 u32 label = GetLabel(i);\r
588                 bool bExists = false;\r
589                 u32 index = 0;\r
590                 FOR(j, counts.size())\r
591                 {\r
592                         if(label == counts[j].first)\r
593                         {\r
594                                 index = j;\r
595                                 bExists = true;\r
596                                 break;\r
597                         }\r
598                 }\r
599                 if(!bExists)\r
600                 {\r
601                         counts.push_back(pair<u32,u32>(label, 0));\r
602                 }\r
603                 else\r
604                 {\r
605                         if(GetFlag(i) == UNUSED)\r
606                         counts[index].second++;\r
607                 }\r
608         }\r
609         // if we don't have any, we set at least one sample for ever class\r
610         perm = randPerm(GetCount());\r
611         FOR(j, counts.size())\r
612         {\r
613                 if(counts[j].second) continue;\r
614                 FOR(i, GetCount())\r
615                 {\r
616                         if(GetLabel(perm[i]) == counts[j].first)\r
617                         {\r
618                                 SetFlag(perm[i], UNUSED);\r
619                                 break;\r
620                         }\r
621                 }\r
622         }\r
623         delete [] perm;\r
624 }\r