Merging requests
[mldemos:mldemos.git] / _AlgorithmsPlugins / KernelMethods / clustererKM.cpp
1 /*********************************************************************\r
2 MLDemos: A User-Friendly visualization toolkit for machine learning\r
3 Copyright (C) 2010  Basilio Noris\r
4 Contact: mldemos@b4silio.com\r
5 \r
6 This library is free software; you can redistribute it and/or\r
7 modify it under the terms of the GNU Lesser General Public\r
8 License as published by the Free Software Foundation; either\r
9 version 2.1 of the License, or (at your option) any later version.\r
10 \r
11 This library is distributed in the hope that it will be useful,\r
12 but WITHOUT ANY WARRANTY; without even the implied warranty of\r
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU\r
14 Library General Public License for more details.\r
15 \r
16 You should have received a copy of the GNU Lesser General Public\r
17 License along with this library; if not, write to the Free\r
18 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.\r
19 *********************************************************************/\r
20 #include "public.h"\r
21 #include "clustererKM.h"\r
22 \r
23 using namespace std;\r
24 \r
25 ClustererKM::~ClustererKM()\r
26 {\r
27     DEL(kmeans);\r
28 }\r
29 \r
30 void ClustererKM::Train(std::vector< fvec > samples)\r
31 {\r
32     if(!samples.size()) return;\r
33     int dim = samples[0].size();\r
34     if(!bIterative)\r
35     {\r
36         DEL(kmeans);\r
37     }\r
38     bool bInit = false;\r
39     if(kmeans && kmeans->GetClusters() != nbClusters) DEL(kmeans);\r
40     if(!kmeans)\r
41     {\r
42         bInit = true;\r
43         kmeans = new KMeansCluster(nbClusters);\r
44         kmeans->AddPoints(samples);\r
45         kmeans->SetPlusPlus(kmeansPlusPlus);\r
46         kmeans->InitClusters();\r
47     }\r
48     kmeans->SetSoft(bSoft);\r
49     kmeans->SetGMM(bGmm);\r
50     kmeans->SetBeta(beta);\r
51     kmeans->SetPower(power);\r
52 \r
53     kmeans->Update(bInit);\r
54 \r
55     if(!bIterative)\r
56     {\r
57         int iterations = 20;\r
58         FOR(i, iterations) kmeans->Update();\r
59     }\r
60 }\r
61 \r
62 fvec ClustererKM::Test( const fvec &sample)\r
63 {\r
64     fvec res;\r
65     res.resize(nbClusters,0);\r
66     if(!kmeans) return res;\r
67     kmeans->Test(sample, res);\r
68     float sum = 0;\r
69     FOR(i, res.size()) sum += res[i];\r
70     FOR(i, res.size()) res[i] /= sum;\r
71     return res;\r
72 }\r
73 \r
74 fvec ClustererKM::Test( const fVec &sample)\r
75 {\r
76     fvec res;\r
77     res.resize(nbClusters,0);\r
78     if(!kmeans) return res;\r
79     kmeans->Test(sample, res);\r
80     float sum = 0;\r
81     FOR(i, res.size()) sum += res[i];\r
82     FOR(i, res.size()) res[i] /= sum;\r
83     return res;\r
84 }\r
85 \r
86 void ClustererKM::SetParams(u32 clusters, int method, float beta, int power, bool kmeansPlusPlus)\r
87 {\r
88 \r
89     this->nbClusters = clusters;\r
90     this->beta = beta;\r
91     this->power = power;\r
92     this->kmeansPlusPlus = kmeansPlusPlus;\r
93 \r
94     switch(method)\r
95     {\r
96     case 0:\r
97         this->bSoft = false;\r
98         this->bGmm = false;\r
99         break;\r
100     case 1:\r
101         this->bSoft = true;\r
102         this->bGmm = false;\r
103         break;\r
104     case 2:\r
105         this->bSoft = false;\r
106         this->bGmm = true;\r
107         break;\r
108     }\r
109 }\r
110 \r
111 \r
112 const char *ClustererKM::GetInfoString()\r
113 {\r
114     char *text = new char[1024];\r
115     sprintf(text, "K-Means\n");\r
116     sprintf(text, "%sClusters: %d\n", text, nbClusters);\r
117     sprintf(text, "%sType:", text);\r
118     if(!bSoft && !bGmm) sprintf(text, "%sK-Means (plusplus: %i)\n", text, kmeansPlusPlus);\r
119     else if(bSoft) sprintf(text, "%sSoft K-Means (beta: %.3f, plusplus: %i)\n", text, beta, kmeansPlusPlus);\r
120     else sprintf(text, "%sGMM\n", text);\r
121     sprintf(text, "%sMetric: ", text);\r
122     switch(power)\r
123     {\r
124     case 0:\r
125         sprintf(text, "%sinfinite norm\n", text);\r
126         break;\r
127     case 1:\r
128         sprintf(text, "%s1-norm (Manhattan)\n", text);\r
129         break;\r
130     case 2:\r
131         sprintf(text, "%s2-norm (Euclidean)\n", text);\r
132         break;\r
133     default:\r
134         sprintf(text, "%s%d-norm\n", text, power);\r
135         break;\r
136     }\r
137     return text;\r
138 }\r