ifa: Gradients
[openmx:openmx.git] / src / omxExpectationBA81.c
1 /*
2   Copyright 2012 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 // Consider replacing log() with log2() in some places? Not worth it?
19
20 #include "omxExpectation.h"
21 #include "omxOpenmpWrap.h"
22 #include "npsolWrap.h"
23 #include "libirt-rpf.h"
24
25 static const char *NAME = "ExpectationBA81";
26
27 typedef double *(*rpf_fn_t)(omxExpectation *oo, omxMatrix *itemParam, const int *quad);
28
29 typedef int (*rpf_numParam_t)(const int numDims, const int numOutcomes);
30 // TODO arguments ought to be in the same order
31 typedef void (*rpf_logprob_t)(const int numDims, const double *restrict param,
32                               const double *restrict th,
33                               const int numOutcomes, double *restrict out);
34 typedef double (*rpf_gradient_t)(const int numDims, const int numOutcomes,
35                                  const double *restrict param, const int which,
36                                  const int outcome, const double *where);
37
38 struct rpf {
39         const char name[8];
40         rpf_numParam_t numParam;
41         rpf_logprob_t logprob;
42         rpf_gradient_t gradient;
43 };
44
45 static const struct rpf rpf_table[] = {
46         { "drm1",
47           irt_rpf_1dim_drm_numParam,
48           irt_rpf_1dim_drm_logprob,
49           irt_rpf_1dim_drm_gradient,
50         },
51         { "drm",
52           irt_rpf_mdim_drm_numParam,
53           irt_rpf_mdim_drm_logprob,
54           NULL //irt_rpf_mdim_drm_gradient,
55         },
56         { "gpcm1",
57           irt_rpf_1dim_gpcm_numParam,
58           irt_rpf_1dim_gpcm_logprob,
59           irt_rpf_1dim_gpcm_gradient,
60         }
61 };
62 static const int numStandardRPF = (sizeof(rpf_table) / sizeof(struct rpf));
63
64 typedef struct {
65
66         omxData *data;
67         int numUnique;
68         omxMatrix *itemSpec;
69         int *Sgroup;              // item's specific group 0..numSpecific-1
70         int maxOutcomes;
71         int maxDims;
72         int numGHpoints;
73         double *GHpoint;
74         double *logGHarea;
75         double *GHarea;
76         long *quadGridSize;       // maxDims
77         long totalPrimaryPoints;  // product of quadGridSize except specific dim
78         long totalQuadPoints;     // product of quadGridSize
79         int maxAbilities;
80         int numSpecific;
81         omxMatrix *design;        // items * maxDims
82         omxMatrix *itemPrior;
83         omxMatrix *itemParam;     // M step version
84         omxMatrix *EitemParam;    // E step version
85         SEXP rpf;
86         rpf_fn_t computeRPF;
87
88         int cacheLXK;             // w/cache,  numUnique * #specific quad points * totalQuadPoints
89         double *lxk;              // wo/cache, numUnique * thread
90         double *allSlxk;          // numUnique * thread
91         double *Slxk;             // numUnique * #specific dimensions * thread
92
93         double *patternLik;       // length numUnique
94         double *logNumIdentical;  // length numUnique
95         int *rowMap;              // length numUnique
96         double ll;                // the most recent finite ll
97
98         int gradientCount;
99         int fitCount;
100 } omxBA81State;
101
102 enum ISpecRow {
103         ISpecID,
104         ISpecOutcomes,
105         ISpecDims,
106         ISpecRowCount
107 };
108
109 /*
110 static void
111 pda(const double *ar, int rows, int cols) {
112         for (int rx=0; rx < rows; rx++) {
113                 for (int cx=0; cx < cols; cx++) {
114                         Rprintf("%.6g ", ar[cx * rows + rx]);
115                 }
116                 Rprintf("\n");
117         }
118
119 }
120 */
121
122 OMXINLINE static void
123 pointToWhere(omxBA81State *state, const int *quad, double *where, int upto)
124 {
125         for (int dx=0; dx < upto; dx++) {
126                 where[dx] = state->GHpoint[quad[dx]];
127         }
128 }
129
130 OMXINLINE static void
131 assignDims(omxMatrix *itemSpec, omxMatrix *design, int dims, int maxDims, int ix,
132            const double *restrict theta, double *restrict ptheta)
133 {
134         for (int dx=0; dx < dims; dx++) {
135                 int ability = (int)omxMatrixElement(design, dx, ix) - 1;
136                 if (ability >= maxDims) ability = maxDims-1;
137                 ptheta[dx] = theta[ability];
138         }
139 }
140
141 /**
142  * This is the main function needed to generate simulated data from
143  * the model. It could be argued that the rest of the estimation
144  * machinery belongs in the fitfunction.
145  *
146  * \param theta Vector of ability parameters, one per ability
147  * \returns A numItems by maxOutcomes colMajor vector of doubles. Caller must Free it.
148  */
149 static double *
150 standardComputeRPF(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
151 {
152         omxBA81State *state = (omxBA81State*) oo->argStruct;
153         omxMatrix *itemSpec = state->itemSpec;
154         int numItems = itemSpec->cols;
155         omxMatrix *design = state->design;
156         int maxDims = state->maxDims;
157
158         double theta[maxDims];
159         pointToWhere(state, quad, theta, maxDims);
160
161         double *outcomeProb = Realloc(NULL, numItems * state->maxOutcomes, double);
162
163         for (int ix=0; ix < numItems; ix++) {
164                 int outcomes = omxMatrixElement(itemSpec, ISpecOutcomes, ix);
165                 double *iparam = omxMatrixColumn(itemParam, ix);
166                 double *out = outcomeProb + ix * state->maxOutcomes;
167                 int id = omxMatrixElement(itemSpec, ISpecID, ix);
168                 int dims = omxMatrixElement(itemSpec, ISpecDims, ix);
169                 double ptheta[dims];
170                 assignDims(itemSpec, design, dims, maxDims, ix, theta, ptheta);
171                 (*rpf_table[id].logprob)(dims, iparam, ptheta, outcomes, out);
172         }
173
174         return outcomeProb;
175 }
176
177 static double *
178 RComputeRPF1(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
179 {
180         omxBA81State *state = (omxBA81State*) oo->argStruct;
181         int maxOutcomes = state->maxOutcomes;
182         omxMatrix *design = state->design;
183         omxMatrix *itemSpec = state->itemSpec;
184         int maxDims = state->maxDims;
185
186         double theta[maxDims];
187         pointToWhere(state, quad, theta, maxDims);
188
189         SEXP invoke;
190         PROTECT(invoke = allocVector(LANGSXP, 4));
191         SETCAR(invoke, state->rpf);
192         SETCADR(invoke, omxExportMatrix(itemParam));
193         SETCADDR(invoke, omxExportMatrix(itemSpec));
194
195         SEXP where;
196         PROTECT(where = allocMatrix(REALSXP, maxDims, itemParam->cols));
197         double *ptheta = REAL(where);
198         for (int ix=0; ix < itemParam->cols; ix++) {
199                 int dims = omxMatrixElement(itemSpec, ISpecDims, ix);
200                 assignDims(itemSpec, design, dims, maxDims, ix, theta, ptheta + ix*maxDims);
201                 for (int dx=dims; dx < maxDims; dx++) {
202                         ptheta[ix*maxDims + dx] = NA_REAL;
203                 }
204         }
205         SETCADDDR(invoke, where);
206
207         SEXP matrix;
208         PROTECT(matrix = eval(invoke, R_GlobalEnv));
209
210         if (!isMatrix(matrix)) {
211                 omxRaiseError(oo->currentState, -1,
212                               "RPF must return an item by outcome matrix");
213                 return NULL;
214         }
215
216         SEXP matrixDims;
217         PROTECT(matrixDims = getAttrib(matrix, R_DimSymbol));
218         int *dimList = INTEGER(matrixDims);
219         int numItems = state->itemSpec->cols;
220         if (dimList[0] != maxOutcomes || dimList[1] != numItems) {
221                 const int errlen = 200;
222                 char errstr[errlen];
223                 snprintf(errstr, errlen, "RPF must return a %d outcomes by %d items matrix",
224                          maxOutcomes, numItems);
225                 omxRaiseError(oo->currentState, -1, errstr);
226                 return NULL;
227         }
228
229         // Unlikely to be of type INTSXP, but just to be safe
230         PROTECT(matrix = coerceVector(matrix, REALSXP));
231         double *restrict got = REAL(matrix);
232
233         // Need to copy because threads cannot share SEXP
234         double *restrict outcomeProb = Realloc(NULL, numItems * maxOutcomes, double);
235
236         // Double check there aren't NAs in the wrong place
237         for (int ix=0; ix < numItems; ix++) {
238                 int numOutcomes = omxMatrixElement(state->itemSpec, ISpecOutcomes, ix);
239                 for (int ox=0; ox < numOutcomes; ox++) {
240                         int vx = ix * maxOutcomes + ox;
241                         if (isnan(got[vx])) {
242                                 const int errlen = 200;
243                                 char errstr[errlen];
244                                 snprintf(errstr, errlen, "RPF returned NA in [%d,%d]", ox,ix);
245                                 omxRaiseError(oo->currentState, -1, errstr);
246                         }
247                         outcomeProb[vx] = got[vx];
248                 }
249         }
250
251         return outcomeProb;
252 }
253
254 static double *
255 RComputeRPF(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
256 {
257         omx_omp_set_lock(&GlobalRLock);
258         PROTECT_INDEX pi = omxProtectSave();
259         double *ret = RComputeRPF1(oo, itemParam, quad);
260         omxProtectRestore(pi);
261         omx_omp_unset_lock(&GlobalRLock);  // hope there was no exception!
262         return ret;
263 }
264
265 OMXINLINE static long
266 encodeLocation(const int dims, const long *restrict grid, const int *restrict quad)
267 {
268         long qx = 0;
269         for (int dx=dims-1; dx >= 0; dx--) {
270                 qx = qx * grid[dx];
271                 qx += quad[dx];
272         }
273         return qx;
274 }
275
276 #define CALC_LXK_CACHED(state, numUnique, quad, tqp, specific) \
277         ((state)->lxk + \
278          (numUnique) * encodeLocation((state)->maxDims, (state)->quadGridSize, quad) + \
279          (numUnique) * (tqp) * (specific))
280
281 OMXINLINE static double *
282 ba81Likelihood(omxExpectation *oo, int specific, const int *restrict quad)
283 {
284         omxBA81State *state = (omxBA81State*) oo->argStruct;
285         int numUnique = state->numUnique;
286         int maxOutcomes = state->maxOutcomes;
287         omxData *data = state->data;
288         int numItems = state->itemSpec->cols;
289         rpf_fn_t rpf_fn = state->computeRPF;
290         int *restrict Sgroup = state->Sgroup;
291         double *restrict lxk;
292
293         if (!state->cacheLXK) {
294                 lxk = state->lxk + numUnique * omx_absolute_thread_num();
295         } else {
296                 lxk = CALC_LXK_CACHED(state, numUnique, quad, state->totalQuadPoints, specific);
297         }
298
299         const double *outcomeProb = (*rpf_fn)(oo, state->EitemParam, quad);
300         if (!outcomeProb) {
301                 OMXZERO(lxk, numUnique);
302                 return lxk;
303         }
304
305         const int *rowMap = state->rowMap;
306         for (int px=0; px < numUnique; px++) {
307                 double lxk1 = 0;
308                 for (int ix=0; ix < numItems; ix++) {
309                         if (specific != Sgroup[ix]) continue;
310                         int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
311                         if (pick == NA_INTEGER) continue;
312                         lxk1 += outcomeProb[ix * maxOutcomes + pick-1];
313                 }
314                 lxk[px] = lxk1;
315         }
316
317         Free(outcomeProb);
318
319         return lxk;
320 }
321
322 OMXINLINE static double *
323 ba81LikelihoodFast(omxExpectation *oo, int specific, const int *restrict quad)
324 {
325         omxBA81State *state = (omxBA81State*) oo->argStruct;
326         if (!state->cacheLXK) {
327                 return ba81LikelihoodFast(oo, specific, quad);
328         } else {
329                 return CALC_LXK_CACHED(state, state->numUnique, quad, state->totalQuadPoints, specific);
330         }
331
332 }
333
334 #define CALC_ALLSLXK(state, numUnique) \
335         (state->allSlxk + omx_absolute_thread_num() * (numUnique))
336
337 #define CALC_SLXK(state, numUnique, numSpecific) \
338         (state->Slxk + omx_absolute_thread_num() * (numUnique) * (numSpecific))
339
340 OMXINLINE static void
341 cai2010(omxExpectation* oo, int recompute, const int *restrict primaryQuad,
342         double *restrict allSlxk, double *restrict Slxk)
343 {
344         omxBA81State *state = (omxBA81State*) oo->argStruct;
345         int numUnique = state->numUnique;
346         int numSpecific = state->numSpecific;
347         int maxDims = state->maxDims;
348         int sDim = maxDims-1;
349
350         int quad[maxDims];
351         memcpy(quad, primaryQuad, sizeof(int)*sDim);
352
353         OMXZERO(Slxk, numUnique * numSpecific);
354         OMXZERO(allSlxk, numUnique);
355
356         for (int sx=0; sx < numSpecific; sx++) {
357                 double *eis = Slxk + numUnique * sx;
358                 int quadGridSize = state->quadGridSize[sDim];
359
360                 for (int qx=0; qx < quadGridSize; qx++) {
361                         quad[sDim] = qx;
362                         double *lxk;
363                         if (recompute) {
364                                 lxk = ba81Likelihood(oo, sx, quad);
365                         } else {
366                                 lxk = CALC_LXK_CACHED(state, numUnique, quad, state->totalQuadPoints, sx);
367                         }
368
369                         for (int ix=0; ix < numUnique; ix++) {
370                                 eis[ix] += exp(lxk[ix] + state->logGHarea[qx]);
371                         }
372                 }
373
374                 for (int px=0; px < numUnique; px++) {
375                         eis[px] = log(eis[px]);
376                         allSlxk[px] += eis[px];
377                 }
378         }
379 }
380
381 OMXINLINE static double
382 logAreaProduct(omxBA81State *state, const int *restrict quad, const int upto)
383 {
384         double logArea = 0;
385         for (int dx=0; dx < upto; dx++) {
386                 logArea += state->logGHarea[quad[dx]];
387         }
388         return logArea;
389 }
390
391 // The idea of this API is to allow passing in a number larger than 1.
392 OMXINLINE static void
393 areaProduct(omxBA81State *state, const int *restrict quad, const int upto, double *restrict out)
394 {
395         for (int dx=0; dx < upto; dx++) {
396                 *out *= state->GHarea[quad[dx]];
397         }
398 }
399
400 OMXINLINE static void
401 decodeLocation(long qx, const int dims, const long *restrict grid,
402                int *restrict quad)
403 {
404         for (int dx=0; dx < dims; dx++) {
405                 quad[dx] = qx % grid[dx];
406                 qx = qx / grid[dx];
407         }
408 }
409
410 static void
411 ba81Estep(omxExpectation *oo) {
412         if(OMX_DEBUG_MML) {Rprintf("Beginning %s Computation.\n", NAME);}
413
414         omxBA81State *state = (omxBA81State*) oo->argStruct;
415         double *patternLik = state->patternLik;
416         int numUnique = state->numUnique;
417         int numSpecific = state->numSpecific;
418
419         omxCopyMatrix(state->EitemParam, state->itemParam);
420
421         OMXZERO(patternLik, numUnique);
422
423         // E-step, marginalize person ability
424         //
425         // Note: In the notation of Bock & Aitkin (1981) and
426         // Cai~(2010), these loops are reversed.  That is, the inner
427         // loop is over quadrature points and the outer loop is over
428         // all response patterns.
429         //
430         if (numSpecific == 0) {
431 #pragma omp parallel for num_threads(oo->currentState->numThreads)
432                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
433                         int quad[state->maxDims];
434                         decodeLocation(qx, state->maxDims, state->quadGridSize, quad);
435
436                         double *lxk = ba81Likelihood(oo, 0, quad);
437
438                         double logArea = logAreaProduct(state, quad, state->maxDims);
439 #pragma omp critical(EstepUpdate)
440                         for (int px=0; px < numUnique; px++) {
441                                 double tmp = exp(lxk[px] + logArea);
442                                 patternLik[px] += tmp;
443                         }
444                 }
445         } else {
446                 int sDim = state->maxDims-1;
447
448 #pragma omp parallel for num_threads(oo->currentState->numThreads)
449                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
450                         int quad[state->maxDims];
451                         decodeLocation(qx, sDim, state->quadGridSize, quad);
452
453                         double *allSlxk = CALC_ALLSLXK(state, numUnique);
454                         double *Slxk = CALC_SLXK(state, numUnique, numSpecific);
455                         cai2010(oo, TRUE, quad, allSlxk, Slxk);
456
457                         double logArea = logAreaProduct(state, quad, sDim);
458 #pragma omp critical(EstepUpdate)
459                         for (int px=0; px < numUnique; px++) {
460                                 double tmp = exp(allSlxk[px] + logArea);
461                                 patternLik[px] += tmp;
462                         }
463                 }
464         }
465
466         for (int px=0; px < numUnique; px++) {
467                 patternLik[px] = log(patternLik[px]);
468         }
469 }
470
471 OMXINLINE static void
472 expectedUpdate(omxData *restrict data, const int *rowMap, const int px, const int item,
473                const double observed, const int outcomes, double *out)
474 {
475         int pick = omxIntDataElementUnsafe(data, rowMap[px], item);
476         if (pick == NA_INTEGER) {
477                 double slice = exp(observed - log(outcomes));
478                 for (int ox=0; ox < outcomes; ox++) {
479                         out[ox] += slice;
480                 }
481         } else {
482                 out[pick-1] += exp(observed);
483         }
484 }
485
486 /** 
487  * \param quad a vector that indexes into a multidimensional quadrature
488  * \param out points to an array numOutcomes wide
489  */
490 OMXINLINE static void
491 ba81Weight(omxExpectation* oo, const int item, const int *quad, int outcomes, double *out)
492 {
493         omxBA81State *state = (omxBA81State*) oo->argStruct;
494         omxData *data = state->data;
495         const int *rowMap = state->rowMap;
496         int specific = state->Sgroup[item];
497         double *patternLik = state->patternLik;
498         double *logNumIdentical = state->logNumIdentical;
499         int numUnique = state->numUnique;
500         int numSpecific = state->numSpecific;
501         int sDim = state->maxDims-1;
502
503         OMXZERO(out, outcomes);
504
505         if (numSpecific == 0) {
506                 double *lxk = ba81LikelihoodFast(oo, specific, quad);
507                 for (int px=0; px < numUnique; px++) {
508                         double observed = logNumIdentical[px] + lxk[px] - patternLik[px];
509                         expectedUpdate(data, rowMap, px, item, observed, outcomes, out);
510                 }
511         } else {
512                 double *allSlxk = CALC_ALLSLXK(state, numUnique);
513                 double *Slxk = CALC_SLXK(state, numUnique, numSpecific);
514                 if (quad[sDim] == 0) {
515                         // allSlxk, Slxk only depend on the ordinate of the primary dimensions
516                         cai2010(oo, !state->cacheLXK, quad, allSlxk, Slxk);
517                 }
518                 double *eis = Slxk + numUnique * specific;
519                 double *lxk = ba81LikelihoodFast(oo, specific, quad);
520
521                 for (int px=0; px < numUnique; px++) {
522                         double observed = logNumIdentical[px] + (allSlxk[px] - eis[px]) +
523                                 (lxk[px] - patternLik[px]);
524                         expectedUpdate(data, rowMap, px, item, observed, outcomes, out);
525                 }
526         }
527 }
528
529 OMXINLINE static double
530 ba81Fit1Ordinate(omxExpectation* oo, const int *quad)
531 {
532         omxBA81State *state = (omxBA81State*) oo->argStruct;
533         omxMatrix *itemParam = state->itemParam;
534         int numItems = itemParam->cols;
535         rpf_fn_t rpf_fn = state->computeRPF;
536         int maxOutcomes = state->maxOutcomes;
537         int maxDims = state->maxDims;
538
539         double *outcomeProb = (*rpf_fn)(oo, itemParam, quad);
540         if (!outcomeProb) return 0;
541
542         double thr_ll = 0;
543         for (int ix=0; ix < numItems; ix++) {
544                 int outcomes = omxMatrixElement(state->itemSpec, ISpecOutcomes, ix);
545                 double out[outcomes];
546                 ba81Weight(oo, ix, quad, outcomes, out);
547                 for (int ox=0; ox < outcomes; ox++) {
548                         double got = out[ox] * outcomeProb[ix * maxOutcomes + ox];
549                         areaProduct(state, quad, maxDims, &got);
550                         thr_ll += got;
551                 }
552         }
553
554         Free(outcomeProb);
555         return thr_ll;
556 }
557
558 static double
559 ba81ComputeFit1(omxExpectation* oo)
560 {
561         omxBA81State *state = (omxBA81State*) oo->argStruct;
562         ++state->fitCount;
563         omxMatrix *itemPrior = state->itemPrior;
564         int numSpecific = state->numSpecific;
565         int maxDims = state->maxDims;
566
567         omxRecompute(itemPrior);
568         double ll = itemPrior->data[0];
569
570         if (numSpecific == 0) {
571 #pragma omp parallel for num_threads(oo->currentState->numThreads)
572                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
573                         int quad[maxDims];
574                         decodeLocation(qx, maxDims, state->quadGridSize, quad);
575                         double thr_ll = ba81Fit1Ordinate(oo, quad);
576
577 #pragma omp atomic
578                         ll += thr_ll;
579                 }
580         } else {
581                 int sDim = state->maxDims-1;
582                 long *quadGridSize = state->quadGridSize;
583
584 #pragma omp parallel for num_threads(oo->currentState->numThreads)
585                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
586                         int quad[maxDims];
587                         decodeLocation(qx, maxDims, state->quadGridSize, quad);
588
589                         // copy loop structure to gradient TODO
590                         double thr_ll = 0;
591                         long specificPoints = quadGridSize[sDim];
592                         for (long sx=0; sx < specificPoints; sx++) {
593                                 quad[sDim] = sx;
594                                 thr_ll += ba81Fit1Ordinate(oo, quad);
595                         }
596 #pragma omp atomic
597                         ll += thr_ll;
598                 }
599         }
600
601         if (isinf(ll)) {
602                 return 2*state->ll;
603         } else {
604                 ll = -2 * ll;
605                 state->ll = ll;
606                 return ll;
607         }
608 }
609
610 double
611 ba81ComputeFit(omxExpectation* oo)
612 {
613         double got = ba81ComputeFit1(oo);
614         return got;
615 }
616
617 void ba81Gradient(omxExpectation* oo, double *out)
618 {
619         omxState* currentState = oo->currentState;
620         int numFreeParams = currentState->numFreeParams;
621         omxBA81State *state = (omxBA81State *) oo->argStruct;
622         ++state->gradientCount;
623         omxMatrix *itemSpec = state->itemSpec;
624         omxMatrix *itemParam = state->itemParam;
625         int maxDims = state->maxDims;
626
627         // sort & do all of an item's parameters together TODO
628         for (int vx = 0; vx < numFreeParams; vx++) {
629             omxFreeVar *fv = currentState->freeVarList + vx;
630             // TODO match the matrix number
631             int item = fv->col[0];
632             int which = fv->row[0];
633             int id = omxMatrixElement(itemSpec, ISpecID, item);
634             int dims = omxMatrixElement(itemSpec, ISpecDims, item);
635             int outcomes = omxMatrixElement(itemSpec, ISpecOutcomes, item);
636             double *iparam = omxMatrixColumn(itemParam, item);
637             double grad=0;
638
639 #pragma omp parallel for num_threads(oo->currentState->numThreads)
640             for (long qx=0; qx < state->totalQuadPoints; qx++) {
641                     int quad[maxDims];
642                     decodeLocation(qx, maxDims, state->quadGridSize, quad);
643                     double where[maxDims];
644                     pointToWhere(state, quad, where, maxDims);
645
646                     double out[outcomes];
647                     ba81Weight(oo, item, quad, outcomes, out); // thread safe with cache off? TODO
648
649                     for (int ox=0; ox < outcomes; ox++) {
650                             double got = (*rpf_table[id].gradient)(dims, outcomes, iparam, which, ox, where);
651                             double piece = out[ox] * got;
652                             areaProduct(state, quad, maxDims, &piece);
653 #pragma omp atomic
654                             grad += piece;
655                     }
656             }
657             double got = (*rpf_table[id].gradient)(dims, outcomes, iparam, which, -1, NULL);
658             out[vx] = -2 * grad * got;
659         }
660 }
661
662 static void ba81Destroy(omxExpectation *oo) {
663         if(OMX_DEBUG) {
664                 Rprintf("Freeing %s function.\n", NAME);
665         }
666         omxBA81State *state = (omxBA81State *) oo->argStruct;
667         Rprintf("fit %d gradient %d\n", state->fitCount, state->gradientCount);
668         omxFreeAllMatrixData(state->itemSpec);
669         omxFreeAllMatrixData(state->itemParam);
670         omxFreeAllMatrixData(state->EitemParam);
671         omxFreeAllMatrixData(state->design);
672         omxFreeAllMatrixData(state->itemPrior);
673         Free(state->logNumIdentical);
674         Free(state->logGHarea);
675         Free(state->rowMap);
676         Free(state->patternLik);
677         Free(state->lxk);
678         Free(state->Slxk);
679         Free(state->allSlxk);
680         Free(state->Sgroup);
681         Free(state);
682 }
683
684 int ba81ExpectationHasGradients(omxExpectation* oo)
685 {
686         omxBA81State *state = (omxBA81State *) oo->argStruct;
687         return state->computeRPF == standardComputeRPF;
688 }
689
690 void omxInitExpectationBA81(omxExpectation* oo) {
691         omxState* currentState = oo->currentState;      
692         SEXP rObj = oo->rObj;
693         SEXP tmp;
694         
695         if(OMX_DEBUG) {
696                 Rprintf("Initializing %s.\n", NAME);
697         }
698         
699         omxBA81State *state = Calloc(1, omxBA81State);
700         state->ll = 10^9;   // finite but big
701         
702         PROTECT(tmp = GET_SLOT(rObj, install("GHpoints")));
703         state->numGHpoints = length(tmp);
704         state->GHpoint = REAL(tmp);
705
706         PROTECT(tmp = GET_SLOT(rObj, install("GHarea")));
707         if (state->numGHpoints != length(tmp)) error("length(GHpoints) != length(GHarea)");
708         state->GHarea = REAL(tmp);
709
710         state->logGHarea = Realloc(NULL, state->numGHpoints, double);
711         for (int px=0; px < state->numGHpoints; px++) {
712                 state->logGHarea[px] = log(state->GHarea[px]);
713         }
714
715         PROTECT(tmp = GET_SLOT(rObj, install("data")));
716         state->data = omxNewDataFromMxDataPtr(tmp, currentState);
717         UNPROTECT(1);
718
719         if (strcmp(omxDataType(state->data), "raw") != 0) {
720                 omxRaiseErrorf(currentState, "%s unable to handle data type %s", NAME, omxDataType(state->data));
721                 return;
722         }
723
724         PROTECT(state->rpf = GET_SLOT(rObj, install("RPF")));
725         if (state->rpf == R_NilValue) {
726                 state->computeRPF = standardComputeRPF;
727         } else {
728                 state->computeRPF = RComputeRPF;
729         }
730
731         state->itemSpec =
732                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemSpec");
733         state->design =
734                 omxNewMatrixFromIndexSlot(rObj, currentState, "Design");
735         state->itemParam =
736                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemParam");
737         state->EitemParam =
738                 omxInitTemporaryMatrix(NULL, state->itemParam->rows, state->itemParam->cols,
739                                        TRUE, currentState);
740         state->itemPrior =
741                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemPrior");
742         
743         oo->computeFun = ba81Estep;
744         oo->destructFun = ba81Destroy;
745         
746         oo->argStruct = (void*) state;
747
748         // TODO: Exactly identical rows do not contribute any information.
749         // The sorting algorithm ought to remove them so we don't waste RAM.
750         // The following summary stats would be cheaper to calculate too.
751
752         int numUnique = 0;
753         omxData *data = state->data;
754         if (omxDataNumFactor(data) != data->cols) {
755                 omxRaiseErrorf(currentState, "%s: all columns must be factors", NAME);
756                 return;
757         }
758
759         for (int rx=0; rx < data->rows;) {
760                 rx += omxDataNumIdenticalRows(state->data, rx);
761                 ++numUnique;
762         }
763         state->numUnique = numUnique;
764
765         state->rowMap = Realloc(NULL, numUnique, int);
766         state->logNumIdentical = Realloc(NULL, numUnique, double);
767
768         int numItems = state->itemParam->cols;
769
770         for (int rx=0, ux=0; rx < data->rows; ux++) {
771                 if (rx == 0) {
772                         // all NA rows will sort to the top
773                         int na=0;
774                         for (int ix=0; ix < numItems; ix++) {
775                                 if (omxIntDataElement(data, 0, ix) == NA_INTEGER) { ++na; }
776                         }
777                         if (na == numItems) {
778                                 omxRaiseErrorf(currentState, "Remove rows with all NAs");
779                                 return;
780                         }
781                 }
782                 int dups = omxDataNumIdenticalRows(state->data, rx);
783                 state->logNumIdentical[ux] = log(dups);
784                 state->rowMap[ux] = rx;
785                 rx += dups;
786         }
787
788         state->patternLik = Realloc(NULL, numUnique, double);
789
790         int numThreads = oo->currentState->numThreads;
791         if (numThreads < 1) numThreads = 1;
792
793         if (state->itemSpec->cols != data->cols || state->itemSpec->rows != ISpecRowCount) {
794                 omxRaiseErrorf(currentState, "ItemSpec must have %d item columns and %d rows",
795                                data->cols, ISpecRowCount);
796                 return;
797         }
798
799         int maxParam = 0;
800         state->maxDims = 0;
801         state->maxOutcomes = 0;
802
803         for (int cx = 0; cx < data->cols; cx++) {
804                 int id = omxMatrixElement(state->itemSpec, ISpecID, cx);
805                 if (id < 0 || id >= numStandardRPF) {
806                         omxRaiseErrorf(currentState, "ItemSpec column %d has unknown item model %d", cx, id);
807                         return;
808                 }
809
810                 int dims = omxMatrixElement(state->itemSpec, ISpecDims, cx);
811                 if (state->maxDims < dims)
812                         state->maxDims = dims;
813
814                 // TODO verify that item model can have requested number of outcomes
815                 int no = omxMatrixElement(state->itemSpec, ISpecOutcomes, cx);
816                 if (state->maxOutcomes < no)
817                         state->maxOutcomes = no;
818
819                 int numParam = (*rpf_table[id].numParam)(dims, no);
820                 if (maxParam < numParam)
821                         maxParam = numParam;
822         }
823
824         if (state->itemParam->rows != maxParam) {
825                 omxRaiseErrorf(currentState, "ItemParam should have %d rows", maxParam);
826                 return;
827         }
828
829         if (state->design == NULL) {
830                 state->maxAbilities = state->maxDims;
831                 state->design = omxInitTemporaryMatrix(NULL, state->maxDims, numItems,
832                                        TRUE, currentState);
833                 for (int ix=0; ix < numItems; ix++) {
834                         for (int dx=0; dx < state->maxDims; dx++) {
835                                 omxSetMatrixElement(state->design, dx, ix, (double)dx+1);
836                         }
837                 }
838         } else {
839                 omxMatrix *design = state->design;
840                 if (design->cols != numItems ||
841                     design->rows != state->maxDims) {
842                         omxRaiseErrorf(currentState, "Design matrix should have %d rows and %d columns",
843                                        state->maxDims, numItems);
844                         return;
845                 }
846
847                 state->maxAbilities = 0;
848                 for (int ix=0; ix < design->rows * design->cols; ix++) {
849                         double got = design->data[ix];
850                         if (!R_FINITE(got)) continue;
851                         if (round(got) != got) error("Design matrix can only contain integers"); // TODO better way?
852                         if (state->maxAbilities < got)
853                                 state->maxAbilities = got;
854                 }
855         }
856         if (state->maxAbilities <= state->maxDims) {
857                 state->Sgroup = Calloc(numItems, int);
858         } else {
859                 int Sgroup0 = state->maxDims;
860                 state->Sgroup = Realloc(NULL, numItems, int);
861                 for (int ix=0; ix < numItems; ix++) {
862                         int ss=-1;
863                         for (int dx=0; dx < state->maxDims; dx++) {
864                                 int ability = omxMatrixElement(state->design, dx, ix);
865                                 if (ability >= Sgroup0) {
866                                         if (ss == -1) {
867                                                 ss = ability;
868                                         } else {
869                                                 omxRaiseErrorf(currentState, "Item %d cannot belong to more than "
870                                                                "1 specific dimension (both %d and %d)",
871                                                                ix, ss, ability);
872                                                 return;
873                                         }
874                                 }
875                         }
876                         if (ss == -1) ss = 0;
877                         state->Sgroup[ix] = ss - Sgroup0;
878                 }
879                 state->numSpecific = state->maxAbilities - state->maxDims + 1;
880                 state->allSlxk = Realloc(NULL, numUnique * numThreads, double);
881                 state->Slxk = Realloc(NULL, numUnique * state->numSpecific * numThreads, double);
882         }
883
884         state->totalQuadPoints = 1;
885         state->totalPrimaryPoints = 1;
886         state->quadGridSize = (long*) R_alloc(state->maxDims, sizeof(long));
887         for (int dx=0; dx < state->maxDims; dx++) {
888                 state->quadGridSize[dx] = state->numGHpoints;
889                 state->totalQuadPoints *= state->quadGridSize[dx];
890                 if (dx < state->maxDims-1) {
891                         state->totalPrimaryPoints *= state->quadGridSize[dx];
892                 }
893         }
894
895         PROTECT(tmp = GET_SLOT(rObj, install("cache")));
896         state->cacheLXK = asLogical(tmp);
897
898         if (!state->cacheLXK) {
899                 state->lxk = Realloc(NULL, numUnique * numThreads, double);
900         } else {
901                 int ns = state->numSpecific;
902                 if (ns == 0) ns = 1;
903                 state->lxk = Realloc(NULL, numUnique * state->totalQuadPoints * ns, double);
904         }
905
906         // verify data bounded between 1 and numOutcomes TODO
907         // hm, looks like something could be added to omxData for column summary stats?
908 }
909
910 SEXP omx_get_rpf_names()
911 {
912         SEXP outsxp;
913         PROTECT(outsxp = allocVector(STRSXP, numStandardRPF));
914         for (int sx=0; sx < numStandardRPF; sx++) {
915                 SET_STRING_ELT(outsxp, sx, mkChar(rpf_table[sx].name));
916         }
917         UNPROTECT(1);
918         return outsxp;
919 }