const-fixing functions
[mldemos:baraks-mldemos.git] / _AlgorithmsPlugins / KernelMethods / classifierRVM.cpp
1 /*********************************************************************\r
2 MLDemos: A User-Friendly visualization toolkit for machine learning\r
3 Copyright (C) 2010  Basilio Noris\r
4 Contact: mldemos@b4silio.com\r
5 \r
6 This library is free software; you can redistribute it and/or\r
7 modify it under the terms of the GNU Lesser General Public\r
8 License as published by the Free Software Foundation; either\r
9 version 2.1 of the License, or (at your option) any later version.\r
10 \r
11 This library is distributed in the hope that it will be useful,\r
12 but WITHOUT ANY WARRANTY; without even the implied warranty of\r
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU\r
14 Library General Public License for more details.\r
15 \r
16 You should have received a copy of the GNU Lesser General Public\r
17 License along with this library; if not, write to the Free\r
18 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.\r
19 *********************************************************************/\r
20 #include <public.h>\r
21 #include "classifierRVM.h"\r
22 \r
23 using namespace std;\r
24 using namespace dlib;\r
25 \r
26 ClassifierRVM::~ClassifierRVM()\r
27 {\r
28     if(decFunction)\r
29     {\r
30 #define KILLCASE(a) case a:{KillDim<a>();return;}\r
31         switch(dim)\r
32         {\r
33         KILLCASE(2);\r
34         KILLCASE(3);\r
35         KILLCASE(4);\r
36         KILLCASE(5);\r
37         KILLCASE(6);\r
38         KILLCASE(7);\r
39         KILLCASE(8);\r
40         KILLCASE(9);\r
41         KILLCASE(10);\r
42         KILLCASE(11);\r
43         KILLCASE(12);\r
44         default:\r
45             KillDim<2>();\r
46             return;\r
47         }\r
48     }\r
49 }\r
50 \r
51 const char *ClassifierRVM::GetInfoString() const\r
52 {\r
53         char *text = new char[1024];\r
54         sprintf(text, "Relevance Vector Machine\n");\r
55         sprintf(text, "%sKernel: ", text);\r
56         switch(kernelType)\r
57         {\r
58         case 0:\r
59                 sprintf(text, "%s linear", text);\r
60                 break;\r
61         case 1:\r
62                 sprintf(text, "%s polynomial (deg: %f %f width: %f)", text, kernelDegree, kernelParam);\r
63                 break;\r
64         case 2:\r
65                 sprintf(text, "%s rbf (gamma: %f)", text, kernelParam);\r
66                 break;\r
67         }\r
68         sprintf(text, "%seps: %f\n", text, epsilon);\r
69         sprintf(text, "%sRelevant Vectors: %d\n", text, GetSVs().size());\r
70         return text;\r
71 }\r
72 \r
73 void ClassifierRVM::Train(std::vector< fvec > samples, ivec labels)\r
74 {\r
75     if(!samples.size()) return;\r
76     dim = samples[0].size();\r
77     if(dim > 12) dim = 12;\r
78 \r
79     classMap.clear();\r
80     int cnt=0;\r
81     FOR(i, labels.size()) if(!classMap.count(labels[i])) classMap[labels[i]] = cnt++;\r
82     for(map<int,int>::iterator it=classMap.begin(); it != classMap.end(); it++) inverseMap[it->second] = it->first;\r
83 \r
84 #define TRAINCASE(a) case a:{TrainDim<a>(samples, labels);return;}\r
85     switch(dim)\r
86     {\r
87     TRAINCASE(2);\r
88     TRAINCASE(3);\r
89     TRAINCASE(4);\r
90     TRAINCASE(5);\r
91     TRAINCASE(6);\r
92     TRAINCASE(7);\r
93     TRAINCASE(8);\r
94     TRAINCASE(9);\r
95     TRAINCASE(10);\r
96     TRAINCASE(11);\r
97     TRAINCASE(12);\r
98     default:\r
99         TrainDim<2>(samples, labels);\r
100         return;\r
101     }\r
102 }\r
103 \r
104 float ClassifierRVM::Test( const fvec &_sample ) const\r
105 {\r
106 #define TESTCASE(a) case a:{return TestDim<a>(_sample);}\r
107     switch(dim)\r
108     {\r
109     TESTCASE(2);\r
110     TESTCASE(3);\r
111     TESTCASE(4);\r
112     TESTCASE(5);\r
113     TESTCASE(6);\r
114     TESTCASE(7);\r
115     TESTCASE(8);\r
116     TESTCASE(9);\r
117     TESTCASE(10);\r
118     TESTCASE(11);\r
119     TESTCASE(12);\r
120     default:\r
121         return TestDim<2>(_sample);\r
122     }\r
123 }\r
124 \r
125 std::vector<fvec> ClassifierRVM::GetSVs() const\r
126 {\r
127 #define SVCASE(a) case a:{return GetSVsDim<a>();}\r
128     switch(dim)\r
129     {\r
130     SVCASE(2);\r
131     SVCASE(3);\r
132     SVCASE(4);\r
133     SVCASE(5);\r
134     SVCASE(6);\r
135     SVCASE(7);\r
136     SVCASE(8);\r
137     SVCASE(9);\r
138     SVCASE(10);\r
139     SVCASE(11);\r
140     SVCASE(12);\r
141     default:\r
142         return GetSVsDim<2>();\r
143     }\r
144 }\r
145 \r
146 template <int N>\r
147 void ClassifierRVM::KillDim()\r
148 {\r
149     if(!decFunction) return;\r
150     switch(kernelTypeTrained)\r
151     {\r
152     case 0:\r
153         if(decFunction) delete [] (linfunc*)decFunction;\r
154         break;\r
155     case 1:\r
156         if(decFunction) delete [] (polfunc*)decFunction;\r
157         break;\r
158     case 2:\r
159         if(decFunction) delete [] (rbffunc*)decFunction;\r
160         break;\r
161     }\r
162     decFunction = 0;\r
163 }\r
164 \r
165 template <int N>\r
166 void ClassifierRVM::TrainDim(std::vector< fvec > _samples, ivec _labels)\r
167 {\r
168     std::vector<sampletype> samples;\r
169     std::vector<double> labels;\r
170     sampletype samp;\r
171     FOR(i, _samples.size()) { FOR(d, dim) samp(d) = _samples[i][d]; samples.push_back(samp); }\r
172     KillDim<N>();\r
173 \r
174     FOR(i, _samples.size()) labels.push_back(_labels[i] == 1 ? 1 : -1);\r
175 \r
176     randomize_samples(samples, labels);\r
177 \r
178     switch(kernelType)\r
179     {\r
180     case 0:\r
181     {\r
182         rvm_trainer<linkernel> train = rvm_trainer<linkernel>();\r
183         train.set_epsilon(epsilon);\r
184         train.set_kernel(linkernel());\r
185         linfunc *fun = new linfunc[1];\r
186         *fun = train.train(samples, labels);\r
187         decFunction = (void *)fun;\r
188         kernelTypeTrained = 0;\r
189     }\r
190         break;\r
191     case 1:\r
192     {\r
193         rvm_trainer<polkernel> train = rvm_trainer<polkernel>();\r
194         train.set_epsilon(epsilon);\r
195         train.set_kernel(polkernel(1./kernelParam, 0, kernelDegree));\r
196         polfunc *fun = new polfunc[1];\r
197         *fun = train.train(samples, labels);\r
198         decFunction = (void *)fun;\r
199         kernelTypeTrained = 1;\r
200     }\r
201         break;\r
202     case 2:\r
203     {\r
204         rvm_trainer<rbfkernel> train = rvm_trainer<rbfkernel>();\r
205         train.set_epsilon(epsilon);\r
206         train.set_kernel(rbfkernel(1./kernelParam));\r
207         rbffunc *fun = new rbffunc[1];\r
208         *fun = train.train(samples, labels);\r
209         decFunction = (void *)fun;\r
210         kernelTypeTrained = 2;\r
211     }\r
212         break;\r
213     }\r
214 }\r
215 \r
216 template <int N>\r
217 float ClassifierRVM::TestDim(const fvec &_sample) const\r
218 {\r
219     float estimate = 0.f;\r
220 \r
221     sampletype sample;\r
222     FOR(d,dim) sample(d) = _sample[d];\r
223     if(!decFunction) return estimate;\r
224     switch(kernelTypeTrained)\r
225     {\r
226     case 0:\r
227     {\r
228         linfunc fun = *(linfunc*)decFunction;\r
229         estimate = fun(sample);\r
230     }\r
231         break;\r
232     case 1:\r
233     {\r
234         polfunc fun = *(polfunc*)decFunction;\r
235         estimate = fun(sample);\r
236     }\r
237         break;\r
238     case 2:\r
239     {\r
240         rbffunc fun = *(rbffunc*)decFunction;\r
241         estimate = fun(sample);\r
242     }\r
243         break;\r
244     }\r
245     return estimate;\r
246 }\r
247 \r
248 template <int N>\r
249 std::vector<fvec> ClassifierRVM::GetSVsDim() const\r
250 {\r
251     std::vector<fvec> SVs;\r
252     switch(kernelTypeTrained)\r
253     {\r
254     case 0:\r
255     {\r
256 \r
257         FOR(i, (*(linfunc*)decFunction).basis_vectors.nr())\r
258         {\r
259             fvec sv;\r
260             sv.push_back((*(linfunc*)decFunction).basis_vectors(i)(0));\r
261             sv.push_back((*(linfunc*)decFunction).basis_vectors(i)(1));\r
262             SVs.push_back(sv);\r
263         }\r
264     }\r
265         break;\r
266     case 1:\r
267     {\r
268 \r
269         FOR(i, (*(polfunc*)decFunction).basis_vectors.nr())\r
270         {\r
271             fvec sv;\r
272             sv.push_back((*(polfunc*)decFunction).basis_vectors(i)(0));\r
273             sv.push_back((*(polfunc*)decFunction).basis_vectors(i)(1));\r
274             SVs.push_back(sv);\r
275         }\r
276     }\r
277         break;\r
278     case 2:\r
279     {\r
280 \r
281         FOR(i, (*(rbffunc*)decFunction).basis_vectors.nr())\r
282         {\r
283             fvec sv;\r
284             sv.push_back((*(rbffunc*)decFunction).basis_vectors(i)(0));\r
285             sv.push_back((*(rbffunc*)decFunction).basis_vectors(i)(1));\r
286             SVs.push_back(sv);\r
287         }\r
288     }\r
289         break;\r
290     }\r
291         return SVs;\r
292 }\r