Move allocation out of computeRPF
[openmx:openmx.git] / src / omxExpectationBA81.h
1 /*
2   Copyright 2012-2013 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #ifndef _OMX_EXPECTATIONBA81_H_
19 #define _OMX_EXPECTATIONBA81_H_
20
21 #include "omxExpectation.h"
22 #include "omxOpenmpWrap.h"
23
24 enum score_option {
25         SCORES_OMIT,
26         SCORES_UNIQUE,
27         SCORES_FULL
28 };
29
30 enum expectation_type {
31         EXPECTATION_UNINITIALIZED,
32         EXPECTATION_AUGMENTED, // E-M
33         EXPECTATION_OBSERVED,  // regular
34 };
35
36 typedef struct {
37
38         // data characteristics
39         omxData *data;
40         int numUnique;
41         int *numIdentical;        // length numUnique
42         int *rowMap;              // length numUnique, index of first instance of pattern
43
44         // item description related
45         std::vector<const double*> itemSpec;
46         std::vector<int> itemOutcomes;
47         int maxOutcomes;
48         int maxDims;
49         int maxAbilities;
50         int numSpecific;
51         int *Sgroup;              // item's specific group 0..numSpecific-1
52         omxMatrix *design;        // items * maxDims
53
54         // quadrature related
55         double Qwidth;
56         double targetQpoints;
57         long quadGridSize;
58         long totalQuadPoints;     // quadGridSize ^ maxDims
59         long totalPrimaryPoints;  // totalQuadPoints except for specific dim TODO
60         std::vector<double> Qpoint;           // quadGridSize
61         std::vector<double> priQarea;         // totalPrimaryPoints
62         std::vector<double> speQarea;         // quadGridSize * numSpecific
63
64         // estimation related
65         omxMatrix *customPrior;
66         omxMatrix *itemParam;
67         omxMatrix *EitemParam;    // E step version
68         int cacheLXK;
69         bool LXKcached;
70         double *lxk;              // wo/cache, numUnique * thread
71         double *allElxk;          // numUnique * thread
72         double *Eslxk;            // numUnique * #specific dimensions * thread
73         double *patternLik;       // numUnique
74         int totalOutcomes;
75         double *expected;         // totalOutcomes * totalQuadPoints
76         std::vector<double> ElatentMean;      // maxAbilities
77         std::vector<double> ElatentCov;       // maxAbilities * maxAbilities
78         omxMatrix *latentMeanOut;
79         omxMatrix *latentCovOut;
80
81         int itemParamVersion;
82         int latentParamVersion;
83         enum expectation_type type;
84         enum score_option scores;
85         bool verbose;
86         bool checkedBadData;
87 } BA81Expect;
88
89 extern const struct rpf *rpf_model;
90 extern int rpf_numModels;
91
92 void computeRPF(BA81Expect *state, omxMatrix *itemParam, const int *quad,
93                 const bool wantlog, double *outcomeProb);
94 void cai2010(omxExpectation* oo, const int thrId, int recompute, const int *primaryQuad);
95 double *ba81LikelihoodFast(omxExpectation *oo, const int thrId, int specific, const int *quad);
96
97 OMXINLINE static int
98 triangleLoc1(int diag)
99 {
100         //if (diag < 1) error("Out of domain");
101         return (diag) * (diag+1) / 2;   // 0 1 3 6 10 15 ..
102 }
103
104 OMXINLINE static int
105 triangleLoc0(int diag)
106 {
107         //if (diag < 0) error("Out of domain");
108         return triangleLoc1(diag+1) - 1;  // 0 2 5 9 14 ..
109 }
110
111 OMXINLINE static double *
112 eBase(BA81Expect *state, int thr)
113 {
114         return state->allElxk + thr * state->numUnique;
115 }
116
117 OMXINLINE static double *
118 esBase(BA81Expect *state, int thr)
119 {
120         return state->Eslxk + thr * state->numUnique * state->numSpecific;
121 }
122
123 OMXINLINE static void
124 pointToWhere(BA81Expect *state, const int *quad, double *where, int upto)
125 {
126         for (int dx=0; dx < upto; dx++) {
127                 where[dx] = state->Qpoint[quad[dx]];
128         }
129 }
130
131 OMXINLINE static long
132 encodeLocation(const int dims, const long grid, const int *quad)
133 {
134         long qx = 0;
135         for (int dx=dims-1; dx >= 0; dx--) {
136                 qx = qx * grid;
137                 qx += quad[dx];
138         }
139         return qx;
140 }
141
142 OMXINLINE static void
143 decodeLocation(long qx, const int dims, const long grid, int *quad)
144 {
145         for (int dx=0; dx < dims; dx++) {
146                 quad[dx] = qx % grid;
147                 qx = qx / grid;
148         }
149 }
150
151 OMXINLINE static double
152 areaProduct(BA81Expect *state, const int *quad, const int sg)
153 {
154         int maxDims = state->maxDims;
155         if (state->numSpecific == 0) {
156                 long qloc = encodeLocation(maxDims, state->quadGridSize, quad);
157                 return state->priQarea[qloc];
158         } else {
159                 long priloc = encodeLocation(maxDims-1, state->quadGridSize, quad);
160                 return (state->priQarea[priloc] *
161                         state->speQarea[sg * state->quadGridSize + quad[maxDims - 1]]);
162         }
163 }
164
165 OMXINLINE static void
166 gramProduct(double *vec, size_t len, double *out)
167 {
168         int cell = 0;
169         for (size_t v1=0; v1 < len; ++v1) {
170                 for (size_t v2=0; v2 <= v1; ++v2) {
171                         out[cell] = vec[v1] * vec[v2];
172                         ++cell;
173                 }
174         }
175 }
176
177 // debug tools
178 void pda(const double *ar, int rows, int cols);
179 void pia(const int *ar, int rows, int cols);
180
181 #endif