fixed log-likelihood function for gmm
[mldemos:baraks-mldemos.git] / _AlgorithmsPlugins / GMM / clustererGMM.cpp
1 /*********************************************************************\r
2 MLDemos: A User-Friendly visualization toolkit for machine learning\r
3 Copyright (C) 2010  Basilio Noris\r
4 Contact: mldemos@b4silio.com\r
5 \r
6 This library is free software; you can redistribute it and/or\r
7 modify it under the terms of the GNU Lesser General Public\r
8 License as published by the Free Software Foundation; either\r
9 version 2.1 of the License, or (at your option) any later version.\r
10 \r
11 This library is distributed in the hope that it will be useful,\r
12 but WITHOUT ANY WARRANTY; without even the implied warranty of\r
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU\r
14 Library General Public License for more details.\r
15 \r
16 You should have received a copy of the GNU Lesser General Public\r
17 License along with this library; if not, write to the Free\r
18 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.\r
19 *********************************************************************/\r
20 #include "public.h"\r
21 #include "clustererGMM.h"\r
22 \r
23 using namespace std;\r
24 \r
25 ClustererGMM::~ClustererGMM()\r
26 {\r
27     DEL(gmm);\r
28 }\r
29 \r
30 void ClustererGMM::Train(std::vector< fvec > samples)\r
31 {\r
32         if(!samples.size()) return;\r
33     dim = samples[0].size();\r
34         DEL(gmm);\r
35         gmm = new Gmm(nbClusters, dim);\r
36         KILL(data);\r
37         data = new float[samples.size()*dim];\r
38         FOR(i, samples.size())\r
39         {\r
40                 FOR(j, dim) data[i*dim + j] = samples[i][j];\r
41         }\r
42         gmm->init(data, samples.size(), initType);\r
43         gmm->em(data, samples.size(),-1e4,(COVARIANCE_TYPE)covarianceType);\r
44 //      FOR(i, nbClusters) gmm->SetPrior(i, 1.f/nbClusters);\r
45 }\r
46 \r
47 fvec ClustererGMM::Test( const fvec &sample)\r
48 {\r
49         fvec res;\r
50         res.resize(nbClusters,0);\r
51         if(!gmm) return res;\r
52         float estimate;\r
53         float sigma;\r
54         FOR(i, nbClusters) res[i] = gmm->pdf(&sample[0], i);\r
55         float sum = 0;\r
56         FOR(i, nbClusters) sum += res[i];\r
57         if(sum > FLT_MIN*3) FOR(i, nbClusters) res[i] /= sum;\r
58         return res;\r
59 }\r
60 \r
61 fvec ClustererGMM::Test( const fVec &sample)\r
62 {\r
63         fvec res;\r
64         res.resize(nbClusters,0);\r
65         if(!gmm) return res;\r
66         float estimate;\r
67         float sigma;\r
68         FOR(i, nbClusters) res[i] = gmm->pdf(sample._, i);\r
69         float sum = 0;\r
70         FOR(i, nbClusters) sum += res[i];\r
71         if(sum > FLT_MIN*3) FOR(i, nbClusters) res[i] /= sum;\r
72         return res;\r
73 }\r
74 \r
75 float ClustererGMM::GetLogLikelihood(std::vector<fvec> samples)\r
76 {\r
77     float *weights = new float[nbClusters];\r
78     float logLik = 0;\r
79     FOR(i, samples.size())\r
80     {\r
81         gmm->pdf(&samples[i][0], weights);\r
82         float likelihood = 0;\r
83         FOR(j, nbClusters) likelihood += weights[j];\r
84         logLik += logf(likelihood);\r
85     }\r
86     delete [] weights;\r
87     return logLik;\r
88 }\r
89 \r
90 float ClustererGMM::GetParameterCount()\r
91 {\r
92     switch ( covarianceType ) {\r
93     case 0: // spherical\r
94         return nbClusters*(dim+1);\r
95         break;\r
96     case 1: // diagonal\r
97         return nbClusters*(2*dim);\r
98         break;\r
99     case 2: // full\r
100         return nbClusters*(dim + dim*(dim+1)/2);\r
101         break;\r
102     }\r
103     return nbClusters;\r
104 }\r
105 \r
106 void ClustererGMM::SetParams(u32 nbClusters, u32 covarianceType, u32 initType)\r
107 {\r
108         this->nbClusters = nbClusters;\r
109         this->covarianceType = covarianceType;\r
110         this->initType = initType;\r
111 }\r
112 \r
113 const char *ClustererGMM::GetInfoString()\r
114 {\r
115         char *text = new char[1024];\r
116         sprintf(text, "GMM\n");\r
117         sprintf(text, "%sClusters: %d\n", text, nbClusters);\r
118         sprintf(text, "%sCovariance Type: ", text);\r
119         switch(covarianceType)\r
120         {\r
121         case 0:\r
122                 sprintf(text, "%sSpherical\n", text);\r
123                 break;\r
124         case 1:\r
125                 sprintf(text, "%sDiagonal\n", text);\r
126                 break;\r
127         case 2:\r
128                 sprintf(text, "%sFull\n", text);\r
129                 break;\r
130         }\r
131         sprintf(text, "%sInitialization Type: ", text);\r
132         switch(initType)\r
133         {\r
134         case 0:\r
135                 sprintf(text, "%sRandom\n", text);\r
136                 break;\r
137         case 1:\r
138                 sprintf(text, "%sUniform\n", text);\r
139                 break;\r
140         case 2:\r
141                 sprintf(text, "%sK-Means\n", text);\r
142                 break;\r
143         }\r
144         return text;\r
145 }\r