const-fixing functions
[mldemos:baraks-mldemos.git] / _AlgorithmsPlugins / KernelMethods / classifierPegasos.cpp
1 /*********************************************************************\r
2 MLDemos: A User-Friendly visualization toolkit for machine learning\r
3 Copyright (C) 2010  Basilio Noris\r
4 Contact: mldemos@b4silio.com\r
5 \r
6 This library is free software; you can redistribute it and/or\r
7 modify it under the terms of the GNU Lesser General Public\r
8 License as published by the Free Software Foundation; either\r
9 version 2.1 of the License, or (at your option) any later version.\r
10 \r
11 This library is distributed in the hope that it will be useful,\r
12 but WITHOUT ANY WARRANTY; without even the implied warranty of\r
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU\r
14 Library General Public License for more details.\r
15 \r
16 You should have received a copy of the GNU Lesser General Public\r
17 License along with this library; if not, write to the Free\r
18 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.\r
19 *********************************************************************/\r
20 #include <public.h>\r
21 #include "classifierPegasos.h"\r
22 \r
23 using namespace std;\r
24 using namespace dlib;\r
25 \r
26 ClassifierPegasos::~ClassifierPegasos()\r
27 {\r
28     if(decFunction)\r
29     {\r
30 #define KILLCASE(a) case a:{KillDim<a>();return;}\r
31         switch(dim)\r
32         {\r
33         KILLCASE(2);\r
34         KILLCASE(3);\r
35         KILLCASE(4);\r
36         KILLCASE(5);\r
37         KILLCASE(6);\r
38         KILLCASE(7);\r
39         KILLCASE(8);\r
40         KILLCASE(9);\r
41         KILLCASE(10);\r
42         KILLCASE(11);\r
43         KILLCASE(12);\r
44         default:\r
45             KillDim<2>();\r
46             return;\r
47         }\r
48     }\r
49 }\r
50 \r
51 const char *ClassifierPegasos::GetInfoString() const\r
52 {\r
53         char *text = new char[1024];\r
54         sprintf(text, "pegasos SVM\n");\r
55         sprintf(text, "%sKernel: ", text);\r
56         switch(kernelType)\r
57         {\r
58         case 0:\r
59                 sprintf(text, "%s linear", text);\r
60                 break;\r
61         case 1:\r
62                 sprintf(text, "%s polynomial (deg: %f %f width: %f)", text, kernelDegree, kernelParam);\r
63                 break;\r
64         case 2:\r
65                 sprintf(text, "%s rbf (gamma: %f)", text, kernelParam);\r
66                 break;\r
67         }\r
68         sprintf(text, "%slambda: %f\n", text, lambda);\r
69         sprintf(text, "%sSupport Vectors: %d\n", text, GetSVs().size());\r
70         return text;\r
71 }\r
72 \r
73 void ClassifierPegasos::Train(std::vector< fvec > samples, ivec labels)\r
74 {\r
75     if(!samples.size()) return;\r
76     dim = samples[0].size();\r
77     if(dim > 12) dim = 12;\r
78 \r
79     classMap.clear();\r
80     int cnt=0;\r
81     FOR(i, labels.size()) if(!classMap.count(labels[i])) classMap[labels[i]] = cnt++;\r
82     for(map<int,int>::iterator it=classMap.begin(); it != classMap.end(); it++) inverseMap[it->second] = it->first;\r
83 \r
84 #define TRAINCASE(a) case a:{TrainDim<a>(samples, labels);return;}\r
85     switch(dim)\r
86     {\r
87     TRAINCASE(2);\r
88     TRAINCASE(3);\r
89     TRAINCASE(4);\r
90     TRAINCASE(5);\r
91     TRAINCASE(6);\r
92     TRAINCASE(7);\r
93     TRAINCASE(8);\r
94     TRAINCASE(9);\r
95     TRAINCASE(10);\r
96     TRAINCASE(11);\r
97     TRAINCASE(12);\r
98     default:\r
99         TrainDim<2>(samples, labels);\r
100         return;\r
101     }\r
102 }\r
103 \r
104 float ClassifierPegasos::Test( const fvec &_sample ) const\r
105 {\r
106 #define TESTCASE(a) case a:{return TestDim<a>(_sample);}\r
107     switch(dim)\r
108     {\r
109     TESTCASE(2);\r
110     TESTCASE(3);\r
111     TESTCASE(4);\r
112     TESTCASE(5);\r
113     TESTCASE(6);\r
114     TESTCASE(7);\r
115     TESTCASE(8);\r
116     TESTCASE(9);\r
117     TESTCASE(10);\r
118     TESTCASE(11);\r
119     TESTCASE(12);\r
120     default:\r
121         return TestDim<2>(_sample);\r
122     }\r
123 }\r
124 \r
125 std::vector<fvec> ClassifierPegasos::GetSVs() const\r
126 {\r
127 #define SVCASE(a) case a:{return GetSVsDim<a>();}\r
128     switch(dim)\r
129     {\r
130     SVCASE(2);\r
131     SVCASE(3);\r
132     SVCASE(4);\r
133     SVCASE(5);\r
134     SVCASE(6);\r
135     SVCASE(7);\r
136     SVCASE(8);\r
137     SVCASE(9);\r
138     SVCASE(10);\r
139     SVCASE(11);\r
140     SVCASE(12);\r
141     default:\r
142         return GetSVsDim<2>();\r
143     }\r
144 }\r
145 \r
146 \r
147 template <int N>\r
148 void ClassifierPegasos::KillDim()\r
149 {\r
150     if(!decFunction) return;\r
151     switch(kernelTypeTrained)\r
152     {\r
153     case 0:\r
154         if(decFunction) delete [] (linfunc*)decFunction;\r
155         break;\r
156     case 1:\r
157         if(decFunction) delete [] (polfunc*)decFunction;\r
158         break;\r
159     case 2:\r
160         if(decFunction) delete [] (rbffunc*)decFunction;\r
161         break;\r
162     }\r
163     decFunction = 0;\r
164 }\r
165 \r
166 template <int N>\r
167 void ClassifierPegasos::TrainDim(std::vector< fvec > _samples, ivec _labels)\r
168 {\r
169     std::vector<sampletype> samples;\r
170     std::vector<double> labels;\r
171     sampletype samp;\r
172     FOR(i, _samples.size()) { FOR(d, dim) samp(d) = _samples[i][d]; samples.push_back(samp); }\r
173     KillDim<N>();\r
174 \r
175     FOR(i, _samples.size()) labels.push_back(_labels[i] == 1 ? 1 : -1);\r
176 \r
177     randomize_samples(samples, labels);\r
178 \r
179     switch(kernelType)\r
180     {\r
181     case 0:\r
182     {\r
183         svm_pegasos<linkernel> train = svm_pegasos<linkernel>();\r
184         train.set_lambda(lambda);\r
185         train.set_kernel(linkernel());\r
186         train.set_max_num_sv(maxSV);\r
187         linfunc *fun = new linfunc[1];\r
188         *fun = batch_cached(train).train(samples, labels);\r
189         decFunction = (void *)fun;\r
190         kernelTypeTrained = 0;\r
191     }\r
192         break;\r
193     case 1:\r
194     {\r
195         svm_pegasos<polkernel> train = svm_pegasos<polkernel>();\r
196         train.set_lambda(lambda);\r
197         train.set_kernel(polkernel(1./kernelParam, 0, kernelDegree));\r
198         train.set_max_num_sv(maxSV);\r
199         polfunc *fun = new polfunc[1];\r
200         *fun = batch_cached(train).train(samples, labels);\r
201         decFunction = (void *)fun;\r
202         kernelTypeTrained = 1;\r
203     }\r
204         break;\r
205     case 2:\r
206     {\r
207         svm_pegasos<rbfkernel> train = svm_pegasos<rbfkernel>();\r
208         train.set_lambda(lambda);\r
209         train.set_kernel(rbfkernel(1./kernelParam));\r
210         train.set_max_num_sv(maxSV);\r
211         rbffunc *fun = new rbffunc[1];\r
212         *fun = batch_cached(train).train(samples, labels);\r
213         decFunction = (void *)fun;\r
214         kernelTypeTrained = 2;\r
215     }\r
216         break;\r
217     }\r
218 }\r
219 \r
220 template <int N>\r
221 float ClassifierPegasos::TestDim(const fvec &_sample) const\r
222 {\r
223     float estimate = 0.f;\r
224 \r
225     sampletype sample;\r
226     FOR(d,dim) sample(d) = _sample[d];\r
227     if(!decFunction) return estimate;\r
228     switch(kernelTypeTrained)\r
229     {\r
230     case 0:\r
231     {\r
232         linfunc fun = *(linfunc*)decFunction;\r
233         estimate = fun(sample);\r
234     }\r
235         break;\r
236     case 1:\r
237     {\r
238         polfunc fun = *(polfunc*)decFunction;\r
239         estimate = fun(sample);\r
240     }\r
241         break;\r
242     case 2:\r
243     {\r
244         rbffunc fun = *(rbffunc*)decFunction;\r
245         estimate = fun(sample);\r
246     }\r
247         break;\r
248     }\r
249     return estimate;\r
250 }\r
251 \r
252 template <int N>\r
253 std::vector<fvec> ClassifierPegasos::GetSVsDim() const\r
254 {\r
255     std::vector<fvec> SVs;\r
256     switch(kernelTypeTrained)\r
257     {\r
258     case 0:\r
259     {\r
260 \r
261         FOR(i, (*(linfunc*)decFunction).basis_vectors.nr())\r
262         {\r
263             fvec sv;\r
264             sv.push_back((*(linfunc*)decFunction).basis_vectors(i)(0));\r
265             sv.push_back((*(linfunc*)decFunction).basis_vectors(i)(1));\r
266             SVs.push_back(sv);\r
267         }\r
268     }\r
269         break;\r
270     case 1:\r
271     {\r
272 \r
273         FOR(i, (*(polfunc*)decFunction).basis_vectors.nr())\r
274         {\r
275             fvec sv;\r
276             sv.push_back((*(polfunc*)decFunction).basis_vectors(i)(0));\r
277             sv.push_back((*(polfunc*)decFunction).basis_vectors(i)(1));\r
278             SVs.push_back(sv);\r
279         }\r
280     }\r
281         break;\r
282     case 2:\r
283     {\r
284 \r
285         FOR(i, (*(rbffunc*)decFunction).basis_vectors.nr())\r
286         {\r
287             fvec sv;\r
288             sv.push_back((*(rbffunc*)decFunction).basis_vectors(i)(0));\r
289             sv.push_back((*(rbffunc*)decFunction).basis_vectors(i)(1));\r
290             SVs.push_back(sv);\r
291         }\r
292     }\r
293         break;\r
294     }\r
295     return SVs;\r
296 }\r