Generic ComputeEM with Ramsay (1975) acceleration
[openmx:openmx.git] / src / omxExpectationBA81.h
1 /*
2   Copyright 2012-2013 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #ifndef _OMX_EXPECTATIONBA81_H_
19 #define _OMX_EXPECTATIONBA81_H_
20
21 #include "omxExpectation.h"
22 #include "omxOpenmpWrap.h"
23
24 enum score_option {
25         SCORES_OMIT,
26         SCORES_UNIQUE,
27         SCORES_FULL
28 };
29
30 enum expectation_type {
31         EXPECTATION_UNINITIALIZED,
32         EXPECTATION_AUGMENTED, // E-M
33         EXPECTATION_OBSERVED,  // regular
34 };
35
36 struct BA81Expect {
37         double LogLargestDouble;       // should be const but need constexpr
38         double LargestDouble;          // should be const but need constexpr
39         double OneOverLargestDouble;   // should be const but need constexpr
40
41         // data characteristics
42         omxData *data;
43         int numUnique;
44         int *numIdentical;        // length numUnique
45         int *rowMap;              // length numUnique, index of first instance of pattern
46
47         // item description related
48         std::vector<const double*> itemSpec;
49         std::vector<int> itemOutcomes;
50         std::vector<int> cumItemOutcomes;
51         int maxDims;
52         int maxAbilities;
53         int numSpecific;
54         int *Sgroup;              // item's specific group 0..numSpecific-1
55         omxMatrix *design;        // items * maxDims
56
57         // quadrature related
58         double Qwidth;
59         double targetQpoints;
60         long quadGridSize;
61         long totalQuadPoints;                 // quadGridSize ^ maxDims
62         long totalPrimaryPoints;              // totalQuadPoints except for specific dim TODO
63         std::vector<double> wherePrep;        // totalQuadPoints * maxDims
64         std::vector<double> whereGram;        // totalQuadPoints * triangleLoc1(maxDims)
65         std::vector<double> Qpoint;           // quadGridSize
66         std::vector<double> priQarea;         // totalPrimaryPoints
67         std::vector<double> speQarea;         // quadGridSize * numSpecific
68
69         // estimation related
70         omxMatrix *customPrior;
71         omxMatrix *itemParam;
72         double *EitemParam;
73         double *patternLik;                   // numUnique
74         double SmallestPatternLik;
75         int excludedPatterns;
76         int totalOutcomes;
77         double *outcomeProb;                  // totalOutcomes * totalQuadPoints
78         double *expected;                     // totalOutcomes * totalQuadPoints
79         int ElatentVersion;
80         std::vector<double> ElatentMean;      // maxAbilities
81         std::vector<double> ElatentCov;       // maxAbilities * maxAbilities
82         omxMatrix *latentMeanOut;
83         omxMatrix *latentCovOut;
84
85         int itemParamVersion;
86         int latentParamVersion;
87         enum expectation_type type;
88         enum score_option scores;
89         bool verbose;
90 };
91
92 extern const struct rpf *rpf_model;
93 extern int rpf_numModels;
94
95 void ba81OutcomeProb(BA81Expect *state, bool estep, bool wantLog);
96
97 OMXINLINE static int
98 triangleLoc1(int diag)
99 {
100         //if (diag < 1) error("Out of domain");
101         return (diag) * (diag+1) / 2;   // 0 1 3 6 10 15 ..
102 }
103
104 OMXINLINE static int
105 triangleLoc0(int diag)
106 {
107         //if (diag < 0) error("Out of domain");
108         return triangleLoc1(diag+1) - 1;  // 0 2 5 9 14 ..
109 }
110
111 OMXINLINE static void
112 pointToWhere(BA81Expect *state, const int *quad, double *where, int upto)
113 {
114         for (int dx=0; dx < upto; dx++) {
115                 where[dx] = state->Qpoint[quad[dx]];
116         }
117 }
118
119 OMXINLINE static void
120 decodeLocation(long qx, const int dims, const long grid, int *quad)
121 {
122         for (int dx=dims-1; dx >= 0; --dx) {
123                 quad[dx] = qx % grid;
124                 qx = qx / grid;
125         }
126 }
127
128 // state->speQarea[sIndex(state, sgroup, sx)]
129 OMXINLINE static
130 int sIndex(BA81Expect *state, int sx, int qx)
131 {
132         //if (sx < 0 || sx >= state->numSpecific) error("Out of domain");
133         //if (qx < 0 || qx >= state->quadGridSize) error("Out of domain");
134         return qx * state->numSpecific + sx;
135 }
136
137 OMXINLINE static double
138 areaProduct(BA81Expect *state, long qx, int sx, const int sg)
139 {
140         if (state->numSpecific == 0) {
141                 return state->priQarea[qx];
142         } else {
143                 if (sx == -1) {
144                         sx = qx % state->quadGridSize;
145                         qx /= state->quadGridSize;
146                 }
147                 return state->priQarea[qx] * state->speQarea[sIndex(state, sg, sx)];
148         }
149 }
150
151 OMXINLINE static void
152 gramProduct(double *vec, size_t len, double *out)
153 {
154         int cell = 0;
155         for (size_t v1=0; v1 < len; ++v1) {
156                 for (size_t v2=0; v2 <= v1; ++v2) {
157                         out[cell] = vec[v1] * vec[v2];
158                         ++cell;
159                 }
160         }
161 }
162
163 OMXINLINE static bool
164 validPatternLik(BA81Expect *state, double pl)
165 {
166         return isfinite(pl) && pl > state->SmallestPatternLik;
167 }
168
169 void ba81SetupQuadrature(omxExpectation* oo);
170 void ba81LikelihoodSlow2(BA81Expect *state, int px, double *out);
171 void cai2010EiEis(BA81Expect *state, int px, double *lxk, double *Eis, double *Ei);
172
173 // debug tools
174 void pda(const double *ar, int rows, int cols);
175 void pia(const int *ar, int rows, int cols);
176
177 #endif