/*********************************************************************
MLDemos: A User-Friendly visualization toolkit for machine learning
Copyright (C) 2010  Basilio Noris
Contact: mldemos@b4silio.com

This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
Library General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free
Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*********************************************************************/
#include <public.h>
#include <cstring>
#include <cmath>
#include "regressorKRLS.h"

using namespace std;

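// Returns a human-readable summary of the model: kernel settings, capacity,
// tolerance (eps) and the number of basis functions kept after training.
// The returned buffer is heap-allocated; the caller is expected to release it.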
const char *RegressorKRLS::GetInfoString()
{
    // Build the info string incrementally; each sprintf appends at the current end of the buffer.
    char *text = new char[1024];
    sprintf(text, "Kernel Recursive Least Squares\n");
    sprintf(text+strlen(text), "Capacity: %d\n", capacity);
    sprintf(text+strlen(text), "Kernel: ");
    switch(kernelType)
    {
    case 0:
        sprintf(text+strlen(text), " linear");
        break;
    case 1:
        sprintf(text+strlen(text), " polynomial (deg: %d width: %f)", kernelDegree, kernelParam);
        break;
    case 2:
        sprintf(text+strlen(text), " rbf (gamma: %f)", kernelParam);
        break;
    }
    sprintf(text+strlen(text), "\neps: %f\n", epsilon);
    sprintf(text+strlen(text), "Basis Functions: %lu\n", (unsigned long)GetSVs().size());
    return text;
}

RegressorKRLS::~RegressorKRLS()
{
    DEL(linTrainer);
    DEL(polTrainer);
    DEL(rbfTrainer);
}

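// Trains the KRLS model. The last input dimension (or the dimension selected by
// outputDim) is used as the regression target; the remaining dimensions form the
// kernel input. dlib's krls trainer is incremental: a sample is added to the
// basis-vector dictionary only if it cannot be approximated by the current
// dictionary within the tolerance 'epsilon', and the dictionary never grows
// beyond 'capacity' vectors (a capacity of 0 means effectively unbounded).
// Note: the _labels argument is not used; targets are read from _samples.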
void RegressorKRLS::Train(std::vector< fvec > _samples, ivec _labels)
{
    if(capacity == 1) capacity = 2;
    samples.clear();
    labels.clear();
    if(!_samples.size()) return;
    dim = _samples[0].size()-1;

    FOR(i, _samples.size())
    {
        reg_sample_type samp(dim);
        FOR(d, dim) samp(d) = _samples[i][d];
        if(outputDim != -1 && outputDim < dim) samp(outputDim) = _samples[i][dim];
        samples.push_back(samp);
        labels.push_back(_samples[i][outputDim != -1 ? outputDim : dim]);
    }
    randomize_samples(samples, labels);

    DEL(linTrainer);
    DEL(polTrainer);
    DEL(rbfTrainer);
    switch(kernelType)
    {
    case 0:
        {
            linTrainer = new dlib::krls<reg_lin_kernel>(reg_lin_kernel(),epsilon,capacity ? capacity : 1000000);
            FOR(i, samples.size())
            {
                linTrainer->train(samples[i], labels[i]);
            }
            linFunc = linTrainer->get_decision_function();
        }
        break;
    case 1:
        {
            polTrainer = new dlib::krls<reg_pol_kernel>(reg_pol_kernel(1./kernelParam,0,kernelDegree),epsilon,capacity ? capacity : 1000000);
            FOR(i, samples.size())
            {
                polTrainer->train(samples[i], labels[i]);
            }
            polFunc = polTrainer->get_decision_function();
        }
        break;
    case 2:
        {
            rbfTrainer = new dlib::krls<reg_rbf_kernel>(reg_rbf_kernel(1./kernelParam),epsilon,capacity ? capacity : 1000000);
            FOR(i, samples.size())
            {
                rbfTrainer->train(samples[i], labels[i]);
            }
            rbfFunc = rbfTrainer->get_decision_function();
        }
        break;
    }
}

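// Evaluates the trained decision function on a single sample. The prediction
// is returned in res[0]; res[1] is left at zero.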
fvec RegressorKRLS::Test( const fvec &_sample )
{
    fvec res;
    res.resize(2,0);
    if(!linTrainer && !polTrainer && !rbfTrainer) return res;
    reg_sample_type sample(dim);
    FOR(d, dim) sample(d) = _sample[d];
    if(outputDim != -1 && outputDim < dim) sample(outputDim) = _sample[dim];
    switch(kernelType)
    {
    case 0:
        res[0] = (*linTrainer)(sample);
        break;
    case 1:
        res[0] = (*polTrainer)(sample);
        break;
    case 2:
        res[0] = (*rbfTrainer)(sample);
        break;
    }
    return res;
}

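// 1D overload of Test() for fixed-size fVec samples. It only fills the first
// input coordinate, so it is meaningful only for models trained on 1D data.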
fVec RegressorKRLS::Test( const fVec &_sample )
{
    fVec res;
    if(!linTrainer && !polTrainer && !rbfTrainer) return res;
    // The sample vector must be allocated before it is written to.
    reg_sample_type sample(1);
    sample(0) = _sample._[0];
    switch(kernelType)
    {
    case 0:
        res[0] = (*linTrainer)(sample);
        break;
    case 1:
        res[0] = (*polTrainer)(sample);
        break;
    case 2:
        res[0] = (*rbfTrainer)(sample);
        break;
    }
    return res;
}

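// Returns the basis vectors retained in the trained decision function (the
// KRLS analogue of support vectors). Each returned vector holds the dim input
// coordinates plus one extra slot; sv[1] is filled with the label of the
// training sample whose first coordinate is closest to it (presumably so the
// basis functions can be drawn at a sensible height).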
std::vector<fvec> RegressorKRLS::GetSVs()
{
    vector<fvec> SVs;
    if(kernelType == 0)
    {
        FOR(i, linFunc.basis_vectors.nr())
        {
            fvec sv(dim+1,0);
            FOR(d, dim) sv[d] = linFunc.basis_vectors(i)(d);
            if(outputDim != -1 && outputDim < dim)
            {
                sv[dim] = sv[outputDim];
                sv[outputDim] = 0;
            }
            SVs.push_back(sv);
        }
    }
    else if(kernelType == 1)
    {
        FOR(i, polFunc.basis_vectors.nr())
        {
            fvec sv(dim+1,0);
            FOR(d, dim) sv[d] = polFunc.basis_vectors(i)(d);
            if(outputDim != -1 && outputDim < dim)
            {
                sv[dim] = sv[outputDim];
                sv[outputDim] = 0;
            }
            SVs.push_back(sv);
        }
    }
    else if(kernelType == 2)
    {
        FOR(i, rbfFunc.basis_vectors.nr())
        {
            fvec sv(dim+1,0);
            FOR(d, dim) sv[d] = rbfFunc.basis_vectors(i)(d);
            if(outputDim != -1 && outputDim < dim)
            {
                sv[dim] = sv[outputDim];
                sv[outputDim] = 0;
            }
            SVs.push_back(sv);
        }
    }

    // For each basis vector, attach the label of the closest training sample
    // (compared on the first coordinate).
    FOR(i, SVs.size())
    {
        int closest = 0;
        double dist = DBL_MAX;
        FOR(j, samples.size())
        {
            double d = fabs(samples[j](0)-SVs[i][0]);
            if(d < dist)
            {
                dist = d;
                closest = j;
            }
        }
        SVs[i][1] = labels[closest];
    }
    return SVs;
}