- Added maximization features, with corresponding plugins. Added painting feature...
[mldemos:mldemos.git] / _3rdParty / dlib / svm / svm_abstract.h
1 // Copyright (C) 2007  Davis E. King (davis@dlib.net)\r
2 // License: Boost Software License   See LICENSE.txt for the full license.\r
3 #undef DLIB_SVm_ABSTRACT_\r
4 #ifdef DLIB_SVm_ABSTRACT_\r
5 \r
6 #include <cmath>\r
7 #include <limits>\r
8 #include <sstream>\r
9 #include "../matrix/matrix_abstract.h"\r
10 #include "../algs.h"\r
11 #include "../serialize.h"\r
12 #include "function_abstract.h"\r
13 #include "kernel_abstract.h"\r
14 \r
15 namespace dlib\r
16 {\r
17 \r
18 // ----------------------------------------------------------------------------------------\r
19 // ----------------------------------------------------------------------------------------\r
20 // ----------------------------------------------------------------------------------------\r
21 \r
22     class invalid_svm_nu_error : public dlib::error \r
23     { \r
24         /*!\r
25             WHAT THIS OBJECT REPRESENTS\r
26                 This object is an exception class used to indicate that a\r
27                 value of nu used for svm training is incompatible with a\r
28                 particular data set.\r
29 \r
30                 this->nu will be set to the invalid value of nu used.\r
31         !*/\r
32 \r
33     public: \r
34         invalid_svm_nu_error(const std::string& msg, double nu_) : dlib::error(msg), nu(nu_) {};\r
35         const double nu;\r
36     };\r
37 \r
38 // ----------------------------------------------------------------------------------------\r
39 \r
40     template <\r
41         typename T\r
42         >\r
43     typename T::type maximum_nu (\r
44         const T& y\r
45     );\r
46     /*!\r
47         requires\r
48             - T == a matrix object or an object convertible to a matrix via \r
49               vector_to_matrix()\r
50             - y.nc() == 1\r
51             - y.nr() > 1\r
52             - for all valid i:\r
53                 - y(i) == -1 or +1\r
54         ensures\r
55             - returns the maximum valid nu that can be used with the svm_nu_trainer and\r
56               the training set labels from the given y vector.\r
57               (i.e. 2.0*min(number of +1 examples in y, number of -1 examples in y)/y.nr())\r
58     !*/\r
59 \r
60 // ----------------------------------------------------------------------------------------\r
61 \r
62     template <\r
63         typename T,\r
64         typename U\r
65         >\r
66     bool is_binary_classification_problem (\r
67         const T& x,\r
68         const U& x_labels\r
69     );\r
70     /*!\r
71         requires\r
72             - T == a matrix or something convertible to a matrix via vector_to_matrix()\r
73             - U == a matrix or something convertible to a matrix via vector_to_matrix()\r
74         ensures\r
75             - returns true if all of the following are true and false otherwise:\r
76                 - is_col_vector(x) == true\r
77                 - is_col_vector(x_labels) == true\r
78                 - x.size() == x_labels.size() \r
79                 - x.size() > 1\r
80                 - there exists at least one sample from both the +1 and -1 classes.\r
81                   (i.e. all samples can't have the same label)\r
82                 - for all valid i:\r
83                     - x_labels(i) == -1 or +1\r
84     !*/\r
85 \r
86 // ----------------------------------------------------------------------------------------\r
87 // ----------------------------------------------------------------------------------------\r
88 // ----------------------------------------------------------------------------------------\r
89 \r
90     template <\r
91         typename K \r
92         >\r
93     class svm_nu_trainer\r
94     {\r
95         /*!\r
96             REQUIREMENTS ON K \r
97                 is a kernel function object as defined in dlib/svm/kernel_abstract.h \r
98 \r
99             WHAT THIS OBJECT REPRESENTS\r
100                 This object implements a trainer for a nu support vector machine for \r
101                 solving binary classification problems.\r
102 \r
103                 The implementation of the nu-svm training algorithm used by this object is based\r
104                 on the following excellent papers:\r
105                     - Chang and Lin, Training {nu}-Support Vector Classifiers: Theory and Algorithms\r
106                     - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector \r
107                       machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm\r
108 \r
109         !*/\r
110 \r
111     public:\r
112         typedef K kernel_type;\r
113         typedef typename kernel_type::scalar_type scalar_type;\r
114         typedef typename kernel_type::sample_type sample_type;\r
115         typedef typename kernel_type::mem_manager_type mem_manager_type;\r
116         typedef decision_function<kernel_type> trained_function_type;\r
117 \r
118         svm_nu_trainer (\r
119         );\r
120         /*!\r
121             ensures\r
122                 - This object is properly initialized and ready to be used\r
123                   to train a support vector machine.\r
124                 - #get_nu() == 0.1 \r
125                 - #get_cache_size() == 200\r
126                 - #get_epsilon() == 0.001\r
127         !*/\r
128 \r
129         svm_nu_trainer (\r
130             const kernel_type& kernel, \r
131             const scalar_type& nu\r
132         );\r
133         /*!\r
134             requires\r
135                 - 0 < nu <= 1\r
136             ensures\r
137                 - This object is properly initialized and ready to be used\r
138                   to train a support vector machine.\r
139                 - #get_kernel() == kernel\r
140                 - #get_nu() == nu\r
141                 - #get_cache_size() == 200\r
142                 - #get_epsilon() == 0.001\r
143         !*/\r
144 \r
145         void set_cache_size (\r
146             long cache_size\r
147         );\r
148         /*!\r
149             requires\r
150                 - cache_size > 0\r
151             ensures\r
152                 - #get_cache_size() == cache_size \r
153         !*/\r
154 \r
155         const long get_cache_size (\r
156         ) const;\r
157         /*!\r
158             ensures\r
159                 - returns the number of megabytes of cache this object will use\r
160                   when it performs training via the this->train() function.\r
161                   (bigger values of this may make training go faster but won't affect \r
162                   the result.  However, too big a value will cause you to run out of \r
163                   memory, obviously.)\r
164         !*/\r
165 \r
166         void set_epsilon (\r
167             scalar_type eps\r
168         );\r
169         /*!\r
170             requires\r
171                 - eps > 0\r
172             ensures\r
173                 - #get_epsilon() == eps \r
174         !*/\r
175 \r
176         const scalar_type get_epsilon (\r
177         ) const;\r
178         /*!\r
179             ensures\r
180                 - returns the error epsilon that determines when training should stop.\r
181                   Generally a good value for this is 0.001.  Smaller values may result\r
182                   in a more accurate solution but take longer to execute.\r
183         !*/\r
184 \r
185         void set_kernel (\r
186             const kernel_type& k\r
187         );\r
188         /*!\r
189             ensures\r
190                 - #get_kernel() == k \r
191         !*/\r
192 \r
193         const kernel_type& get_kernel (\r
194         ) const;\r
195         /*!\r
196             ensures\r
197                 - returns a copy of the kernel function in use by this object\r
198         !*/\r
199 \r
200         void set_nu (\r
201             scalar_type nu\r
202         );\r
203         /*!\r
204             requires\r
205                 - 0 < nu <= 1\r
206             ensures\r
207                 - #get_nu() == nu\r
208         !*/\r
209 \r
210         const scalar_type get_nu (\r
211         ) const;\r
212         /*!\r
213             ensures\r
214                 - returns the nu svm parameter.  This is a value between 0 and\r
215                   1.  It is the parameter that determines the trade off between\r
216                   trying to fit the training data exactly or allowing more errors \r
217                   but hopefully improving the generalization ability of the \r
218                   resulting classifier.  Smaller values encourage exact fitting \r
219                   while larger values of nu may encourage better generalization. \r
220                   For more information you should consult the papers referenced \r
221                   above.\r
222         !*/\r
223 \r
224         template <\r
225             typename in_sample_vector_type,\r
226             typename in_scalar_vector_type\r
227             >\r
228         const decision_function<kernel_type> train (\r
229             const in_sample_vector_type& x,\r
230             const in_scalar_vector_type& y\r
231         ) const;\r
232         /*!\r
233             requires\r
234                 - is_binary_classification_problem(x,y) == true\r
235                 - x == a matrix or something convertible to a matrix via vector_to_matrix().\r
236                   Also, x should contain sample_type objects.\r
237                 - y == a matrix or something convertible to a matrix via vector_to_matrix().\r
238                   Also, y should contain scalar_type objects.\r
239             ensures\r
240                 - trains a nu support vector classifier given the training samples in x and \r
241                   labels in y.  Training is done when the error is less than get_epsilon().\r
242                 - returns a decision function F with the following properties:\r
243                     - if (new_x is a sample predicted have +1 label) then\r
244                         - F(new_x) >= 0\r
245                     - else\r
246                         - F(new_x) < 0\r
247             throws\r
248                 - invalid_svm_nu_error\r
249                   This exception is thrown if get_nu() >= maximum_nu(y)\r
250                 - std::bad_alloc\r
251         !*/\r
252 \r
253         void swap (\r
254             svm_nu_trainer& item\r
255         );\r
256         /*!\r
257             ensures\r
258                 - swaps *this and item\r
259         !*/\r
260     }; \r
261 \r
262     template <typename K>\r
263     void swap (\r
264         svm_nu_trainer<K>& a,\r
265         svm_nu_trainer<K>& b\r
266     ) { a.swap(b); }\r
267     /*!\r
268         provides a global swap\r
269     !*/\r
270 \r
271 // ----------------------------------------------------------------------------------------\r
272 \r
273     template <\r
274         typename trainer_type,\r
275         typename in_sample_vector_type,\r
276         typename in_scalar_vector_type\r
277         >\r
278     const probabilistic_decision_function<typename trainer_type::kernel_type> \r
279     train_probabilistic_decision_function (\r
280         const trainer_type& trainer,\r
281         const in_sample_vector_type& x,\r
282         const in_scalar_vector_type& y,\r
283         const long folds\r
284     )\r
285     /*!\r
286         requires\r
287             - 1 < folds <= x.nr()\r
288             - is_binary_classification_problem(x,y) == true\r
289             - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer)\r
290         ensures\r
291             - trains a nu support vector classifier given the training samples in x and \r
292               labels in y.  \r
293             - returns a probabilistic_decision_function that represents the trained svm.\r
294             - The parameters of the probability model are estimated by performing k-fold \r
295               cross validation. \r
296             - The number of folds used is given by the folds argument.\r
297         throws\r
298             - any exceptions thrown by trainer.train()\r
299             - std::bad_alloc\r
300     !*/\r
301 \r
302 // ----------------------------------------------------------------------------------------\r
303 // ----------------------------------------------------------------------------------------\r
304 //                                  Miscellaneous functions\r
305 // ----------------------------------------------------------------------------------------\r
306 // ----------------------------------------------------------------------------------------\r
307 \r
308     template <\r
309         typename trainer_type,\r
310         typename in_sample_vector_type,\r
311         typename in_scalar_vector_type\r
312         >\r
313     const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type> \r
314     cross_validate_trainer (\r
315         const trainer_type& trainer,\r
316         const in_sample_vector_type& x,\r
317         const in_scalar_vector_type& y,\r
318         const long folds\r
319     );\r
320     /*!\r
321         requires\r
322             - is_binary_classification_problem(x,y) == true\r
323             - 1 < folds <= x.nr()\r
324             - trainer_type == some kind of trainer object (e.g. svm_nu_trainer)\r
325         ensures\r
326             - performs k-fold cross validation by using the given trainer to solve the\r
327               given binary classification problem for the given number of folds.\r
328               Each fold is tested using the output of the trainer and the average \r
329               classification accuracy from all folds is returned.  \r
330             - The average accuracy is computed by running test_binary_decision_function()\r
331               on each fold and its output is averaged and returned.\r
332             - The number of folds used is given by the folds argument.\r
333         throws\r
334             - any exceptions thrown by trainer.train()\r
335             - std::bad_alloc\r
336     !*/\r
337 \r
338 // ----------------------------------------------------------------------------------------\r
339 \r
340     template <\r
341         typename dec_funct_type,\r
342         typename in_sample_vector_type,\r
343         typename in_scalar_vector_type\r
344         >\r
345     const matrix<typename dec_funct_type::scalar_type, 1, 2, typename dec_funct_type::mem_manager_type> \r
346     test_binary_decision_function (\r
347         const dec_funct_type& dec_funct,\r
348         const in_sample_vector_type& x_test,\r
349         const in_scalar_vector_type& y_test\r
350     );\r
351     /*!\r
352         requires\r
353             - is_binary_classification_problem(x_test,y_test) == true\r
354             - dec_funct_type == some kind of decision function object (e.g. decision_function)\r
355         ensures\r
356             - Tests the given decision function by calling it on the x_test and y_test samples.\r
357               The output of dec_funct is interpreted as a prediction for the +1 class\r
358               if its output is >= 0 and as a prediction for the -1 class otherwise.\r
359             - The test accuracy is returned in a row vector, let us call it R.  Both \r
360               quantities in R are numbers between 0 and 1 which represent the fraction \r
361               of examples correctly classified.  R(0) is the fraction of +1 examples \r
362               correctly classified and R(1) is the fraction of -1 examples correctly \r
363               classified.\r
364         throws\r
365             - std::bad_alloc\r
366     !*/\r
367 \r
368 // ----------------------------------------------------------------------------------------\r
369 // ----------------------------------------------------------------------------------------\r
370 \r
371     template <\r
372         typename T,\r
373         typename U\r
374         >\r
375     void randomize_samples (\r
376         T& samples,\r
377         U& labels \r
378     );\r
379     /*!\r
380         requires\r
381             - T == a matrix object or an object compatible with std::vector that contains \r
382               a swappable type.\r
383             - U == a matrix object or an object compatible with std::vector that contains \r
384               a swappable type.\r
385             - if samples or labels are matrix objects then is_vector(samples) == true and\r
386               is_vector(labels) == true\r
387             - samples.size() == labels.size()\r
388         ensures\r
389             - randomizes the order of the samples and labels but preserves\r
390               the pairing between each sample and its label\r
391             - A default initialized random number generator is used to perform the randomizing.\r
392               Note that this means that each call this this function does the same thing.  \r
393               That is, the random number generator always uses the same seed.\r
394             - for all valid i:\r
395                 - let r == the random index samples(i) was moved to.  then:\r
396                     - #labels(r) == labels(i)\r
397     !*/\r
398 \r
399 // ----------------------------------------------------------------------------------------\r
400 \r
401     template <\r
402         typename T,\r
403         typename U,\r
404         typename rand_type\r
405         >\r
406     void randomize_samples (\r
407         T& samples,\r
408         U& labels,\r
409         rand_type& rnd\r
410     );\r
411     /*!\r
412         requires\r
413             - T == a matrix object or an object compatible with std::vector that contains \r
414               a swappable type.\r
415             - U == a matrix object or an object compatible with std::vector that contains \r
416               a swappable type.\r
417             - if samples or labels are matrix objects then is_vector(samples) == true and\r
418               is_vector(labels) == true\r
419             - samples.size() == labels.size()\r
420             - rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface\r
421         ensures\r
422             - randomizes the order of the samples and labels but preserves\r
423               the pairing between each sample and its label\r
424             - the given rnd random number generator object is used to do the randomizing\r
425             - for all valid i:\r
426                 - let r == the random index samples(i) was moved to.  then:\r
427                     - #labels(r) == labels(i)\r
428     !*/\r
429 \r
430 // ----------------------------------------------------------------------------------------\r
431 \r
432     template <\r
433         typename T\r
434         >\r
435     void randomize_samples (\r
436         T& samples\r
437     );\r
438     /*!\r
439         requires\r
440             - T == a matrix object or an object compatible with std::vector that contains \r
441               a swappable type.\r
442             - if samples is a matrix then is_vector(samples) == true \r
443         ensures\r
444             - randomizes the order of the elements inside samples \r
445             - A default initialized random number generator is used to perform the randomizing.\r
446               Note that this means that each call this this function does the same thing.  \r
447               That is, the random number generator always uses the same seed.\r
448     !*/\r
449 \r
450 // ----------------------------------------------------------------------------------------\r
451 \r
452     template <\r
453         typename T,\r
454         typename rand_type\r
455         >\r
456     void randomize_samples (\r
457         T& samples,\r
458         rand_type& rnd\r
459     );\r
460     /*!\r
461         requires\r
462             - T == a matrix object or an object compatible with std::vector that contains \r
463               a swappable type.\r
464             - if samples is a matrix then is_vector(samples) == true \r
465         ensures\r
466             - randomizes the order of the elements inside samples \r
467             - the given rnd random number generator object is used to do the randomizing\r
468     !*/\r
469 \r
470 // ----------------------------------------------------------------------------------------\r
471 // ----------------------------------------------------------------------------------------\r
472 \r
473 }\r
474 \r
475 #endif // DLIB_SVm_ABSTRACT_\r
476 \r
477 \r