Merge branch 'tweaks' into devel
[mldemos:baraks-mldemos.git] / _AlgorithmsPlugins / KernelMethods / classifierRVM.cpp
1 /*********************************************************************\r
2 MLDemos: A User-Friendly visualization toolkit for machine learning\r
3 Copyright (C) 2010  Basilio Noris\r
4 Contact: mldemos@b4silio.com\r
5 \r
6 This library is free software; you can redistribute it and/or\r
7 modify it under the terms of the GNU Lesser General Public\r
8 License as published by the Free Software Foundation; either\r
9 version 2.1 of the License, or (at your option) any later version.\r
10 \r
11 This library is distributed in the hope that it will be useful,\r
12 but WITHOUT ANY WARRANTY; without even the implied warranty of\r
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU\r
14 Library General Public License for more details.\r
15 \r
16 You should have received a copy of the GNU Lesser General Public\r
17 License along with this library; if not, write to the Free\r
18 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.\r
19 *********************************************************************/\r
20 #include <public.h>\r
21 #include "classifierRVM.h"\r
22 \r
23 using namespace std;\r
24 using namespace dlib;\r
25 \r
26 ClassifierRVM::~ClassifierRVM()\r
27 {\r
28     if(decFunction)\r
29     {\r
30 #define KILLCASE(a) case a:{KillDim<a>();return;}\r
31         switch(dim)\r
32         {\r
33         KILLCASE(2);\r
34         KILLCASE(3);\r
35         KILLCASE(4);\r
36         KILLCASE(5);\r
37         KILLCASE(6);\r
38         KILLCASE(7);\r
39         KILLCASE(8);\r
40         KILLCASE(9);\r
41         KILLCASE(10);\r
42         KILLCASE(11);\r
43         KILLCASE(12);\r
44         default:\r
45             KillDim<0>();\r
46             return;\r
47         }\r
48     }\r
49 }\r
50 \r
51 const char *ClassifierRVM::GetInfoString() const\r
52 {\r
53         char *text = new char[1024];\r
54         sprintf(text, "Relevance Vector Machine\n");\r
55         sprintf(text, "%sKernel: ", text);\r
56         switch(kernelType)\r
57         {\r
58         case 0:\r
59                 sprintf(text, "%s linear", text);\r
60                 break;\r
61         case 1:\r
62                 sprintf(text, "%s polynomial (deg: %d width: %f)", text, kernelDegree, kernelParam);\r
63                 break;\r
64         case 2:\r
65                 sprintf(text, "%s rbf (gamma: %f)", text, kernelParam);\r
66                 break;\r
67         }\r
68         sprintf(text, "%seps: %f\n", text, epsilon);\r
69         sprintf(text, "%sRelevant Vectors: %lu\n", text, (unsigned long)GetSVs().size());\r
70         return text;\r
71 }\r
72 \r
73 void ClassifierRVM::Train(std::vector< fvec > samples, ivec labels)\r
74 {\r
75     if(!samples.size()) return;\r
76     dim = samples[0].size();\r
77 \r
78     classMap.clear();\r
79     int cnt=0;\r
80     FOR(i, labels.size()) if(!classMap.count(labels[i])) classMap[labels[i]] = cnt++;\r
81     for(map<int,int>::iterator it=classMap.begin(); it != classMap.end(); it++) inverseMap[it->second] = it->first;\r
82 \r
83 #define TRAINCASE(a) case a:{TrainDim<a>(samples, labels);return;}\r
84     switch(dim)\r
85     {\r
86     TRAINCASE(2);\r
87     TRAINCASE(3);\r
88     TRAINCASE(4);\r
89     TRAINCASE(5);\r
90     TRAINCASE(6);\r
91     TRAINCASE(7);\r
92     TRAINCASE(8);\r
93     TRAINCASE(9);\r
94     TRAINCASE(10);\r
95     TRAINCASE(11);\r
96     TRAINCASE(12);\r
97     default:\r
98         TrainDim<0>(samples, labels);\r
99         return;\r
100     }\r
101 }\r
102 \r
103 float ClassifierRVM::Test( const fvec &_sample ) const\r
104 {\r
105 #define TESTCASE(a) case a:{return TestDim<a>(_sample);}\r
106     switch(dim)\r
107     {\r
108     TESTCASE(2);\r
109     TESTCASE(3);\r
110     TESTCASE(4);\r
111     TESTCASE(5);\r
112     TESTCASE(6);\r
113     TESTCASE(7);\r
114     TESTCASE(8);\r
115     TESTCASE(9);\r
116     TESTCASE(10);\r
117     TESTCASE(11);\r
118     TESTCASE(12);\r
119     default:\r
120         return TestDim<0>(_sample);\r
121     }\r
122 }\r
123 \r
124 std::vector<fvec> ClassifierRVM::GetSVs() const\r
125 {\r
126 #define SVCASE(a) case a:{return GetSVsDim<a>();}\r
127     switch(dim)\r
128     {\r
129     SVCASE(2);\r
130     SVCASE(3);\r
131     SVCASE(4);\r
132     SVCASE(5);\r
133     SVCASE(6);\r
134     SVCASE(7);\r
135     SVCASE(8);\r
136     SVCASE(9);\r
137     SVCASE(10);\r
138     SVCASE(11);\r
139     SVCASE(12);\r
140     default:\r
141         return GetSVsDim<0>();\r
142     }\r
143 }\r
144 \r
145 template <int N>\r
146 void ClassifierRVM::KillDim()\r
147 {\r
148     if(!decFunction) return;\r
149     switch(kernelTypeTrained)\r
150     {\r
151     case 0:\r
152         if(decFunction) delete [] (linfunc*)decFunction;\r
153         break;\r
154     case 1:\r
155         if(decFunction) delete [] (polfunc*)decFunction;\r
156         break;\r
157     case 2:\r
158         if(decFunction) delete [] (rbffunc*)decFunction;\r
159         break;\r
160     }\r
161     decFunction = 0;\r
162 }\r
163 \r
164 template <int N>\r
165 void ClassifierRVM::TrainDim(std::vector< fvec > _samples, ivec _labels)\r
166 {\r
167     std::vector<sampletype> samples;\r
168     std::vector<double> labels;\r
169     sampletype samp(dim);\r
170     FOR(i, _samples.size()) { FOR(d, dim) samp(d) = _samples[i][d]; samples.push_back(samp); }\r
171     KillDim<N>();\r
172 \r
173     FOR(i, _samples.size()) labels.push_back(_labels[i] == 1 ? 1 : -1);\r
174 \r
175     randomize_samples(samples, labels);\r
176 \r
177     switch(kernelType)\r
178     {\r
179     case 0:\r
180     {\r
181         rvm_trainer<linkernel> train = rvm_trainer<linkernel>();\r
182         train.set_epsilon(epsilon);\r
183         train.set_kernel(linkernel());\r
184         linfunc *fun = new linfunc[1];\r
185         *fun = train.train(samples, labels);\r
186         decFunction = (void *)fun;\r
187         kernelTypeTrained = 0;\r
188     }\r
189         break;\r
190     case 1:\r
191     {\r
192         rvm_trainer<polkernel> train = rvm_trainer<polkernel>();\r
193         train.set_epsilon(epsilon);\r
194         train.set_kernel(polkernel(1./kernelParam, 0, kernelDegree));\r
195         polfunc *fun = new polfunc[1];\r
196         *fun = train.train(samples, labels);\r
197         decFunction = (void *)fun;\r
198         kernelTypeTrained = 1;\r
199     }\r
200         break;\r
201     case 2:\r
202     {\r
203         rvm_trainer<rbfkernel> train = rvm_trainer<rbfkernel>();\r
204         train.set_epsilon(epsilon);\r
205         train.set_kernel(rbfkernel(1./kernelParam));\r
206         rbffunc *fun = new rbffunc[1];\r
207         *fun = train.train(samples, labels);\r
208         decFunction = (void *)fun;\r
209         kernelTypeTrained = 2;\r
210     }\r
211         break;\r
212     }\r
213 }\r
214 \r
215 template <int N>\r
216 float ClassifierRVM::TestDim(const fvec &_sample) const\r
217 {\r
218     float estimate = 0.f;\r
219 \r
220     sampletype sample(dim);\r
221     FOR(d,dim) sample(d) = _sample[d];\r
222     if(!decFunction) return estimate;\r
223     switch(kernelTypeTrained)\r
224     {\r
225     case 0:\r
226     {\r
227         linfunc fun = *(linfunc*)decFunction;\r
228         estimate = fun(sample);\r
229     }\r
230         break;\r
231     case 1:\r
232     {\r
233         polfunc fun = *(polfunc*)decFunction;\r
234         estimate = fun(sample);\r
235     }\r
236         break;\r
237     case 2:\r
238     {\r
239         rbffunc fun = *(rbffunc*)decFunction;\r
240         estimate = fun(sample);\r
241     }\r
242         break;\r
243     }\r
244     return estimate;\r
245 }\r
246 \r
247 template <int N>\r
248 std::vector<fvec> ClassifierRVM::GetSVsDim() const\r
249 {\r
250     std::vector<fvec> SVs;\r
251     switch(kernelTypeTrained)\r
252     {\r
253     case 0:\r
254     {\r
255 \r
256         FOR(i, (*(linfunc*)decFunction).basis_vectors.nr())\r
257         {\r
258             fvec sv(dim);\r
259             FOR(d, dim) sv[d] = (*(linfunc*)decFunction).basis_vectors(i)(d);\r
260             SVs.push_back(sv);\r
261         }\r
262     }\r
263         break;\r
264     case 1:\r
265     {\r
266 \r
267         FOR(i, (*(polfunc*)decFunction).basis_vectors.nr())\r
268         {\r
269             fvec sv(dim);\r
270             FOR(d, dim) sv[d] = (*(polfunc*)decFunction).basis_vectors(i)(d);\r
271             SVs.push_back(sv);\r
272         }\r
273     }\r
274         break;\r
275     case 2:\r
276     {\r
277 \r
278         FOR(i, (*(rbffunc*)decFunction).basis_vectors.nr())\r
279         {\r
280             fvec sv(dim);\r
281             FOR(d, dim) sv[d] = (*(rbffunc*)decFunction).basis_vectors(i)(d);\r
282             SVs.push_back(sv);\r
283         }\r
284     }\r
285         break;\r
286     }\r
287         return SVs;\r
288 }\r