Reorg E step 9/9
[openmx:openmx.git] / src / omxExpectationBA81.cpp
1 /*
2   Copyright 2012-2013 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include <Rmath.h>
19
20 #include "omxExpectationBA81.h"
21 #include "glue.h"
22 #include "libifa-rpf.h"
23 #include "dmvnorm.h"
24
// Table of response probability function models. In this file it is indexed
// by the model ID stored in an item spec to reach the prob/logprob functions.
const struct rpf *rpf_model = nullptr;
// Presumably the number of entries in rpf_model; it is not read in this part
// of the file -- TODO confirm where it is assigned.
int rpf_numModels;
27
28 void pda(const double *ar, int rows, int cols)
29 {
30         std::string buf;
31         for (int rx=0; rx < rows; rx++) {   // column major order
32                 for (int cx=0; cx < cols; cx++) {
33                         buf += string_snprintf("%.6g, ", ar[cx * rows + rx]);
34                 }
35                 buf += "\n";
36         }
37         mxLogBig(buf);
38 }
39
40 void pia(const int *ar, int rows, int cols)
41 {
42         std::string buf;
43         for (int rx=0; rx < rows; rx++) {   // column major order
44                 for (int cx=0; cx < cols; cx++) {
45                         buf += string_snprintf("%d, ", ar[cx * rows + rx]);
46                 }
47                 buf += "\n";
48         }
49         mxLogBig(buf);
50 }
51
52 // state->speQarea[sIndex(state, sx, qx)]
53 OMXINLINE static
54 int sIndex(BA81Expect *state, int sx, int qx)
55 {
56         //if (sx < 0 || sx >= state->numSpecific) error("Out of domain");
57         //if (qx < 0 || qx >= state->quadGridSize) error("Out of domain");
58         return sx * state->quadGridSize + qx;
59 }
60
61 // Depends on item parameters, but not latent distribution
62 void computeRPF(BA81Expect *state, omxMatrix *itemParam, const int *quad,
63                 const bool wantlog, double *out)
64 {
65         omxMatrix *design = state->design;
66         int maxDims = state->maxDims;
67         size_t numItems = state->itemSpec.size();
68
69         double theta[maxDims];
70         pointToWhere(state, quad, theta, maxDims);
71
72         for (size_t ix=0; ix < numItems; ix++) {
73                 const double *spec = state->itemSpec[ix];
74                 int id = spec[RPF_ISpecID];
75                 int dims = spec[RPF_ISpecDims];
76                 double ptheta[dims];
77
78                 for (int dx=0; dx < dims; dx++) {
79                         int ability = (int)omxMatrixElement(design, dx, ix) - 1;
80                         if (ability >= maxDims) ability = maxDims-1;
81                         ptheta[dx] = theta[ability];
82                 }
83
84                 double *iparam = omxMatrixColumn(itemParam, ix);
85                 if (wantlog) {
86                         (*rpf_model[id].logprob)(spec, iparam, ptheta, out);
87                 } else {
88                         (*rpf_model[id].prob)(spec, iparam, ptheta, out);
89                 }
90 #if 0
91                 for (int ox=0; ox < state->itemOutcomes[ix]; ox++) {
92                         if (!isfinite(out[ox]) || out[ox] > 0) {
93                                 mxLog("item param");
94                                 pda(iparam, itemParam->rows, 1);
95                                 mxLog("where");
96                                 pda(ptheta, dims, 1);
97                                 error("RPF returned %20.20f", out[ox]);
98                         }
99                 }
100 #endif
101                 out += state->itemOutcomes[ix];
102         }
103 }
104
105 OMXINLINE static double *
106 getLXKcache(BA81Expect *state, const long qx, const int specific)
107 {
108         long ordinate;
109         if (state->numSpecific == 0) {
110                 ordinate = qx;
111         } else {
112                 ordinate = specific * state->totalQuadPoints + qx;
113         }
114         return state->lxk + state->numUnique * ordinate;
115 }
116
117 OMXINLINE static double *
118 ba81Likelihood1(omxExpectation *oo, const int thrId, int specific, const long qx)
119 {
120         BA81Expect *state = (BA81Expect*) oo->argStruct;
121         int numUnique = state->numUnique;
122         std::vector<int> &itemOutcomes = state->itemOutcomes;
123         omxData *data = state->data;
124         size_t numItems = state->itemSpec.size();
125         int *Sgroup = state->Sgroup;
126         double *lxk;
127
128         if (!state->cacheLXK) {
129                 lxk = state->lxk + numUnique * thrId;
130         } else {
131                 lxk = getLXKcache(state, qx, specific);
132         }
133
134         const int *rowMap = state->rowMap;
135         for (int px=0; px < numUnique; px++) {
136                 double lxk1 = 1;
137                 const double *oProb = state->outcomeProb + qx * state->totalOutcomes;
138                 for (size_t ix=0; ix < numItems; ix++) {
139                         int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
140                         if (specific == Sgroup[ix] && pick != NA_INTEGER) {
141                                 double piece = oProb[pick-1];  // move -1 elsewhere TODO
142                                 lxk1 *= piece;
143                                 //mxLog("%d pick %d piece %.7f", ix, pick-1, piece);
144                         }
145                         oProb += itemOutcomes[ix];
146                 }
147                 lxk[px] = lxk1;
148         }
149
150         return lxk;
151 }
152
153 double *ba81LikelihoodFast1(omxExpectation *oo, const int thrId, int specific, const long qx)
154 {
155         BA81Expect *state = (BA81Expect*) oo->argStruct;
156         if (!state->cacheLXK) {
157                 double *ret = ba81Likelihood1(oo, thrId, specific, qx);
158                 return ret;
159         } else {
160                 return getLXKcache(state, qx, specific);
161         }
162
163 }
164
165 static OMXINLINE void
166 ba81LikelihoodSlow2(BA81Expect *state, int px, double *out)
167 {
168         long totalQuadPoints = state->totalQuadPoints;
169         std::vector<int> &itemOutcomes = state->itemOutcomes;
170         size_t numItems = state->itemSpec.size();
171         omxData *data = state->data;
172         const int *rowMap = state->rowMap;
173         int totalOutcomes = state->totalOutcomes;
174         int numSpecific = state->numSpecific;
175         int outcomeBase = -itemOutcomes[0];
176
177         if (numSpecific == 0) {
178                 for (long qx=0; qx < totalQuadPoints; ++qx) {
179                         out[qx] = 1.0;
180                 }
181
182                 for (size_t ix=0; ix < numItems; ix++) {
183                         outcomeBase += itemOutcomes[ix];
184                         int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
185                         if (pick == NA_INTEGER) continue;
186                         pick -= 1;
187
188                         double *oProb = state->outcomeProb + outcomeBase;
189                         for (long qx=0; qx < totalQuadPoints; ++qx) {
190                                 out[qx] *= oProb[pick];
191                                 oProb += totalOutcomes;
192                         }
193                 }
194         } else {
195                 for (long qx=0; qx < totalQuadPoints * numSpecific; ++qx) {
196                         out[qx] = 1.0;
197                 }
198
199                 for (size_t ix=0; ix < numItems; ix++) {
200                         outcomeBase += itemOutcomes[ix];
201                         int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
202                         if (pick == NA_INTEGER) continue;
203                         pick -= 1;
204                         int Sbase = state->Sgroup[ix] * totalQuadPoints;
205                         double *oProb = state->outcomeProb + outcomeBase;
206                         for (long qx=0; qx < state->totalQuadPoints; qx++) {
207                                 out[Sbase + qx] *= oProb[pick];
208                                 oProb += totalOutcomes;
209                         }
210                 }
211         }
212 }
213
214 static OMXINLINE void
215 cai2010EiEis(BA81Expect *state, int px, double *lxk, double *Eis, double *Ei)
216 {
217         long totalQuadPoints = state->totalQuadPoints;
218         int numSpecific = state->numSpecific;
219         long totalPrimaryPoints = state->totalPrimaryPoints;
220         long specificPoints = state->quadGridSize;
221
222         for (int sgroup=0, Sbase=0; sgroup < numSpecific; ++sgroup, Sbase += totalQuadPoints) {
223                 long qloc = 0;
224                 for (long qx=0; qx < totalPrimaryPoints; qx++) {
225                         for (long sx=0; sx < specificPoints; sx++) {
226                                 double area = state->speQarea[sIndex(state, sgroup, sx)];
227                                 double piece = lxk[Sbase + qloc] * area;
228                                 Eis[totalPrimaryPoints * sgroup + qx] += piece;
229                                 ++qloc;
230                         }
231                         Ei[qx] *= Eis[totalPrimaryPoints * sgroup + qx];
232                 }
233         }
234 }
235
236 static OMXINLINE void
237 ba81LikelihoodFast2(BA81Expect *state, int px, double *out)
238 {
239         long totalQuadPoints = state->totalQuadPoints;
240         int numSpecific = state->numSpecific;
241
242         if (state->cacheLXK) {
243                 if (numSpecific == 0) {
244                         for (long qx=0; qx < totalQuadPoints; qx++) {
245                                 out[qx] = state->lxk[state->numUnique * qx + px]; // transpose cache to avoid copy TODO
246                         }
247                 } else {
248                         long qloc = 0;
249                         for (int sx=0; sx < numSpecific; ++sx) {
250                                 for (long qx=0; qx < totalQuadPoints; ++qx) {
251                                         long ordinate = sx * state->totalQuadPoints + qx;
252                                         out[qloc] = state->lxk[state->numUnique * ordinate + px];
253                                         ++qloc;
254                                 }
255                         }
256                 }
257         } else {
258                 ba81LikelihoodSlow2(state, px, out);
259         }
260 }
261
262 OMXINLINE static void
263 mapLatentSpace(BA81Expect *state, int sgroup, double piece, const double *where,
264                const double *whereGram, double *latentDist)
265 {
266         int maxDims = state->maxDims;
267         int maxAbilities = state->maxAbilities;
268         int pmax = maxDims;
269         if (state->numSpecific) pmax -= 1;
270
271         if (sgroup == 0) {
272                 int gx = 0;
273                 int cx = maxAbilities;
274                 for (int d1=0; d1 < pmax; d1++) {
275                         double piece_w1 = piece * where[d1];
276                         latentDist[d1] += piece_w1;
277                         for (int d2=0; d2 <= d1; d2++) {
278                                 double piece_cov = piece * whereGram[gx];
279                                 latentDist[cx] += piece_cov;
280                                 ++cx; ++gx;
281                         }
282                 }
283         }
284
285         if (state->numSpecific) {
286                 int sdim = pmax + sgroup;
287                 double piece_w1 = piece * where[pmax];
288                 latentDist[sdim] += piece_w1;
289
290                 double piece_var = piece * whereGram[triangleLoc0(pmax)];
291                 int to = maxAbilities + triangleLoc0(sdim);
292                 latentDist[to] += piece_var;
293         }
294 }
295
296 // Eslxk, allElxk (Ei, Eis) depend on the ordinate of the primary dimensions
297 void cai2010(omxExpectation* oo, const int thrId, int recompute, const long primaryQ)
298 {
299         BA81Expect *state = (BA81Expect*) oo->argStruct;
300         int numUnique = state->numUnique;
301         int numSpecific = state->numSpecific;
302         int quadGridSize = state->quadGridSize;
303         double *allElxk = eBase(state, thrId);
304         double *Eslxk = esBase(state, thrId);
305
306         for (int px=0; px < numUnique; px++) {
307                 allElxk[px] = 1;
308                 for (int sx=0; sx < numSpecific; sx++) {
309                         Eslxk[sx * numUnique + px] = 0;
310                 }
311         }
312
313         if (!state->cacheLXK) recompute = TRUE;
314
315         for (int sx=0; sx < quadGridSize; sx++) {
316                 long qloc = primaryQ * quadGridSize + sx;
317
318                 for (int sgroup=0; sgroup < numSpecific; sgroup++) {
319                         double *myEslxk = Eslxk + sgroup * numUnique;
320                         double *lxk;     // a.k.a. "L_is"
321                         if (recompute) {
322                                 lxk = ba81Likelihood1(oo, thrId, sgroup, qloc);
323                         } else {
324                                 lxk = getLXKcache(state, qloc, sgroup);
325                         }
326
327                         for (int ix=0; ix < numUnique; ix++) {
328                                 double area = state->speQarea[sIndex(state, sgroup, sx)];
329                                 double piece = lxk[ix] * area;
330                                 //mxLog("E.is(%d) (%ld,%d) %.6f + %.6f = %.6f",
331                                 //  sgroup, primaryQ, sx, lxk[ix], area, piece);
332                                 myEslxk[ix] += piece;
333                         }
334                 }
335         }
336
337         for (int sx=0; sx < numSpecific; sx++) {
338                 for (int px=0; px < numUnique; px++) {
339                         //mxLog("E.is(%d) at (%ld) %.6f", sx, primaryQ,
340                         //  Eslxk[sx * numUnique + px]);
341                         allElxk[px] *= Eslxk[sx * numUnique + px];  // allSlxk a.k.a. "E_i"
342                 }
343         }
344 }
345
346 static void ba81OutcomeProb(BA81Expect *state)
347 {
348         int maxDims = state->maxDims;
349         double *qProb = state->outcomeProb =
350                 Realloc(state->outcomeProb, state->totalOutcomes * state->totalQuadPoints, double);
351         for (long qx=0; qx < state->totalQuadPoints; qx++) {
352                 int quad[maxDims];
353                 decodeLocation(qx, maxDims, state->quadGridSize, quad);
354                 double where[maxDims];
355                 pointToWhere(state, quad, where, maxDims);
356                 computeRPF(state, state->EitemParam, quad, FALSE, qProb);
357                 qProb += state->totalOutcomes;
358         }
359 }
360
361 static void ba81Estep1(omxExpectation *oo)
362 {
363         if(OMX_DEBUG) {mxLog("Beginning %s Computation.", oo->name);}
364
365         BA81Expect *state = (BA81Expect*) oo->argStruct;
366         if (state->verbose) {
367                 mxLog("%s: lxk(%d) patternLik ElatentMean ElatentCov",
368                       oo->name, omxGetMatrixVersion(state->EitemParam));
369         }
370
371         int numUnique = state->numUnique;
372         int numSpecific = state->numSpecific;
373         int maxDims = state->maxDims;
374         int maxAbilities = state->maxAbilities;
375         int primaryDims = maxDims;
376         int totalOutcomes = state->totalOutcomes;
377         omxData *data = state->data;
378         int *numIdentical = state->numIdentical;
379         long totalQuadPoints = state->totalQuadPoints;
380
381         state->patternLik = Realloc(state->patternLik, numUnique, double);
382         double *patternLik = state->patternLik;
383         std::vector<double> thrExpected(totalOutcomes * totalQuadPoints * Global->numThreads);
384
385         int numLatents = maxAbilities + triangleLoc1(maxAbilities);
386         int numLatentsPerThread = numUnique * numLatents;
387         double *latentDist = Calloc(numUnique * numLatents * Global->numThreads, double);
388
389         int whereChunk = maxDims + triangleLoc1(maxDims);
390         std::vector<double> wherePrep(totalQuadPoints * whereChunk);
391         for (long qx=0; qx < totalQuadPoints; qx++) {
392                 double *wh = wherePrep.data() + qx * whereChunk;
393                 int quad[maxDims];
394                 decodeLocation(qx, maxDims, state->quadGridSize, quad);
395                 pointToWhere(state, quad, wh, maxDims);
396                 gramProduct(wh, maxDims, wh + maxDims);
397         }
398
399         if (numSpecific == 0) {
400 #pragma omp parallel for num_threads(Global->numThreads) schedule(static,32)
401                 for (int px=0; px < numUnique; px++) {
402                         int thrId = omx_absolute_thread_num();
403                         double *thrLatentDist = latentDist + thrId * numLatentsPerThread;
404
405                         std::vector<double> lxk(totalQuadPoints); // make thread local TODO
406                         ba81LikelihoodSlow2(state, px, lxk.data());
407
408                         double *lxkCache = state->cacheLXK? state->lxk + px : NULL;
409                         double patternLik1 = 0;
410                         double *wh = wherePrep.data();
411                         for (long qx=0; qx < totalQuadPoints; qx++) {
412                                 double area = state->priQarea[qx];
413                                 double tmp = lxk[qx] * area;
414                                 patternLik1 += tmp;
415                                 mapLatentSpace(state, 0, tmp, wh, wh + maxDims,
416                                                thrLatentDist + px * numLatents);
417
418                                 if (lxkCache) {
419                                         *lxkCache = lxk[qx];
420                                         lxkCache += numUnique;
421                                 }
422                                 wh += whereChunk;
423                         }
424
425                         patternLik[px] = patternLik1;
426                 }
427         } else {
428                 primaryDims -= 1;
429                 long totalPrimaryPoints = state->totalPrimaryPoints;
430                 long specificPoints = state->quadGridSize;
431                 double *EiCache = state->EiCache;
432
433 #pragma omp parallel for num_threads(Global->numThreads) schedule(static, totalPrimaryPoints*8)
434                 for (int ex=0; ex < totalPrimaryPoints * numUnique; ++ex) {
435                         EiCache[ex] = 1.0;
436                 }
437
438 #pragma omp parallel for num_threads(Global->numThreads) schedule(static,32)
439                 for (int px=0; px < numUnique; px++) {
440                         int thrId = omx_absolute_thread_num();
441                         double *thrLatentDist = latentDist + thrId * numLatentsPerThread;
442                         double *myEi = EiCache + px * totalPrimaryPoints;
443
444                         std::vector<double> lxk(totalQuadPoints * numSpecific);
445                         ba81LikelihoodSlow2(state, px, lxk.data());
446
447                         std::vector<double> Eis(totalPrimaryPoints * numSpecific, 0.0);
448                         cai2010EiEis(state, px, lxk.data(), Eis.data(), myEi);
449
450                         double patternLik1 = 0;
451                         double *wh = wherePrep.data();
452                         for (long qx=0, qloc=0; qx < totalPrimaryPoints; ++qx) {
453                                 double priArea = state->priQarea[qx];
454                                 double EiArea = myEi[qx] * priArea;
455                                 patternLik1 += EiArea;
456                                 for (long sx=0; sx < specificPoints; sx++) {
457                                         for (int sgroup=0; sgroup < numSpecific; sgroup++) {
458                                                 double area = state->speQarea[sgroup * specificPoints + sx];
459                                                 double lxk1 = lxk[totalQuadPoints * sgroup + qloc];
460                                                 double Eis1 = Eis[totalPrimaryPoints * sgroup + qx];
461                                                 double tmp = (EiArea / Eis1) * lxk1 * area;
462                                                 mapLatentSpace(state, sgroup, tmp, wh, wh + maxDims,
463                                                                thrLatentDist + px * numLatents);
464                                                 if (state->cacheLXK) {
465                                                         double *lxkCache = getLXKcache(state, qloc, sgroup);
466                                                         lxkCache[px] = lxk1;
467                                                 }
468                                         }
469                                         wh += whereChunk;
470                                         ++qloc;
471                                 }
472                         }
473
474                         patternLik[px] = patternLik1;
475                 }
476         }
477
478         long expectedSize = totalQuadPoints * totalOutcomes;
479 #pragma omp parallel for num_threads(Global->numThreads) schedule(static,64)
480         for (long qx=0; qx < expectedSize; ++qx) {
481                 state->expected[qx] = 0;
482                 double *e1 = thrExpected.data() + qx;
483                 for (int tx=0; tx < Global->numThreads; ++tx) {
484                         state->expected[qx] += *e1;
485                         e1 += expectedSize;
486                 }
487         }
488
489         //mxLog("raw latent");
490         //pda(latentDist, numLatents, numUnique);
491
492 #pragma omp parallel for num_threads(Global->numThreads) schedule(dynamic)
493         for (int lx=0; lx < maxAbilities + triangleLoc1(primaryDims); ++lx) {
494                 for (int tx=1; tx < Global->numThreads; ++tx) {
495                         double *thrLatentDist = latentDist + tx * numLatentsPerThread;
496                         for (int px=0; px < numUnique; px++) {
497                                 int loc = px * numLatents + lx;
498                                 latentDist[loc] += thrLatentDist[loc];
499                         }
500                 }
501         }
502
503 #pragma omp parallel for num_threads(Global->numThreads)
504         for (int sdim=primaryDims; sdim < maxAbilities; sdim++) {
505                 for (int tx=1; tx < Global->numThreads; ++tx) {
506                         double *thrLatentDist = latentDist + tx * numLatentsPerThread;
507                         for (int px=0; px < numUnique; px++) {
508                                 int loc = px * numLatents + maxAbilities + triangleLoc0(sdim);
509                                 latentDist[loc] += thrLatentDist[loc];
510                         }
511                 }
512         }
513
514 #pragma omp parallel for num_threads(Global->numThreads)
515         for (int px=0; px < numUnique; px++) {
516                 if (!isfinite(patternLik[px])) {
517                         omxRaiseErrorf(globalState, "Likelihood of pattern %d is %.3g",
518                                        px, patternLik[px]);
519                 }
520
521                 double *latentDist1 = latentDist + px * numLatents;
522                 double weight = numIdentical[px] / patternLik[px];
523                 int cx = maxAbilities;
524                 for (int d1=0; d1 < primaryDims; d1++) {
525                         latentDist1[d1] *= weight;
526                         for (int d2=0; d2 <= d1; d2++) {
527                                 latentDist1[cx] *= weight;
528                                 ++cx;
529                         }
530                 }
531                 for (int sdim=primaryDims; sdim < maxAbilities; sdim++) {
532                         latentDist1[sdim] *= weight;
533                         int loc = maxAbilities + triangleLoc0(sdim);
534                         latentDist1[loc] *= weight;
535                 }
536 #if 0
537                 if (!isfinite(patternLik[px])) {
538                         error("Likelihood of row %d is %f", state->rowMap[px], patternLik[px]);
539                 }
540 #endif
541         }
542
543         //mxLog("raw latent after weighting");
544         //pda(latentDist, numLatents, numUnique);
545
546         std::vector<double> &ElatentMean = state->ElatentMean;
547         std::vector<double> &ElatentCov = state->ElatentCov;
548         
549         ElatentMean.assign(ElatentMean.size(), 0.0);
550         ElatentCov.assign(ElatentCov.size(), 0.0);
551
552 #pragma omp parallel for num_threads(Global->numThreads)
553         for (int d1=0; d1 < maxAbilities; d1++) {
554                 for (int px=0; px < numUnique; px++) {
555                         double *latentDist1 = latentDist + px * numLatents;
556                         int cx = maxAbilities + triangleLoc1(d1);
557                         if (d1 < primaryDims) {
558                                 ElatentMean[d1] += latentDist1[d1];
559                                 for (int d2=0; d2 <= d1; d2++) {
560                                         int cell = d2 * maxAbilities + d1;
561                                         ElatentCov[cell] += latentDist1[cx];
562                                         ++cx;
563                                 }
564                         } else {
565                                 ElatentMean[d1] += latentDist1[d1];
566                                 int cell = d1 * maxAbilities + d1;
567                                 int loc = maxAbilities + triangleLoc0(d1);
568                                 ElatentCov[cell] += latentDist1[loc];
569                         }
570                 }
571         }
572
573         //pda(ElatentMean.data(), 1, state->maxAbilities);
574         //pda(ElatentCov.data(), state->maxAbilities, state->maxAbilities);
575
576         for (int d1=0; d1 < maxAbilities; d1++) {
577                 ElatentMean[d1] /= data->rows;
578         }
579
580         for (int d1=0; d1 < primaryDims; d1++) {
581                 for (int d2=0; d2 <= d1; d2++) {
582                         int cell = d2 * maxAbilities + d1;
583                         int tcell = d1 * maxAbilities + d2;
584                         ElatentCov[tcell] = ElatentCov[cell] =
585                                 ElatentCov[cell] / data->rows - ElatentMean[d1] * ElatentMean[d2];
586                 }
587         }
588         for (int sdim=primaryDims; sdim < maxAbilities; sdim++) {
589                 int cell = sdim * maxAbilities + sdim;
590                 ElatentCov[cell] = ElatentCov[cell] / data->rows - ElatentMean[sdim] * ElatentMean[sdim];
591         }
592
593         if (state->cacheLXK) state->LXKcached = TRUE;
594
595         Free(latentDist);
596
597         //mxLog("E-step");
598         //pda(ElatentMean.data(), 1, state->maxAbilities);
599         //pda(ElatentCov.data(), state->maxAbilities, state->maxAbilities);
600 }
601
602 static int getLatentVersion(BA81Expect *state)
603 {
604         return omxGetMatrixVersion(state->latentMeanOut) + omxGetMatrixVersion(state->latentCovOut);
605 }
606
607 // Attempt G-H grid? http://dbarajassolano.wordpress.com/2012/01/26/on-sparse-grid-quadratures/
608 static void ba81SetupQuadrature(omxExpectation* oo, int gridsize)
609 {
610         BA81Expect *state = (BA81Expect *) oo->argStruct;
611         if (state->verbose) {
612                 mxLog("%s: quadrature(%d)", oo->name, getLatentVersion(state));
613         }
614         int numUnique = state->numUnique;
615         int numThreads = Global->numThreads;
616         int maxDims = state->maxDims;
617         double Qwidth = state->Qwidth;
618         int numSpecific = state->numSpecific;
619         int priDims = maxDims - (numSpecific? 1 : 0);
620
621         // try starting small and increasing to the cap TODO
622         state->quadGridSize = gridsize;
623
624         state->totalQuadPoints = 1;
625         for (int dx=0; dx < maxDims; dx++) {
626                 state->totalQuadPoints *= state->quadGridSize;
627         }
628
629         state->totalPrimaryPoints = state->totalQuadPoints;
630
631         if (numSpecific) {
632                 state->totalPrimaryPoints /= state->quadGridSize;
633                 state->speQarea.resize(gridsize * numSpecific);
634         }
635
636         state->Qpoint.resize(gridsize);
637         state->priQarea.resize(state->totalPrimaryPoints);
638
639         double qgs = state->quadGridSize-1;
640         for (int px=0; px < state->quadGridSize; px ++) {
641                 state->Qpoint[px] = Qwidth - px * 2 * Qwidth / qgs;
642         }
643
644         //pda(state->latentMeanOut->data, 1, state->maxAbilities);
645         //pda(state->latentCovOut->data, state->maxAbilities, state->maxAbilities);
646
647         double totalArea = 0;
648         for (int qx=0; qx < state->totalPrimaryPoints; qx++) {
649                 int quad[priDims];
650                 decodeLocation(qx, priDims, state->quadGridSize, quad);
651                 double where[priDims];
652                 pointToWhere(state, quad, where, priDims);
653                 state->priQarea[qx] = exp(dmvnorm(priDims, where,
654                                                   state->latentMeanOut->data,
655                                                   state->latentCovOut->data));
656                 totalArea += state->priQarea[qx];
657         }
658         for (int qx=0; qx < state->totalPrimaryPoints; qx++) {
659                 state->priQarea[qx] /= totalArea;
660                 //mxLog("%.5g,", state->priQarea[qx]);
661         }
662
663         for (int sx=0; sx < numSpecific; sx++) {
664                 totalArea = 0;
665                 int covCell = (priDims + sx) * state->maxAbilities + priDims + sx;
666                 double mean = state->latentMeanOut->data[priDims + sx];
667                 double var = state->latentCovOut->data[covCell];
668                 //mxLog("setup[%d] %.2f %.2f", sx, mean, var);
669                 for (int qx=0; qx < state->quadGridSize; qx++) {
670                         double den = dnorm(state->Qpoint[qx], mean, sqrt(var), FALSE);
671                         state->speQarea[sIndex(state, sx, qx)] = den;
672                         totalArea += den;
673                 }
674                 for (int qx=0; qx < state->quadGridSize; qx++) {
675                         state->speQarea[sIndex(state, sx, qx)] /= totalArea;
676                 }
677                 //pda(state->speQarea.data() + sIndex(state, sx, 0), 1, state->quadGridSize);
678         }
679
680         if (!state->cacheLXK) {
681                 state->lxk = Realloc(state->lxk, numUnique * numThreads, double);
682         } else {
683                 int ns = state->numSpecific;
684                 if (ns == 0) ns = 1;
685                 long numOrdinate = ns * state->totalQuadPoints;
686                 state->lxk = Realloc(state->lxk, numUnique * numOrdinate, double);
687         }
688
689         state->expected = Realloc(state->expected, state->totalOutcomes * state->totalQuadPoints, double);
690         if (state->numSpecific) {
691                 state->EiCache = Realloc(state->EiCache, state->totalPrimaryPoints * numUnique, double);
692         }
693 }
694
695 static void ba81buildLXKcache(omxExpectation *oo)
696 {
697         BA81Expect *state = (BA81Expect *) oo->argStruct;
698         if (!state->cacheLXK || state->LXKcached) return;
699         
700         ba81Estep1(oo);
701 }
702
703 OMXINLINE static void
704 expectedUpdate(omxData *data, const int *rowMap, const int px, const int item,
705                const double observed, double *out)
706 {
707         int pick = omxIntDataElementUnsafe(data, rowMap[px], item);
708         if (pick != NA_INTEGER) {
709                 out[pick-1] += observed;
710         }
711 }
712
713 OMXINLINE static void
714 ba81Expected(omxExpectation* oo)
715 {
716         BA81Expect *state = (BA81Expect*) oo->argStruct;
717         if (state->verbose) mxLog("%s: EM.expected", oo->name);
718
719         omxData *data = state->data;
720         int numSpecific = state->numSpecific;
721         const int *rowMap = state->rowMap;
722         double *patternLik = state->patternLik;
723         int *numIdentical = state->numIdentical;
724         int numUnique = state->numUnique;
725         int numItems = state->EitemParam->cols;
726         int totalOutcomes = state->totalOutcomes;
727         std::vector<int> &itemOutcomes = state->itemOutcomes;
728         long totalQuadPoints = state->totalQuadPoints;
729         long totalPrimaryPoints = state->totalPrimaryPoints;
730         long specificPoints = state->quadGridSize;
731
732         OMXZERO(state->expected, totalOutcomes * totalQuadPoints);
733         std::vector<double> thrExpected(totalOutcomes * totalQuadPoints * Global->numThreads);
734
735         if (numSpecific == 0) {
736 #pragma omp parallel for num_threads(Global->numThreads) schedule(static,32)
737                 for (int px=0; px < numUnique; px++) {
738                         int thrId = omx_absolute_thread_num();
739                         double *myExpected = thrExpected.data() + thrId * totalOutcomes * totalQuadPoints;
740
741                         std::vector<double> lxk(totalQuadPoints);
742                         ba81LikelihoodFast2(state, px, lxk.data());
743
744                         double weight = numIdentical[px] / patternLik[px];
745
746                         int outcomeBase = -itemOutcomes[0];
747                         for (int ix=0; ix < numItems; ix++) {
748                                 outcomeBase += itemOutcomes[ix];
749                                 int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
750                                 if (pick == NA_INTEGER) continue;
751                                 pick -= 1;
752
753                                 double *out = myExpected + outcomeBase;
754                                 for (long qx=0; qx < totalQuadPoints; ++qx) {
755                                         out[pick] += weight * lxk[qx];
756                                         out += totalOutcomes;
757                                 }
758                         }
759                 }
760         } else {
761 #pragma omp parallel for num_threads(Global->numThreads) schedule(static,32)
762                 for (int px=0; px < numUnique; px++) {
763                         int thrId = omx_absolute_thread_num();
764                         double *myExpected = thrExpected.data() + thrId * totalOutcomes * totalQuadPoints;
765
766                         std::vector<double> lxk(totalQuadPoints * numSpecific);
767                         ba81LikelihoodFast2(state, px, lxk.data());
768
769                         std::vector<double> Eis(totalPrimaryPoints * numSpecific, 0.0);
770                         std::vector<double> Ei(totalPrimaryPoints, 1.0);
771                         cai2010EiEis(state, px, lxk.data(), Eis.data(), Ei.data());
772
773                         double weight = numIdentical[px] / patternLik[px];
774
775                         int outcomeBase = -itemOutcomes[0];
776                         for (int ix=0; ix < numItems; ix++) {
777                                 outcomeBase += itemOutcomes[ix];
778                                 int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
779                                 if (pick == NA_INTEGER) continue;
780                                 pick -= 1;
781
782                                 int Sgroup = state->Sgroup[ix];
783
784                                 double *out = myExpected + outcomeBase;
785                                 long qloc = 0;
786                                 for (long qx=0; qx < totalPrimaryPoints; qx++) {
787                                         double Ei1 = Ei[qx];
788                                         for (long sx=0; sx < specificPoints; sx++) {
789                                                 double lxk1 = lxk[totalQuadPoints * Sgroup + qloc];
790                                                 double Eis1 = Eis[totalPrimaryPoints * Sgroup + qx];
791                                                 out[pick] += weight * (Ei1 / Eis1) * lxk1;
792                                                 out += totalOutcomes;
793                                                 ++qloc;
794                                         }
795                                 }
796                         }
797                 }
798         }
799
800         long expectedSize = totalQuadPoints * totalOutcomes;
801 #pragma omp parallel for num_threads(Global->numThreads) schedule(static,64)
802         for (long ex=0; ex < expectedSize; ++ex) {
803                 state->expected[ex] = 0;
804                 double *e1 = thrExpected.data() + ex;
805                 for (int tx=0; tx < Global->numThreads; ++tx) {
806                         state->expected[ex] += *e1;
807                         e1 += expectedSize;
808                 }
809         }
810         //pda(state->expected, state->totalOutcomes, state->totalQuadPoints);
811 }
812
813 OMXINLINE static void
814 accumulateScores(BA81Expect *state, int px, int sgroup, double piece, const double *where,
815                  int primaryDims, int covEntries, std::vector<double> *mean, std::vector<double> *cov)
816 {
817         int maxDims = state->maxDims;
818         int maxAbilities = state->maxAbilities;
819
820         if (sgroup == 0) {
821                 int cx=0;
822                 for (int d1=0; d1 < primaryDims; d1++) {
823                         double piece_w1 = piece * where[d1];
824                         double &dest1 = (*mean)[px * maxAbilities + d1];
825 #pragma omp atomic
826                         dest1 += piece_w1;
827                         for (int d2=0; d2 <= d1; d2++) {
828                                 double &dest2 = (*cov)[px * covEntries + cx];
829 #pragma omp atomic
830                                 dest2 += where[d2] * piece_w1;
831                                 ++cx;
832                         }
833                 }
834         }
835
836         if (state->numSpecific) {
837                 int sdim = maxDims + sgroup - 1;
838                 double piece_w1 = piece * where[primaryDims];
839                 double &dest3 = (*mean)[px * maxAbilities + sdim];
840 #pragma omp atomic
841                 dest3 += piece_w1;
842
843                 double &dest4 = (*cov)[px * covEntries + triangleLoc0(sdim)];
844 #pragma omp atomic
845                 dest4 += piece_w1 * where[primaryDims];
846         }
847 }
848
849 static void
850 EAPinternalFast(omxExpectation *oo, std::vector<double> *mean, std::vector<double> *cov)
851 {
852         BA81Expect *state = (BA81Expect*) oo->argStruct;
853         if (state->verbose) mxLog("%s: EAP", oo->name);
854
855         int numUnique = state->numUnique;
856         int numSpecific = state->numSpecific;
857         int maxDims = state->maxDims;
858         int maxAbilities = state->maxAbilities;
859         int primaryDims = maxDims;
860         int covEntries = triangleLoc1(maxAbilities);
861
862         mean->assign(numUnique * maxAbilities, 0);
863         cov->assign(numUnique * covEntries, 0);
864
865         if (numSpecific == 0) {
866 #pragma omp parallel for num_threads(Global->numThreads)
867                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
868                         const int thrId = omx_absolute_thread_num();
869                         int quad[maxDims];
870                         decodeLocation(qx, maxDims, state->quadGridSize, quad);
871                         double where[maxDims];
872                         pointToWhere(state, quad, where, maxDims);
873
874                         double *lxk = ba81LikelihoodFast1(oo, thrId, 0, qx);
875
876                         double area = state->priQarea[qx];
877                         for (int px=0; px < numUnique; px++) {
878                                 double tmp = lxk[px] * area;
879                                 accumulateScores(state, px, 0, tmp, where, primaryDims, covEntries, mean, cov);
880                         }
881                 }
882         } else {
883                 primaryDims -= 1;
884                 int sDim = primaryDims;
885                 long specificPoints = state->quadGridSize;
886
887 #pragma omp parallel for num_threads(Global->numThreads)
888                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
889                         const int thrId = omx_absolute_thread_num();
890                         int quad[maxDims];
891                         decodeLocation(qx, primaryDims, state->quadGridSize, quad);
892
893                         cai2010(oo, thrId, FALSE, qx);
894                         double *allElxk = eBase(state, thrId);
895                         double *Eslxk = esBase(state, thrId);
896
897                         for (int sgroup=0; sgroup < numSpecific; sgroup++) {
898                                 for (long sx=0; sx < specificPoints; sx++) {
899                                         long qloc = qx * specificPoints + sx;
900                                         quad[sDim] = sx;
901                                         double where[maxDims];
902                                         pointToWhere(state, quad, where, maxDims);
903                                         double area = areaProduct(state, qx, sx, sgroup);
904                                         double *lxk = ba81LikelihoodFast1(oo, thrId, sgroup, qloc);
905                                         for (int px=0; px < numUnique; px++) {
906                                                 double Ei = allElxk[px];
907                                                 double Eis = Eslxk[sgroup * numUnique + px];
908                                                 double tmp = ((Ei / Eis) * lxk[px] * area);
909                                                 accumulateScores(state, px, sgroup, tmp, where, primaryDims,
910                                                                  covEntries, mean, cov);
911                                         }
912                                 }
913                         }
914                 }
915         }
916
917         double *patternLik = state->patternLik;
918         for (int px=0; px < numUnique; px++) {
919                 double denom = patternLik[px];
920                 for (int ax=0; ax < maxAbilities; ax++) {
921                         (*mean)[px * maxAbilities + ax] /= denom;
922                 }
923                 for (int cx=0; cx < triangleLoc1(primaryDims); ++cx) {
924                         (*cov)[px * covEntries + cx] /= denom;
925                 }
926                 for (int sx=0; sx < numSpecific; sx++) {
927                         (*cov)[px * covEntries + triangleLoc0(primaryDims + sx)] /= denom;
928                 }
929                 int cx=0;
930                 for (int a1=0; a1 < primaryDims; ++a1) {
931                         for (int a2=0; a2 <= a1; ++a2) {
932                                 double ma1 = (*mean)[px * maxAbilities + a1];
933                                 double ma2 = (*mean)[px * maxAbilities + a2];
934                                 (*cov)[px * covEntries + cx] -= ma1 * ma2;
935                                 ++cx;
936                         }
937                 }
938                 for (int sx=0; sx < numSpecific; sx++) {
939                         int sdim = primaryDims + sx;
940                         double ma1 = (*mean)[px * maxAbilities + sdim];
941                         (*cov)[px * covEntries + triangleLoc0(sdim)] -= ma1 * ma1;
942                 }
943         }
944 }
945
946 static void recomputePatternLik(omxExpectation *oo)
947 {
948         BA81Expect *state = (BA81Expect*) oo->argStruct;
949         if (state->verbose) mxLog("%s: patternLik", oo->name);
950
951         int numUnique = state->numUnique;
952         long totalQuadPoints = state->totalQuadPoints;
953         double *patternLik = state->patternLik;
954         OMXZERO(patternLik, numUnique);
955
956         if (state->numSpecific == 0) {
957 #pragma omp parallel for num_threads(Global->numThreads) schedule(static,32)
958                 for (int px=0; px < numUnique; px++) {
959                         std::vector<double> lxk(totalQuadPoints);
960                         ba81LikelihoodFast2(state, px, lxk.data());
961                         for (long qx=0; qx < totalQuadPoints; qx++) {
962                                 patternLik[px] += lxk[qx] * state->priQarea[qx];
963                         }
964                 }
965         } else {
966                 double *EiCache = state->EiCache;
967                 long totalPrimaryPoints = state->totalPrimaryPoints;
968
969 #pragma omp parallel for num_threads(Global->numThreads) schedule(static,32)
970                 for (int px=0; px < numUnique; px++) {
971                         double *Ei = EiCache + totalPrimaryPoints * px;
972                         for (long qx=0; qx < totalPrimaryPoints; qx++) {
973                                 patternLik[px] += Ei[qx] * state->priQarea[qx];
974                         }
975                 }
976         }
977 }
978
979 static void
980 ba81compute(omxExpectation *oo, const char *context)
981 {
982         BA81Expect *state = (BA81Expect *) oo->argStruct;
983
984         if (context) {
985                 if (strcmp(context, "EM")==0) {
986                         state->type = EXPECTATION_AUGMENTED;
987                 } else if (context[0] == 0) {
988                         state->type = EXPECTATION_OBSERVED;
989                 } else {
990                         omxRaiseErrorf(globalState, "Unknown context '%s'", context);
991                         return;
992                 }
993         }
994
995         omxRecompute(state->EitemParam);
996
997         bool itemClean = state->itemParamVersion == omxGetMatrixVersion(state->EitemParam);
998         bool latentClean = state->latentParamVersion == getLatentVersion(state);
999
1000         if (state->verbose) {
1001                 mxLog("%s: Qinit %d itemClean %d latentClean %d (1=clean)",
1002                       oo->name, state->Qpoint.size() != 0, itemClean, latentClean);
1003         }
1004
1005         if (state->Qpoint.size() == 0 || !latentClean) {
1006                 ba81SetupQuadrature(oo, state->targetQpoints);
1007         }
1008         if (itemClean) {
1009                 ba81buildLXKcache(oo);
1010                 if (!latentClean) recomputePatternLik(oo);
1011         } else {
1012                 ba81OutcomeProb(state);
1013                 ba81Estep1(oo);
1014         }
1015
1016         if (state->type == EXPECTATION_AUGMENTED) {
1017                 ba81Expected(oo);
1018         }
1019
1020         state->itemParamVersion = omxGetMatrixVersion(state->EitemParam);
1021         state->latentParamVersion = getLatentVersion(state);
1022 }
1023
1024 static void
1025 copyScore(int rows, int maxAbilities, std::vector<double> &mean,
1026           std::vector<double> &cov, const int rx, double *scores, const int dest)
1027 {
1028         for (int ax=0; ax < maxAbilities; ++ax) {
1029                 scores[rows * ax + dest] = mean[maxAbilities * rx + ax];
1030         }
1031         for (int ax=0; ax < maxAbilities; ++ax) {
1032                 scores[rows * (maxAbilities + ax) + dest] =
1033                         sqrt(cov[triangleLoc1(maxAbilities) * rx + triangleLoc0(ax)]);
1034         }
1035         for (int ax=0; ax < triangleLoc1(maxAbilities); ++ax) {
1036                 scores[rows * (2*maxAbilities + ax) + dest] =
1037                         cov[triangleLoc1(maxAbilities) * rx + ax];
1038         }
1039 }
1040
1041 /**
1042  * MAP is not affected by the number of items. EAP is. Likelihood can
1043  * get concentrated in a single quadrature ordinate. For 3PL, response
1044  * patterns can have a bimodal likelihood. This will confuse MAP and
1045  * is a key advantage of EAP (Thissen & Orlando, 2001, p. 136).
1046  *
1047  * Thissen, D. & Orlando, M. (2001). IRT for items scored in two
1048  * categories. In D. Thissen & H. Wainer (Eds.), \emph{Test scoring}
1049  * (pp 73-140). Lawrence Erlbaum Associates, Inc.
1050  */
1051 static void
1052 ba81PopulateAttributes(omxExpectation *oo, SEXP robj)
1053 {
1054         BA81Expect *state = (BA81Expect *) oo->argStruct;
1055         int maxAbilities = state->maxAbilities;
1056
1057         SEXP Rmean, Rcov;
1058         PROTECT(Rmean = allocVector(REALSXP, maxAbilities));
1059         memcpy(REAL(Rmean), state->ElatentMean.data(), maxAbilities * sizeof(double));
1060
1061         PROTECT(Rcov = allocMatrix(REALSXP, maxAbilities, maxAbilities));
1062         memcpy(REAL(Rcov), state->ElatentCov.data(), maxAbilities * maxAbilities * sizeof(double));
1063
1064         setAttrib(robj, install("empirical.mean"), Rmean);
1065         setAttrib(robj, install("empirical.cov"), Rcov);
1066
1067         if (state->type == EXPECTATION_AUGMENTED) {
1068                 int numUnique = state->numUnique;
1069                 int totalOutcomes = state->totalOutcomes;
1070                 SEXP Rlik;
1071                 SEXP Rexpected;
1072
1073                 PROTECT(Rlik = allocVector(REALSXP, numUnique));
1074                 memcpy(REAL(Rlik), state->patternLik, sizeof(double) * numUnique);
1075
1076                 PROTECT(Rexpected = allocMatrix(REALSXP, totalOutcomes, state->totalQuadPoints));
1077                 memcpy(REAL(Rexpected), state->expected, sizeof(double) * totalOutcomes * state->totalQuadPoints);
1078
1079                 setAttrib(robj, install("patternLikelihood"), Rlik);
1080                 setAttrib(robj, install("em.expected"), Rexpected);
1081         }
1082
1083         if (state->scores == SCORES_OMIT || state->type == EXPECTATION_UNINITIALIZED) return;
1084
1085         // TODO Wainer & Thissen. (1987). Estimating ability with the wrong
1086         // model. Journal of Educational Statistics, 12, 339-368.
1087
1088         /*
1089         int numQpoints = state->targetQpoints * 2;  // make configurable TODO
1090
1091         if (numQpoints < 1 + 2.0 * sqrt(state->itemSpec->cols)) {
1092                 // Thissen & Orlando (2001, p. 136)
1093                 warning("EAP requires at least 2*sqrt(items) quadrature points");
1094         }
1095
1096         ba81SetupQuadrature(oo, numQpoints, 0);
1097         ba81Estep1(oo);
1098         */
1099
1100         std::vector<double> mean;
1101         std::vector<double> cov;
1102         EAPinternalFast(oo, &mean, &cov);
1103
1104         int numUnique = state->numUnique;
1105         omxData *data = state->data;
1106         int rows = state->scores == SCORES_FULL? data->rows : numUnique;
1107         int cols = 2 * maxAbilities + triangleLoc1(maxAbilities);
1108         SEXP Rscores;
1109         PROTECT(Rscores = allocMatrix(REALSXP, rows, cols));
1110         double *scores = REAL(Rscores);
1111
1112         const int SMALLBUF = 10;
1113         char buf[SMALLBUF];
1114         SEXP names;
1115         PROTECT(names = allocVector(STRSXP, cols));
1116         for (int nx=0; nx < maxAbilities; ++nx) {
1117                 snprintf(buf, SMALLBUF, "s%d", nx+1);
1118                 SET_STRING_ELT(names, nx, mkChar(buf));
1119                 snprintf(buf, SMALLBUF, "se%d", nx+1);
1120                 SET_STRING_ELT(names, maxAbilities + nx, mkChar(buf));
1121         }
1122         for (int nx=0; nx < triangleLoc1(maxAbilities); ++nx) {
1123                 snprintf(buf, SMALLBUF, "cov%d", nx+1);
1124                 SET_STRING_ELT(names, maxAbilities*2 + nx, mkChar(buf));
1125         }
1126         SEXP dimnames;
1127         PROTECT(dimnames = allocVector(VECSXP, 2));
1128         SET_VECTOR_ELT(dimnames, 1, names);
1129         setAttrib(Rscores, R_DimNamesSymbol, dimnames);
1130
1131         if (state->scores == SCORES_FULL) {
1132 #pragma omp parallel for num_threads(Global->numThreads)
1133                 for (int rx=0; rx < numUnique; rx++) {
1134                         int dups = omxDataNumIdenticalRows(state->data, state->rowMap[rx]);
1135                         for (int dup=0; dup < dups; dup++) {
1136                                 int dest = omxDataIndex(data, state->rowMap[rx]+dup);
1137                                 copyScore(rows, maxAbilities, mean, cov, rx, scores, dest);
1138                         }
1139                 }
1140         } else {
1141 #pragma omp parallel for num_threads(Global->numThreads)
1142                 for (int rx=0; rx < numUnique; rx++) {
1143                         copyScore(rows, maxAbilities, mean, cov, rx, scores, rx);
1144                 }
1145         }
1146
1147         setAttrib(robj, install("scores.out"), Rscores);
1148 }
1149
1150 static void ba81Destroy(omxExpectation *oo) {
1151         if(OMX_DEBUG) {
1152                 mxLog("Freeing %s function.", oo->name);
1153         }
1154         BA81Expect *state = (BA81Expect *) oo->argStruct;
1155         omxFreeAllMatrixData(state->EitemParam);
1156         omxFreeAllMatrixData(state->design);
1157         omxFreeAllMatrixData(state->latentMeanOut);
1158         omxFreeAllMatrixData(state->latentCovOut);
1159         omxFreeAllMatrixData(state->customPrior);
1160         omxFreeAllMatrixData(state->itemParam);
1161         Free(state->numIdentical);
1162         Free(state->rowMap);
1163         Free(state->patternLik);
1164         Free(state->lxk);
1165         Free(state->Eslxk);
1166         Free(state->allElxk);
1167         Free(state->Sgroup);
1168         Free(state->expected);
1169         Free(state->outcomeProb);
1170         Free(state->EiCache);
1171         delete state;
1172 }
1173
1174 void getMatrixDims(SEXP r_theta, int *rows, int *cols)
1175 {
1176     SEXP matrixDims;
1177     PROTECT(matrixDims = getAttrib(r_theta, R_DimSymbol));
1178     int *dimList = INTEGER(matrixDims);
1179     *rows = dimList[0];
1180     *cols = dimList[1];
1181     UNPROTECT(1);
1182 }
1183
1184 static void ignoreSetVarGroup(omxExpectation*, FreeVarGroup *)
1185 {}
1186
1187 void omxInitExpectationBA81(omxExpectation* oo) {
1188         omxState* currentState = oo->currentState;      
1189         SEXP rObj = oo->rObj;
1190         SEXP tmp;
1191         
1192         if(OMX_DEBUG) {
1193                 mxLog("Initializing %s.", oo->name);
1194         }
1195         if (!rpf_model) {
1196                 if (0) {
1197                         const int wantVersion = 3;
1198                         int version;
1199                         get_librpf_t get_librpf = (get_librpf_t) R_GetCCallable("rpf", "get_librpf_model_GPL");
1200                         (*get_librpf)(&version, &rpf_numModels, &rpf_model);
1201                         if (version < wantVersion) error("librpf binary API %d installed, at least %d is required",
1202                                                          version, wantVersion);
1203                 } else {
1204                         rpf_numModels = librpf_numModels;
1205                         rpf_model = librpf_model;
1206                 }
1207         }
1208         
1209         BA81Expect *state = new BA81Expect;
1210         state->numSpecific = 0;
1211         state->numIdentical = NULL;
1212         state->rowMap = NULL;
1213         state->design = NULL;
1214         state->lxk = NULL;
1215         state->patternLik = NULL;
1216         state->Eslxk = NULL;
1217         state->allElxk = NULL;
1218         state->outcomeProb = NULL;
1219         state->expected = NULL;
1220         state->type = EXPECTATION_UNINITIALIZED;
1221         state->scores = SCORES_OMIT;
1222         state->itemParam = NULL;
1223         state->customPrior = NULL;
1224         state->itemParamVersion = 0;
1225         state->latentParamVersion = 0;
1226         state->EiCache = NULL;
1227         oo->argStruct = (void*) state;
1228
1229         PROTECT(tmp = GET_SLOT(rObj, install("data")));
1230         state->data = omxDataLookupFromState(tmp, currentState);
1231
1232         if (strcmp(omxDataType(state->data), "raw") != 0) {
1233                 omxRaiseErrorf(currentState, "%s unable to handle data type %s", oo->name, omxDataType(state->data));
1234                 return;
1235         }
1236
1237         PROTECT(tmp = GET_SLOT(rObj, install("ItemSpec")));
1238         for (int sx=0; sx < length(tmp); ++sx) {
1239                 SEXP model = VECTOR_ELT(tmp, sx);
1240                 if (!OBJECT(model)) {
1241                         error("Item models must inherit rpf.base");
1242                 }
1243                 SEXP spec;
1244                 PROTECT(spec = GET_SLOT(model, install("spec")));
1245                 state->itemSpec.push_back(REAL(spec));
1246         }
1247
1248         PROTECT(tmp = GET_SLOT(rObj, install("design")));
1249         if (!isNull(tmp)) {
1250                 // better to demand integers and not coerce to real TODO
1251                 state->design = omxNewMatrixFromRPrimitive(tmp, globalState, FALSE, 0);
1252         }
1253
1254         state->latentMeanOut = omxNewMatrixFromSlot(rObj, currentState, "mean");
1255         if (!state->latentMeanOut) error("Failed to retrieve mean matrix");
1256         state->latentCovOut  = omxNewMatrixFromSlot(rObj, currentState, "cov");
1257         if (!state->latentCovOut) error("Failed to retrieve cov matrix");
1258
1259         state->EitemParam =
1260                 omxNewMatrixFromSlot(rObj, currentState, "EItemParam");
1261         if (!state->EitemParam) error("Must supply EItemParam");
1262
1263         state->itemParam =
1264                 omxNewMatrixFromSlot(rObj, globalState, "ItemParam");
1265
1266         if (state->EitemParam->rows != state->itemParam->rows ||
1267             state->EitemParam->cols != state->itemParam->cols) {
1268                 error("ItemParam and EItemParam must be of the same dimension");
1269         }
1270
1271         oo->computeFun = ba81compute;
1272         oo->setVarGroup = ignoreSetVarGroup;
1273         oo->destructFun = ba81Destroy;
1274         oo->populateAttrFun = ba81PopulateAttributes;
1275         
1276         // TODO: Exactly identical rows do not contribute any information.
1277         // The sorting algorithm ought to remove them so we don't waste RAM.
1278         // The following summary stats would be cheaper to calculate too.
1279
1280         int numUnique = 0;
1281         omxData *data = state->data;
1282         if (omxDataNumFactor(data) != data->cols) {
1283                 // verify they are ordered factors TODO
1284                 omxRaiseErrorf(currentState, "%s: all columns must be factors", oo->name);
1285                 return;
1286         }
1287
1288         for (int rx=0; rx < data->rows;) {
1289                 rx += omxDataNumIdenticalRows(state->data, rx);
1290                 ++numUnique;
1291         }
1292         state->numUnique = numUnique;
1293
1294         state->rowMap = Realloc(NULL, numUnique, int);
1295         state->numIdentical = Realloc(NULL, numUnique, int);
1296
1297         state->customPrior =
1298                 omxNewMatrixFromSlot(rObj, globalState, "CustomPrior");
1299         
1300         int numItems = state->EitemParam->cols;
1301         if (data->cols != numItems) {
1302                 error("Data has %d columns for %d items", data->cols, numItems);
1303         }
1304
1305         int numThreads = Global->numThreads;
1306
1307         int maxSpec = 0;
1308         int maxParam = 0;
1309         state->maxDims = 0;
1310
1311         std::vector<int> &itemOutcomes = state->itemOutcomes;
1312         itemOutcomes.resize(numItems);
1313         int totalOutcomes = 0;
1314         for (int cx = 0; cx < data->cols; cx++) {
1315                 const double *spec = state->itemSpec[cx];
1316                 int id = spec[RPF_ISpecID];
1317                 int dims = spec[RPF_ISpecDims];
1318                 if (state->maxDims < dims)
1319                         state->maxDims = dims;
1320
1321                 int no = spec[RPF_ISpecOutcomes];
1322                 itemOutcomes[cx] = no;
1323                 totalOutcomes += no;
1324
1325                 // TODO this summary stat should be available from omxData
1326                 int dataMax=0;
1327                 for (int rx=0; rx < data->rows; rx++) {
1328                         int pick = omxIntDataElementUnsafe(data, rx, cx);
1329                         if (dataMax < pick)
1330                                 dataMax = pick;
1331                 }
1332                 if (dataMax > no) {
1333                         error("Data for item %d has %d outcomes, not %d", cx+1, dataMax, no);
1334                 } else if (dataMax < no) {
1335                         warning("Data for item %d has only %d outcomes, not %d", cx+1, dataMax, no);
1336                         // promote to error?
1337                         // should complain if an outcome is not represented in the data TODO
1338                 }
1339
1340                 int numSpec = (*rpf_model[id].numSpec)(spec);
1341                 if (maxSpec < numSpec)
1342                         maxSpec = numSpec;
1343
1344                 int numParam = (*rpf_model[id].numParam)(spec);
1345                 if (maxParam < numParam)
1346                         maxParam = numParam;
1347         }
1348
1349         state->totalOutcomes = totalOutcomes;
1350
1351         if (int(state->itemSpec.size()) != data->cols) {
1352                 omxRaiseErrorf(currentState, "ItemSpec must contain %d item model specifications",
1353                                data->cols);
1354                 return;
1355         }
1356         if (state->EitemParam->rows != maxParam) {
1357                 omxRaiseErrorf(currentState, "ItemParam should have %d rows", maxParam);
1358                 return;
1359         }
1360
1361         std::vector<bool> byOutcome(totalOutcomes, false);
1362         int outcomesSeen = 0;
1363         for (int rx=0, ux=0; rx < data->rows; ux++) {
1364                 if (rx == 0) {
1365                         // all NA rows will sort to the top
1366                         int na=0;
1367                         for (int ix=0; ix < numItems; ix++) {
1368                                 if (omxIntDataElement(data, 0, ix) == NA_INTEGER) { ++na; }
1369                         }
1370                         if (na == numItems) {
1371                                 omxRaiseErrorf(currentState, "Remove rows with all NAs");
1372                                 return;
1373                         }
1374                 }
1375                 int dups = omxDataNumIdenticalRows(state->data, rx);
1376                 state->numIdentical[ux] = dups;
1377                 state->rowMap[ux] = rx;
1378                 rx += dups;
1379
1380                 if (outcomesSeen < totalOutcomes) {
1381                         for (int ix=0, outcomeBase=0; ix < numItems; outcomeBase += itemOutcomes[ix], ++ix) {
1382                                 int pick = omxIntDataElementUnsafe(data, rx, ix);
1383                                 if (pick == NA_INTEGER) continue;
1384                                 --pick;
1385                                 if (!byOutcome[outcomeBase + pick]) {
1386                                         byOutcome[outcomeBase + pick] = true;
1387                                         if (++outcomesSeen == totalOutcomes) break;
1388                                 }
1389                         }
1390                 }
1391         }
1392
1393         if (outcomesSeen < totalOutcomes) {
1394                 std::string buf;
1395                 for (int ix=0, outcomeBase=0; ix < numItems; outcomeBase += itemOutcomes[ix], ++ix) {
1396                         for (int pick=0; pick < itemOutcomes[ix]; ++pick) {
1397                                 if (byOutcome[outcomeBase + pick]) continue;
1398                                 buf += string_snprintf(" item %d outcome %d", 1+ix, 1+pick);
1399                         }
1400                 }
1401                 omxRaiseErrorf(globalState, "Never endorsed:%s\n"
1402                                "You must collapse categories or omit items to estimate item parameters.",
1403                                buf.c_str());
1404         }
1405
1406         if (state->design == NULL) {
1407                 state->maxAbilities = state->maxDims;
1408                 state->design = omxInitTemporaryMatrix(NULL, state->maxDims, numItems,
1409                                        TRUE, currentState);
1410                 for (int ix=0; ix < numItems; ix++) {
1411                         const double *spec = state->itemSpec[ix];
1412                         int dims = spec[RPF_ISpecDims];
1413                         for (int dx=0; dx < state->maxDims; dx++) {
1414                                 omxSetMatrixElement(state->design, dx, ix, dx < dims? (double)dx+1 : nan(""));
1415                         }
1416                 }
1417         } else {
1418                 omxMatrix *design = state->design;
1419                 if (design->cols != numItems ||
1420                     design->rows != state->maxDims) {
1421                         omxRaiseErrorf(currentState, "Design matrix should have %d rows and %d columns",
1422                                        state->maxDims, numItems);
1423                         return;
1424                 }
1425
1426                 state->maxAbilities = 0;
1427                 for (int ix=0; ix < design->rows * design->cols; ix++) {
1428                         double got = design->data[ix];
1429                         if (!R_FINITE(got)) continue;
1430                         if (round(got) != (int)got) error("Design matrix can only contain integers"); // TODO better way?
1431                         if (state->maxAbilities < got)
1432                                 state->maxAbilities = got;
1433                 }
1434                 for (int ix=0; ix < design->cols; ix++) {
1435                         const double *idesign = omxMatrixColumn(design, ix);
1436                         int ddim = 0;
1437                         for (int rx=0; rx < design->rows; rx++) {
1438                                 if (isfinite(idesign[rx])) ddim += 1;
1439                         }
1440                         const double *spec = state->itemSpec[ix];
1441                         int dims = spec[RPF_ISpecDims];
1442                         if (ddim > dims) error("Item %d has %d dims but design assigns %d", ix, dims, ddim);
1443                 }
1444         }
1445         if (state->maxAbilities <= state->maxDims) {
1446                 state->Sgroup = Calloc(numItems, int);
1447         } else {
1448                 // Not sure if this is correct, revisit TODO
1449                 int Sgroup0 = -1;
1450                 state->Sgroup = Realloc(NULL, numItems, int);
1451                 for (int dx=0; dx < state->maxDims; dx++) {
1452                         for (int ix=0; ix < numItems; ix++) {
1453                                 int ability = omxMatrixElement(state->design, dx, ix);
1454                                 if (dx < state->maxDims - 1) {
1455                                         if (Sgroup0 <= ability)
1456                                                 Sgroup0 = ability+1;
1457                                         continue;
1458                                 }
1459                                 int ss=-1;
1460                                 if (ability >= Sgroup0) {
1461                                         if (ss == -1) {
1462                                                 ss = ability;
1463                                         } else {
1464                                                 omxRaiseErrorf(currentState, "Item %d cannot belong to more than "
1465                                                                "1 specific dimension (both %d and %d)",
1466                                                                ix, ss, ability);
1467                                                 return;
1468                                         }
1469                                 }
1470                                 if (ss == -1) ss = Sgroup0;
1471                                 state->Sgroup[ix] = ss - Sgroup0;
1472                         }
1473                 }
1474                 state->numSpecific = state->maxAbilities - state->maxDims + 1;
1475                 state->allElxk = Realloc(NULL, numUnique * numThreads, double);
1476                 state->Eslxk = Realloc(NULL, numUnique * state->numSpecific * numThreads, double);
1477         }
1478
1479         if (state->latentMeanOut->rows * state->latentMeanOut->cols != state->maxAbilities) {
1480                 error("The mean matrix '%s' must be 1x%d or %dx1", state->latentMeanOut->name,
1481                       state->maxAbilities, state->maxAbilities);
1482         }
1483         if (state->latentCovOut->rows != state->maxAbilities ||
1484             state->latentCovOut->cols != state->maxAbilities) {
1485                 error("The cov matrix '%s' must be %dx%d",
1486                       state->latentCovOut->name, state->maxAbilities, state->maxAbilities);
1487         }
1488
1489         PROTECT(tmp = GET_SLOT(rObj, install("verbose")));
1490         state->verbose = asLogical(tmp);
1491
1492         PROTECT(tmp = GET_SLOT(rObj, install("cache")));
1493         state->cacheLXK = asLogical(tmp);
1494         state->LXKcached = FALSE;
1495
1496         PROTECT(tmp = GET_SLOT(rObj, install("qpoints")));
1497         state->targetQpoints = asReal(tmp);
1498
1499         PROTECT(tmp = GET_SLOT(rObj, install("qwidth")));
1500         state->Qwidth = asReal(tmp);
1501
1502         PROTECT(tmp = GET_SLOT(rObj, install("scores")));
1503         const char *score_option = CHAR(asChar(tmp));
1504         if (strcmp(score_option, "omit")==0) state->scores = SCORES_OMIT;
1505         if (strcmp(score_option, "unique")==0) state->scores = SCORES_UNIQUE;
1506         if (strcmp(score_option, "full")==0) state->scores = SCORES_FULL;
1507
1508         state->ElatentMean.resize(state->maxAbilities);
1509         state->ElatentCov.resize(state->maxAbilities * state->maxAbilities);
1510
1511         // verify data bounded between 1 and numOutcomes TODO
1512         // hm, looks like something could be added to omxData for column summary stats?
1513 }