ifa: Add omxIntDataElementUnsafe
[openmx:openmx.git] / src / omxExpectationBA81.c
1 /*
2   Copyright 2012 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include "omxExpectation.h"
19 #include "omxOpenmpWrap.h"
20 #include "npsolWrap.h"
21 #include "libirt-rpf.h"
22
23 static const char *NAME = "ExpectationBA81";
24
25 typedef double *(*rpf_fn_t)(omxExpectation *oo, omxMatrix *itemParam, const int *quad);
26
27 typedef int (*rpf_numParam_t)(const int numDims, const int numOutcomes);
28 typedef void (*rpf_logprob_t)(const int numDims, const double *restrict param,
29                               const double *restrict th,
30                               const int numOutcomes, double *restrict out);
31 struct rpf {
32         const char name[8];
33         rpf_numParam_t numParam;
34         rpf_logprob_t logprob;
35 };
36
37 static const struct rpf rpf_table[] = {
38         { "drm1",  irt_rpf_1dim_drm_numParam,  irt_rpf_1dim_drm_logprob },
39         { "drm",   irt_rpf_mdim_drm_numParam,  irt_rpf_mdim_drm_logprob },
40         { "gpcm1", irt_rpf_1dim_gpcm_numParam, irt_rpf_1dim_gpcm_logprob }
41 };
42 static const int numStandardRPF = (sizeof(rpf_table) / sizeof(struct rpf));
43
44 typedef struct {
45
46         omxData *data;
47         int numUnique;
48         omxMatrix *itemSpec;
49         int *Sgroup;              // item's specific group 0..numSpecific-1
50         int maxOutcomes;
51         int maxDims;
52         long *numQuadPoints;      // maxDims
53         long totalPrimaryPoints;  // product of numQuadPoints except specific dim
54         long totalQuadPoints;     // product of numQuadPoints
55         int maxAbilities;
56         int numSpecific;
57         omxMatrix *design;        // items * maxDims
58         omxMatrix *itemPrior;
59         omxMatrix *itemParam;     // M step version
60         omxMatrix *EitemParam;    // E step version
61         SEXP rpf;
62         rpf_fn_t computeRPF;
63
64         double *lxk;              // numUnique * #specific quad points * thread
65         double *allSlxk;          // numUnique * thread
66         double *Slxk;             // numUnique * #specific dimensions * thread
67
68         double *patternLik;       // length numUnique
69         double *logNumIdentical;  // length numUnique
70         double ll;                // the most recent finite ll
71
72 } omxBA81State;
73
74 enum ISpecRow {
75         ISpecID,
76         ISpecOutcomes,
77         ISpecDims,
78         ISpecRowCount
79 };
80
81 // This acts like a prior distribution for difficulty.
82 // PC-BILOG does not impose any additional prior
83 // on difficulty estimates (Baker & Kim, 2004, p. 196).
84 //
85 // quadrature points and weights via BILOG
86 static const int numQuadratures = 10;
87 static const double quadPoint[] = {
88         -4.000, -3.111,-2.222,-1.333,-0.4444,
89         0.4444, 1.333, 2.222, 3.111, 4.000
90 };
91 static const double quadArea[] = {
92         0.000119, 0.002805, 0.03002, 0.1458, 0.3213,
93         0.3213, 0.1458, 0.03002, 0.002805, 0.000119
94 };
95 static const double quadLogArea[] = {
96         -9.036387, -5.876352, -3.505891, -1.925519, -1.135380,
97         -1.135380, -1.925519, -3.505891, -5.876352, -9.036387
98 };
99
100 /*
101 static void
102 pda(const double *ar, int rows, int cols) {
103         for (int rx=0; rx < rows; rx++) {
104                 for (int cx=0; cx < cols; cx++) {
105                         Rprintf("%.6g ", ar[cx * rows + rx]);
106                 }
107                 Rprintf("\n");
108         }
109
110 }
111 */
112
113 OMXINLINE static void
114 pointToWhere(const int *quad, double *where, int upto)
115 {
116         for (int dx=0; dx < upto; dx++) {
117                 where[dx] = quadPoint[quad[dx]];
118         }
119 }
120
121 OMXINLINE static void
122 assignDims(omxMatrix *itemSpec, omxMatrix *design, int dims, int maxDims, int ix,
123            const double *restrict theta, double *restrict ptheta)
124 {
125         for (int dx=0; dx < dims; dx++) {
126                 int ability = (int)omxMatrixElement(design, dx, ix) - 1;
127                 if (ability >= maxDims) ability = maxDims-1;
128                 ptheta[dx] = theta[ability];
129         }
130 }
131
132 /**
133  * This is the main function needed to generate simulated data from
134  * the model. It could be argued that the rest of the estimation
135  * machinery belongs in the fitfunction.
136  *
137  * \param theta Vector of ability parameters, one per ability
138  * \returns A numItems by maxOutcomes colMajor vector of doubles. Caller must Free it.
139  */
140 static double *
141 standardComputeRPF(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
142 {
143         omxBA81State *state = (omxBA81State*) oo->argStruct;
144         omxMatrix *itemSpec = state->itemSpec;
145         int numItems = itemSpec->cols;
146         omxMatrix *design = state->design;
147         int maxDims = state->maxDims;
148
149         double theta[maxDims];
150         pointToWhere(quad, theta, maxDims);
151
152         double *outcomeProb = Realloc(NULL, numItems * state->maxOutcomes, double);
153
154         for (int ix=0; ix < numItems; ix++) {
155                 int outcomes = omxMatrixElement(itemSpec, ISpecOutcomes, ix);
156                 double *iparam = omxMatrixColumn(itemParam, ix);
157                 double *out = outcomeProb + ix * state->maxOutcomes;
158                 int id = omxMatrixElement(itemSpec, ISpecID, ix);
159                 int dims = omxMatrixElement(itemSpec, ISpecDims, ix);
160                 double ptheta[dims];
161                 assignDims(itemSpec, design, dims, maxDims, ix, theta, ptheta);
162                 (*rpf_table[id].logprob)(dims, iparam, ptheta, outcomes, out);
163         }
164
165         return outcomeProb;
166 }
167
168 static double *
169 RComputeRPF1(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
170 {
171         omxBA81State *state = (omxBA81State*) oo->argStruct;
172         int maxOutcomes = state->maxOutcomes;
173         omxMatrix *design = state->design;
174         omxMatrix *itemSpec = state->itemSpec;
175         int maxDims = state->maxDims;
176
177         double theta[maxDims];
178         pointToWhere(quad, theta, maxDims);
179
180         SEXP invoke;
181         PROTECT(invoke = allocVector(LANGSXP, 4));
182         SETCAR(invoke, state->rpf);
183         SETCADR(invoke, omxExportMatrix(itemParam));
184         SETCADDR(invoke, omxExportMatrix(itemSpec));
185
186         SEXP where;
187         PROTECT(where = allocMatrix(REALSXP, maxDims, itemParam->cols));
188         double *ptheta = REAL(where);
189         for (int ix=0; ix < itemParam->cols; ix++) {
190                 int dims = omxMatrixElement(itemSpec, ISpecDims, ix);
191                 assignDims(itemSpec, design, dims, maxDims, ix, theta, ptheta + ix*maxDims);
192                 for (int dx=dims; dx < maxDims; dx++) {
193                         ptheta[ix*maxDims + dx] = NA_REAL;
194                 }
195         }
196         SETCADDDR(invoke, where);
197
198         SEXP matrix;
199         PROTECT(matrix = eval(invoke, R_GlobalEnv));
200
201         if (!isMatrix(matrix)) {
202                 omxRaiseError(oo->currentState, -1,
203                               "RPF must return an item by outcome matrix");
204                 return NULL;
205         }
206
207         SEXP matrixDims;
208         PROTECT(matrixDims = getAttrib(matrix, R_DimSymbol));
209         int *dimList = INTEGER(matrixDims);
210         int numItems = state->itemSpec->cols;
211         if (dimList[0] != maxOutcomes || dimList[1] != numItems) {
212                 const int errlen = 200;
213                 char errstr[errlen];
214                 snprintf(errstr, errlen, "RPF must return a %d outcomes by %d items matrix",
215                          maxOutcomes, numItems);
216                 omxRaiseError(oo->currentState, -1, errstr);
217                 return NULL;
218         }
219
220         // Unlikely to be of type INTSXP, but just to be safe
221         PROTECT(matrix = coerceVector(matrix, REALSXP));
222         double *restrict got = REAL(matrix);
223
224         // Need to copy because threads cannot share SEXP
225         double *restrict outcomeProb = Realloc(NULL, numItems * maxOutcomes, double);
226
227         // Double check there aren't NAs in the wrong place
228         for (int ix=0; ix < numItems; ix++) {
229                 int numOutcomes = omxMatrixElement(state->itemSpec, ISpecOutcomes, ix);
230                 for (int ox=0; ox < numOutcomes; ox++) {
231                         int vx = ix * maxOutcomes + ox;
232                         if (isnan(got[vx])) {
233                                 const int errlen = 200;
234                                 char errstr[errlen];
235                                 snprintf(errstr, errlen, "RPF returned NA in [%d,%d]", ox,ix);
236                                 omxRaiseError(oo->currentState, -1, errstr);
237                         }
238                         outcomeProb[vx] = got[vx];
239                 }
240         }
241
242         return outcomeProb;
243 }
244
245 static double *
246 RComputeRPF(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
247 {
248         omx_omp_set_lock(&GlobalRLock);
249         PROTECT_INDEX pi = omxProtectSave();
250         double *ret = RComputeRPF1(oo, itemParam, quad);
251         omxProtectRestore(pi);
252         omx_omp_unset_lock(&GlobalRLock);  // hope there was no exception!
253         return ret;
254 }
255
256 /** If lxk was cached then it would save 1 evaluation per quadrature
257  * per fit.  However, this is only a good tradeoff if there is enough
258  * RAM. This is nice potential optimization for small models, but
259  * unfortunately, larger models would not benefit.
260  */
261 OMXINLINE static double *
262 ba81Likelihood(omxExpectation *oo, int specific, const int *restrict quad)   // TODO optimize more
263 {
264         omxBA81State *state = (omxBA81State*) oo->argStruct;
265         int numUnique = state->numUnique;
266         int maxOutcomes = state->maxOutcomes;
267         omxData *data = state->data;
268         int numItems = state->itemSpec->cols;
269         rpf_fn_t rpf_fn = state->computeRPF;
270         int *restrict Sgroup = state->Sgroup;
271         double *restrict lxk = state->lxk + numUnique * omx_absolute_thread_num();
272
273         const double *outcomeProb = (*rpf_fn)(oo, state->EitemParam, quad);
274         if (!outcomeProb) {
275                 OMXZERO(lxk, numUnique);
276                 return lxk;
277         }
278
279         for (int px=0, row=0; px < numUnique; px++) {
280                 double lxk1 = 0;
281                 for (int ix=0; ix < numItems; ix++) {
282                         if (specific != Sgroup[ix]) continue;
283                         int pick = omxIntDataElementUnsafe(data, row, ix);
284                         if (pick == NA_INTEGER) continue;
285                         lxk1 += outcomeProb[ix * maxOutcomes + pick-1];
286                 }
287                 lxk[px] = lxk1;
288                 row += omxDataNumIdenticalRows(data, row);
289         }
290
291         Free(outcomeProb);
292
293         return lxk;
294 }
295
296 #define CALC_ALLSLXK(state, numUnique) (state->allSlxk + omx_absolute_thread_num() * (numUnique))
297 #define CALC_SLXK(state, numUnique, numSpecific) (state->Slxk + omx_absolute_thread_num() * (numUnique) * (numSpecific))
298
299 OMXINLINE static void
300 cai2010(omxExpectation* oo, const int *restrict primaryQuad,
301         double *restrict allSlxk, double *restrict Slxk)
302 {
303         omxBA81State *state = (omxBA81State*) oo->argStruct;
304         int numUnique = state->numUnique;
305         int numSpecific = state->numSpecific;
306         int maxDims = state->maxDims;
307         int sDim = maxDims-1;
308
309         int quad[maxDims];
310         memcpy(quad, primaryQuad, sizeof(int)*sDim);
311
312         OMXZERO(Slxk, numUnique * numSpecific);
313         OMXZERO(allSlxk, numUnique);
314
315         for (int sx=0; sx < numSpecific; sx++) {
316                 double *eis = Slxk + numUnique * sx;
317                 int numQuadPoints = state->numQuadPoints[sDim];
318
319                 for (int qx=0; qx < numQuadPoints; qx++) {
320                         quad[sDim] = qx;
321                         double *lxk = ba81Likelihood(oo, sx, quad);
322
323                         for (int ix=0; ix < numUnique; ix++) {
324                                 eis[ix] += exp(lxk[ix] + quadLogArea[qx]);
325                         }
326                 }
327
328                 for (int px=0; px < numUnique; px++) {
329                         eis[px] = log(eis[px]);
330                         allSlxk[px] += eis[px];
331                 }
332         }
333 }
334
335 OMXINLINE static double
336 areaProduct(const int *restrict quad, const int upto)
337 {
338         double logArea = 0;
339         for (int dx=0; dx < upto; dx++) {
340                 logArea += quadLogArea[quad[dx]];
341         }
342         return logArea;
343 }
344
345 OMXINLINE static void
346 decodeLocation(long qx, const int dims, const long *restrict grid,
347                int *restrict quad)
348 {
349         for (int dx=0; dx < dims; dx++) {
350                 quad[dx] = qx % grid[dx];
351                 qx = qx / grid[dx];
352         }
353 }
354
355 static void
356 ba81Estep(omxExpectation *oo) {
357         if(OMX_DEBUG_MML) {Rprintf("Beginning %s Computation.\n", NAME);}
358
359         omxBA81State *state = (omxBA81State*) oo->argStruct;
360         double *patternLik = state->patternLik;
361         int numUnique = state->numUnique;
362         int numSpecific = state->numSpecific;
363
364         omxCopyMatrix(state->EitemParam, state->itemParam);
365
366         OMXZERO(patternLik, numUnique);
367
368         // E-step, marginalize person ability
369         //
370         // Note: In the notation of Bock & Aitkin (1981) and
371         // Cai~(2010), these loops are reversed.  That is, the inner
372         // loop is over quadrature points and the outer loop is over
373         // all response patterns.
374         //
375         if (numSpecific == 0) {
376 #pragma omp parallel for num_threads(oo->currentState->numThreads)
377                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
378                         int quad[state->maxDims];
379                         decodeLocation(qx, state->maxDims, state->numQuadPoints, quad);
380
381                         double *lxk = ba81Likelihood(oo, 0, quad);
382
383                         double logArea = areaProduct(quad, state->maxDims);
384 #pragma omp critical(EstepUpdate)
385                         for (int px=0; px < numUnique; px++) {
386                                 double tmp = exp(lxk[px] + logArea);
387                                 patternLik[px] += tmp;
388                         }
389                 }
390         } else {
391                 int sDim = state->maxDims-1;
392
393 #pragma omp parallel for num_threads(oo->currentState->numThreads)
394                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
395                         int quad[state->maxDims];
396                         decodeLocation(qx, sDim, state->numQuadPoints, quad);
397
398                         double *allSlxk = CALC_ALLSLXK(state, numUnique);
399                         double *Slxk = CALC_SLXK(state, numUnique, numSpecific);
400                         cai2010(oo, quad, allSlxk, Slxk);
401
402                         double logArea = areaProduct(quad, sDim);
403 #pragma omp critical(EstepUpdate)
404                         for (int px=0; px < numUnique; px++) {
405                                 double tmp = exp(allSlxk[px] + logArea);
406                                 patternLik[px] += tmp;
407                         }
408                 }
409         }
410
411         for (int px=0; px < numUnique; px++) {
412                 patternLik[px] = log(patternLik[px]);
413         }
414 }
415
416 OMXINLINE static void
417 expectedUpdate(omxData *restrict data, int *restrict row, const int item,
418                const double observed, const int outcomes, double *out)
419 {
420         int pick = omxIntDataElementUnsafe(data, *row, item);
421         if (pick == NA_INTEGER) {
422                 double slice = exp(observed - log(outcomes));
423                 for (int ox=0; ox < outcomes; ox++) {
424                         out[ox] += slice;
425                 }
426         } else {
427                 out[pick-1] += exp(observed);
428         }
429         *row += omxDataNumIdenticalRows(data, *row);
430 }
431
432 /** 
433  * \param quad a vector that indexes into a multidimensional quadrature
434  * \param out points to an array numOutcomes wide
435  */
436 static void
437 ba81Weight(omxExpectation* oo, const int item, const int *quad, int outcomes, double *out)
438 {
439         omxBA81State *state = (omxBA81State*) oo->argStruct;
440         omxData *data = state->data;
441         int specific = state->Sgroup[item];
442         int numItems = state->itemSpec->cols;
443         if (item < 0 || item >= numItems) error("ba81Weight: item %d out of range", item);
444
445         double *patternLik = state->patternLik;
446         double *logNumIdentical = state->logNumIdentical;
447         int numUnique = state->numUnique;
448         int numSpecific = state->numSpecific;
449         int maxDims = state->maxDims;
450         int sDim = state->maxDims-1;
451
452         OMXZERO(out, outcomes);
453
454         if (numSpecific == 0) {
455                 double *lxk = ba81Likelihood(oo, specific, quad);
456                 for (int px=0, row=0; px < numUnique; px++) {
457                         double observed = logNumIdentical[px] + lxk[px] - patternLik[px];
458                         expectedUpdate(data, &row, item, observed, outcomes, out);
459                 }
460         } else {
461                 double *allSlxk = CALC_ALLSLXK(state, numUnique);
462                 double *Slxk = CALC_SLXK(state, numUnique, numSpecific);
463                 if (quad[sDim] == 0) {
464                         // allSlxk, Slxk only depend on the ordinate of the primary dimensions
465                         cai2010(oo, quad, allSlxk, Slxk);
466                 }
467                 double *eis = Slxk + numUnique * specific;
468                 double *lxk = ba81Likelihood(oo, specific, quad);
469
470                 for (int px=0, row=0; px < numUnique; px++) {
471                         double observed = logNumIdentical[px] + (allSlxk[px] - eis[px]) +
472                                 (lxk[px] - patternLik[px]);
473                         expectedUpdate(data, &row, item, observed, outcomes, out);
474                 }
475         }
476
477         double logArea = areaProduct(quad, maxDims);
478
479         for (int ox=0; ox < outcomes; ox++) {
480                 out[ox] = log(out[ox]) + logArea;
481         }
482 }
483
484 OMXINLINE static double
485 ba81Fit1Ordinate(omxExpectation* oo, const int *quad)
486 {
487         omxBA81State *state = (omxBA81State*) oo->argStruct;
488         omxMatrix *itemParam = state->itemParam;
489         int numItems = itemParam->cols;
490         rpf_fn_t rpf_fn = state->computeRPF;
491         int maxOutcomes = state->maxOutcomes;
492
493         double *outcomeProb = (*rpf_fn)(oo, itemParam, quad);
494         if (!outcomeProb) return 0;
495
496         double thr_ll = 0;
497         for (int ix=0; ix < numItems; ix++) {
498                 int outcomes = omxMatrixElement(state->itemSpec, ISpecOutcomes, ix);
499                 double out[outcomes];
500                 ba81Weight(oo, ix, quad, outcomes, out);
501                 for (int ox=0; ox < outcomes; ox++) {
502                         double got = exp(out[ox]) * outcomeProb[ix * maxOutcomes + ox];
503                         thr_ll += got;
504                 }
505         }
506
507         Free(outcomeProb);
508         return thr_ll;
509 }
510
511 static double
512 ba81ComputeFit1(omxExpectation* oo)
513 {
514         omxBA81State *state = (omxBA81State*) oo->argStruct;
515         omxMatrix *itemPrior = state->itemPrior;
516         int numSpecific = state->numSpecific;
517         int maxDims = state->maxDims;
518
519         omxRecompute(itemPrior);
520         double ll = itemPrior->data[0];
521
522         if (numSpecific == 0) {
523 #pragma omp parallel for num_threads(oo->currentState->numThreads)
524                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
525                         int quad[maxDims];
526                         decodeLocation(qx, maxDims, state->numQuadPoints, quad);
527
528                         double thr_ll = ba81Fit1Ordinate(oo, quad);
529
530 #pragma omp atomic
531                         ll += thr_ll;
532                 }
533         } else {
534                 int sDim = state->maxDims-1;
535                 long *numQuadPoints = state->numQuadPoints;
536
537 #pragma omp parallel for num_threads(oo->currentState->numThreads)
538                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
539                         int quad[maxDims];
540                         decodeLocation(qx, maxDims, state->numQuadPoints, quad);
541
542                         double thr_ll = 0;
543                         long specificPoints = numQuadPoints[sDim];
544                         for (long sx=0; sx < specificPoints; sx++) {
545                                 quad[sDim] = sx;
546                                 thr_ll += ba81Fit1Ordinate(oo, quad);
547                         }
548 #pragma omp atomic
549                         ll += thr_ll;
550                 }
551         }
552
553         if (isinf(ll)) {
554                 return 2*state->ll;
555         } else {
556                 // TODO need to *2 also?
557                 state->ll = -ll;
558                 return -ll;
559         }
560 }
561
562 double
563 ba81ComputeFit(omxExpectation* oo)
564 {
565         double got = ba81ComputeFit1(oo);
566         return got;
567 }
568
569 static void ba81Destroy(omxExpectation *oo) {
570         if(OMX_DEBUG) {
571                 Rprintf("Freeing %s function.\n", NAME);
572         }
573         omxBA81State *state = (omxBA81State *) oo->argStruct;
574         omxFreeAllMatrixData(state->itemSpec);
575         omxFreeAllMatrixData(state->itemParam);
576         omxFreeAllMatrixData(state->EitemParam);
577         omxFreeAllMatrixData(state->design);
578         omxFreeAllMatrixData(state->itemPrior);
579         if (state->patternLik) Free(state->patternLik);
580         if (state->lxk) Free(state->lxk);
581         if (state->Slxk) Free(state->Slxk);
582         if (state->allSlxk) Free(state->allSlxk);
583         if (state->Sgroup) Free(state->Sgroup);
584         Free(state);
585 }
586
587 void omxInitExpectationBA81(omxExpectation* oo) {
588         omxState* currentState = oo->currentState;      
589         SEXP nextMatrix;
590         SEXP rObj = oo->rObj;
591         
592         if(OMX_DEBUG) {
593                 Rprintf("Initializing %s.\n", NAME);
594         }
595         
596         omxBA81State *state = Calloc(1, omxBA81State);
597         state->ll = 10^9;   // finite but big
598         
599         PROTECT(nextMatrix = GET_SLOT(rObj, install("data")));
600         state->data =
601                 omxNewDataFromMxDataPtr(nextMatrix, currentState);
602         UNPROTECT(1);
603
604         if (strcmp(omxDataType(state->data), "raw") != 0) {
605                 omxRaiseErrorf(currentState, "%s unable to handle data type %s", NAME, omxDataType(state->data));
606                 return;
607         }
608
609         PROTECT(state->rpf = GET_SLOT(rObj, install("RPF")));
610         if (state->rpf == R_NilValue) {
611                 state->computeRPF = standardComputeRPF;
612                 // and analytic gradient TODO
613         } else {
614                 state->computeRPF = RComputeRPF;
615         }
616
617         state->itemSpec =
618                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemSpec");
619         state->design =
620                 omxNewMatrixFromIndexSlot(rObj, currentState, "Design");
621         state->itemParam =
622                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemParam");
623         state->EitemParam =
624                 omxInitTemporaryMatrix(NULL, state->itemParam->rows, state->itemParam->cols,
625                                        TRUE, currentState);
626         state->itemPrior =
627                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemPrior");
628         
629         oo->computeFun = ba81Estep;
630         oo->destructFun = ba81Destroy;
631         
632         oo->argStruct = (void*) state;
633
634         // TODO: Exactly identical rows do not contribute any information.
635         // The sorting algorithm ought to remove them so we don't waste RAM.
636         // The following summary stats would be cheaper to calculate too.
637
638         int numUnique = 0;
639         omxData *data = state->data;
640         if (omxDataNumFactor(data) != data->cols) {
641                 omxRaiseErrorf(currentState, "%s: all columns must be factors", NAME);
642                 return;
643         }
644
645         for (int rx=0; rx < data->rows;) {
646                 rx += omxDataNumIdenticalRows(state->data, rx);
647                 ++numUnique;
648         }
649         state->numUnique = numUnique;
650
651         state->logNumIdentical =
652                 (double*) R_alloc(numUnique, sizeof(double));
653
654         int numItems = state->itemParam->cols;
655
656         for (int rx=0, ux=0; rx < data->rows;) {
657                 if (rx == 0) {
658                         // all NA rows will sort to the top
659                         int na=0;
660                         for (int ix=0; ix < numItems; ix++) {
661                                 if (omxIntDataElement(data, 0, ix) == NA_INTEGER) { ++na; }
662                         }
663                         if (na == numItems) {
664                                 omxRaiseErrorf(currentState, "Remove rows with all NAs");
665                                 return;
666                         }
667                 }
668                 int dups = omxDataNumIdenticalRows(state->data, rx);
669                 state->logNumIdentical[ux++] = log(dups);
670                 rx += dups;
671         }
672
673         state->patternLik = Realloc(NULL, numUnique, double);
674
675         int numThreads = oo->currentState->numThreads;
676         if (numThreads < 1) numThreads = 1;
677
678         if (state->itemSpec->cols != data->cols || state->itemSpec->rows != ISpecRowCount) {
679                 omxRaiseErrorf(currentState, "ItemSpec must have %d item columns and %d rows",
680                                data->cols, ISpecRowCount);
681                 return;
682         }
683
684         int maxParam = 0;
685         state->maxDims = 0;
686         state->maxOutcomes = 0;
687
688         for (int cx = 0; cx < data->cols; cx++) {
689                 int id = omxMatrixElement(state->itemSpec, ISpecID, cx);
690                 if (id < 0 || id >= numStandardRPF) {
691                         omxRaiseErrorf(currentState, "ItemSpec column %d has unknown item model %d", cx, id);
692                         return;
693                 }
694
695                 int dims = omxMatrixElement(state->itemSpec, ISpecDims, cx);
696                 if (state->maxDims < dims)
697                         state->maxDims = dims;
698
699                 // TODO verify that item model can have requested number of outcomes
700                 int no = omxMatrixElement(state->itemSpec, ISpecOutcomes, cx);
701                 if (state->maxOutcomes < no)
702                         state->maxOutcomes = no;
703
704                 int numParam = (*rpf_table[id].numParam)(dims, no);
705                 if (maxParam < numParam)
706                         maxParam = numParam;
707         }
708
709         if (state->itemParam->rows != maxParam) {
710                 omxRaiseErrorf(currentState, "ItemParam should have %d rows", maxParam);
711                 return;
712         }
713
714         if (state->design == NULL) {
715                 state->maxAbilities = state->maxDims;
716                 state->design = omxInitTemporaryMatrix(NULL, state->maxDims, numItems,
717                                        TRUE, currentState);
718                 for (int ix=0; ix < numItems; ix++) {
719                         for (int dx=0; dx < state->maxDims; dx++) {
720                                 omxSetMatrixElement(state->design, dx, ix, (double)dx+1);
721                         }
722                 }
723         } else {
724                 omxMatrix *design = state->design;
725                 if (design->cols != numItems ||
726                     design->rows != state->maxDims) {
727                         omxRaiseErrorf(currentState, "Design matrix should have %d rows and %d columns",
728                                        state->maxDims, numItems);
729                         return;
730                 }
731
732                 state->maxAbilities = 0;
733                 for (int ix=0; ix < design->rows * design->cols; ix++) {
734                         double got = design->data[ix];
735                         if (!R_FINITE(got)) continue;
736                         if (round(got) != got) error("Design matrix can only contain integers"); // TODO better way?
737                         if (state->maxAbilities < got)
738                                 state->maxAbilities = got;
739                 }
740         }
741         if (state->maxAbilities <= state->maxDims) {
742                 state->Sgroup = Calloc(numItems, int);
743         } else {
744                 int Sgroup0 = state->maxDims;
745                 state->Sgroup = Realloc(NULL, numItems, int);
746                 for (int ix=0; ix < numItems; ix++) {
747                         int ss=-1;
748                         for (int dx=0; dx < state->maxDims; dx++) {
749                                 int ability = omxMatrixElement(state->design, dx, ix);
750                                 if (ability >= Sgroup0) {
751                                         if (ss == -1) {
752                                                 ss = ability;
753                                         } else {
754                                                 omxRaiseErrorf(currentState, "Item %d cannot belong to more than "
755                                                                "1 specific dimension (both %d and %d)",
756                                                                ix, ss, ability);
757                                                 return;
758                                         }
759                                 }
760                         }
761                         if (ss == -1) ss = 0;
762                         state->Sgroup[ix] = ss - Sgroup0;
763                 }
764                 state->numSpecific = state->maxAbilities - state->maxDims + 1;
765                 state->allSlxk = Realloc(NULL, numUnique * numThreads, double);
766                 state->Slxk = Realloc(NULL, numUnique * state->numSpecific * numThreads, double);
767         }
768
769         state->totalQuadPoints = 1;
770         state->totalPrimaryPoints = 1;
771         state->numQuadPoints = (long*) R_alloc(state->maxDims, sizeof(long));
772         for (int dx=0; dx < state->maxDims; dx++) {
773                 state->numQuadPoints[dx] = numQuadratures;   // make configurable TODO
774                 state->totalQuadPoints *= state->numQuadPoints[dx];
775                 if (dx < state->maxDims-1) {
776                         state->totalPrimaryPoints *= state->numQuadPoints[dx];
777                 }
778         }
779
780         state->lxk = Realloc(NULL, numUnique * numThreads, double);
781
782         // verify data bounded between 1 and numOutcomes TODO
783         // hm, looks like something could be added to omxData?
784 }
785
786 SEXP omx_get_rpf_names()
787 {
788         SEXP outsxp;
789         PROTECT(outsxp = allocVector(STRSXP, numStandardRPF));
790         for (int sx=0; sx < numStandardRPF; sx++) {
791                 SET_STRING_ELT(outsxp, sx, mkChar(rpf_table[sx].name));
792         }
793         UNPROTECT(1);
794         return outsxp;
795 }