Reorg E step 1/9
[openmx:openmx.git] / src / omxExpectationBA81.cpp
1 /*
2   Copyright 2012-2013 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include <valarray>
19 #include <Rmath.h>
20
21 #include "omxExpectationBA81.h"
22 #include "glue.h"
23 #include "libifa-rpf.h"
24 #include "dmvnorm.h"
25
// Table of response probability function models.  computeRPF indexes it with
// the model id read from an item's spec (spec[RPF_ISpecID]) and calls the
// entry's prob / logprob function.  Starts out NULL here; presumably assigned
// during package initialization elsewhere -- TODO confirm.
26 const struct rpf *rpf_model = NULL;
// Presumably the number of entries in rpf_model; it is not read in this part
// of the file -- TODO confirm against the code that fills the table.
27 int rpf_numModels;
28
29 void pda(const double *ar, int rows, int cols)
30 {
31         std::string buf;
32         for (int rx=0; rx < rows; rx++) {   // column major order
33                 for (int cx=0; cx < cols; cx++) {
34                         buf += string_snprintf("%.6g, ", ar[cx * rows + rx]);
35                 }
36                 buf += "\n";
37         }
38         mxLogBig(buf);
39 }
40
41 void pia(const int *ar, int rows, int cols)
42 {
43         std::string buf;
44         for (int rx=0; rx < rows; rx++) {   // column major order
45                 for (int cx=0; cx < cols; cx++) {
46                         buf += string_snprintf("%d, ", ar[cx * rows + rx]);
47                 }
48                 buf += "\n";
49         }
50         mxLogBig(buf);
51 }
52
53 // state->speQarea[sIndex(state, sx, qx)]
54 OMXINLINE static
55 int sIndex(BA81Expect *state, int sx, int qx)
56 {
57         //if (sx < 0 || sx >= state->numSpecific) error("Out of domain");
58         //if (qx < 0 || qx >= state->quadGridSize) error("Out of domain");
59         return sx * state->quadGridSize + qx;
60 }
61
62 // Depends on item parameters, but not latent distribution
63 void computeRPF(BA81Expect *state, omxMatrix *itemParam, const int *quad,
64                 const bool wantlog, double *out)
65 {
66         omxMatrix *design = state->design;
67         int maxDims = state->maxDims;
68         size_t numItems = state->itemSpec.size();
69
70         double theta[maxDims];
71         pointToWhere(state, quad, theta, maxDims);
72
73         for (size_t ix=0; ix < numItems; ix++) {
74                 const double *spec = state->itemSpec[ix];
75                 int id = spec[RPF_ISpecID];
76                 int dims = spec[RPF_ISpecDims];
77                 double ptheta[dims];
78
79                 for (int dx=0; dx < dims; dx++) {
80                         int ability = (int)omxMatrixElement(design, dx, ix) - 1;
81                         if (ability >= maxDims) ability = maxDims-1;
82                         ptheta[dx] = theta[ability];
83                 }
84
85                 double *iparam = omxMatrixColumn(itemParam, ix);
86                 if (wantlog) {
87                         (*rpf_model[id].logprob)(spec, iparam, ptheta, out);
88                 } else {
89                         (*rpf_model[id].prob)(spec, iparam, ptheta, out);
90                 }
91 #if 0
92                 for (int ox=0; ox < state->itemOutcomes[ix]; ox++) {
93                         if (!isfinite(out[ox]) || out[ox] > 0) {
94                                 mxLog("item param");
95                                 pda(iparam, itemParam->rows, 1);
96                                 mxLog("where");
97                                 pda(ptheta, dims, 1);
98                                 error("RPF returned %20.20f", out[ox]);
99                         }
100                 }
101 #endif
102                 out += state->itemOutcomes[ix];
103         }
104 }
105
106 OMXINLINE static double *
107 getLXKcache(BA81Expect *state, const long qx, const int specific)
108 {
109         long ordinate;
110         if (state->numSpecific == 0) {
111                 ordinate = qx;
112         } else {
113                 ordinate = specific * state->totalQuadPoints + qx;
114         }
115         return state->lxk + state->numUnique * ordinate;
116 }
117
118 OMXINLINE static double *
119 ba81Likelihood(omxExpectation *oo, const int thrId, int specific, const long qx)
120 {
121         BA81Expect *state = (BA81Expect*) oo->argStruct;
122         int numUnique = state->numUnique;
123         std::vector<int> &itemOutcomes = state->itemOutcomes;
124         omxData *data = state->data;
125         size_t numItems = state->itemSpec.size();
126         int *Sgroup = state->Sgroup;
127         double *lxk;
128
129         if (!state->cacheLXK) {
130                 lxk = state->lxk + numUnique * thrId;
131         } else {
132                 lxk = getLXKcache(state, qx, specific);
133         }
134
135         const int *rowMap = state->rowMap;
136         for (int px=0; px < numUnique; px++) {
137                 double lxk1 = 1;
138                 const double *oProb = state->outcomeProb + qx * state->totalOutcomes;
139                 for (size_t ix=0; ix < numItems; ix++) {
140                         int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
141                         if (specific == Sgroup[ix] && pick != NA_INTEGER) {
142                                 double piece = oProb[pick-1];  // move -1 elsewhere TODO
143                                 lxk1 *= piece;
144                                 //mxLog("%d pick %d piece %.7f", ix, pick-1, piece);
145                         }
146                         oProb += itemOutcomes[ix];
147                 }
148                 lxk[px] = lxk1;
149         }
150
151         return lxk;
152 }
153
154 double *ba81LikelihoodFast(omxExpectation *oo, const int thrId, int specific, const long qx)
155 {
156         BA81Expect *state = (BA81Expect*) oo->argStruct;
157         if (!state->cacheLXK) {
158                 double *ret = ba81Likelihood(oo, thrId, specific, qx);
159                 return ret;
160         } else {
161                 return getLXKcache(state, qx, specific);
162         }
163
164 }
165
166 OMXINLINE static void
167 mapLatentSpace(BA81Expect *state, int sgroup, double piece, const double *where,
168                const double *whereGram, double *latentDist)
169 {
170         int maxDims = state->maxDims;
171         int maxAbilities = state->maxAbilities;
172         int pmax = maxDims;
173         if (state->numSpecific) pmax -= 1;
174
175         if (sgroup == 0) {
176                 int gx = 0;
177                 int cx = maxAbilities;
178                 for (int d1=0; d1 < pmax; d1++) {
179                         double piece_w1 = piece * where[d1];
180                         latentDist[d1] += piece_w1;
181                         for (int d2=0; d2 <= d1; d2++) {
182                                 double piece_cov = piece * whereGram[gx];
183                                 latentDist[cx] += piece_cov;
184                                 ++cx; ++gx;
185                         }
186                 }
187         }
188
189         if (state->numSpecific) {
190                 int sdim = pmax + sgroup;
191                 double piece_w1 = piece * where[pmax];
192                 latentDist[sdim] += piece_w1;
193
194                 double piece_var = piece * whereGram[triangleLoc0(pmax)];
195                 int to = maxAbilities + triangleLoc0(sdim);
196                 latentDist[to] += piece_var;
197         }
198 }
199
200 // Eslxk, allElxk (Ei, Eis) depend on the ordinate of the primary dimensions
201 void cai2010(omxExpectation* oo, const int thrId, int recompute, const long primaryQ)
202 {
203         BA81Expect *state = (BA81Expect*) oo->argStruct;
204         int numUnique = state->numUnique;
205         int numSpecific = state->numSpecific;
206         int quadGridSize = state->quadGridSize;
207         double *allElxk = eBase(state, thrId);
208         double *Eslxk = esBase(state, thrId);
209
210         for (int px=0; px < numUnique; px++) {
211                 allElxk[px] = 1;
212                 for (int sx=0; sx < numSpecific; sx++) {
213                         Eslxk[sx * numUnique + px] = 0;
214                 }
215         }
216
217         if (!state->cacheLXK) recompute = TRUE;
218
219         for (int sx=0; sx < quadGridSize; sx++) {
220                 long qloc = primaryQ * quadGridSize + sx;
221
222                 for (int sgroup=0; sgroup < numSpecific; sgroup++) {
223                         double *myEslxk = Eslxk + sgroup * numUnique;
224                         double *lxk;     // a.k.a. "L_is"
225                         if (recompute) {
226                                 lxk = ba81Likelihood(oo, thrId, sgroup, qloc);
227                         } else {
228                                 lxk = getLXKcache(state, qloc, sgroup);
229                         }
230
231                         for (int ix=0; ix < numUnique; ix++) {
232                                 double area = state->speQarea[sIndex(state, sgroup, sx)];
233                                 double piece = lxk[ix] * area;
234                                 //mxLog("E.is(%d) (%ld,%d) %.6f + %.6f = %.6f",
235                                 //  sgroup, primaryQ, sx, lxk[ix], area, piece);
236                                 myEslxk[ix] += piece;
237                         }
238                 }
239         }
240
241         for (int sx=0; sx < numSpecific; sx++) {
242                 for (int px=0; px < numUnique; px++) {
243                         //mxLog("E.is(%d) at (%ld) %.6f", sx, primaryQ,
244                         //  Eslxk[sx * numUnique + px]);
245                         allElxk[px] *= Eslxk[sx * numUnique + px];  // allSlxk a.k.a. "E_i"
246                 }
247         }
248 }
249
250 static void ba81OutcomeProb(BA81Expect *state)
251 {
252         int maxDims = state->maxDims;
253         double *qProb = state->outcomeProb =
254                 Realloc(state->outcomeProb, state->totalOutcomes * state->totalQuadPoints, double);
255         for (long qx=0; qx < state->totalQuadPoints; qx++) {
256                 int quad[maxDims];
257                 decodeLocation(qx, maxDims, state->quadGridSize, quad);
258                 double where[maxDims];
259                 pointToWhere(state, quad, where, maxDims);
260                 computeRPF(state, state->EitemParam, quad, FALSE, qProb);
261                 qProb += state->totalOutcomes;
262         }
263 }
264
265 static void ba81Estep1(omxExpectation *oo)
266 {
267         if(OMX_DEBUG) {mxLog("Beginning %s Computation.", oo->name);}
268
269         BA81Expect *state = (BA81Expect*) oo->argStruct;
270         if (state->verbose) {
271                 mxLog("%s: lxk(%d) patternLik ElatentMean ElatentCov",
272                       oo->name, omxGetMatrixVersion(state->EitemParam));
273         }
274
275         int numUnique = state->numUnique;
276         size_t numItems = state->itemSpec.size();
277         int numSpecific = state->numSpecific;
278         int maxDims = state->maxDims;
279         int maxAbilities = state->maxAbilities;
280         int primaryDims = maxDims;
281         int totalOutcomes = state->totalOutcomes;
282         omxData *data = state->data;
283         int *numIdentical = state->numIdentical;
284         long totalQuadPoints = state->totalQuadPoints;
285
286         state->patternLik = Realloc(state->patternLik, numUnique, double);
287         double *patternLik = state->patternLik;
288         OMXZERO(patternLik, numUnique); // remove TODO
289         std::vector<double> thrExpected(totalOutcomes * totalQuadPoints * Global->numThreads);
290
291         int numLatents = maxAbilities + triangleLoc1(maxAbilities);
292         int numLatentsPerThread = numUnique * numLatents;
293         double *latentDist = Calloc(numUnique * numLatents * Global->numThreads, double);
294
295         const int *rowMap = state->rowMap;
296         std::vector<int> &itemOutcomes = state->itemOutcomes;
297
298         int whereChunk = maxDims + triangleLoc1(maxDims);
299         std::vector<double> wherePrep(totalQuadPoints * whereChunk);
300         for (long qx=0; qx < totalQuadPoints; qx++) {
301                 double *wh = wherePrep.data() + qx * whereChunk;
302                 int quad[maxDims];
303                 decodeLocation(qx, maxDims, state->quadGridSize, quad);
304                 pointToWhere(state, quad, wh, maxDims);
305                 gramProduct(wh, maxDims, wh + maxDims);
306         }
307
308         if (numSpecific == 0) {
309 #pragma omp parallel for num_threads(Global->numThreads) schedule(static,32)
310                 for (int px=0; px < numUnique; px++) {
311                         int thrId = omx_absolute_thread_num();
312                         double *thrLatentDist = latentDist + thrId * numLatentsPerThread; // reshape matrix TODO
313                         double *myExpected = thrExpected.data() + thrId * totalOutcomes * totalQuadPoints;
314
315                         std::valarray<double> lxk(1, totalQuadPoints);
316                         int outcomeBase = -itemOutcomes[0];
317
318                         for (size_t ix=0; ix < numItems; ix++) {
319                                 outcomeBase += itemOutcomes[ix];
320                                 int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
321                                 if (pick == NA_INTEGER) continue;
322                                 pick -= 1;
323
324                                 double *oProb = state->outcomeProb + outcomeBase;
325                                 for (long qx=0; qx < totalQuadPoints; ++qx) {
326                                         lxk[qx] *= oProb[pick];
327                                         oProb += totalOutcomes;
328                                 }
329                         }
330                         
331                         double *lxkCache = state->cacheLXK? state->lxk + px : NULL;
332                         double patternLik1 = 0;
333                         double *wh = wherePrep.data();
334                         for (long qx=0; qx < totalQuadPoints; qx++) {
335                                 double area = state->priQarea[qx];
336                                 double tmp = lxk[qx] * area;
337                                 patternLik1 += tmp;
338                                 mapLatentSpace(state, 0, tmp, wh, wh + maxDims,
339                                                thrLatentDist + px * numLatents);
340
341                                 if (lxkCache) {
342                                         *lxkCache = lxk[qx];
343                                         lxkCache += numUnique;
344                                 }
345                                 wh += whereChunk;
346                         }
347
348                         patternLik[px] = patternLik1;
349                         double weight = numIdentical[px] / patternLik1;
350
351                         outcomeBase = -itemOutcomes[0];
352                         for (size_t ix=0; ix < numItems; ix++) {
353                                 outcomeBase += itemOutcomes[ix];
354                                 int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
355                                 if (pick == NA_INTEGER) continue;
356                                 pick -= 1;
357
358                                 double *out = myExpected + outcomeBase;
359                                 for (long qx=0; qx < totalQuadPoints; ++qx) {
360                                         out[pick] += weight * lxk[qx];
361                                         out += totalOutcomes;
362                                 }
363                         }
364                 }
365         } else {
366                 primaryDims -= 1;
367                 long totalPrimaryPoints = state->totalPrimaryPoints;
368                 long specificPoints = state->quadGridSize;
369
370 #pragma omp parallel for num_threads(Global->numThreads) schedule(static,32)
371                 for (int px=0; px < numUnique; px++) {
372                         int thrId = omx_absolute_thread_num();
373                         double *thrLatentDist = latentDist + thrId * numLatentsPerThread; // reshape matrix TODO
374                         double *myExpected = thrExpected.data() + thrId * totalOutcomes * totalQuadPoints;
375
376                         std::valarray<double> lxk(1, totalQuadPoints * numSpecific);
377                         int outcomeBase = -itemOutcomes[0];
378
379                         for (size_t ix=0; ix < numItems; ix++) {
380                                 outcomeBase += itemOutcomes[ix];
381                                 int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
382                                 if (pick == NA_INTEGER) continue;
383                                 pick -= 1;
384                                 int Sbase = state->Sgroup[ix] * totalQuadPoints;
385                                 double *oProb = state->outcomeProb + outcomeBase;
386                                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
387                                         lxk[Sbase + qx] *= oProb[pick];
388                                         oProb += totalOutcomes;
389                                 }
390                         }
391
392                         std::valarray<double> Eis(0.0, totalPrimaryPoints * numSpecific);
393                         std::valarray<double> Ei(1.0, totalPrimaryPoints);
394                         for (int sgroup=0; sgroup < numSpecific; ++sgroup) {
395                                 int Sbase = sgroup * totalQuadPoints;
396                                 for (long qx=0; qx < totalPrimaryPoints; qx++) {
397                                         for (long sx=0; sx < specificPoints; sx++) {
398                                                 long qloc = qx * specificPoints + sx; // change to ++
399                                                 double area = state->speQarea[sIndex(state, sgroup, sx)];
400                                                 double piece = lxk[Sbase + qloc] * area;
401                                                 Eis[totalPrimaryPoints * sgroup + qx] += piece;
402                                         }
403                                         Ei[qx] *= Eis[totalPrimaryPoints * sgroup + qx];
404                                 }
405                         }
406
407                         double *wh = wherePrep.data();
408                         for (long qx=0; qx < totalPrimaryPoints; qx++) {
409                                 double Ei1 = Ei[qx];
410                                 for (long sx=0; sx < specificPoints; sx++) {
411                                         long qloc = qx * specificPoints + sx; // change to ++
412                                         for (int sgroup=0; sgroup < numSpecific; sgroup++) {
413                                                 double area = areaProduct(state, qx, sx, sgroup);
414                                                 double lxk1 = lxk[totalQuadPoints * sgroup + qloc];
415                                                 double Eis1 = Eis[totalPrimaryPoints * sgroup + qx];
416                                                 double tmp = ((Ei1 / Eis1) * lxk1 * area);
417                                                 mapLatentSpace(state, sgroup, tmp, wh, wh + maxDims,
418                                                                thrLatentDist + px * numLatents);
419                                                 if (state->cacheLXK) {
420                                                         double *lxkCache = getLXKcache(state, qloc, sgroup);
421                                                         lxkCache[px] = lxk1;
422                                                 }
423                                         }
424                                         wh += whereChunk;
425                                 }
426                         }
427
428                         double patternLik1 = 0;
429                         for (long qx=0; qx < totalPrimaryPoints; qx++) {
430                                 double priArea = state->priQarea[qx];
431                                 patternLik1 += Ei[qx] * priArea; // combine in upper loop TODO
432                         }
433                         patternLik[px] = patternLik1;
434
435                         double weight = numIdentical[px] / patternLik1;
436
437                         outcomeBase = -itemOutcomes[0];
438                         for (size_t ix=0; ix < numItems; ix++) {
439                                 outcomeBase += itemOutcomes[ix];
440                                 int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
441                                 if (pick == NA_INTEGER) continue;
442                                 pick -= 1;
443
444                                 int Sgroup = state->Sgroup[ix];
445
446                                 //double *out = myExpected + outcomeBase;
447                                 for (long qx=0; qx < totalPrimaryPoints; qx++) {
448                                         double Ei1 = Ei[qx];
449                                         for (long sx=0; sx < specificPoints; sx++) {
450                                                 long qloc = qx * specificPoints + sx; // change to ++
451                                                 double *out = myExpected + qloc * totalOutcomes + outcomeBase;
452                                                 double lxk1 = lxk[totalQuadPoints * Sgroup + qloc];
453                                                 double Eis1 = Eis[totalPrimaryPoints * Sgroup + qx];
454                                                 out[pick] += weight * (Ei1 / Eis1) * lxk1;
455                                                 //out += totalOutcomes;
456                                         }
457                                 }
458                         }
459                 }
460         }
461
462         long expectedSize = totalQuadPoints * totalOutcomes;
463 #pragma omp parallel for num_threads(Global->numThreads) schedule(static,64)
464         for (long qx=0; qx < expectedSize; ++qx) {
465                 state->expected[qx] = 0;
466                 double *e1 = thrExpected.data() + qx;
467                 for (int tx=0; tx < Global->numThreads; ++tx) {
468                         state->expected[qx] += *e1;
469                         e1 += expectedSize;
470                 }
471         }
472
473         //mxLog("raw latent");
474         //pda(latentDist, numLatents, numUnique);
475
476 #pragma omp parallel for num_threads(Global->numThreads) schedule(dynamic)
477         for (int lx=0; lx < maxAbilities + triangleLoc1(primaryDims); ++lx) {
478                 for (int tx=1; tx < Global->numThreads; ++tx) {
479                         double *thrLatentDist = latentDist + tx * numLatentsPerThread;
480                         for (int px=0; px < numUnique; px++) {
481                                 int loc = px * numLatents + lx;
482                                 latentDist[loc] += thrLatentDist[loc];
483                         }
484                 }
485         }
486
487 #pragma omp parallel for num_threads(Global->numThreads)
488         for (int sdim=primaryDims; sdim < maxAbilities; sdim++) {
489                 for (int tx=1; tx < Global->numThreads; ++tx) {
490                         double *thrLatentDist = latentDist + tx * numLatentsPerThread;
491                         for (int px=0; px < numUnique; px++) {
492                                 int loc = px * numLatents + maxAbilities + triangleLoc0(sdim);
493                                 latentDist[loc] += thrLatentDist[loc];
494                         }
495                 }
496         }
497
498 #pragma omp parallel for num_threads(Global->numThreads)
499         for (int px=0; px < numUnique; px++) {
500                 if (!std::isfinite(patternLik[px])) {
501                         omxRaiseErrorf(globalState, "Likelihood of pattern %d is %.3g",
502                                        px, patternLik[px]);
503                 }
504
505                 double *latentDist1 = latentDist + px * numLatents;
506                 double weight = numIdentical[px] / patternLik[px];
507                 int cx = maxAbilities;
508                 for (int d1=0; d1 < primaryDims; d1++) {
509                         latentDist1[d1] *= weight;
510                         for (int d2=0; d2 <= d1; d2++) {
511                                 latentDist1[cx] *= weight;
512                                 ++cx;
513                         }
514                 }
515                 for (int sdim=primaryDims; sdim < maxAbilities; sdim++) {
516                         latentDist1[sdim] *= weight;
517                         int loc = maxAbilities + triangleLoc0(sdim);
518                         latentDist1[loc] *= weight;
519                 }
520 #if 0
521                 if (!isfinite(patternLik[px])) {
522                         error("Likelihood of row %d is %f", state->rowMap[px], patternLik[px]);
523                 }
524 #endif
525         }
526
527         //mxLog("raw latent after weighting");
528         //pda(latentDist, numLatents, numUnique);
529
530         std::vector<double> &ElatentMean = state->ElatentMean;
531         std::vector<double> &ElatentCov = state->ElatentCov;
532         
533         ElatentMean.assign(ElatentMean.size(), 0.0);
534         ElatentCov.assign(ElatentCov.size(), 0.0);
535
536 #pragma omp parallel for num_threads(Global->numThreads)
537         for (int d1=0; d1 < maxAbilities; d1++) {
538                 for (int px=0; px < numUnique; px++) {
539                         double *latentDist1 = latentDist + px * numLatents;
540                         int cx = maxAbilities + triangleLoc1(d1);
541                         if (d1 < primaryDims) {
542                                 ElatentMean[d1] += latentDist1[d1];
543                                 for (int d2=0; d2 <= d1; d2++) {
544                                         int cell = d2 * maxAbilities + d1;
545                                         ElatentCov[cell] += latentDist1[cx];
546                                         ++cx;
547                                 }
548                         } else {
549                                 ElatentMean[d1] += latentDist1[d1];
550                                 int cell = d1 * maxAbilities + d1;
551                                 int loc = maxAbilities + triangleLoc0(d1);
552                                 ElatentCov[cell] += latentDist1[loc];
553                         }
554                 }
555         }
556
557         //pda(ElatentMean.data(), 1, state->maxAbilities);
558         //pda(ElatentCov.data(), state->maxAbilities, state->maxAbilities);
559
560         for (int d1=0; d1 < maxAbilities; d1++) {
561                 ElatentMean[d1] /= data->rows;
562         }
563
564         for (int d1=0; d1 < primaryDims; d1++) {
565                 for (int d2=0; d2 <= d1; d2++) {
566                         int cell = d2 * maxAbilities + d1;
567                         int tcell = d1 * maxAbilities + d2;
568                         ElatentCov[tcell] = ElatentCov[cell] =
569                                 ElatentCov[cell] / data->rows - ElatentMean[d1] * ElatentMean[d2];
570                 }
571         }
572         for (int sdim=primaryDims; sdim < maxAbilities; sdim++) {
573                 int cell = sdim * maxAbilities + sdim;
574                 ElatentCov[cell] = ElatentCov[cell] / data->rows - ElatentMean[sdim] * ElatentMean[sdim];
575         }
576
577         if (state->cacheLXK) state->LXKcached = TRUE;
578
579         Free(latentDist);
580
581         //mxLog("E-step");
582         //pda(ElatentMean.data(), 1, state->maxAbilities);
583         //pda(ElatentCov.data(), state->maxAbilities, state->maxAbilities);
584 }
585
586 static int getLatentVersion(BA81Expect *state)
587 {
588         return omxGetMatrixVersion(state->latentMeanOut) + omxGetMatrixVersion(state->latentCovOut);
589 }
590
591 // Attempt G-H grid? http://dbarajassolano.wordpress.com/2012/01/26/on-sparse-grid-quadratures/
592 static void ba81SetupQuadrature(omxExpectation* oo, int gridsize)
593 {
594         BA81Expect *state = (BA81Expect *) oo->argStruct;
595         if (state->verbose) {
596                 mxLog("%s: quadrature(%d)", oo->name, getLatentVersion(state));
597         }
598         int numUnique = state->numUnique;
599         int numThreads = Global->numThreads;
600         int maxDims = state->maxDims;
601         double Qwidth = state->Qwidth;
602         int numSpecific = state->numSpecific;
603         int priDims = maxDims - (numSpecific? 1 : 0);
604
605         // try starting small and increasing to the cap TODO
606         state->quadGridSize = gridsize;
607
608         state->totalQuadPoints = 1;
609         for (int dx=0; dx < maxDims; dx++) {
610                 state->totalQuadPoints *= state->quadGridSize;
611         }
612
613         state->totalPrimaryPoints = state->totalQuadPoints;
614
615         if (numSpecific) {
616                 state->totalPrimaryPoints /= state->quadGridSize;
617                 state->speQarea.resize(gridsize * numSpecific);
618         }
619
620         state->Qpoint.resize(gridsize);
621         state->priQarea.resize(state->totalPrimaryPoints);
622
623         double qgs = state->quadGridSize-1;
624         for (int px=0; px < state->quadGridSize; px ++) {
625                 state->Qpoint[px] = Qwidth - px * 2 * Qwidth / qgs;
626         }
627
628         //pda(state->latentMeanOut->data, 1, state->maxAbilities);
629         //pda(state->latentCovOut->data, state->maxAbilities, state->maxAbilities);
630
631         double totalArea = 0;
632         for (int qx=0; qx < state->totalPrimaryPoints; qx++) {
633                 int quad[priDims];
634                 decodeLocation(qx, priDims, state->quadGridSize, quad);
635                 double where[priDims];
636                 pointToWhere(state, quad, where, priDims);
637                 state->priQarea[qx] = exp(dmvnorm(priDims, where,
638                                                   state->latentMeanOut->data,
639                                                   state->latentCovOut->data));
640                 totalArea += state->priQarea[qx];
641         }
642         for (int qx=0; qx < state->totalPrimaryPoints; qx++) {
643                 state->priQarea[qx] /= totalArea;
644                 //mxLog("%.5g,", state->priQarea[qx]);
645         }
646
647         for (int sx=0; sx < numSpecific; sx++) {
648                 totalArea = 0;
649                 int covCell = (priDims + sx) * state->maxAbilities + priDims + sx;
650                 double mean = state->latentMeanOut->data[priDims + sx];
651                 double var = state->latentCovOut->data[covCell];
652                 //mxLog("setup[%d] %.2f %.2f", sx, mean, var);
653                 for (int qx=0; qx < state->quadGridSize; qx++) {
654                         double den = dnorm(state->Qpoint[qx], mean, sqrt(var), FALSE);
655                         state->speQarea[sIndex(state, sx, qx)] = den;
656                         totalArea += den;
657                 }
658                 for (int qx=0; qx < state->quadGridSize; qx++) {
659                         state->speQarea[sIndex(state, sx, qx)] /= totalArea;
660                 }
661                 //pda(state->speQarea.data() + sIndex(state, sx, 0), 1, state->quadGridSize);
662         }
663
664         if (!state->cacheLXK) {
665                 state->lxk = Realloc(state->lxk, numUnique * numThreads, double);
666         } else {
667                 int ns = state->numSpecific;
668                 if (ns == 0) ns = 1;
669                 long numOrdinate = ns * state->totalQuadPoints;
670                 state->lxk = Realloc(state->lxk, numUnique * numOrdinate, double);
671         }
672
673         state->expected = Realloc(state->expected, state->totalOutcomes * state->totalQuadPoints, double);
674 }
675
676 static void ba81buildLXKcache(omxExpectation *oo)
677 {
678         BA81Expect *state = (BA81Expect *) oo->argStruct;
679         if (!state->cacheLXK || state->LXKcached) return;
680         
681         ba81Estep1(oo);
682 }
683
684 OMXINLINE static void
685 expectedUpdate(omxData *data, const int *rowMap, const int px, const int item,
686                const double observed, double *out)
687 {
688         int pick = omxIntDataElementUnsafe(data, rowMap[px], item);
689         if (pick != NA_INTEGER) {
690                 out[pick-1] += observed;
691         }
692 }
693
694 OMXINLINE static void
695 ba81Expected(omxExpectation* oo)
696 {
697         BA81Expect *state = (BA81Expect*) oo->argStruct;
698         if (state->verbose) mxLog("%s: EM.expected", oo->name);
699
700         omxData *data = state->data;
701         int numSpecific = state->numSpecific;
702         const int *rowMap = state->rowMap;
703         double *patternLik = state->patternLik;
704         int *numIdentical = state->numIdentical;
705         int numUnique = state->numUnique;
706         int numItems = state->EitemParam->cols;
707         int totalOutcomes = state->totalOutcomes;
708         std::vector<int> &itemOutcomes = state->itemOutcomes;
709
710         OMXZERO(state->expected, totalOutcomes * state->totalQuadPoints);
711
712         if (numSpecific == 0) {
713 #pragma omp parallel for num_threads(Global->numThreads)
714                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
715                         int thrId = omx_absolute_thread_num();
716                         double *lxk = ba81LikelihoodFast(oo, thrId, 0, qx);
717                         for (int px=0; px < numUnique; px++) {
718                                 double *out = state->expected + qx * totalOutcomes;
719                                 double observed = numIdentical[px] * lxk[px] / patternLik[px];
720                                 for (int ix=0; ix < numItems; ix++) {
721                                         const int outcomes = itemOutcomes[ix];
722                                         expectedUpdate(data, rowMap, px, ix, observed, out);
723                                         out += outcomes;
724                                 }
725                         }
726                 }
727         } else {
728                 long specificPoints = state->quadGridSize;
729
730 #pragma omp parallel for num_threads(Global->numThreads)
731                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
732                         int thrId = omx_absolute_thread_num();
733
734                         cai2010(oo, thrId, FALSE, qx);
735                         double *allElxk = eBase(state, thrId);
736                         double *Eslxk = esBase(state, thrId);
737
738                         for (long sx=0; sx < specificPoints; sx++) {
739                                 long qloc = qx * specificPoints + sx;
740
741                                 for (int sgroup=0; sgroup < numSpecific; sgroup++) {
742                                         double *lxk = ba81LikelihoodFast(oo, thrId, sgroup, qloc);
743                                         double *myEslxk = Eslxk + sgroup * numUnique;
744
745                                         for (int px=0; px < numUnique; px++) {
746                                                 double *out = state->expected + totalOutcomes * qloc;
747
748                                                 for (int ix=0; ix < numItems; ix++) {
749                                                         const int outcomes = itemOutcomes[ix];
750                                                         if (state->Sgroup[ix] == sgroup) {
751                                                                 double Ei = allElxk[px];
752                                                                 double Eis = myEslxk[px];
753                                                                 double observed = (numIdentical[px] * (Ei / Eis) *
754                                                                                    (lxk[px] / patternLik[px]));
755                                                                 expectedUpdate(data, rowMap, px, ix, observed, out);
756                                                         }
757                                                         out += outcomes;
758                                                 }
759                                         }
760                                 }
761                         }
762                 }
763         }
764
765         if (!state->checkedBadData) {
766                 std::vector<double> byOutcome(totalOutcomes, 0);
767                 for (int ox=0; ox < totalOutcomes; ++ox) {
768                         for (long qx=0; qx < state->totalQuadPoints; qx++) {
769                                 byOutcome[ox] += state->expected[totalOutcomes * qx + ox];
770                         }
771                         if (byOutcome[ox] == 0) {
772                                 int uptoItem = 0;
773                                 for (size_t cx = 0; cx < itemOutcomes.size(); cx++) {
774                                         if (ox < uptoItem + itemOutcomes[cx]) {
775                                                 int bad = ox - uptoItem;
776                                                 omxRaiseErrorf(globalState, "Item %lu outcome %d is never endorsed.\n"
777                                                                "You must collapse categories or omit this item to estimate item parameters.", 1+cx, 1+bad);
778                                                 break;
779                                         }
780                                         uptoItem += itemOutcomes[cx];
781                                 }
782                         }
783                 }
784                 state->checkedBadData = TRUE;
785         }
786         //pda(state->expected, state->totalOutcomes, state->totalQuadPoints);
787 }
788
789 OMXINLINE static void
790 accumulateScores(BA81Expect *state, int px, int sgroup, double piece, const double *where,
791                  int primaryDims, int covEntries, std::vector<double> *mean, std::vector<double> *cov)
792 {
793         int maxDims = state->maxDims;
794         int maxAbilities = state->maxAbilities;
795
796         if (sgroup == 0) {
797                 int cx=0;
798                 for (int d1=0; d1 < primaryDims; d1++) {
799                         double piece_w1 = piece * where[d1];
800                         double &dest1 = (*mean)[px * maxAbilities + d1];
801 #pragma omp atomic
802                         dest1 += piece_w1;
803                         for (int d2=0; d2 <= d1; d2++) {
804                                 double &dest2 = (*cov)[px * covEntries + cx];
805 #pragma omp atomic
806                                 dest2 += where[d2] * piece_w1;
807                                 ++cx;
808                         }
809                 }
810         }
811
812         if (state->numSpecific) {
813                 int sdim = maxDims + sgroup - 1;
814                 double piece_w1 = piece * where[primaryDims];
815                 double &dest3 = (*mean)[px * maxAbilities + sdim];
816 #pragma omp atomic
817                 dest3 += piece_w1;
818
819                 double &dest4 = (*cov)[px * covEntries + triangleLoc0(sdim)];
820 #pragma omp atomic
821                 dest4 += piece_w1 * where[primaryDims];
822         }
823 }
824
825 static void
826 EAPinternalFast(omxExpectation *oo, std::vector<double> *mean, std::vector<double> *cov)
827 {
828         BA81Expect *state = (BA81Expect*) oo->argStruct;
829         if (state->verbose) mxLog("%s: EAP", oo->name);
830
831         int numUnique = state->numUnique;
832         int numSpecific = state->numSpecific;
833         int maxDims = state->maxDims;
834         int maxAbilities = state->maxAbilities;
835         int primaryDims = maxDims;
836         int covEntries = triangleLoc1(maxAbilities);
837
838         mean->assign(numUnique * maxAbilities, 0);
839         cov->assign(numUnique * covEntries, 0);
840
841         if (numSpecific == 0) {
842 #pragma omp parallel for num_threads(Global->numThreads)
843                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
844                         const int thrId = omx_absolute_thread_num();
845                         int quad[maxDims];
846                         decodeLocation(qx, maxDims, state->quadGridSize, quad);
847                         double where[maxDims];
848                         pointToWhere(state, quad, where, maxDims);
849
850                         double *lxk = ba81LikelihoodFast(oo, thrId, 0, qx);
851
852                         double area = state->priQarea[qx];
853                         for (int px=0; px < numUnique; px++) {
854                                 double tmp = lxk[px] * area;
855                                 accumulateScores(state, px, 0, tmp, where, primaryDims, covEntries, mean, cov);
856                         }
857                 }
858         } else {
859                 primaryDims -= 1;
860                 int sDim = primaryDims;
861                 long specificPoints = state->quadGridSize;
862
863 #pragma omp parallel for num_threads(Global->numThreads)
864                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
865                         const int thrId = omx_absolute_thread_num();
866                         int quad[maxDims];
867                         decodeLocation(qx, primaryDims, state->quadGridSize, quad);
868
869                         cai2010(oo, thrId, FALSE, qx);
870                         double *allElxk = eBase(state, thrId);
871                         double *Eslxk = esBase(state, thrId);
872
873                         for (int sgroup=0; sgroup < numSpecific; sgroup++) {
874                                 for (long sx=0; sx < specificPoints; sx++) {
875                                         long qloc = qx * specificPoints + sx;
876                                         quad[sDim] = sx;
877                                         double where[maxDims];
878                                         pointToWhere(state, quad, where, maxDims);
879                                         double area = areaProduct(state, qx, sx, sgroup);
880                                         double *lxk = ba81LikelihoodFast(oo, thrId, sgroup, qloc);
881                                         for (int px=0; px < numUnique; px++) {
882                                                 double Ei = allElxk[px];
883                                                 double Eis = Eslxk[sgroup * numUnique + px];
884                                                 double tmp = ((Ei / Eis) * lxk[px] * area);
885                                                 accumulateScores(state, px, sgroup, tmp, where, primaryDims,
886                                                                  covEntries, mean, cov);
887                                         }
888                                 }
889                         }
890                 }
891         }
892
893         double *patternLik = state->patternLik;
894         for (int px=0; px < numUnique; px++) {
895                 double denom = patternLik[px];
896                 for (int ax=0; ax < maxAbilities; ax++) {
897                         (*mean)[px * maxAbilities + ax] /= denom;
898                 }
899                 for (int cx=0; cx < triangleLoc1(primaryDims); ++cx) {
900                         (*cov)[px * covEntries + cx] /= denom;
901                 }
902                 for (int sx=0; sx < numSpecific; sx++) {
903                         (*cov)[px * covEntries + triangleLoc0(primaryDims + sx)] /= denom;
904                 }
905                 int cx=0;
906                 for (int a1=0; a1 < primaryDims; ++a1) {
907                         for (int a2=0; a2 <= a1; ++a2) {
908                                 double ma1 = (*mean)[px * maxAbilities + a1];
909                                 double ma2 = (*mean)[px * maxAbilities + a2];
910                                 (*cov)[px * covEntries + cx] -= ma1 * ma2;
911                                 ++cx;
912                         }
913                 }
914                 for (int sx=0; sx < numSpecific; sx++) {
915                         int sdim = primaryDims + sx;
916                         double ma1 = (*mean)[px * maxAbilities + sdim];
917                         (*cov)[px * covEntries + triangleLoc0(sdim)] -= ma1 * ma1;
918                 }
919         }
920 }
921
922 static void recomputePatternLik(omxExpectation *oo)
923 {
924         BA81Expect *estate = (BA81Expect*) oo->argStruct;
925         if (estate->verbose) mxLog("%s: patternLik", oo->name);
926
927         int numUnique = estate->numUnique;
928         int numSpecific = estate->numSpecific;
929         int maxDims = estate->maxDims;
930         int primaryDims = maxDims;
931         double *patternLik = estate->patternLik;
932         OMXZERO(patternLik, numUnique);
933
934         if (numSpecific == 0) {
935 #pragma omp parallel for num_threads(Global->numThreads)
936                 for (long qx=0; qx < estate->totalQuadPoints; qx++) {
937                         const int thrId = omx_absolute_thread_num();
938                         double area = estate->priQarea[qx];
939                         double *lxk = ba81LikelihoodFast(oo, thrId, 0, qx);
940
941                         for (int px=0; px < numUnique; px++) {
942                                 double tmp = (lxk[px] * area);
943 #pragma omp atomic
944                                 patternLik[px] += tmp;
945                         }
946                 }
947         } else {
948                 primaryDims -= 1;
949
950 #pragma omp parallel for num_threads(Global->numThreads)
951                 for (long qx=0; qx < estate->totalPrimaryPoints; qx++) {
952                         const int thrId = omx_absolute_thread_num();
953
954                         cai2010(oo, thrId, FALSE, qx);
955                         double *allElxk = eBase(estate, thrId);
956
957                         double priArea = estate->priQarea[qx];
958                         for (int px=0; px < numUnique; px++) {
959                                 double Ei = allElxk[px];
960                                 double tmp = (Ei * priArea);
961 #pragma omp atomic
962                                 patternLik[px] += tmp;
963                         }
964                 }
965         }
966 }
967
968 static void
969 ba81compute(omxExpectation *oo, const char *context)
970 {
971         BA81Expect *state = (BA81Expect *) oo->argStruct;
972
973         if (context) {
974                 if (strcmp(context, "EM")==0) {
975                         state->type = EXPECTATION_AUGMENTED;
976                 } else if (context[0] == 0) {
977                         state->type = EXPECTATION_OBSERVED;
978                 } else {
979                         omxRaiseErrorf(globalState, "Unknown context '%s'", context);
980                         return;
981                 }
982         }
983
984         omxRecompute(state->EitemParam);
985
986         bool itemClean = state->itemParamVersion == omxGetMatrixVersion(state->EitemParam);
987         bool latentClean = state->latentParamVersion == getLatentVersion(state);
988
989         if (state->verbose) {
990                 mxLog("%s: Qinit %d itemClean %d latentClean %d (1=clean)",
991                       oo->name, state->Qpoint.size() != 0, itemClean, latentClean);
992         }
993
994         if (state->Qpoint.size() == 0 || !latentClean) {
995                 ba81SetupQuadrature(oo, state->targetQpoints);
996         }
997         if (itemClean) {
998                 ba81buildLXKcache(oo);
999                 if (!latentClean) recomputePatternLik(oo);
1000         } else {
1001                 ba81OutcomeProb(state);
1002                 ba81Estep1(oo);
1003         }
1004
1005         if (state->type == EXPECTATION_AUGMENTED) {
1006                 //ba81Expected(oo);
1007         }
1008
1009         state->itemParamVersion = omxGetMatrixVersion(state->EitemParam);
1010         state->latentParamVersion = getLatentVersion(state);
1011 }
1012
1013 static void
1014 copyScore(int rows, int maxAbilities, std::vector<double> &mean,
1015           std::vector<double> &cov, const int rx, double *scores, const int dest)
1016 {
1017         for (int ax=0; ax < maxAbilities; ++ax) {
1018                 scores[rows * ax + dest] = mean[maxAbilities * rx + ax];
1019         }
1020         for (int ax=0; ax < maxAbilities; ++ax) {
1021                 scores[rows * (maxAbilities + ax) + dest] =
1022                         sqrt(cov[triangleLoc1(maxAbilities) * rx + triangleLoc0(ax)]);
1023         }
1024         for (int ax=0; ax < triangleLoc1(maxAbilities); ++ax) {
1025                 scores[rows * (2*maxAbilities + ax) + dest] =
1026                         cov[triangleLoc1(maxAbilities) * rx + ax];
1027         }
1028 }
1029
1030 /**
1031  * MAP is not affected by the number of items. EAP is. Likelihood can
1032  * get concentrated in a single quadrature ordinate. For 3PL, response
1033  * patterns can have a bimodal likelihood. This will confuse MAP and
1034  * is a key advantage of EAP (Thissen & Orlando, 2001, p. 136).
1035  *
1036  * Thissen, D. & Orlando, M. (2001). IRT for items scored in two
1037  * categories. In D. Thissen & H. Wainer (Eds.), \emph{Test scoring}
1038  * (pp 73-140). Lawrence Erlbaum Associates, Inc.
1039  */
1040 static void
1041 ba81PopulateAttributes(omxExpectation *oo, SEXP robj)
1042 {
1043         BA81Expect *state = (BA81Expect *) oo->argStruct;
1044         int maxAbilities = state->maxAbilities;
1045
1046         SEXP Rmean, Rcov;
1047         PROTECT(Rmean = allocVector(REALSXP, maxAbilities));
1048         memcpy(REAL(Rmean), state->ElatentMean.data(), maxAbilities * sizeof(double));
1049
1050         PROTECT(Rcov = allocMatrix(REALSXP, maxAbilities, maxAbilities));
1051         memcpy(REAL(Rcov), state->ElatentCov.data(), maxAbilities * maxAbilities * sizeof(double));
1052
1053         setAttrib(robj, install("empirical.mean"), Rmean);
1054         setAttrib(robj, install("empirical.cov"), Rcov);
1055
1056         if (state->type == EXPECTATION_AUGMENTED) {
1057                 int numUnique = state->numUnique;
1058                 int totalOutcomes = state->totalOutcomes;
1059                 SEXP Rlik;
1060                 SEXP Rexpected;
1061
1062                 PROTECT(Rlik = allocVector(REALSXP, numUnique));
1063                 memcpy(REAL(Rlik), state->patternLik, sizeof(double) * numUnique);
1064
1065                 PROTECT(Rexpected = allocMatrix(REALSXP, totalOutcomes, state->totalQuadPoints));
1066                 memcpy(REAL(Rexpected), state->expected, sizeof(double) * totalOutcomes * state->totalQuadPoints);
1067
1068                 setAttrib(robj, install("patternLikelihood"), Rlik);
1069                 setAttrib(robj, install("em.expected"), Rexpected);
1070         }
1071
1072         if (state->scores == SCORES_OMIT || state->type == EXPECTATION_UNINITIALIZED) return;
1073
1074         // TODO Wainer & Thissen. (1987). Estimating ability with the wrong
1075         // model. Journal of Educational Statistics, 12, 339-368.
1076
1077         /*
1078         int numQpoints = state->targetQpoints * 2;  // make configurable TODO
1079
1080         if (numQpoints < 1 + 2.0 * sqrt(state->itemSpec->cols)) {
1081                 // Thissen & Orlando (2001, p. 136)
1082                 warning("EAP requires at least 2*sqrt(items) quadrature points");
1083         }
1084
1085         ba81SetupQuadrature(oo, numQpoints, 0);
1086         ba81Estep1(oo);
1087         */
1088
1089         std::vector<double> mean;
1090         std::vector<double> cov;
1091         EAPinternalFast(oo, &mean, &cov);
1092
1093         int numUnique = state->numUnique;
1094         omxData *data = state->data;
1095         int rows = state->scores == SCORES_FULL? data->rows : numUnique;
1096         int cols = 2 * maxAbilities + triangleLoc1(maxAbilities);
1097         SEXP Rscores;
1098         PROTECT(Rscores = allocMatrix(REALSXP, rows, cols));
1099         double *scores = REAL(Rscores);
1100
1101         const int SMALLBUF = 10;
1102         char buf[SMALLBUF];
1103         SEXP names;
1104         PROTECT(names = allocVector(STRSXP, cols));
1105         for (int nx=0; nx < maxAbilities; ++nx) {
1106                 snprintf(buf, SMALLBUF, "s%d", nx+1);
1107                 SET_STRING_ELT(names, nx, mkChar(buf));
1108                 snprintf(buf, SMALLBUF, "se%d", nx+1);
1109                 SET_STRING_ELT(names, maxAbilities + nx, mkChar(buf));
1110         }
1111         for (int nx=0; nx < triangleLoc1(maxAbilities); ++nx) {
1112                 snprintf(buf, SMALLBUF, "cov%d", nx+1);
1113                 SET_STRING_ELT(names, maxAbilities*2 + nx, mkChar(buf));
1114         }
1115         SEXP dimnames;
1116         PROTECT(dimnames = allocVector(VECSXP, 2));
1117         SET_VECTOR_ELT(dimnames, 1, names);
1118         setAttrib(Rscores, R_DimNamesSymbol, dimnames);
1119
1120         if (state->scores == SCORES_FULL) {
1121 #pragma omp parallel for num_threads(Global->numThreads)
1122                 for (int rx=0; rx < numUnique; rx++) {
1123                         int dups = omxDataNumIdenticalRows(state->data, state->rowMap[rx]);
1124                         for (int dup=0; dup < dups; dup++) {
1125                                 int dest = omxDataIndex(data, state->rowMap[rx]+dup);
1126                                 copyScore(rows, maxAbilities, mean, cov, rx, scores, dest);
1127                         }
1128                 }
1129         } else {
1130 #pragma omp parallel for num_threads(Global->numThreads)
1131                 for (int rx=0; rx < numUnique; rx++) {
1132                         copyScore(rows, maxAbilities, mean, cov, rx, scores, rx);
1133                 }
1134         }
1135
1136         setAttrib(robj, install("scores.out"), Rscores);
1137 }
1138
1139 static void ba81Destroy(omxExpectation *oo) {
1140         if(OMX_DEBUG) {
1141                 mxLog("Freeing %s function.", oo->name);
1142         }
1143         BA81Expect *state = (BA81Expect *) oo->argStruct;
1144         omxFreeAllMatrixData(state->EitemParam);
1145         omxFreeAllMatrixData(state->design);
1146         omxFreeAllMatrixData(state->latentMeanOut);
1147         omxFreeAllMatrixData(state->latentCovOut);
1148         omxFreeAllMatrixData(state->customPrior);
1149         omxFreeAllMatrixData(state->itemParam);
1150         Free(state->numIdentical);
1151         Free(state->rowMap);
1152         Free(state->patternLik);
1153         Free(state->lxk);
1154         Free(state->Eslxk);
1155         Free(state->allElxk);
1156         Free(state->Sgroup);
1157         Free(state->expected);
1158         Free(state->outcomeProb);
1159         delete state;
1160 }
1161
1162 void getMatrixDims(SEXP r_theta, int *rows, int *cols)
1163 {
1164     SEXP matrixDims;
1165     PROTECT(matrixDims = getAttrib(r_theta, R_DimSymbol));
1166     int *dimList = INTEGER(matrixDims);
1167     *rows = dimList[0];
1168     *cols = dimList[1];
1169     UNPROTECT(1);
1170 }
1171
1172 static void ignoreSetVarGroup(omxExpectation*, FreeVarGroup *)
1173 {}
1174
1175 void omxInitExpectationBA81(omxExpectation* oo) {
1176         omxState* currentState = oo->currentState;      
1177         SEXP rObj = oo->rObj;
1178         SEXP tmp;
1179         
1180         if(OMX_DEBUG) {
1181                 mxLog("Initializing %s.", oo->name);
1182         }
1183         if (!rpf_model) {
1184                 if (0) {
1185                         const int wantVersion = 3;
1186                         int version;
1187                         get_librpf_t get_librpf = (get_librpf_t) R_GetCCallable("rpf", "get_librpf_model_GPL");
1188                         (*get_librpf)(&version, &rpf_numModels, &rpf_model);
1189                         if (version < wantVersion) error("librpf binary API %d installed, at least %d is required",
1190                                                          version, wantVersion);
1191                 } else {
1192                         rpf_numModels = librpf_numModels;
1193                         rpf_model = librpf_model;
1194                 }
1195         }
1196         
1197         BA81Expect *state = new BA81Expect;
1198         state->checkedBadData = FALSE;
1199         state->numSpecific = 0;
1200         state->numIdentical = NULL;
1201         state->rowMap = NULL;
1202         state->design = NULL;
1203         state->lxk = NULL;
1204         state->patternLik = NULL;
1205         state->Eslxk = NULL;
1206         state->allElxk = NULL;
1207         state->outcomeProb = NULL;
1208         state->expected = NULL;
1209         state->type = EXPECTATION_UNINITIALIZED;
1210         state->scores = SCORES_OMIT;
1211         state->itemParam = NULL;
1212         state->customPrior = NULL;
1213         state->itemParamVersion = 0;
1214         state->latentParamVersion = 0;
1215         oo->argStruct = (void*) state;
1216
1217         PROTECT(tmp = GET_SLOT(rObj, install("data")));
1218         state->data = omxDataLookupFromState(tmp, currentState);
1219
1220         if (strcmp(omxDataType(state->data), "raw") != 0) {
1221                 omxRaiseErrorf(currentState, "%s unable to handle data type %s", oo->name, omxDataType(state->data));
1222                 return;
1223         }
1224
1225         PROTECT(tmp = GET_SLOT(rObj, install("ItemSpec")));
1226         for (int sx=0; sx < length(tmp); ++sx) {
1227                 SEXP model = VECTOR_ELT(tmp, sx);
1228                 if (!OBJECT(model)) {
1229                         error("Item models must inherit rpf.base");
1230                 }
1231                 SEXP spec;
1232                 PROTECT(spec = GET_SLOT(model, install("spec")));
1233                 state->itemSpec.push_back(REAL(spec));
1234         }
1235
1236         PROTECT(tmp = GET_SLOT(rObj, install("design")));
1237         if (!isNull(tmp)) {
1238                 // better to demand integers and not coerce to real TODO
1239                 state->design = omxNewMatrixFromRPrimitive(tmp, globalState, FALSE, 0);
1240         }
1241
1242         state->latentMeanOut = omxNewMatrixFromSlot(rObj, currentState, "mean");
1243         if (!state->latentMeanOut) error("Failed to retrieve mean matrix");
1244         state->latentCovOut  = omxNewMatrixFromSlot(rObj, currentState, "cov");
1245         if (!state->latentCovOut) error("Failed to retrieve cov matrix");
1246
1247         state->EitemParam =
1248                 omxNewMatrixFromSlot(rObj, currentState, "EItemParam");
1249         if (!state->EitemParam) error("Must supply EItemParam");
1250
1251         state->itemParam =
1252                 omxNewMatrixFromSlot(rObj, globalState, "ItemParam");
1253
1254         if (state->EitemParam->rows != state->itemParam->rows ||
1255             state->EitemParam->cols != state->itemParam->cols) {
1256                 error("ItemParam and EItemParam must be of the same dimension");
1257         }
1258
1259         oo->computeFun = ba81compute;
1260         oo->setVarGroup = ignoreSetVarGroup;
1261         oo->destructFun = ba81Destroy;
1262         oo->populateAttrFun = ba81PopulateAttributes;
1263         
1264         // TODO: Exactly identical rows do not contribute any information.
1265         // The sorting algorithm ought to remove them so we don't waste RAM.
1266         // The following summary stats would be cheaper to calculate too.
1267
1268         int numUnique = 0;
1269         omxData *data = state->data;
1270         if (omxDataNumFactor(data) != data->cols) {
1271                 // verify they are ordered factors TODO
1272                 omxRaiseErrorf(currentState, "%s: all columns must be factors", oo->name);
1273                 return;
1274         }
1275
1276         for (int rx=0; rx < data->rows;) {
1277                 rx += omxDataNumIdenticalRows(state->data, rx);
1278                 ++numUnique;
1279         }
1280         state->numUnique = numUnique;
1281
1282         state->rowMap = Realloc(NULL, numUnique, int);
1283         state->numIdentical = Realloc(NULL, numUnique, int);
1284
1285         state->customPrior =
1286                 omxNewMatrixFromSlot(rObj, globalState, "CustomPrior");
1287         
1288         int numItems = state->EitemParam->cols;
1289         if (data->cols != numItems) {
1290                 error("Data has %d columns for %d items", data->cols, numItems);
1291         }
1292
1293         for (int rx=0, ux=0; rx < data->rows; ux++) {
1294                 if (rx == 0) {
1295                         // all NA rows will sort to the top
1296                         int na=0;
1297                         for (int ix=0; ix < numItems; ix++) {
1298                                 if (omxIntDataElement(data, 0, ix) == NA_INTEGER) { ++na; }
1299                         }
1300                         if (na == numItems) {
1301                                 omxRaiseErrorf(currentState, "Remove rows with all NAs");
1302                                 return;
1303                         }
1304                 }
1305                 int dups = omxDataNumIdenticalRows(state->data, rx);
1306                 state->numIdentical[ux] = dups;
1307                 state->rowMap[ux] = rx;
1308                 rx += dups;
1309         }
1310
1311         int numThreads = Global->numThreads;
1312
1313         int maxSpec = 0;
1314         int maxParam = 0;
1315         state->maxDims = 0;
1316
1317         std::vector<int> &itemOutcomes = state->itemOutcomes;
1318         itemOutcomes.resize(numItems);
1319         int totalOutcomes = 0;
1320         for (int cx = 0; cx < data->cols; cx++) {
1321                 const double *spec = state->itemSpec[cx];
1322                 int id = spec[RPF_ISpecID];
1323                 int dims = spec[RPF_ISpecDims];
1324                 if (state->maxDims < dims)
1325                         state->maxDims = dims;
1326
1327                 int no = spec[RPF_ISpecOutcomes];
1328                 itemOutcomes[cx] = no;
1329                 totalOutcomes += no;
1330
1331                 // TODO this summary stat should be available from omxData
1332                 int dataMax=0;
1333                 for (int rx=0; rx < data->rows; rx++) {
1334                         int pick = omxIntDataElementUnsafe(data, rx, cx);
1335                         if (dataMax < pick)
1336                                 dataMax = pick;
1337                 }
1338                 if (dataMax > no) {
1339                         error("Data for item %d has %d outcomes, not %d", cx+1, dataMax, no);
1340                 } else if (dataMax < no) {
1341                         warning("Data for item %d has only %d outcomes, not %d", cx+1, dataMax, no);
1342                         // promote to error?
1343                         // should complain if an outcome is not represented in the data TODO
1344                 }
1345
1346                 int numSpec = (*rpf_model[id].numSpec)(spec);
1347                 if (maxSpec < numSpec)
1348                         maxSpec = numSpec;
1349
1350                 int numParam = (*rpf_model[id].numParam)(spec);
1351                 if (maxParam < numParam)
1352                         maxParam = numParam;
1353         }
1354
1355         state->totalOutcomes = totalOutcomes;
1356
1357         if (int(state->itemSpec.size()) != data->cols) {
1358                 omxRaiseErrorf(currentState, "ItemSpec must contain %d item model specifications",
1359                                data->cols);
1360                 return;
1361         }
1362         if (state->EitemParam->rows != maxParam) {
1363                 omxRaiseErrorf(currentState, "ItemParam should have %d rows", maxParam);
1364                 return;
1365         }
1366
1367         if (state->design == NULL) {
1368                 state->maxAbilities = state->maxDims;
1369                 state->design = omxInitTemporaryMatrix(NULL, state->maxDims, numItems,
1370                                        TRUE, currentState);
1371                 for (int ix=0; ix < numItems; ix++) {
1372                         const double *spec = state->itemSpec[ix];
1373                         int dims = spec[RPF_ISpecDims];
1374                         for (int dx=0; dx < state->maxDims; dx++) {
1375                                 omxSetMatrixElement(state->design, dx, ix, dx < dims? (double)dx+1 : nan(""));
1376                         }
1377                 }
1378         } else {
1379                 omxMatrix *design = state->design;
1380                 if (design->cols != numItems ||
1381                     design->rows != state->maxDims) {
1382                         omxRaiseErrorf(currentState, "Design matrix should have %d rows and %d columns",
1383                                        state->maxDims, numItems);
1384                         return;
1385                 }
1386
1387                 state->maxAbilities = 0;
1388                 for (int ix=0; ix < design->rows * design->cols; ix++) {
1389                         double got = design->data[ix];
1390                         if (!R_FINITE(got)) continue;
1391                         if (round(got) != (int)got) error("Design matrix can only contain integers"); // TODO better way?
1392                         if (state->maxAbilities < got)
1393                                 state->maxAbilities = got;
1394                 }
1395                 for (int ix=0; ix < design->cols; ix++) {
1396                         const double *idesign = omxMatrixColumn(design, ix);
1397                         int ddim = 0;
1398                         for (int rx=0; rx < design->rows; rx++) {
1399                                 if (std::isfinite(idesign[rx])) ddim += 1;
1400                         }
1401                         const double *spec = state->itemSpec[ix];
1402                         int dims = spec[RPF_ISpecDims];
1403                         if (ddim > dims) error("Item %d has %d dims but design assigns %d", ix, dims, ddim);
1404                 }
1405         }
1406         if (state->maxAbilities <= state->maxDims) {
1407                 state->Sgroup = Calloc(numItems, int);
1408         } else {
1409                 // Not sure if this is correct, revisit TODO
1410                 int Sgroup0 = -1;
1411                 state->Sgroup = Realloc(NULL, numItems, int);
1412                 for (int dx=0; dx < state->maxDims; dx++) {
1413                         for (int ix=0; ix < numItems; ix++) {
1414                                 int ability = omxMatrixElement(state->design, dx, ix);
1415                                 if (dx < state->maxDims - 1) {
1416                                         if (Sgroup0 <= ability)
1417                                                 Sgroup0 = ability+1;
1418                                         continue;
1419                                 }
1420                                 int ss=-1;
1421                                 if (ability >= Sgroup0) {
1422                                         if (ss == -1) {
1423                                                 ss = ability;
1424                                         } else {
1425                                                 omxRaiseErrorf(currentState, "Item %d cannot belong to more than "
1426                                                                "1 specific dimension (both %d and %d)",
1427                                                                ix, ss, ability);
1428                                                 return;
1429                                         }
1430                                 }
1431                                 if (ss == -1) ss = Sgroup0;
1432                                 state->Sgroup[ix] = ss - Sgroup0;
1433                         }
1434                 }
1435                 state->numSpecific = state->maxAbilities - state->maxDims + 1;
1436                 state->allElxk = Realloc(NULL, numUnique * numThreads, double);
1437                 state->Eslxk = Realloc(NULL, numUnique * state->numSpecific * numThreads, double);
1438         }
1439
1440         if (state->latentMeanOut->rows * state->latentMeanOut->cols != state->maxAbilities) {
1441                 error("The mean matrix '%s' must be 1x%d or %dx1", state->latentMeanOut->name,
1442                       state->maxAbilities, state->maxAbilities);
1443         }
1444         if (state->latentCovOut->rows != state->maxAbilities ||
1445             state->latentCovOut->cols != state->maxAbilities) {
1446                 error("The cov matrix '%s' must be %dx%d",
1447                       state->latentCovOut->name, state->maxAbilities, state->maxAbilities);
1448         }
1449
1450         PROTECT(tmp = GET_SLOT(rObj, install("verbose")));
1451         state->verbose = asLogical(tmp);
1452
1453         PROTECT(tmp = GET_SLOT(rObj, install("cache")));
1454         state->cacheLXK = asLogical(tmp);
1455         state->LXKcached = FALSE;
1456
1457         PROTECT(tmp = GET_SLOT(rObj, install("qpoints")));
1458         state->targetQpoints = asReal(tmp);
1459
1460         PROTECT(tmp = GET_SLOT(rObj, install("qwidth")));
1461         state->Qwidth = asReal(tmp);
1462
1463         PROTECT(tmp = GET_SLOT(rObj, install("scores")));
1464         const char *score_option = CHAR(asChar(tmp));
1465         if (strcmp(score_option, "omit")==0) state->scores = SCORES_OMIT;
1466         if (strcmp(score_option, "unique")==0) state->scores = SCORES_UNIQUE;
1467         if (strcmp(score_option, "full")==0) state->scores = SCORES_FULL;
1468
1469         state->ElatentMean.resize(state->maxAbilities);
1470         state->ElatentCov.resize(state->maxAbilities * state->maxAbilities);
1471
1472         // verify data bounded between 1 and numOutcomes TODO
1473         // hm, looks like something could be added to omxData for column summary stats?
1474 }