ifa: fastGHQuad
[openmx:openmx.git] / src / omxExpectationBA81.c
1 /*
2   Copyright 2012 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include "omxExpectation.h"
19 #include "omxOpenmpWrap.h"
20 #include "npsolWrap.h"
21 #include "libirt-rpf.h"
22
23 static const char *NAME = "ExpectationBA81";
24
25 typedef double *(*rpf_fn_t)(omxExpectation *oo, omxMatrix *itemParam, const int *quad);
26
27 typedef int (*rpf_numParam_t)(const int numDims, const int numOutcomes);
28 typedef void (*rpf_logprob_t)(const int numDims, const double *restrict param,
29                               const double *restrict th,
30                               const int numOutcomes, double *restrict out);
31 struct rpf {
32         const char name[8];
33         rpf_numParam_t numParam;
34         rpf_logprob_t logprob;
35 };
36
37 static const struct rpf rpf_table[] = {
38         { "drm1",  irt_rpf_1dim_drm_numParam,  irt_rpf_1dim_drm_logprob },
39         { "drm",   irt_rpf_mdim_drm_numParam,  irt_rpf_mdim_drm_logprob },
40         { "gpcm1", irt_rpf_1dim_gpcm_numParam, irt_rpf_1dim_gpcm_logprob }
41 };
42 static const int numStandardRPF = (sizeof(rpf_table) / sizeof(struct rpf));
43
44 typedef struct {
45
46         omxData *data;
47         int numUnique;
48         omxMatrix *itemSpec;
49         int *Sgroup;              // item's specific group 0..numSpecific-1
50         int maxOutcomes;
51         int maxDims;
52         int numGHpoints;
53         double *GHpoint;
54         double *GHarea;
55         long *quadGridSize;       // maxDims
56         long totalPrimaryPoints;  // product of quadGridSize except specific dim
57         long totalQuadPoints;     // product of quadGridSize
58         int maxAbilities;
59         int numSpecific;
60         omxMatrix *design;        // items * maxDims
61         omxMatrix *itemPrior;
62         omxMatrix *itemParam;     // M step version
63         omxMatrix *EitemParam;    // E step version
64         SEXP rpf;
65         rpf_fn_t computeRPF;
66
67         int cacheLXK;             // w/cache,  numUnique * #specific quad points * totalQuadPoints
68         double *lxk;              // wo/cache, numUnique * thread
69         double *allSlxk;          // numUnique * thread
70         double *Slxk;             // numUnique * #specific dimensions * thread
71
72         double *patternLik;       // length numUnique
73         double *logNumIdentical;  // length numUnique
74         double ll;                // the most recent finite ll
75
76 } omxBA81State;
77
78 enum ISpecRow {
79         ISpecID,
80         ISpecOutcomes,
81         ISpecDims,
82         ISpecRowCount
83 };
84
85 /*
86 static void
87 pda(const double *ar, int rows, int cols) {
88         for (int rx=0; rx < rows; rx++) {
89                 for (int cx=0; cx < cols; cx++) {
90                         Rprintf("%.6g ", ar[cx * rows + rx]);
91                 }
92                 Rprintf("\n");
93         }
94
95 }
96 */
97
98 OMXINLINE static void
99 pointToWhere(omxBA81State *state, const int *quad, double *where, int upto)
100 {
101         for (int dx=0; dx < upto; dx++) {
102                 where[dx] = state->GHpoint[quad[dx]];
103         }
104 }
105
106 OMXINLINE static void
107 assignDims(omxMatrix *itemSpec, omxMatrix *design, int dims, int maxDims, int ix,
108            const double *restrict theta, double *restrict ptheta)
109 {
110         for (int dx=0; dx < dims; dx++) {
111                 int ability = (int)omxMatrixElement(design, dx, ix) - 1;
112                 if (ability >= maxDims) ability = maxDims-1;
113                 ptheta[dx] = theta[ability];
114         }
115 }
116
117 /**
118  * This is the main function needed to generate simulated data from
119  * the model. It could be argued that the rest of the estimation
120  * machinery belongs in the fitfunction.
121  *
122  * \param theta Vector of ability parameters, one per ability
123  * \returns A numItems by maxOutcomes colMajor vector of doubles. Caller must Free it.
124  */
125 static double *
126 standardComputeRPF(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
127 {
128         omxBA81State *state = (omxBA81State*) oo->argStruct;
129         omxMatrix *itemSpec = state->itemSpec;
130         int numItems = itemSpec->cols;
131         omxMatrix *design = state->design;
132         int maxDims = state->maxDims;
133
134         double theta[maxDims];
135         pointToWhere(state, quad, theta, maxDims);
136
137         double *outcomeProb = Realloc(NULL, numItems * state->maxOutcomes, double);
138
139         for (int ix=0; ix < numItems; ix++) {
140                 int outcomes = omxMatrixElement(itemSpec, ISpecOutcomes, ix);
141                 double *iparam = omxMatrixColumn(itemParam, ix);
142                 double *out = outcomeProb + ix * state->maxOutcomes;
143                 int id = omxMatrixElement(itemSpec, ISpecID, ix);
144                 int dims = omxMatrixElement(itemSpec, ISpecDims, ix);
145                 double ptheta[dims];
146                 assignDims(itemSpec, design, dims, maxDims, ix, theta, ptheta);
147                 (*rpf_table[id].logprob)(dims, iparam, ptheta, outcomes, out);
148         }
149
150         return outcomeProb;
151 }
152
153 static double *
154 RComputeRPF1(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
155 {
156         omxBA81State *state = (omxBA81State*) oo->argStruct;
157         int maxOutcomes = state->maxOutcomes;
158         omxMatrix *design = state->design;
159         omxMatrix *itemSpec = state->itemSpec;
160         int maxDims = state->maxDims;
161
162         double theta[maxDims];
163         pointToWhere(state, quad, theta, maxDims);
164
165         SEXP invoke;
166         PROTECT(invoke = allocVector(LANGSXP, 4));
167         SETCAR(invoke, state->rpf);
168         SETCADR(invoke, omxExportMatrix(itemParam));
169         SETCADDR(invoke, omxExportMatrix(itemSpec));
170
171         SEXP where;
172         PROTECT(where = allocMatrix(REALSXP, maxDims, itemParam->cols));
173         double *ptheta = REAL(where);
174         for (int ix=0; ix < itemParam->cols; ix++) {
175                 int dims = omxMatrixElement(itemSpec, ISpecDims, ix);
176                 assignDims(itemSpec, design, dims, maxDims, ix, theta, ptheta + ix*maxDims);
177                 for (int dx=dims; dx < maxDims; dx++) {
178                         ptheta[ix*maxDims + dx] = NA_REAL;
179                 }
180         }
181         SETCADDDR(invoke, where);
182
183         SEXP matrix;
184         PROTECT(matrix = eval(invoke, R_GlobalEnv));
185
186         if (!isMatrix(matrix)) {
187                 omxRaiseError(oo->currentState, -1,
188                               "RPF must return an item by outcome matrix");
189                 return NULL;
190         }
191
192         SEXP matrixDims;
193         PROTECT(matrixDims = getAttrib(matrix, R_DimSymbol));
194         int *dimList = INTEGER(matrixDims);
195         int numItems = state->itemSpec->cols;
196         if (dimList[0] != maxOutcomes || dimList[1] != numItems) {
197                 const int errlen = 200;
198                 char errstr[errlen];
199                 snprintf(errstr, errlen, "RPF must return a %d outcomes by %d items matrix",
200                          maxOutcomes, numItems);
201                 omxRaiseError(oo->currentState, -1, errstr);
202                 return NULL;
203         }
204
205         // Unlikely to be of type INTSXP, but just to be safe
206         PROTECT(matrix = coerceVector(matrix, REALSXP));
207         double *restrict got = REAL(matrix);
208
209         // Need to copy because threads cannot share SEXP
210         double *restrict outcomeProb = Realloc(NULL, numItems * maxOutcomes, double);
211
212         // Double check there aren't NAs in the wrong place
213         for (int ix=0; ix < numItems; ix++) {
214                 int numOutcomes = omxMatrixElement(state->itemSpec, ISpecOutcomes, ix);
215                 for (int ox=0; ox < numOutcomes; ox++) {
216                         int vx = ix * maxOutcomes + ox;
217                         if (isnan(got[vx])) {
218                                 const int errlen = 200;
219                                 char errstr[errlen];
220                                 snprintf(errstr, errlen, "RPF returned NA in [%d,%d]", ox,ix);
221                                 omxRaiseError(oo->currentState, -1, errstr);
222                         }
223                         outcomeProb[vx] = got[vx];
224                 }
225         }
226
227         return outcomeProb;
228 }
229
230 static double *
231 RComputeRPF(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
232 {
233         omx_omp_set_lock(&GlobalRLock);
234         PROTECT_INDEX pi = omxProtectSave();
235         double *ret = RComputeRPF1(oo, itemParam, quad);
236         omxProtectRestore(pi);
237         omx_omp_unset_lock(&GlobalRLock);  // hope there was no exception!
238         return ret;
239 }
240
241 OMXINLINE static long
242 encodeLocation(const int dims, const long *restrict grid, const int *restrict quad)
243 {
244         long qx = 0;
245         for (int dx=dims-1; dx >= 0; dx--) {
246                 qx = qx * grid[dx];
247                 qx += quad[dx];
248         }
249         return qx;
250 }
251
252 #define CALC_LXK_CACHED(state, numUnique, quad, tqp, specific) \
253         ((state)->lxk + \
254          (numUnique) * encodeLocation((state)->maxDims, (state)->quadGridSize, quad) + \
255          (numUnique) * (tqp) * (specific))
256
257 OMXINLINE static double *
258 ba81Likelihood(omxExpectation *oo, int specific, const int *restrict quad)
259 {
260         omxBA81State *state = (omxBA81State*) oo->argStruct;
261         int numUnique = state->numUnique;
262         int maxOutcomes = state->maxOutcomes;
263         omxData *data = state->data;
264         int numItems = state->itemSpec->cols;
265         rpf_fn_t rpf_fn = state->computeRPF;
266         int *restrict Sgroup = state->Sgroup;
267         double *restrict lxk;
268
269         if (!state->cacheLXK) {
270                 lxk = state->lxk + numUnique * omx_absolute_thread_num();
271         } else {
272                 lxk = CALC_LXK_CACHED(state, numUnique, quad, state->totalQuadPoints, specific);
273         }
274
275         const double *outcomeProb = (*rpf_fn)(oo, state->EitemParam, quad);
276         if (!outcomeProb) {
277                 OMXZERO(lxk, numUnique);
278                 return lxk;
279         }
280
281         for (int px=0, row=0; px < numUnique; px++) {
282                 double lxk1 = 0;
283                 for (int ix=0; ix < numItems; ix++) {
284                         if (specific != Sgroup[ix]) continue;
285                         int pick = omxIntDataElementUnsafe(data, row, ix);
286                         if (pick == NA_INTEGER) continue;
287                         lxk1 += outcomeProb[ix * maxOutcomes + pick-1];
288                 }
289                 lxk[px] = lxk1;
290                 row += omxDataNumIdenticalRows(data, row);
291         }
292
293         Free(outcomeProb);
294
295         return lxk;
296 }
297
298 OMXINLINE static double *
299 ba81LikelihoodFast(omxExpectation *oo, int specific, const int *restrict quad)
300 {
301         omxBA81State *state = (omxBA81State*) oo->argStruct;
302         if (!state->cacheLXK) {
303                 return ba81LikelihoodFast(oo, specific, quad);
304         } else {
305                 return CALC_LXK_CACHED(state, state->numUnique, quad, state->totalQuadPoints, specific);
306         }
307
308 }
309
310 #define CALC_ALLSLXK(state, numUnique) \
311         (state->allSlxk + omx_absolute_thread_num() * (numUnique))
312
313 #define CALC_SLXK(state, numUnique, numSpecific) \
314         (state->Slxk + omx_absolute_thread_num() * (numUnique) * (numSpecific))
315
316 OMXINLINE static void
317 cai2010(omxExpectation* oo, int recompute, const int *restrict primaryQuad,
318         double *restrict allSlxk, double *restrict Slxk)
319 {
320         omxBA81State *state = (omxBA81State*) oo->argStruct;
321         int numUnique = state->numUnique;
322         int numSpecific = state->numSpecific;
323         int maxDims = state->maxDims;
324         int sDim = maxDims-1;
325
326         int quad[maxDims];
327         memcpy(quad, primaryQuad, sizeof(int)*sDim);
328
329         OMXZERO(Slxk, numUnique * numSpecific);
330         OMXZERO(allSlxk, numUnique);
331
332         for (int sx=0; sx < numSpecific; sx++) {
333                 double *eis = Slxk + numUnique * sx;
334                 int quadGridSize = state->quadGridSize[sDim];
335
336                 for (int qx=0; qx < quadGridSize; qx++) {
337                         quad[sDim] = qx;
338                         double *lxk;
339                         if (recompute) {
340                                 lxk = ba81Likelihood(oo, sx, quad);
341                         } else {
342                                 lxk = CALC_LXK_CACHED(state, numUnique, quad, state->totalQuadPoints, sx);
343                         }
344
345                         for (int ix=0; ix < numUnique; ix++) {
346                                 eis[ix] += exp(lxk[ix] + state->GHarea[qx]);
347                         }
348                 }
349
350                 for (int px=0; px < numUnique; px++) {
351                         eis[px] = log(eis[px]);
352                         allSlxk[px] += eis[px];
353                 }
354         }
355 }
356
357 OMXINLINE static double
358 areaProduct(omxBA81State *state, const int *restrict quad, const int upto)
359 {
360         double logArea = 0;
361         for (int dx=0; dx < upto; dx++) {
362                 logArea += state->GHarea[quad[dx]];
363         }
364         return logArea;
365 }
366
367 OMXINLINE static void
368 decodeLocation(long qx, const int dims, const long *restrict grid,
369                int *restrict quad)
370 {
371         for (int dx=0; dx < dims; dx++) {
372                 quad[dx] = qx % grid[dx];
373                 qx = qx / grid[dx];
374         }
375 }
376
377 static void
378 ba81Estep(omxExpectation *oo) {
379         if(OMX_DEBUG_MML) {Rprintf("Beginning %s Computation.\n", NAME);}
380
381         omxBA81State *state = (omxBA81State*) oo->argStruct;
382         double *patternLik = state->patternLik;
383         int numUnique = state->numUnique;
384         int numSpecific = state->numSpecific;
385
386         omxCopyMatrix(state->EitemParam, state->itemParam);
387
388         OMXZERO(patternLik, numUnique);
389
390         // E-step, marginalize person ability
391         //
392         // Note: In the notation of Bock & Aitkin (1981) and
393         // Cai~(2010), these loops are reversed.  That is, the inner
394         // loop is over quadrature points and the outer loop is over
395         // all response patterns.
396         //
397         if (numSpecific == 0) {
398 #pragma omp parallel for num_threads(oo->currentState->numThreads)
399                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
400                         int quad[state->maxDims];
401                         decodeLocation(qx, state->maxDims, state->quadGridSize, quad);
402
403                         double *lxk = ba81Likelihood(oo, 0, quad);
404
405                         double logArea = areaProduct(state, quad, state->maxDims);
406 #pragma omp critical(EstepUpdate)
407                         for (int px=0; px < numUnique; px++) {
408                                 double tmp = exp(lxk[px] + logArea);
409                                 patternLik[px] += tmp;
410                         }
411                 }
412         } else {
413                 int sDim = state->maxDims-1;
414
415 #pragma omp parallel for num_threads(oo->currentState->numThreads)
416                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
417                         int quad[state->maxDims];
418                         decodeLocation(qx, sDim, state->quadGridSize, quad);
419
420                         double *allSlxk = CALC_ALLSLXK(state, numUnique);
421                         double *Slxk = CALC_SLXK(state, numUnique, numSpecific);
422                         cai2010(oo, TRUE, quad, allSlxk, Slxk);
423
424                         double logArea = areaProduct(state, quad, sDim);
425 #pragma omp critical(EstepUpdate)
426                         for (int px=0; px < numUnique; px++) {
427                                 double tmp = exp(allSlxk[px] + logArea);
428                                 patternLik[px] += tmp;
429                         }
430                 }
431         }
432
433         for (int px=0; px < numUnique; px++) {
434                 patternLik[px] = log(patternLik[px]);
435         }
436 }
437
438 OMXINLINE static void
439 expectedUpdate(omxData *restrict data, int *restrict row, const int item,
440                const double observed, const int outcomes, double *out)
441 {
442         int pick = omxIntDataElementUnsafe(data, *row, item);
443         if (pick == NA_INTEGER) {
444                 double slice = exp(observed - log(outcomes));
445                 for (int ox=0; ox < outcomes; ox++) {
446                         out[ox] += slice;
447                 }
448         } else {
449                 out[pick-1] += exp(observed);
450         }
451         *row += omxDataNumIdenticalRows(data, *row);
452 }
453
454 /** 
455  * \param quad a vector that indexes into a multidimensional quadrature
456  * \param out points to an array numOutcomes wide
457  */
458 OMXINLINE static void
459 ba81Weight(omxExpectation* oo, const int item, const int *quad, int outcomes, double *out)
460 {
461         omxBA81State *state = (omxBA81State*) oo->argStruct;
462         omxData *data = state->data;
463         int specific = state->Sgroup[item];
464         double *patternLik = state->patternLik;
465         double *logNumIdentical = state->logNumIdentical;
466         int numUnique = state->numUnique;
467         int numSpecific = state->numSpecific;
468         int maxDims = state->maxDims;
469         int sDim = state->maxDims-1;
470
471         OMXZERO(out, outcomes);
472
473         if (numSpecific == 0) {
474                 double *lxk = ba81LikelihoodFast(oo, specific, quad);
475                 for (int px=0, row=0; px < numUnique; px++) {
476                         double observed = logNumIdentical[px] + lxk[px] - patternLik[px];
477                         expectedUpdate(data, &row, item, observed, outcomes, out);
478                 }
479         } else {
480                 double *allSlxk = CALC_ALLSLXK(state, numUnique);
481                 double *Slxk = CALC_SLXK(state, numUnique, numSpecific);
482                 if (quad[sDim] == 0) {
483                         // allSlxk, Slxk only depend on the ordinate of the primary dimensions
484                         cai2010(oo, !state->cacheLXK, quad, allSlxk, Slxk);
485                 }
486                 double *eis = Slxk + numUnique * specific;
487                 double *lxk = ba81LikelihoodFast(oo, specific, quad);
488
489                 for (int px=0, row=0; px < numUnique; px++) {
490                         double observed = logNumIdentical[px] + (allSlxk[px] - eis[px]) +
491                                 (lxk[px] - patternLik[px]);
492                         expectedUpdate(data, &row, item, observed, outcomes, out);
493                 }
494         }
495
496         double logArea = areaProduct(state, quad, maxDims);
497
498         for (int ox=0; ox < outcomes; ox++) {
499                 out[ox] = log(out[ox]) + logArea;
500         }
501 }
502
503 OMXINLINE static double
504 ba81Fit1Ordinate(omxExpectation* oo, const int *quad)
505 {
506         omxBA81State *state = (omxBA81State*) oo->argStruct;
507         omxMatrix *itemParam = state->itemParam;
508         int numItems = itemParam->cols;
509         rpf_fn_t rpf_fn = state->computeRPF;
510         int maxOutcomes = state->maxOutcomes;
511
512         double *outcomeProb = (*rpf_fn)(oo, itemParam, quad);
513         if (!outcomeProb) return 0;
514
515         double thr_ll = 0;
516         for (int ix=0; ix < numItems; ix++) {
517                 int outcomes = omxMatrixElement(state->itemSpec, ISpecOutcomes, ix);
518                 double out[outcomes];
519                 ba81Weight(oo, ix, quad, outcomes, out);
520                 for (int ox=0; ox < outcomes; ox++) {
521                         double got = exp(out[ox]) * outcomeProb[ix * maxOutcomes + ox];
522                         thr_ll += got;
523                 }
524         }
525
526         Free(outcomeProb);
527         return thr_ll;
528 }
529
530 static double
531 ba81ComputeFit1(omxExpectation* oo)
532 {
533         omxBA81State *state = (omxBA81State*) oo->argStruct;
534         omxMatrix *itemPrior = state->itemPrior;
535         int numSpecific = state->numSpecific;
536         int maxDims = state->maxDims;
537
538         omxRecompute(itemPrior);
539         double ll = itemPrior->data[0];
540
541         if (numSpecific == 0) {
542 #pragma omp parallel for num_threads(oo->currentState->numThreads)
543                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
544                         int quad[maxDims];
545                         decodeLocation(qx, maxDims, state->quadGridSize, quad);
546
547                         double thr_ll = ba81Fit1Ordinate(oo, quad);
548
549 #pragma omp atomic
550                         ll += thr_ll;
551                 }
552         } else {
553                 int sDim = state->maxDims-1;
554                 long *quadGridSize = state->quadGridSize;
555
556 #pragma omp parallel for num_threads(oo->currentState->numThreads)
557                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
558                         int quad[maxDims];
559                         decodeLocation(qx, maxDims, state->quadGridSize, quad);
560
561                         double thr_ll = 0;
562                         long specificPoints = quadGridSize[sDim];
563                         for (long sx=0; sx < specificPoints; sx++) {
564                                 quad[sDim] = sx;
565                                 thr_ll += ba81Fit1Ordinate(oo, quad);
566                         }
567 #pragma omp atomic
568                         ll += thr_ll;
569                 }
570         }
571
572         if (isinf(ll)) {
573                 return 2*state->ll;
574         } else {
575                 // TODO need to *2 also?
576                 state->ll = -ll;
577                 return -ll;
578         }
579 }
580
581 double
582 ba81ComputeFit(omxExpectation* oo)
583 {
584         double got = ba81ComputeFit1(oo);
585         return got;
586 }
587
588 static void ba81Destroy(omxExpectation *oo) {
589         if(OMX_DEBUG) {
590                 Rprintf("Freeing %s function.\n", NAME);
591         }
592         omxBA81State *state = (omxBA81State *) oo->argStruct;
593         omxFreeAllMatrixData(state->itemSpec);
594         omxFreeAllMatrixData(state->itemParam);
595         omxFreeAllMatrixData(state->EitemParam);
596         omxFreeAllMatrixData(state->design);
597         omxFreeAllMatrixData(state->itemPrior);
598         Free(state->logNumIdentical);
599         Free(state->patternLik);
600         Free(state->lxk);
601         Free(state->Slxk);
602         Free(state->allSlxk);
603         Free(state->Sgroup);
604         Free(state);
605 }
606
607 void omxInitExpectationBA81(omxExpectation* oo) {
608         omxState* currentState = oo->currentState;      
609         SEXP rObj = oo->rObj;
610         SEXP tmp;
611         
612         if(OMX_DEBUG) {
613                 Rprintf("Initializing %s.\n", NAME);
614         }
615         
616         omxBA81State *state = Calloc(1, omxBA81State);
617         state->ll = 10^9;   // finite but big
618         
619         PROTECT(tmp = GET_SLOT(rObj, install("GHpoints")));
620         state->numGHpoints = length(tmp);
621         state->GHpoint = REAL(tmp);
622
623         PROTECT(tmp = GET_SLOT(rObj, install("GHarea")));
624         if (state->numGHpoints != length(tmp)) error("length(GHpoints) != length(GHarea)");
625         state->GHarea = REAL(tmp);
626
627         PROTECT(tmp = GET_SLOT(rObj, install("data")));
628         state->data = omxNewDataFromMxDataPtr(tmp, currentState);
629         UNPROTECT(1);
630
631         if (strcmp(omxDataType(state->data), "raw") != 0) {
632                 omxRaiseErrorf(currentState, "%s unable to handle data type %s", NAME, omxDataType(state->data));
633                 return;
634         }
635
636         PROTECT(state->rpf = GET_SLOT(rObj, install("RPF")));
637         if (state->rpf == R_NilValue) {
638                 state->computeRPF = standardComputeRPF;
639                 // and analytic gradient TODO
640         } else {
641                 state->computeRPF = RComputeRPF;
642         }
643
644         state->itemSpec =
645                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemSpec");
646         state->design =
647                 omxNewMatrixFromIndexSlot(rObj, currentState, "Design");
648         state->itemParam =
649                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemParam");
650         state->EitemParam =
651                 omxInitTemporaryMatrix(NULL, state->itemParam->rows, state->itemParam->cols,
652                                        TRUE, currentState);
653         state->itemPrior =
654                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemPrior");
655         
656         oo->computeFun = ba81Estep;
657         //      oo->gradientFun = ba81Gradient;
658         oo->destructFun = ba81Destroy;
659         
660         oo->argStruct = (void*) state;
661
662         // TODO: Exactly identical rows do not contribute any information.
663         // The sorting algorithm ought to remove them so we don't waste RAM.
664         // The following summary stats would be cheaper to calculate too.
665
666         int numUnique = 0;
667         omxData *data = state->data;
668         if (omxDataNumFactor(data) != data->cols) {
669                 omxRaiseErrorf(currentState, "%s: all columns must be factors", NAME);
670                 return;
671         }
672
673         for (int rx=0; rx < data->rows;) {
674                 rx += omxDataNumIdenticalRows(state->data, rx);
675                 ++numUnique;
676         }
677         state->numUnique = numUnique;
678
679         state->logNumIdentical = Realloc(NULL, numUnique, double);
680
681         int numItems = state->itemParam->cols;
682
683         for (int rx=0, ux=0; rx < data->rows;) {
684                 if (rx == 0) {
685                         // all NA rows will sort to the top
686                         int na=0;
687                         for (int ix=0; ix < numItems; ix++) {
688                                 if (omxIntDataElement(data, 0, ix) == NA_INTEGER) { ++na; }
689                         }
690                         if (na == numItems) {
691                                 omxRaiseErrorf(currentState, "Remove rows with all NAs");
692                                 return;
693                         }
694                 }
695                 int dups = omxDataNumIdenticalRows(state->data, rx);
696                 state->logNumIdentical[ux++] = log(dups);
697                 rx += dups;
698         }
699
700         state->patternLik = Realloc(NULL, numUnique, double);
701
702         int numThreads = oo->currentState->numThreads;
703         if (numThreads < 1) numThreads = 1;
704
705         if (state->itemSpec->cols != data->cols || state->itemSpec->rows != ISpecRowCount) {
706                 omxRaiseErrorf(currentState, "ItemSpec must have %d item columns and %d rows",
707                                data->cols, ISpecRowCount);
708                 return;
709         }
710
711         int maxParam = 0;
712         state->maxDims = 0;
713         state->maxOutcomes = 0;
714
715         for (int cx = 0; cx < data->cols; cx++) {
716                 int id = omxMatrixElement(state->itemSpec, ISpecID, cx);
717                 if (id < 0 || id >= numStandardRPF) {
718                         omxRaiseErrorf(currentState, "ItemSpec column %d has unknown item model %d", cx, id);
719                         return;
720                 }
721
722                 int dims = omxMatrixElement(state->itemSpec, ISpecDims, cx);
723                 if (state->maxDims < dims)
724                         state->maxDims = dims;
725
726                 // TODO verify that item model can have requested number of outcomes
727                 int no = omxMatrixElement(state->itemSpec, ISpecOutcomes, cx);
728                 if (state->maxOutcomes < no)
729                         state->maxOutcomes = no;
730
731                 int numParam = (*rpf_table[id].numParam)(dims, no);
732                 if (maxParam < numParam)
733                         maxParam = numParam;
734         }
735
736         if (state->itemParam->rows != maxParam) {
737                 omxRaiseErrorf(currentState, "ItemParam should have %d rows", maxParam);
738                 return;
739         }
740
741         if (state->design == NULL) {
742                 state->maxAbilities = state->maxDims;
743                 state->design = omxInitTemporaryMatrix(NULL, state->maxDims, numItems,
744                                        TRUE, currentState);
745                 for (int ix=0; ix < numItems; ix++) {
746                         for (int dx=0; dx < state->maxDims; dx++) {
747                                 omxSetMatrixElement(state->design, dx, ix, (double)dx+1);
748                         }
749                 }
750         } else {
751                 omxMatrix *design = state->design;
752                 if (design->cols != numItems ||
753                     design->rows != state->maxDims) {
754                         omxRaiseErrorf(currentState, "Design matrix should have %d rows and %d columns",
755                                        state->maxDims, numItems);
756                         return;
757                 }
758
759                 state->maxAbilities = 0;
760                 for (int ix=0; ix < design->rows * design->cols; ix++) {
761                         double got = design->data[ix];
762                         if (!R_FINITE(got)) continue;
763                         if (round(got) != got) error("Design matrix can only contain integers"); // TODO better way?
764                         if (state->maxAbilities < got)
765                                 state->maxAbilities = got;
766                 }
767         }
768         if (state->maxAbilities <= state->maxDims) {
769                 state->Sgroup = Calloc(numItems, int);
770         } else {
771                 int Sgroup0 = state->maxDims;
772                 state->Sgroup = Realloc(NULL, numItems, int);
773                 for (int ix=0; ix < numItems; ix++) {
774                         int ss=-1;
775                         for (int dx=0; dx < state->maxDims; dx++) {
776                                 int ability = omxMatrixElement(state->design, dx, ix);
777                                 if (ability >= Sgroup0) {
778                                         if (ss == -1) {
779                                                 ss = ability;
780                                         } else {
781                                                 omxRaiseErrorf(currentState, "Item %d cannot belong to more than "
782                                                                "1 specific dimension (both %d and %d)",
783                                                                ix, ss, ability);
784                                                 return;
785                                         }
786                                 }
787                         }
788                         if (ss == -1) ss = 0;
789                         state->Sgroup[ix] = ss - Sgroup0;
790                 }
791                 state->numSpecific = state->maxAbilities - state->maxDims + 1;
792                 state->allSlxk = Realloc(NULL, numUnique * numThreads, double);
793                 state->Slxk = Realloc(NULL, numUnique * state->numSpecific * numThreads, double);
794         }
795
796         state->totalQuadPoints = 1;
797         state->totalPrimaryPoints = 1;
798         state->quadGridSize = (long*) R_alloc(state->maxDims, sizeof(long));
799         for (int dx=0; dx < state->maxDims; dx++) {
800                 state->quadGridSize[dx] = state->numGHpoints;
801                 state->totalQuadPoints *= state->quadGridSize[dx];
802                 if (dx < state->maxDims-1) {
803                         state->totalPrimaryPoints *= state->quadGridSize[dx];
804                 }
805         }
806
807         PROTECT(tmp = GET_SLOT(rObj, install("cache")));
808         state->cacheLXK = asLogical(tmp);
809
810         if (!state->cacheLXK) {
811                 state->lxk = Realloc(NULL, numUnique * numThreads, double);
812         } else {
813                 int ns = state->numSpecific;
814                 if (ns == 0) ns = 1;
815                 state->lxk = Realloc(NULL, numUnique * state->totalQuadPoints * ns, double);
816         }
817
818         // verify data bounded between 1 and numOutcomes TODO
819         // hm, looks like something could be added to omxData?
820 }
821
822 SEXP omx_get_rpf_names()
823 {
824         SEXP outsxp;
825         PROTECT(outsxp = allocVector(STRSXP, numStandardRPF));
826         for (int sx=0; sx < numStandardRPF; sx++) {
827                 SET_STRING_ELT(outsxp, sx, mkChar(rpf_table[sx].name));
828         }
829         UNPROTECT(1);
830         return outsxp;
831 }