Avoid exp(); more than doubles performance; no detectable loss in accuracy
[openmx:openmx.git] / src / omxExpectationBA81.h
1 /*
2   Copyright 2012-2013 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #ifndef _OMX_EXPECTATIONBA81_H_
19 #define _OMX_EXPECTATIONBA81_H_
20
21 #include "omxExpectation.h"
22 #include "omxOpenmpWrap.h"
23
24 enum score_option {
25         SCORES_OMIT,
26         SCORES_UNIQUE,
27         SCORES_FULL
28 };
29
30 enum expectation_type {
31         EXPECTATION_UNINITIALIZED,
32         EXPECTATION_AUGMENTED, // E-M
33         EXPECTATION_OBSERVED,  // regular
34 };
35
36 typedef struct {
37
38         // data characteristics
39         omxData *data;
40         int numUnique;
41         int *numIdentical;        // length numUnique
42         double *logNumIdentical;  // length numUnique
43         int *rowMap;              // length numUnique, index of first instance of pattern
44
45         // item description related
46         std::vector<const double*> itemSpec;
47         int maxOutcomes;
48         int maxDims;
49         int maxAbilities;
50         int numSpecific;
51         int *Sgroup;              // item's specific group 0..numSpecific-1
52         omxMatrix *design;        // items * maxDims
53
54         // quadrature related
55         double Qwidth;
56         double targetQpoints;
57         long quadGridSize;
58         long totalQuadPoints;     // quadGridSize ^ maxDims
59         long totalPrimaryPoints;  // totalQuadPoints except for specific dim TODO
60         std::vector<double> Qpoint;           // quadGridSize
61         std::vector<double> priLogQarea;      // totalPrimaryPoints
62         std::vector<double> speLogQarea;      // quadGridSize * numSpecific
63         std::vector<double> priQarea;         // totalPrimaryPoints
64         std::vector<double> speQarea;         // quadGridSize * numSpecific
65
66         // estimation related
67         omxMatrix *customPrior;
68         omxMatrix *itemParam;
69         omxMatrix *EitemParam;    // E step version
70         int cacheLXK;
71         bool LXKcached;
72         double *lxk;              // wo/cache, numUnique * thread
73         std::vector<double> allElxk;          // numUnique * thread
74         std::vector<double> Eslxk;            // numUnique * #specific dimensions * thread
75         double *patternLik;       // numUnique
76         double *_logPatternLik;   // numUnique
77         int totalOutcomes;
78         double *expected;         // totalOutcomes * totalQuadPoints
79         double *ElatentMean;      // maxAbilities * numUnique
80         double *ElatentCov;       // maxAbilities * maxAbilities * numUnique ; only lower triangle is used
81         omxMatrix *latentMeanOut;
82         omxMatrix *latentCovOut;
83
84         enum expectation_type type;
85         enum score_option scores;
86 } BA81Expect;
87
88 extern const struct rpf *rpf_model;
89 extern int rpf_numModels;
90
91 void ba81buildLXKcache(omxExpectation *oo);
92 double *computeRPF(BA81Expect *state, omxMatrix *itemParam, const int *quad, const bool wantlog);
93 void ba81SetupQuadrature(omxExpectation* oo, int gridsize, int flat);
94 void ba81Estep1(omxExpectation *oo);
95 void cai2010(omxExpectation* oo, const int thrId, int recompute, const int *primaryQuad);
96 double *ba81LikelihoodFast(omxExpectation *oo, const int thrId, int specific, const int *quad);
97 double *getLogPatternLik(omxExpectation* oo);
98
99 OMXINLINE static int
100 triangleLoc1(int diag)
101 {
102         //if (diag < 1) error("Out of domain");
103         return (diag) * (diag+1) / 2;   // 0 1 3 6 10 15 ..
104 }
105
106 OMXINLINE static int
107 triangleLoc0(int diag)
108 {
109         //if (diag < 0) error("Out of domain");
110         return triangleLoc1(diag+1) - 1;  // 0 2 5 9 14 ..
111 }
112
113 // state->allElxk[eIndex(state, px)]
114 OMXINLINE static int
115 eIndex(BA81Expect *state, int thr, int px)
116 {
117         return thr * state->numUnique + px;
118 }
119
120 // state->Eslxk[esIndex(state, sx, px)]
121 OMXINLINE static int
122 esIndex(BA81Expect *state, int thr, int sx, int px)
123 {
124         return (thr * state->numUnique * state->numSpecific +
125                 state->numUnique * sx + px);
126 }
127
128 OMXINLINE static void
129 pointToWhere(BA81Expect *state, const int *quad, double *where, int upto)
130 {
131         for (int dx=0; dx < upto; dx++) {
132                 where[dx] = state->Qpoint[quad[dx]];
133         }
134 }
135
136 OMXINLINE static long
137 encodeLocation(const int dims, const long grid, const int *quad)
138 {
139         long qx = 0;
140         for (int dx=dims-1; dx >= 0; dx--) {
141                 qx = qx * grid;
142                 qx += quad[dx];
143         }
144         return qx;
145 }
146
147 OMXINLINE static void
148 decodeLocation(long qx, const int dims, const long grid, int *quad)
149 {
150         for (int dx=0; dx < dims; dx++) {
151                 quad[dx] = qx % grid;
152                 qx = qx / grid;
153         }
154 }
155
156 OMXINLINE static double
157 logAreaProduct(BA81Expect *state, const int *quad, const int sg) // remove? TODO
158 {
159         int maxDims = state->maxDims;
160         if (state->numSpecific == 0) {
161                 long qloc = encodeLocation(maxDims, state->quadGridSize, quad);
162                 return state->priLogQarea[qloc];
163         } else {
164                 long priloc = encodeLocation(maxDims-1, state->quadGridSize, quad);
165                 return (state->priLogQarea[priloc] +
166                         state->speLogQarea[sg * state->quadGridSize + quad[maxDims - 1]]);
167         }
168 }
169
170 OMXINLINE static double
171 areaProduct(BA81Expect *state, const int *quad, const int sg)
172 {
173         int maxDims = state->maxDims;
174         if (state->numSpecific == 0) {
175                 long qloc = encodeLocation(maxDims, state->quadGridSize, quad);
176                 return state->priQarea[qloc];
177         } else {
178                 long priloc = encodeLocation(maxDims-1, state->quadGridSize, quad);
179                 return (state->priQarea[priloc] *
180                         state->speQarea[sg * state->quadGridSize + quad[maxDims - 1]]);
181         }
182 }
183
184 // debug tools
185 void pda(const double *ar, int rows, int cols);
186 void pia(const int *ar, int rows, int cols);
187
188 #endif