ifa: Add checks for invalid priors & gradients
[openmx:openmx.git] / src / omxExpectationBA81.c
1 /*
2   Copyright 2012 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 // Consider replacing log() with log2() in some places? Not worth it?
19
20 #include "omxExpectation.h"
21 #include "omxOpenmpWrap.h"
22 #include "npsolWrap.h"
23 #include "libirt-rpf.h"
24 #include "merge.h"
25
26 static const char *NAME = "ExpectationBA81";
27
28 static const struct rpf *rpf_model = NULL;
29 static int rpf_numModels;
30
31 typedef double *(*rpf_fn_t)(omxExpectation *oo, omxMatrix *itemParam, const int *quad);
32
33 typedef struct {
34
35         // data characteristics
36         omxData *data;
37         int numUnique;
38         double *logNumIdentical;  // length numUnique
39         int *rowMap;              // length numUnique
40
41         // item description related
42         omxMatrix *itemSpec;
43         int maxOutcomes;
44         int maxDims;
45         int maxAbilities;
46         int numSpecific;
47         int *Sgroup;              // item's specific group 0..numSpecific-1
48         omxMatrix *design;        // items * maxDims
49
50         // quadrature related
51         int numQpoints;
52         double *Qpoint;
53         double *Qarea;
54         double *logQarea;
55         long *quadGridSize;       // maxDims
56         long totalPrimaryPoints;  // product of quadGridSize except specific dim
57         long totalQuadPoints;     // product of quadGridSize
58
59         // estimation related
60         omxMatrix *EitemParam;    // E step version
61         omxMatrix *itemParam;     // M step version
62         SEXP rpf;
63         rpf_fn_t computeRPF;
64         omxMatrix *customPrior;
65         int *paramMap;
66         int cacheLXK;             // w/cache,  numUnique * #specific quad points * totalQuadPoints
67         double *lxk;              // wo/cache, numUnique * thread
68         double *allSlxk;          // numUnique * thread
69         double *Slxk;             // numUnique * #specific dimensions * thread
70         double *patternLik;       // length numUnique
71         double ll;                // the most recent finite ll
72
73         int gradientCount;
74         int fitCount;
75 } omxBA81State;
76
77 static void
78 pda(const double *ar, int rows, int cols) {
79         for (int rx=0; rx < rows; rx++) {
80                 for (int cx=0; cx < cols; cx++) {
81                         Rprintf("%.6g ", ar[cx * rows + rx]);
82                 }
83                 Rprintf("\n");
84         }
85
86 }
87
88 static int
89 findFreeVarLocation(omxMatrix *itemParam, const omxFreeVar *fv)
90 {
91         for (int lx=0; lx < fv->numLocations; lx++) {
92                 if (~fv->matrices[lx] == itemParam->matrixNumber) return lx;
93         }
94         return -1;
95 }
96
97 static int
98 compareFV(const int *fv1x, const int *fv2x, omxExpectation* oo)
99 {
100         omxState* currentState = oo->currentState;
101         omxBA81State *state = (omxBA81State *) oo->argStruct;
102         omxMatrix *itemParam = state->itemParam;
103         omxFreeVar *fv1 = currentState->freeVarList + *fv1x;
104         omxFreeVar *fv2 = currentState->freeVarList + *fv2x;
105         int l1 = findFreeVarLocation(itemParam, fv1);
106         int l2 = findFreeVarLocation(itemParam, fv2);
107         if (l1 == -1 && l2 == -1) return 0;
108         if ((l1 == -1) ^ (l2 == -1)) return l1 == -1? 1:-1;  // TODO reversed?
109         // Columns are items. Sort columns together
110         return fv1->col[l1] - fv2->col[l1];
111 }
112
113 static void buildParamMap(omxExpectation* oo)
114 {
115         omxState* currentState = oo->currentState;
116         omxBA81State *state = (omxBA81State *) oo->argStruct;
117         int numFreeParams = currentState->numFreeParams;
118         state->paramMap = Realloc(NULL, numFreeParams, int);
119         for (int px=0; px < numFreeParams; px++) { state->paramMap[px] = px; }
120         freebsd_mergesort(state->paramMap, numFreeParams, sizeof(int),
121                           (mergesort_cmp_t)compareFV, oo);
122 }
123
124 OMXINLINE static void
125 pointToWhere(omxBA81State *state, const int *quad, double *where, int upto)
126 {
127         for (int dx=0; dx < upto; dx++) {
128                 where[dx] = state->Qpoint[quad[dx]];
129         }
130 }
131
132 OMXINLINE static void
133 assignDims(omxMatrix *itemSpec, omxMatrix *design, int dims, int maxDims, int ix,
134            const double *restrict theta, double *restrict ptheta)
135 {
136         for (int dx=0; dx < dims; dx++) {
137                 int ability = (int)omxMatrixElement(design, dx, ix) - 1;
138                 if (ability >= maxDims) ability = maxDims-1;
139                 ptheta[dx] = theta[ability];
140         }
141 }
142
143 /**
144  * This is the main function needed to generate simulated data from
145  * the model. It could be argued that the rest of the estimation
146  * machinery belongs in the fitfunction.
147  *
148  * \param theta Vector of ability parameters, one per ability
149  * \returns A numItems by maxOutcomes colMajor vector of doubles. Caller must Free it.
150  */
151 static double *
152 standardComputeRPF(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
153 {
154         omxBA81State *state = (omxBA81State*) oo->argStruct;
155         omxMatrix *itemSpec = state->itemSpec;
156         int numItems = itemSpec->cols;
157         omxMatrix *design = state->design;
158         int maxDims = state->maxDims;
159
160         double theta[maxDims];
161         pointToWhere(state, quad, theta, maxDims);
162
163         double *outcomeProb = Realloc(NULL, numItems * state->maxOutcomes, double);
164
165         for (int ix=0; ix < numItems; ix++) {
166                 const double *spec = omxMatrixColumn(itemSpec, ix);
167                 double *iparam = omxMatrixColumn(itemParam, ix);
168                 double *out = outcomeProb + ix * state->maxOutcomes;
169                 int id = spec[RPF_ISpecID];
170                 int dims = spec[RPF_ISpecDims];
171                 double ptheta[dims];
172                 assignDims(itemSpec, design, dims, maxDims, ix, theta, ptheta);
173                 (*rpf_model[id].logprob)(spec, iparam, ptheta, out);
174         }
175
176         return outcomeProb;
177 }
178
179 static double *
180 RComputeRPF1(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
181 {
182         omxBA81State *state = (omxBA81State*) oo->argStruct;
183         int maxOutcomes = state->maxOutcomes;
184         omxMatrix *design = state->design;
185         omxMatrix *itemSpec = state->itemSpec;
186         int maxDims = state->maxDims;
187
188         double theta[maxDims];
189         pointToWhere(state, quad, theta, maxDims);
190
191         SEXP invoke;
192         PROTECT(invoke = allocVector(LANGSXP, 4));
193         SETCAR(invoke, state->rpf);
194         SETCADR(invoke, omxExportMatrix(itemParam));
195         SETCADDR(invoke, omxExportMatrix(itemSpec));
196
197         SEXP where;
198         PROTECT(where = allocMatrix(REALSXP, maxDims, itemParam->cols));
199         double *ptheta = REAL(where);
200         for (int ix=0; ix < itemParam->cols; ix++) {
201                 int dims = omxMatrixElement(itemSpec, RPF_ISpecDims, ix);
202                 assignDims(itemSpec, design, dims, maxDims, ix, theta, ptheta + ix*maxDims);
203                 for (int dx=dims; dx < maxDims; dx++) {
204                         ptheta[ix*maxDims + dx] = NA_REAL;
205                 }
206         }
207         SETCADDDR(invoke, where);
208
209         SEXP matrix;
210         PROTECT(matrix = eval(invoke, R_GlobalEnv));
211
212         if (!isMatrix(matrix)) {
213                 omxRaiseError(oo->currentState, -1,
214                               "RPF must return an item by outcome matrix");
215                 return NULL;
216         }
217
218         SEXP matrixDims;
219         PROTECT(matrixDims = getAttrib(matrix, R_DimSymbol));
220         int *dimList = INTEGER(matrixDims);
221         int numItems = state->itemSpec->cols;
222         if (dimList[0] != maxOutcomes || dimList[1] != numItems) {
223                 const int errlen = 200;
224                 char errstr[errlen];
225                 snprintf(errstr, errlen, "RPF must return a %d outcomes by %d items matrix",
226                          maxOutcomes, numItems);
227                 omxRaiseError(oo->currentState, -1, errstr);
228                 return NULL;
229         }
230
231         // Unlikely to be of type INTSXP, but just to be safe
232         PROTECT(matrix = coerceVector(matrix, REALSXP));
233         double *restrict got = REAL(matrix);
234
235         // Need to copy because threads cannot share SEXP
236         double *restrict outcomeProb = Realloc(NULL, numItems * maxOutcomes, double);
237
238         // Double check there aren't NAs in the wrong place
239         for (int ix=0; ix < numItems; ix++) {
240                 int numOutcomes = omxMatrixElement(state->itemSpec, RPF_ISpecOutcomes, ix);
241                 for (int ox=0; ox < numOutcomes; ox++) {
242                         int vx = ix * maxOutcomes + ox;
243                         if (isnan(got[vx])) {
244                                 const int errlen = 200;
245                                 char errstr[errlen];
246                                 snprintf(errstr, errlen, "RPF returned NA in [%d,%d]", ox,ix);
247                                 omxRaiseError(oo->currentState, -1, errstr);
248                         }
249                         outcomeProb[vx] = got[vx];
250                 }
251         }
252
253         return outcomeProb;
254 }
255
256 static double *
257 RComputeRPF(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
258 {
259         omx_omp_set_lock(&GlobalRLock);
260         PROTECT_INDEX pi = omxProtectSave();
261         double *ret = RComputeRPF1(oo, itemParam, quad);
262         omxProtectRestore(pi);
263         omx_omp_unset_lock(&GlobalRLock);  // hope there was no exception!
264         return ret;
265 }
266
267 OMXINLINE static long
268 encodeLocation(const int dims, const long *restrict grid, const int *restrict quad)
269 {
270         long qx = 0;
271         for (int dx=dims-1; dx >= 0; dx--) {
272                 qx = qx * grid[dx];
273                 qx += quad[dx];
274         }
275         return qx;
276 }
277
278 #define CALC_LXK_CACHED(state, numUnique, quad, tqp, specific) \
279         ((state)->lxk + \
280          (numUnique) * encodeLocation((state)->maxDims, (state)->quadGridSize, quad) + \
281          (numUnique) * (tqp) * (specific))
282
283 OMXINLINE static double *
284 ba81Likelihood(omxExpectation *oo, int specific, const int *restrict quad)
285 {
286         omxBA81State *state = (omxBA81State*) oo->argStruct;
287         int numUnique = state->numUnique;
288         int maxOutcomes = state->maxOutcomes;
289         omxData *data = state->data;
290         int numItems = state->itemSpec->cols;
291         rpf_fn_t rpf_fn = state->computeRPF;
292         int *restrict Sgroup = state->Sgroup;
293         double *restrict lxk;
294
295         if (!state->cacheLXK) {
296                 lxk = state->lxk + numUnique * omx_absolute_thread_num();
297         } else {
298                 lxk = CALC_LXK_CACHED(state, numUnique, quad, state->totalQuadPoints, specific);
299         }
300
301         const double *outcomeProb = (*rpf_fn)(oo, state->EitemParam, quad);
302         if (!outcomeProb) {
303                 OMXZERO(lxk, numUnique);
304                 return lxk;
305         }
306
307         const int *rowMap = state->rowMap;
308         for (int px=0; px < numUnique; px++) {
309                 double lxk1 = 0;
310                 for (int ix=0; ix < numItems; ix++) {
311                         if (specific != Sgroup[ix]) continue;
312                         int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
313                         if (pick == NA_INTEGER) continue;
314                         lxk1 += outcomeProb[ix * maxOutcomes + pick-1];
315                 }
316                 lxk[px] = lxk1;
317         }
318
319         Free(outcomeProb);
320
321         return lxk;
322 }
323
324 OMXINLINE static double *
325 ba81LikelihoodFast(omxExpectation *oo, int specific, const int *restrict quad)
326 {
327         omxBA81State *state = (omxBA81State*) oo->argStruct;
328         if (!state->cacheLXK) {
329                 return ba81LikelihoodFast(oo, specific, quad);
330         } else {
331                 return CALC_LXK_CACHED(state, state->numUnique, quad, state->totalQuadPoints, specific);
332         }
333
334 }
335
336 #define CALC_ALLSLXK(state, numUnique) \
337         (state->allSlxk + omx_absolute_thread_num() * (numUnique))
338
339 #define CALC_SLXK(state, numUnique, numSpecific) \
340         (state->Slxk + omx_absolute_thread_num() * (numUnique) * (numSpecific))
341
342 OMXINLINE static void
343 cai2010(omxExpectation* oo, int recompute, const int *restrict primaryQuad,
344         double *restrict allSlxk, double *restrict Slxk)
345 {
346         omxBA81State *state = (omxBA81State*) oo->argStruct;
347         int numUnique = state->numUnique;
348         int numSpecific = state->numSpecific;
349         int maxDims = state->maxDims;
350         int sDim = maxDims-1;
351
352         int quad[maxDims];
353         memcpy(quad, primaryQuad, sizeof(int)*sDim);
354
355         OMXZERO(Slxk, numUnique * numSpecific);
356         OMXZERO(allSlxk, numUnique);
357
358         for (int sx=0; sx < numSpecific; sx++) {
359                 double *eis = Slxk + numUnique * sx;
360                 int quadGridSize = state->quadGridSize[sDim];
361
362                 for (int qx=0; qx < quadGridSize; qx++) {
363                         quad[sDim] = qx;
364                         double *lxk;
365                         if (recompute) {
366                                 lxk = ba81Likelihood(oo, sx, quad);
367                         } else {
368                                 lxk = CALC_LXK_CACHED(state, numUnique, quad, state->totalQuadPoints, sx);
369                         }
370
371                         for (int ix=0; ix < numUnique; ix++) {
372                                 eis[ix] += exp(lxk[ix] + state->logQarea[qx]);
373                         }
374                 }
375
376                 for (int px=0; px < numUnique; px++) {
377                         eis[px] = log(eis[px]);
378                         allSlxk[px] += eis[px];
379                 }
380         }
381 }
382
383 OMXINLINE static double
384 logAreaProduct(omxBA81State *state, const int *restrict quad, const int upto)
385 {
386         double logArea = 0;
387         for (int dx=0; dx < upto; dx++) {
388                 logArea += state->logQarea[quad[dx]];
389         }
390         return logArea;
391 }
392
393 // The idea of this API is to allow passing in a number larger than 1.
394 OMXINLINE static void
395 areaProduct(omxBA81State *state, const int *restrict quad, const int upto, double *restrict out)
396 {
397         for (int dx=0; dx < upto; dx++) {
398                 *out *= state->Qarea[quad[dx]];
399         }
400 }
401
402 OMXINLINE static void
403 decodeLocation(long qx, const int dims, const long *restrict grid,
404                int *restrict quad)
405 {
406         for (int dx=0; dx < dims; dx++) {
407                 quad[dx] = qx % grid[dx];
408                 qx = qx / grid[dx];
409         }
410 }
411
412 static void
413 ba81Estep(omxExpectation *oo) {
414         if(OMX_DEBUG_MML) {Rprintf("Beginning %s Computation.\n", NAME);}
415
416         omxBA81State *state = (omxBA81State*) oo->argStruct;
417         double *patternLik = state->patternLik;
418         int numUnique = state->numUnique;
419         int numSpecific = state->numSpecific;
420
421         omxCopyMatrix(state->EitemParam, state->itemParam);
422
423         OMXZERO(patternLik, numUnique);
424
425         // E-step, marginalize person ability
426         //
427         // Note: In the notation of Bock & Aitkin (1981) and
428         // Cai~(2010), these loops are reversed.  That is, the inner
429         // loop is over quadrature points and the outer loop is over
430         // all response patterns.
431         //
432         if (numSpecific == 0) {
433 #pragma omp parallel for num_threads(oo->currentState->numThreads)
434                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
435                         int quad[state->maxDims];
436                         decodeLocation(qx, state->maxDims, state->quadGridSize, quad);
437
438                         double *lxk = ba81Likelihood(oo, 0, quad);
439
440                         double logArea = logAreaProduct(state, quad, state->maxDims);
441 #pragma omp critical(EstepUpdate)
442                         for (int px=0; px < numUnique; px++) {
443                                 double tmp = exp(lxk[px] + logArea);
444                                 patternLik[px] += tmp;
445                         }
446                 }
447         } else {
448                 int sDim = state->maxDims-1;
449
450 #pragma omp parallel for num_threads(oo->currentState->numThreads)
451                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
452                         int quad[state->maxDims];
453                         decodeLocation(qx, sDim, state->quadGridSize, quad);
454
455                         double *allSlxk = CALC_ALLSLXK(state, numUnique);
456                         double *Slxk = CALC_SLXK(state, numUnique, numSpecific);
457                         cai2010(oo, TRUE, quad, allSlxk, Slxk);
458
459                         double logArea = logAreaProduct(state, quad, sDim);
460 #pragma omp critical(EstepUpdate)
461                         for (int px=0; px < numUnique; px++) {
462                                 double tmp = exp(allSlxk[px] + logArea);
463                                 patternLik[px] += tmp;
464                         }
465                 }
466         }
467
468         for (int px=0; px < numUnique; px++) {
469                 patternLik[px] = log(patternLik[px]);
470         }
471 }
472
473 OMXINLINE static void
474 expectedUpdate(omxData *restrict data, const int *rowMap, const int px, const int item,
475                const double observed, const int outcomes, double *out)
476 {
477         int pick = omxIntDataElementUnsafe(data, rowMap[px], item);
478         if (pick == NA_INTEGER) {
479                 double slice = exp(observed - log(outcomes));
480                 for (int ox=0; ox < outcomes; ox++) {
481                         out[ox] += slice;
482                 }
483         } else {
484                 out[pick-1] += exp(observed);
485         }
486 }
487
488 /** 
489  * \param quad a vector that indexes into a multidimensional quadrature
490  * \param out points to an array numOutcomes wide
491  */
492 OMXINLINE static void
493 ba81Weight(omxExpectation* oo, const int item, const int *quad, int outcomes, double *out)
494 {
495         omxBA81State *state = (omxBA81State*) oo->argStruct;
496         omxData *data = state->data;
497         const int *rowMap = state->rowMap;
498         int specific = state->Sgroup[item];
499         double *patternLik = state->patternLik;
500         double *logNumIdentical = state->logNumIdentical;
501         int numUnique = state->numUnique;
502         int numSpecific = state->numSpecific;
503         int sDim = state->maxDims-1;
504
505         OMXZERO(out, outcomes);
506
507         if (numSpecific == 0) {
508                 double *lxk = ba81LikelihoodFast(oo, specific, quad);
509                 for (int px=0; px < numUnique; px++) {
510                         double observed = logNumIdentical[px] + lxk[px] - patternLik[px];
511                         expectedUpdate(data, rowMap, px, item, observed, outcomes, out);
512                 }
513         } else {
514                 double *allSlxk = CALC_ALLSLXK(state, numUnique);
515                 double *Slxk = CALC_SLXK(state, numUnique, numSpecific);
516                 if (quad[sDim] == 0) {
517                         // allSlxk, Slxk only depend on the ordinate of the primary dimensions
518                         cai2010(oo, !state->cacheLXK, quad, allSlxk, Slxk);
519                 }
520                 double *eis = Slxk + numUnique * specific;
521
522                 // Avoid recalc when cache disabled with modest buffer? TODO
523                 double *lxk = ba81LikelihoodFast(oo, specific, quad);
524
525                 for (int px=0; px < numUnique; px++) {
526                         double observed = logNumIdentical[px] + (allSlxk[px] - eis[px]) +
527                                 (lxk[px] - patternLik[px]);
528                         expectedUpdate(data, rowMap, px, item, observed, outcomes, out);
529                 }
530         }
531 }
532
533 OMXINLINE static double
534 ba81Fit1Ordinate(omxExpectation* oo, const int *quad)
535 {
536         omxBA81State *state = (omxBA81State*) oo->argStruct;
537         omxMatrix *itemParam = state->itemParam;
538         int numItems = itemParam->cols;
539         rpf_fn_t rpf_fn = state->computeRPF;
540         int maxOutcomes = state->maxOutcomes;
541         int maxDims = state->maxDims;
542
543         double *outcomeProb = (*rpf_fn)(oo, itemParam, quad);
544         if (!outcomeProb) return 0;
545
546         double thr_ll = 0;
547         for (int ix=0; ix < numItems; ix++) {
548                 int outcomes = omxMatrixElement(state->itemSpec, RPF_ISpecOutcomes, ix);
549                 double out[outcomes];
550                 ba81Weight(oo, ix, quad, outcomes, out);
551                 for (int ox=0; ox < outcomes; ox++) {
552 #if 0
553                         if (!isnormal(outcomeProb[ix * maxOutcomes + ox])) {
554                                 pda(itemParam->data, itemParam->rows, itemParam->cols);
555                                 pda(outcomeProb, outcomes, numItems);
556                                 error("RPF produced NAs");
557                         }
558 #endif
559                         double got = out[ox] * outcomeProb[ix * maxOutcomes + ox];
560                         areaProduct(state, quad, maxDims, &got);
561                         thr_ll += got;
562                 }
563         }
564
565         Free(outcomeProb);
566         return thr_ll;
567 }
568
569 static double
570 ba81ComputeFit1(omxExpectation* oo)
571 {
572         omxBA81State *state = (omxBA81State*) oo->argStruct;
573         ++state->fitCount;
574         omxMatrix *customPrior = state->customPrior;
575         int numSpecific = state->numSpecific;
576         int maxDims = state->maxDims;
577
578         double ll = 0;
579         if (customPrior) {
580                 omxRecompute(customPrior);
581                 ll = customPrior->data[0];
582         } else {
583                 omxMatrix *itemSpec = state->itemSpec;
584                 omxMatrix *itemParam = state->itemParam;
585                 int numItems = itemSpec->cols;
586                 for (int ix=0; ix < numItems; ix++) {
587                         const double *spec = omxMatrixColumn(itemSpec, ix);
588                         int id = spec[RPF_ISpecID];
589                         double *iparam = omxMatrixColumn(itemParam, ix);
590                         ll += (*rpf_model[id].prior)(spec, iparam);
591                 }
592         }
593
594         if (!isnormal(ll)) {
595                 omxMatrix *itemParam = state->itemParam;
596                 omxPrint(itemParam, "item param");
597                 error("Bayesian prior returned NA; do you need to add a lbound/ubound?");
598         }
599
600         if (numSpecific == 0) {
601 #pragma omp parallel for num_threads(oo->currentState->numThreads)
602                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
603                         int quad[maxDims];
604                         decodeLocation(qx, maxDims, state->quadGridSize, quad);
605                         double thr_ll = ba81Fit1Ordinate(oo, quad);
606
607 #pragma omp atomic
608                         ll += thr_ll;
609                 }
610         } else {
611                 int sDim = state->maxDims-1;
612                 long *quadGridSize = state->quadGridSize;
613
614 #pragma omp parallel for num_threads(oo->currentState->numThreads)
615                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
616                         int quad[maxDims];
617                         decodeLocation(qx, maxDims, quadGridSize, quad);
618
619                         double thr_ll = 0;
620                         long specificPoints = quadGridSize[sDim];
621                         for (long sx=0; sx < specificPoints; sx++) {
622                                 quad[sDim] = sx;
623                                 thr_ll += ba81Fit1Ordinate(oo, quad);
624                         }
625 #pragma omp atomic
626                         ll += thr_ll;
627                 }
628         }
629
630         if (isinf(ll)) {
631                 return 2*state->ll;
632         } else {
633                 ll = -2 * ll;
634                 state->ll = ll;
635                 return ll;
636         }
637 }
638
639 double
640 ba81ComputeFit(omxExpectation* oo)
641 {
642         double got = ba81ComputeFit1(oo);
643         return got;
644 }
645
646 OMXINLINE static void
647 ba81ItemGradientOrdinate(omxExpectation* oo, omxBA81State *state,
648                          int maxDims, int *quad, int item, const double *spec, int id,
649                          int outcomes, double *iparam, int numParam, int *paramMask, double *gq)
650 {
651         double where[maxDims];
652         pointToWhere(state, quad, where, maxDims);
653         double weight[outcomes];
654         ba81Weight(oo, item, quad, outcomes, weight);
655
656         (*rpf_model[id].gradient)(spec, iparam, paramMask, where, weight, gq);
657
658         for (int ox=0; ox < numParam; ox++) {
659                 if (paramMask[ox] == -1) continue;
660
661 #if 0
662                 if (!isnormal(gq[ox])) {
663                         Rprintf("item spec:\n");
664                         pda(spec, (*rpf_model[id].numSpec)(spec), 1);
665                         Rprintf("item parameters:\n");
666                         pda(iparam, numParam, 1);
667                         Rprintf("where:\n");
668                         pda(where, maxDims, 1);
669                         Rprintf("weight:\n");
670                         pda(weight, outcomes, 1);
671                         error("Gradient for item %d param %d is %f; are you missing a lbound/ubound?",
672                               item, ox, gq[ox]);
673                 }
674 #endif
675
676                 areaProduct(state, quad, maxDims, gq+ox);
677         }
678 }
679
680 OMXINLINE static void
681 ba81ItemGradient(omxExpectation* oo, omxBA81State *state, const double *spec, omxMatrix *itemParam,
682                  int item, int id, int outcomes, int numParam, int *paramMask, double *out)
683 {
684         int maxDims = state->maxDims;
685         double *iparam = omxMatrixColumn(itemParam, item);
686         double gradient[numParam];
687         OMXZERO(gradient, numParam);
688
689         if (state->numSpecific == 0) {
690 #pragma omp parallel for num_threads(oo->currentState->numThreads)
691                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
692                         int quad[maxDims];
693                         decodeLocation(qx, maxDims, state->quadGridSize, quad);
694                         double gq[numParam];
695                         //                      for (int gx=0; gx < numParam; gx++) gq[gx] = 3.14159; // debugging TODO
696
697                         ba81ItemGradientOrdinate(oo, state, maxDims, quad, item, spec, id,
698                                                  outcomes, iparam, numParam, paramMask, gq);
699
700 #pragma omp critical(GradientUpdate)
701                         for (int ox=0; ox < numParam; ox++) {
702                                 gradient[ox] += gq[ox];
703                         }
704                 }
705         } else {
706                 int sDim = state->maxDims-1;
707                 long *quadGridSize = state->quadGridSize;
708 #pragma omp parallel for num_threads(oo->currentState->numThreads)
709                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
710                         int quad[maxDims];
711                         decodeLocation(qx, maxDims, quadGridSize, quad);
712                         double gsubtotal[numParam];
713                         OMXZERO(gsubtotal, numParam);
714
715                         long specificPoints = quadGridSize[sDim];
716                         for (long sx=0; sx < specificPoints; sx++) {
717                                 double gq[numParam];
718                                 quad[sDim] = sx;
719                                 ba81ItemGradientOrdinate(oo, state, maxDims, quad, item, spec, id,
720                                                          outcomes, iparam, numParam, paramMask, gq);
721                                 for (int gx=0; gx < numParam; gx++) {
722                                         gsubtotal[gx] += gq[gx];
723                                 }
724                         }
725 #pragma omp critical(GradientUpdate)
726                         for (int gx=0; gx < numParam; gx++) {
727                                 gradient[gx] += gsubtotal[gx];
728                         }
729                 }
730         }
731
732         (*rpf_model[id].gradient)(spec, iparam, paramMask, NULL, NULL, gradient);
733
734         for (int px=0; px < numParam; px++) {
735                 int loc = paramMask[px];
736                 if (loc == -1) continue;
737
738                 // Need to check because this can happen if
739                 // lbounds/ubounds are not set appropriately.
740                 if (!isnormal(gradient[px])) {
741                         Rprintf("item spec:\n");
742                         pda(spec, (*rpf_model[id].numSpec)(spec), 1);
743                         Rprintf("item parameters:\n");
744                         pda(iparam, numParam, 1);
745                         error("Gradient for item %d param %d is %f; are you missing a lbound/ubound?",
746                               item, px, gradient[px]);
747                 }
748
749                 out[loc] = -2 * gradient[px];
750         }
751 }
752
753 void ba81Gradient(omxExpectation* oo, double *out)
754 {
755         omxState* currentState = oo->currentState;
756         int numFreeParams = currentState->numFreeParams;
757         omxBA81State *state = (omxBA81State *) oo->argStruct;
758         if (!state->paramMap) buildParamMap(oo);
759         ++state->gradientCount;
760         omxMatrix *itemSpec = state->itemSpec;
761         omxMatrix *itemParam = state->itemParam;
762
763         int vx = 0;
764         while (vx < numFreeParams) {
765             omxFreeVar *fv = currentState->freeVarList + state->paramMap[vx];
766             int vloc = findFreeVarLocation(itemParam, fv);
767             if (vloc < 0) {
768                     ++vx;
769                     continue;
770             }
771
772             int item = fv->col[vloc];
773             const double *spec = omxMatrixColumn(itemSpec, item);
774             int id = spec[RPF_ISpecID];
775             int outcomes = spec[RPF_ISpecOutcomes];
776             int numParam = (*rpf_model[id].numParam)(spec);
777
778             int paramMask[numParam];
779             for (int px=0; px < numParam; px++) { paramMask[px] = -1; }
780
781             if (fv->row[vloc] >= numParam) {
782                     warning("Item %d has too many free parameters", item);
783                     continue;
784             }
785             paramMask[fv->row[vloc]] = vx;
786
787             while (++vx < numFreeParams) {
788                     omxFreeVar *fv = currentState->freeVarList + state->paramMap[vx];
789                     int vloc = findFreeVarLocation(itemParam, fv);
790                     if (fv->col[vloc] != item) break;
791                     if (fv->row[vloc] >= numParam) {
792                             warning("Item %d has too many free parameters", item);
793                             continue;
794                     }
795                     paramMask[fv->row[vloc]] = vx;
796             }
797
798             ba81ItemGradient(oo, state, spec, itemParam, item,
799                              id, outcomes, numParam, paramMask, out);
800         }
801 }
802
803 static int
804 getNumThreads(omxExpectation* oo)
805 {
806         int numThreads = oo->currentState->numThreads;
807         if (numThreads < 1) numThreads = 1;
808         return numThreads;
809 }
810
811 static void
812 ba81SetupQuadrature(omxExpectation* oo, int numPoints, double *points, double *area)
813 {
814         omxBA81State *state = (omxBA81State *) oo->argStruct;
815         int numUnique = state->numUnique;
816         int numThreads = getNumThreads(oo);
817
818         state->numQpoints = numPoints;
819
820         Free(state->Qpoint);
821         Free(state->Qarea);
822         state->Qpoint = Realloc(NULL, numPoints, double);
823         state->Qarea = Realloc(NULL, numPoints, double);
824         memcpy(state->Qpoint, points, sizeof(double)*numPoints);
825         memcpy(state->Qarea, area, sizeof(double)*numPoints);
826
827         Free(state->logQarea);
828
829         state->logQarea = Realloc(NULL, state->numQpoints, double);
830         for (int px=0; px < state->numQpoints; px++) {
831                 state->logQarea[px] = log(state->Qarea[px]);
832         }
833
834         state->totalQuadPoints = 1;
835         state->totalPrimaryPoints = 1;
836         state->quadGridSize = (long*) R_alloc(state->maxDims, sizeof(long));
837         for (int dx=0; dx < state->maxDims; dx++) {
838                 state->quadGridSize[dx] = state->numQpoints;
839                 state->totalQuadPoints *= state->quadGridSize[dx];
840                 if (dx < state->maxDims-1) {
841                         state->totalPrimaryPoints *= state->quadGridSize[dx];
842                 }
843         }
844
845         Free(state->lxk);
846
847         if (!state->cacheLXK) {
848                 state->lxk = Realloc(NULL, numUnique * numThreads, double);
849         } else {
850                 int ns = state->numSpecific;
851                 if (ns == 0) ns = 1;
852                 state->lxk = Realloc(NULL, numUnique * state->totalQuadPoints * ns, double);
853         }
854 }
855
856 static void
857 ba81EAP1(omxExpectation *oo, double *workspace, long qx, int maxDims, int numUnique,
858          double *ability, double *cov, double *spstats)
859 {
860         omxBA81State *state = (omxBA81State *) oo->argStruct;
861         double *patternLik = state->patternLik;
862         int quad[maxDims];
863         decodeLocation(qx, maxDims, state->quadGridSize, quad);
864         double where[maxDims];
865         pointToWhere(state, quad, where, maxDims);
866         double logArea = logAreaProduct(state, quad, maxDims);
867         double *lxk = ba81LikelihoodFast(oo, 0, quad);
868         double *myspace = workspace + 2 * maxDims * numUnique * omx_absolute_thread_num();
869
870         for (int px=0; px < numUnique; px++) {
871                 double *piece = myspace + px * 2 * maxDims;
872                 double plik = exp(lxk[px] + logArea - patternLik[px]);
873                 for (int dx=0; dx < maxDims; dx++) {
874                         piece[dx] = where[dx] * plik;
875                 }
876                 /*
877                 for (int d1=0; d1 < maxDims; d1++) {
878                         for (int d2=0; d2 <= d1; d2++) {
879                                 covPiece[d1 * maxDims + d2] += piece[d1] * piece[d2];
880                         }
881                 }
882                 */
883         }
884 #pragma omp critical(EAP1Update)
885         for (int px=0; px < numUnique; px++) {
886                 double *piece = myspace + px * 2 * maxDims;
887                 double *arow = ability + px * 2 * maxDims;
888                 for (int dx=0; dx < maxDims; dx++) {
889                         arow[dx*2] += piece[dx];
890                 }
891                 /*
892                 for (int d1=0; d1 < maxDims; d1++) {
893                         for (int d2=0; d2 <= d1; d2++) {
894                                 int loc = d1 * maxDims + d2;
895                                 cov[loc] += covPiece[loc];
896                         }
897                 }
898                 */
899         }
900 }
901
902 static void
903 ba81EAP2(omxExpectation *oo, double *workspace, long qx, int maxDims, int numUnique,
904          double *ability, double *spstats)
905 {
906         omxBA81State *state = (omxBA81State *) oo->argStruct;
907         double *patternLik = state->patternLik;
908         int quad[maxDims];
909         decodeLocation(qx, maxDims, state->quadGridSize, quad);
910         double where[maxDims];
911         pointToWhere(state, quad, where, maxDims);
912         double logArea = logAreaProduct(state, quad, maxDims);
913         double *lxk = ba81LikelihoodFast(oo, 0, quad);
914
915         for (int px=0; px < numUnique; px++) {
916                 double psd[maxDims];
917                 double *arow = ability + px * 2 * maxDims;
918                 for (int dx=0; dx < maxDims; dx++) {
919                         // is this just sqrt(variance) and redundant with the covariance matrix? TODO
920                         double ldiff = log(fabs(where[dx] - arow[dx*2]));
921                         psd[dx] = exp(2 * ldiff + lxk[px] + logArea - patternLik[px]);
922                 }
923 #pragma omp critical(EAP1Update)
924                 for (int dx=0; dx < maxDims; dx++) {
925                         arow[dx*2+1] += psd[dx];
926                 }
927         }
928 }
929
930 /**
931  * MAP is not affected by the number of items. EAP is. Likelihood can
932  * get concentrated in a single quadrature ordinate. For 3PL, response
933  * patterns can have a bimodal likelihood. This will confuse MAP and
934  * is a key advantage of EAP (Thissen & Orlando, 2001, p. 136).
935  *
936  * Thissen, D. & Orlando, M. (2001). IRT for items scored in two
937  * categories. In D. Thissen & H. Wainer (Eds.), \emph{Test scoring}
938  * (pp 73-140). Lawrence Erlbaum Associates, Inc.
939  */
940 omxRListElement *
941 ba81EAP(omxExpectation *oo, int *numReturns)
942 {
943         omxBA81State *state = (omxBA81State *) oo->argStruct;
944         int maxDims = state->maxDims;
945         //int numSpecific = state->numSpecific;
946
947         *numReturns = 2; // + (maxDims > 1) + (numSpecific > 1);
948         omxRListElement *out = (omxRListElement*) R_alloc(*numReturns, sizeof(omxRListElement));
949
950         out[0].numValues = 1;
951         out[0].values = (double*) R_alloc(1, sizeof(double));
952         strcpy(out[0].label, "Minus2LogLikelihood");
953         out[0].values[0] = state->ll;
954
955         omxData *data = state->data;
956         int numUnique = state->numUnique;
957
958         // TODO Wainer & Thissen. (1987). Estimating ability with the wrong
959         // model. Journal of Educational Statistics, 12, 339-368.
960
961         int numQpoints = state->numQpoints * 2;  // make configurable TODO
962
963         if (numQpoints < 1 + 2.0 * sqrt(state->itemSpec->cols)) {
964                 // Thissen & Orlando (2001, p. 136)
965                 warning("EAP requires at least 2*sqrt(items) quadrature points");
966         }
967
968         double Qpoint[numQpoints];
969         double Qarea[numQpoints];
970         const double Qwidth = 4;
971         for (int qx=0; qx < numQpoints; qx++) {
972                 Qpoint[qx] = Qwidth - qx * Qwidth*2 / (numQpoints-1);
973                 Qarea[qx] = 1.0/numQpoints;
974         }
975         ba81SetupQuadrature(oo, numQpoints, Qpoint, Qarea);
976         ba81Estep(oo);   // recalc patternLik with a flat prior
977
978         double *cov = NULL;
979         /*
980         if (maxDims > 1) {
981                 strcpy(out[2].label, "ability.cov");
982                 out[2].numValues = -1;
983                 out[2].rows = maxDims;
984                 out[2].cols = maxDims;
985                 out[2].values = (double*) R_alloc(out[2].rows * out[2].cols, sizeof(double));
986                 cov = out[2].values;
987                 OMXZERO(cov, out[2].rows * out[2].cols);
988         }
989         */
990         double *spstats = NULL;
991         /*
992         if (numSpecific) {
993                 strcpy(out[3].label, "specific");
994                 out[3].numValues = -1;
995                 out[3].rows = numSpecific;
996                 out[3].cols = 2;
997                 out[3].values = (double*) R_alloc(out[3].rows * out[3].cols, sizeof(double));
998                 spstats = out[3].values;
999         }
1000         */
1001
1002         // allocation of workspace could be optional
1003         int numThreads = getNumThreads(oo);
1004         double *workspace = Realloc(NULL, numUnique * maxDims * 2 * numThreads, double);
1005
1006         // Need a separate work space because the destination needs
1007         // to be in unsorted order with duplicated rows.
1008         double *ability = Calloc(numUnique * maxDims * 2, double);
1009
1010 #pragma omp parallel for num_threads(oo->currentState->numThreads)
1011         for (long qx=0; qx < state->totalQuadPoints; qx++) {
1012                 ba81EAP1(oo, workspace, qx, maxDims, numUnique, ability, cov, spstats);
1013         }
1014
1015         /*
1016         // make symmetric
1017         for (int d1=0; d1 < maxDims; d1++) {
1018                 for (int d2=0; d2 < d1; d2++) {
1019                         cov[d2 * maxDims + d1] = cov[d1 * maxDims + d2];
1020                 }
1021         }
1022         */
1023
1024 #pragma omp parallel for num_threads(oo->currentState->numThreads)
1025         for (long qx=0; qx < state->totalQuadPoints; qx++) {
1026                 ba81EAP2(oo, workspace, qx, maxDims, numUnique, ability, spstats);
1027         }
1028
1029         for (int px=0; px < numUnique; px++) {
1030                 double *arow = ability + px * 2 * maxDims;
1031                 for (int dx=0; dx < maxDims; dx++) {
1032                         arow[dx*2+1] = sqrt(arow[dx*2+1]);
1033                 }
1034         }
1035
1036         strcpy(out[1].label, "ability");
1037         out[1].numValues = -1;
1038         out[1].rows = data->rows;
1039         out[1].cols = 2 * maxDims;
1040         out[1].values = (double*) R_alloc(out[1].rows * out[1].cols, sizeof(double));
1041
1042         for (int rx=0; rx < numUnique; rx++) {
1043                 double *pa = ability + rx * 2 * maxDims;
1044
1045                 int dups = omxDataNumIdenticalRows(state->data, state->rowMap[rx]);
1046                 for (int dup=0; dup < dups; dup++) {
1047                         int dest = omxDataIndex(data, state->rowMap[rx]+dup);
1048                         int col=0;
1049                         for (int dx=0; dx < maxDims; dx++) {
1050                                 out[1].values[col * out[1].rows + dest] = pa[col]; ++col;
1051                                 out[1].values[col * out[1].rows + dest] = pa[col]; ++col;
1052                         }
1053                 }
1054         }
1055         Free(ability);
1056         Free(workspace);
1057         return out;
1058 }
1059
1060 static void ba81Destroy(omxExpectation *oo) {
1061         if(OMX_DEBUG) {
1062                 Rprintf("Freeing %s function.\n", NAME);
1063         }
1064         omxBA81State *state = (omxBA81State *) oo->argStruct;
1065         Rprintf("fit %d gradient %d\n", state->fitCount, state->gradientCount);
1066         omxFreeAllMatrixData(state->itemSpec);
1067         omxFreeAllMatrixData(state->itemParam);
1068         omxFreeAllMatrixData(state->EitemParam);
1069         omxFreeAllMatrixData(state->design);
1070         omxFreeAllMatrixData(state->customPrior);
1071         Free(state->logNumIdentical);
1072         Free(state->Qpoint);
1073         Free(state->Qarea);
1074         Free(state->logQarea);
1075         Free(state->rowMap);
1076         Free(state->patternLik);
1077         Free(state->lxk);
1078         Free(state->Slxk);
1079         Free(state->allSlxk);
1080         Free(state->Sgroup);
1081         Free(state->paramMap);
1082         Free(state);
1083 }
1084
1085 int ba81ExpectationHasGradients(omxExpectation* oo)
1086 {
1087         omxBA81State *state = (omxBA81State *) oo->argStruct;
1088         return state->computeRPF == standardComputeRPF;
1089 }
1090
1091 void omxInitExpectationBA81(omxExpectation* oo) {
1092         omxState* currentState = oo->currentState;      
1093         SEXP rObj = oo->rObj;
1094         SEXP tmp;
1095         
1096         if(OMX_DEBUG) {
1097                 Rprintf("Initializing %s.\n", NAME);
1098         }
1099         if (!rpf_model) {
1100                 get_librpf_t get_librpf = (get_librpf_t) R_GetCCallable("rpf", "get_librpf_model");
1101                 (*get_librpf)(&rpf_numModels, &rpf_model);
1102         }
1103         
1104         omxBA81State *state = Calloc(1, omxBA81State);
1105         oo->argStruct = (void*) state;
1106
1107         state->ll = 1e20;   // finite but big
1108         
1109         PROTECT(tmp = GET_SLOT(rObj, install("data")));
1110         state->data = omxNewDataFromMxDataPtr(tmp, currentState);
1111         UNPROTECT(1);
1112
1113         if (strcmp(omxDataType(state->data), "raw") != 0) {
1114                 omxRaiseErrorf(currentState, "%s unable to handle data type %s", NAME, omxDataType(state->data));
1115                 return;
1116         }
1117
1118         PROTECT(state->rpf = GET_SLOT(rObj, install("RPF")));
1119         if (state->rpf == R_NilValue) {
1120                 state->computeRPF = standardComputeRPF;
1121         } else {
1122                 state->computeRPF = RComputeRPF;
1123         }
1124
1125         state->itemSpec =
1126                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemSpec");
1127         state->design =
1128                 omxNewMatrixFromIndexSlot(rObj, currentState, "Design");
1129         state->itemParam =
1130                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemParam");
1131         state->EitemParam =
1132                 omxInitTemporaryMatrix(NULL, state->itemParam->rows, state->itemParam->cols,
1133                                        TRUE, currentState);
1134         state->customPrior =
1135                 omxNewMatrixFromIndexSlot(rObj, currentState, "CustomPrior");
1136         
1137         oo->computeFun = ba81Estep;
1138         oo->destructFun = ba81Destroy;
1139         
1140         // TODO: Exactly identical rows do not contribute any information.
1141         // The sorting algorithm ought to remove them so we don't waste RAM.
1142         // The following summary stats would be cheaper to calculate too.
1143
1144         int numUnique = 0;
1145         omxData *data = state->data;
1146         if (omxDataNumFactor(data) != data->cols) {
1147                 // verify they are ordered factors TODO
1148                 omxRaiseErrorf(currentState, "%s: all columns must be factors", NAME);
1149                 return;
1150         }
1151
1152         for (int rx=0; rx < data->rows;) {
1153                 rx += omxDataNumIdenticalRows(state->data, rx);
1154                 ++numUnique;
1155         }
1156         state->numUnique = numUnique;
1157
1158         state->rowMap = Realloc(NULL, numUnique, int);
1159         state->logNumIdentical = Realloc(NULL, numUnique, double);
1160
1161         int numItems = state->itemParam->cols;
1162
1163         for (int rx=0, ux=0; rx < data->rows; ux++) {
1164                 if (rx == 0) {
1165                         // all NA rows will sort to the top
1166                         int na=0;
1167                         for (int ix=0; ix < numItems; ix++) {
1168                                 if (omxIntDataElement(data, 0, ix) == NA_INTEGER) { ++na; }
1169                         }
1170                         if (na == numItems) {
1171                                 omxRaiseErrorf(currentState, "Remove rows with all NAs");
1172                                 return;
1173                         }
1174                 }
1175                 int dups = omxDataNumIdenticalRows(state->data, rx);
1176                 state->logNumIdentical[ux] = log(dups);
1177                 state->rowMap[ux] = rx;
1178                 rx += dups;
1179         }
1180
1181         state->patternLik = Realloc(NULL, numUnique, double);
1182
1183         int numThreads = getNumThreads(oo);
1184
1185         int maxSpec = 0;
1186         int maxParam = 0;
1187         state->maxDims = 0;
1188         state->maxOutcomes = 0;
1189
1190         for (int cx = 0; cx < data->cols; cx++) {
1191                 const double *spec = omxMatrixColumn(state->itemSpec, cx);
1192                 int id = spec[RPF_ISpecID];
1193                 if (id < 0 || id >= rpf_numModels) {
1194                         omxRaiseErrorf(currentState, "ItemSpec column %d has unknown item model %d", cx, id);
1195                         return;
1196                 }
1197
1198                 int dims = spec[RPF_ISpecDims];
1199                 if (state->maxDims < dims)
1200                         state->maxDims = dims;
1201
1202                 // TODO verify that item model can have requested number of outcomes
1203                 int no = spec[RPF_ISpecOutcomes];
1204                 if (state->maxOutcomes < no)
1205                         state->maxOutcomes = no;
1206
1207                 int numSpec = (*rpf_model[id].numSpec)(spec);
1208                 if (maxSpec < numSpec)
1209                         maxSpec = numSpec;
1210
1211                 int numParam = (*rpf_model[id].numParam)(spec);
1212                 if (maxParam < numParam)
1213                         maxParam = numParam;
1214         }
1215
1216         if (state->itemSpec->cols != data->cols || state->itemSpec->rows != maxSpec) {
1217                 omxRaiseErrorf(currentState, "ItemSpec must have %d item columns and %d rows",
1218                                data->cols, maxSpec);
1219                 return;
1220         }
1221         if (state->itemParam->rows != maxParam) {
1222                 omxRaiseErrorf(currentState, "ItemParam should have %d rows", maxParam);
1223                 return;
1224         }
1225
1226         if (state->design == NULL) {
1227                 state->maxAbilities = state->maxDims;
1228                 state->design = omxInitTemporaryMatrix(NULL, state->maxDims, numItems,
1229                                        TRUE, currentState);
1230                 for (int ix=0; ix < numItems; ix++) {
1231                         for (int dx=0; dx < state->maxDims; dx++) {
1232                                 omxSetMatrixElement(state->design, dx, ix, (double)dx+1);
1233                         }
1234                 }
1235         } else {
1236                 omxMatrix *design = state->design;
1237                 if (design->cols != numItems ||
1238                     design->rows != state->maxDims) {
1239                         omxRaiseErrorf(currentState, "Design matrix should have %d rows and %d columns",
1240                                        state->maxDims, numItems);
1241                         return;
1242                 }
1243
1244                 state->maxAbilities = 0;
1245                 for (int ix=0; ix < design->rows * design->cols; ix++) {
1246                         double got = design->data[ix];
1247                         if (!R_FINITE(got)) continue;
1248                         if (round(got) != got) error("Design matrix can only contain integers"); // TODO better way?
1249                         if (state->maxAbilities < got)
1250                                 state->maxAbilities = got;
1251                 }
1252         }
1253         if (state->maxAbilities <= state->maxDims) {
1254                 state->Sgroup = Calloc(numItems, int);
1255         } else {
1256                 // Not sure if this is correct, revisit TODO
1257                 int Sgroup0 = -1;
1258                 state->Sgroup = Realloc(NULL, numItems, int);
1259                 for (int dx=0; dx < state->maxDims; dx++) {
1260                         for (int ix=0; ix < numItems; ix++) {
1261                                 int ability = omxMatrixElement(state->design, dx, ix);
1262                                 if (dx < state->maxDims - 1) {
1263                                         if (Sgroup0 <= ability)
1264                                                 Sgroup0 = ability+1;
1265                                         continue;
1266                                 }
1267                                 int ss=-1;
1268                                 if (ability >= Sgroup0) {
1269                                         if (ss == -1) {
1270                                                 ss = ability;
1271                                         } else {
1272                                                 omxRaiseErrorf(currentState, "Item %d cannot belong to more than "
1273                                                                "1 specific dimension (both %d and %d)",
1274                                                                ix, ss, ability);
1275                                                 return;
1276                                         }
1277                                 }
1278                                 if (ss == -1) ss = Sgroup0;
1279                                 state->Sgroup[ix] = ss - Sgroup0;
1280                         }
1281                 }
1282                 state->numSpecific = state->maxAbilities - state->maxDims + 1;
1283                 state->allSlxk = Realloc(NULL, numUnique * numThreads, double);
1284                 state->Slxk = Realloc(NULL, numUnique * state->numSpecific * numThreads, double);
1285         }
1286
1287         PROTECT(tmp = GET_SLOT(rObj, install("cache")));
1288         state->cacheLXK = asLogical(tmp);
1289
1290         PROTECT(tmp = GET_SLOT(rObj, install("GHpoints")));
1291         double *qpoints = REAL(tmp);
1292         int numQPoints = length(tmp);
1293
1294         PROTECT(tmp = GET_SLOT(rObj, install("GHarea")));
1295         double *qarea = REAL(tmp);
1296         if (numQPoints != length(tmp)) error("length(GHpoints) != length(GHarea)");
1297
1298         ba81SetupQuadrature(oo, numQPoints, qpoints, qarea);
1299
1300         // verify data bounded between 1 and numOutcomes TODO
1301         // hm, looks like something could be added to omxData for column summary stats?
1302 }