ifa: Fixup
[openmx:openmx.git] / src / omxExpectationBA81.c
1 /*
2   Copyright 2012 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 // Consider replacing log() with log2() in some places? Not worth it?
19
20 #include "omxExpectation.h"
21 #include "omxOpenmpWrap.h"
22 #include "npsolWrap.h"
23 #include "libirt-rpf.h"
24 #include "merge.h"
25
26 static const char *NAME = "ExpectationBA81";
27
28 static const struct rpf *rpf_model = NULL;
29 static int rpf_numModels;
30
31 typedef double *(*rpf_fn_t)(omxExpectation *oo, omxMatrix *itemParam, const int *quad);
32
33 typedef struct {
34
35         // data characteristics
36         omxData *data;
37         int numUnique;
38         double *logNumIdentical;  // length numUnique
39         int *rowMap;              // length numUnique
40
41         // item description related
42         omxMatrix *itemSpec;
43         int maxOutcomes;
44         int maxDims;
45         int maxAbilities;
46         int numSpecific;
47         int *Sgroup;              // item's specific group 0..numSpecific-1
48         omxMatrix *design;        // items * maxDims
49
50         // quadrature related
51         int numQpoints;
52         double *Qpoint;
53         double *Qarea;
54         double *logQarea;
55         long *quadGridSize;       // maxDims
56         long totalPrimaryPoints;  // product of quadGridSize except specific dim
57         long totalQuadPoints;     // product of quadGridSize
58
59         // estimation related
60         omxMatrix *EitemParam;    // E step version
61         omxMatrix *itemParam;     // M step version
62         SEXP rpf;
63         rpf_fn_t computeRPF;
64         omxMatrix *customPrior;
65         int *paramMap;
66         int cacheLXK;             // w/cache,  numUnique * #specific quad points * totalQuadPoints
67         double *lxk;              // wo/cache, numUnique * thread
68         double *allSlxk;          // numUnique * thread
69         double *Slxk;             // numUnique * #specific dimensions * thread
70         double *patternLik;       // length numUnique
71         double ll;                // the most recent finite ll
72
73         int gradientCount;
74         int fitCount;
75 } omxBA81State;
76
77 static void
78 pda(const double *ar, int rows, int cols) {
79         for (int rx=0; rx < rows; rx++) {
80                 for (int cx=0; cx < cols; cx++) {
81                         Rprintf("%.6g ", ar[cx * rows + rx]);
82                 }
83                 Rprintf("\n");
84         }
85
86 }
87
88 static int
89 findFreeVarLocation(omxMatrix *itemParam, const omxFreeVar *fv)
90 {
91         for (int lx=0; lx < fv->numLocations; lx++) {
92                 if (~fv->matrices[lx] == itemParam->matrixNumber) return lx;
93         }
94         return -1;
95 }
96
97 static int
98 compareFV(const int *fv1x, const int *fv2x, omxExpectation* oo)
99 {
100         omxState* currentState = oo->currentState;
101         omxBA81State *state = (omxBA81State *) oo->argStruct;
102         omxMatrix *itemParam = state->itemParam;
103         omxFreeVar *fv1 = currentState->freeVarList + *fv1x;
104         omxFreeVar *fv2 = currentState->freeVarList + *fv2x;
105         int l1 = findFreeVarLocation(itemParam, fv1);
106         int l2 = findFreeVarLocation(itemParam, fv2);
107         if (l1 == -1 && l2 == -1) return 0;
108         if ((l1 == -1) ^ (l2 == -1)) return l1 == -1? 1:-1;  // TODO reversed?
109         // Columns are items. Sort columns together
110         return fv1->col[l1] - fv2->col[l1];
111 }
112
113 static void buildParamMap(omxExpectation* oo)
114 {
115         omxState* currentState = oo->currentState;
116         omxBA81State *state = (omxBA81State *) oo->argStruct;
117         int numFreeParams = currentState->numFreeParams;
118         state->paramMap = Realloc(NULL, numFreeParams, int);
119         for (int px=0; px < numFreeParams; px++) { state->paramMap[px] = px; }
120         freebsd_mergesort(state->paramMap, numFreeParams, sizeof(int),
121                           (mergesort_cmp_t)compareFV, oo);
122 }
123
124 OMXINLINE static void
125 pointToWhere(omxBA81State *state, const int *quad, double *where, int upto)
126 {
127         for (int dx=0; dx < upto; dx++) {
128                 where[dx] = state->Qpoint[quad[dx]];
129         }
130 }
131
132 OMXINLINE static void
133 assignDims(omxMatrix *itemSpec, omxMatrix *design, int dims, int maxDims, int ix,
134            const double *restrict theta, double *restrict ptheta)
135 {
136         for (int dx=0; dx < dims; dx++) {
137                 int ability = (int)omxMatrixElement(design, dx, ix) - 1;
138                 if (ability >= maxDims) ability = maxDims-1;
139                 ptheta[dx] = theta[ability];
140         }
141 }
142
143 /**
144  * This is the main function needed to generate simulated data from
145  * the model. It could be argued that the rest of the estimation
146  * machinery belongs in the fitfunction.
147  *
148  * \param theta Vector of ability parameters, one per ability
149  * \returns A numItems by maxOutcomes colMajor vector of doubles. Caller must Free it.
150  */
151 static double *
152 standardComputeRPF(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
153 {
154         omxBA81State *state = (omxBA81State*) oo->argStruct;
155         omxMatrix *itemSpec = state->itemSpec;
156         int numItems = itemSpec->cols;
157         omxMatrix *design = state->design;
158         int maxDims = state->maxDims;
159
160         double theta[maxDims];
161         pointToWhere(state, quad, theta, maxDims);
162
163         double *outcomeProb = Realloc(NULL, numItems * state->maxOutcomes, double);
164
165         for (int ix=0; ix < numItems; ix++) {
166                 const double *spec = omxMatrixColumn(itemSpec, ix);
167                 double *iparam = omxMatrixColumn(itemParam, ix);
168                 double *out = outcomeProb + ix * state->maxOutcomes;
169                 int id = spec[RPF_ISpecID];
170                 int dims = spec[RPF_ISpecDims];
171                 double ptheta[dims];
172                 assignDims(itemSpec, design, dims, maxDims, ix, theta, ptheta);
173                 (*rpf_model[id].logprob)(spec, iparam, ptheta, out);
174         }
175
176         return outcomeProb;
177 }
178
179 static double *
180 RComputeRPF1(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
181 {
182         omxBA81State *state = (omxBA81State*) oo->argStruct;
183         int maxOutcomes = state->maxOutcomes;
184         omxMatrix *design = state->design;
185         omxMatrix *itemSpec = state->itemSpec;
186         int maxDims = state->maxDims;
187
188         double theta[maxDims];
189         pointToWhere(state, quad, theta, maxDims);
190
191         SEXP invoke;
192         PROTECT(invoke = allocVector(LANGSXP, 4));
193         SETCAR(invoke, state->rpf);
194         SETCADR(invoke, omxExportMatrix(itemParam));
195         SETCADDR(invoke, omxExportMatrix(itemSpec));
196
197         SEXP where;
198         PROTECT(where = allocMatrix(REALSXP, maxDims, itemParam->cols));
199         double *ptheta = REAL(where);
200         for (int ix=0; ix < itemParam->cols; ix++) {
201                 int dims = omxMatrixElement(itemSpec, RPF_ISpecDims, ix);
202                 assignDims(itemSpec, design, dims, maxDims, ix, theta, ptheta + ix*maxDims);
203                 for (int dx=dims; dx < maxDims; dx++) {
204                         ptheta[ix*maxDims + dx] = NA_REAL;
205                 }
206         }
207         SETCADDDR(invoke, where);
208
209         SEXP matrix;
210         PROTECT(matrix = eval(invoke, R_GlobalEnv));
211
212         if (!isMatrix(matrix)) {
213                 omxRaiseError(oo->currentState, -1,
214                               "RPF must return an item by outcome matrix");
215                 return NULL;
216         }
217
218         SEXP matrixDims;
219         PROTECT(matrixDims = getAttrib(matrix, R_DimSymbol));
220         int *dimList = INTEGER(matrixDims);
221         int numItems = state->itemSpec->cols;
222         if (dimList[0] != maxOutcomes || dimList[1] != numItems) {
223                 const int errlen = 200;
224                 char errstr[errlen];
225                 snprintf(errstr, errlen, "RPF must return a %d outcomes by %d items matrix",
226                          maxOutcomes, numItems);
227                 omxRaiseError(oo->currentState, -1, errstr);
228                 return NULL;
229         }
230
231         // Unlikely to be of type INTSXP, but just to be safe
232         PROTECT(matrix = coerceVector(matrix, REALSXP));
233         double *restrict got = REAL(matrix);
234
235         // Need to copy because threads cannot share SEXP
236         double *restrict outcomeProb = Realloc(NULL, numItems * maxOutcomes, double);
237
238         // Double check there aren't NAs in the wrong place
239         for (int ix=0; ix < numItems; ix++) {
240                 int numOutcomes = omxMatrixElement(state->itemSpec, RPF_ISpecOutcomes, ix);
241                 for (int ox=0; ox < numOutcomes; ox++) {
242                         int vx = ix * maxOutcomes + ox;
243                         if (isnan(got[vx])) {
244                                 const int errlen = 200;
245                                 char errstr[errlen];
246                                 snprintf(errstr, errlen, "RPF returned NA in [%d,%d]", ox,ix);
247                                 omxRaiseError(oo->currentState, -1, errstr);
248                         }
249                         outcomeProb[vx] = got[vx];
250                 }
251         }
252
253         return outcomeProb;
254 }
255
256 static double *
257 RComputeRPF(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
258 {
259         omx_omp_set_lock(&GlobalRLock);
260         PROTECT_INDEX pi = omxProtectSave();
261         double *ret = RComputeRPF1(oo, itemParam, quad);
262         omxProtectRestore(pi);
263         omx_omp_unset_lock(&GlobalRLock);  // hope there was no exception!
264         return ret;
265 }
266
267 OMXINLINE static long
268 encodeLocation(const int dims, const long *restrict grid, const int *restrict quad)
269 {
270         long qx = 0;
271         for (int dx=dims-1; dx >= 0; dx--) {
272                 qx = qx * grid[dx];
273                 qx += quad[dx];
274         }
275         return qx;
276 }
277
278 #define CALC_LXK_CACHED(state, numUnique, quad, tqp, specific) \
279         ((state)->lxk + \
280          (numUnique) * encodeLocation((state)->maxDims, (state)->quadGridSize, quad) + \
281          (numUnique) * (tqp) * (specific))
282
283 OMXINLINE static double *
284 ba81Likelihood(omxExpectation *oo, int specific, const int *restrict quad)
285 {
286         omxBA81State *state = (omxBA81State*) oo->argStruct;
287         int numUnique = state->numUnique;
288         int maxOutcomes = state->maxOutcomes;
289         omxData *data = state->data;
290         int numItems = state->itemSpec->cols;
291         rpf_fn_t rpf_fn = state->computeRPF;
292         int *restrict Sgroup = state->Sgroup;
293         double *restrict lxk;
294
295         if (!state->cacheLXK) {
296                 lxk = state->lxk + numUnique * omx_absolute_thread_num();
297         } else {
298                 lxk = CALC_LXK_CACHED(state, numUnique, quad, state->totalQuadPoints, specific);
299         }
300
301         const double *outcomeProb = (*rpf_fn)(oo, state->EitemParam, quad);
302         if (!outcomeProb) {
303                 OMXZERO(lxk, numUnique);
304                 return lxk;
305         }
306
307         const int *rowMap = state->rowMap;
308         for (int px=0; px < numUnique; px++) {
309                 double lxk1 = 0;
310                 for (int ix=0; ix < numItems; ix++) {
311                         if (specific != Sgroup[ix]) continue;
312                         int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
313                         if (pick == NA_INTEGER) continue;
314                         lxk1 += outcomeProb[ix * maxOutcomes + pick-1];
315                 }
316                 lxk[px] = lxk1;
317         }
318
319         Free(outcomeProb);
320
321         return lxk;
322 }
323
324 OMXINLINE static double *
325 ba81LikelihoodFast(omxExpectation *oo, int specific, const int *restrict quad)
326 {
327         omxBA81State *state = (omxBA81State*) oo->argStruct;
328         if (!state->cacheLXK) {
329                 return ba81LikelihoodFast(oo, specific, quad);
330         } else {
331                 return CALC_LXK_CACHED(state, state->numUnique, quad, state->totalQuadPoints, specific);
332         }
333
334 }
335
336 #define CALC_ALLSLXK(state, numUnique) \
337         (state->allSlxk + omx_absolute_thread_num() * (numUnique))
338
339 #define CALC_SLXK(state, numUnique, numSpecific) \
340         (state->Slxk + omx_absolute_thread_num() * (numUnique) * (numSpecific))
341
342 OMXINLINE static void
343 cai2010(omxExpectation* oo, int recompute, const int *restrict primaryQuad,
344         double *restrict allSlxk, double *restrict Slxk)
345 {
346         omxBA81State *state = (omxBA81State*) oo->argStruct;
347         int numUnique = state->numUnique;
348         int numSpecific = state->numSpecific;
349         int maxDims = state->maxDims;
350         int sDim = maxDims-1;
351
352         int quad[maxDims];
353         memcpy(quad, primaryQuad, sizeof(int)*sDim);
354
355         OMXZERO(Slxk, numUnique * numSpecific);
356         OMXZERO(allSlxk, numUnique);
357
358         for (int sx=0; sx < numSpecific; sx++) {
359                 double *eis = Slxk + numUnique * sx;
360                 int quadGridSize = state->quadGridSize[sDim];
361
362                 for (int qx=0; qx < quadGridSize; qx++) {
363                         quad[sDim] = qx;
364                         double *lxk;
365                         if (recompute) {
366                                 lxk = ba81Likelihood(oo, sx, quad);
367                         } else {
368                                 lxk = CALC_LXK_CACHED(state, numUnique, quad, state->totalQuadPoints, sx);
369                         }
370
371                         for (int ix=0; ix < numUnique; ix++) {
372                                 eis[ix] += exp(lxk[ix] + state->logQarea[qx]);
373                         }
374                 }
375
376                 for (int px=0; px < numUnique; px++) {
377                         eis[px] = log(eis[px]);
378                         allSlxk[px] += eis[px];
379                 }
380         }
381 }
382
383 OMXINLINE static double
384 logAreaProduct(omxBA81State *state, const int *restrict quad, const int upto)
385 {
386         double logArea = 0;
387         for (int dx=0; dx < upto; dx++) {
388                 logArea += state->logQarea[quad[dx]];
389         }
390         return logArea;
391 }
392
393 // The idea of this API is to allow passing in a number larger than 1.
394 OMXINLINE static void
395 areaProduct(omxBA81State *state, const int *restrict quad, const int upto, double *restrict out)
396 {
397         for (int dx=0; dx < upto; dx++) {
398                 *out *= state->Qarea[quad[dx]];
399         }
400 }
401
402 OMXINLINE static void
403 decodeLocation(long qx, const int dims, const long *restrict grid,
404                int *restrict quad)
405 {
406         for (int dx=0; dx < dims; dx++) {
407                 quad[dx] = qx % grid[dx];
408                 qx = qx / grid[dx];
409         }
410 }
411
412 static void
413 ba81Estep(omxExpectation *oo) {
414         if(OMX_DEBUG_MML) {Rprintf("Beginning %s Computation.\n", NAME);}
415
416         omxBA81State *state = (omxBA81State*) oo->argStruct;
417         double *patternLik = state->patternLik;
418         int numUnique = state->numUnique;
419         int numSpecific = state->numSpecific;
420
421         omxCopyMatrix(state->EitemParam, state->itemParam);
422
423         OMXZERO(patternLik, numUnique);
424
425         // E-step, marginalize person ability
426         //
427         // Note: In the notation of Bock & Aitkin (1981) and
428         // Cai~(2010), these loops are reversed.  That is, the inner
429         // loop is over quadrature points and the outer loop is over
430         // all response patterns.
431         //
432         if (numSpecific == 0) {
433 #pragma omp parallel for num_threads(oo->currentState->numThreads)
434                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
435                         int quad[state->maxDims];
436                         decodeLocation(qx, state->maxDims, state->quadGridSize, quad);
437
438                         double *lxk = ba81Likelihood(oo, 0, quad);
439
440                         double logArea = logAreaProduct(state, quad, state->maxDims);
441 #pragma omp critical(EstepUpdate)
442                         for (int px=0; px < numUnique; px++) {
443                                 double tmp = exp(lxk[px] + logArea);
444                                 patternLik[px] += tmp;
445                         }
446                 }
447         } else {
448                 int sDim = state->maxDims-1;
449
450 #pragma omp parallel for num_threads(oo->currentState->numThreads)
451                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
452                         int quad[state->maxDims];
453                         decodeLocation(qx, sDim, state->quadGridSize, quad);
454
455                         double *allSlxk = CALC_ALLSLXK(state, numUnique);
456                         double *Slxk = CALC_SLXK(state, numUnique, numSpecific);
457                         cai2010(oo, TRUE, quad, allSlxk, Slxk);
458
459                         double logArea = logAreaProduct(state, quad, sDim);
460 #pragma omp critical(EstepUpdate)
461                         for (int px=0; px < numUnique; px++) {
462                                 double tmp = exp(allSlxk[px] + logArea);
463                                 patternLik[px] += tmp;
464                         }
465                 }
466         }
467
468         for (int px=0; px < numUnique; px++) {
469                 patternLik[px] = log(patternLik[px]);
470         }
471 }
472
473 OMXINLINE static void
474 expectedUpdate(omxData *restrict data, const int *rowMap, const int px, const int item,
475                const double observed, const int outcomes, double *out)
476 {
477         int pick = omxIntDataElementUnsafe(data, rowMap[px], item);
478         if (pick == NA_INTEGER) {
479                 double slice = exp(observed - log(outcomes));
480                 for (int ox=0; ox < outcomes; ox++) {
481                         out[ox] += slice;
482                 }
483         } else {
484                 out[pick-1] += exp(observed);
485         }
486 }
487
488 /** 
489  * \param quad a vector that indexes into a multidimensional quadrature
490  * \param out points to an array numOutcomes wide
491  */
492 OMXINLINE static void
493 ba81Weight(omxExpectation* oo, const int item, const int *quad, int outcomes, double *out)
494 {
495         omxBA81State *state = (omxBA81State*) oo->argStruct;
496         omxData *data = state->data;
497         const int *rowMap = state->rowMap;
498         int specific = state->Sgroup[item];
499         double *patternLik = state->patternLik;
500         double *logNumIdentical = state->logNumIdentical;
501         int numUnique = state->numUnique;
502         int numSpecific = state->numSpecific;
503         int sDim = state->maxDims-1;
504
505         OMXZERO(out, outcomes);
506
507         if (numSpecific == 0) {
508                 double *lxk = ba81LikelihoodFast(oo, specific, quad);
509                 for (int px=0; px < numUnique; px++) {
510                         double observed = logNumIdentical[px] + lxk[px] - patternLik[px];
511                         expectedUpdate(data, rowMap, px, item, observed, outcomes, out);
512                 }
513         } else {
514                 double *allSlxk = CALC_ALLSLXK(state, numUnique);
515                 double *Slxk = CALC_SLXK(state, numUnique, numSpecific);
516                 if (quad[sDim] == 0) {
517                         // allSlxk, Slxk only depend on the ordinate of the primary dimensions
518                         cai2010(oo, !state->cacheLXK, quad, allSlxk, Slxk);
519                 }
520                 double *eis = Slxk + numUnique * specific;
521
522                 // Avoid recalc when cache disabled with modest buffer? TODO
523                 double *lxk = ba81LikelihoodFast(oo, specific, quad);
524
525                 for (int px=0; px < numUnique; px++) {
526                         double observed = logNumIdentical[px] + (allSlxk[px] - eis[px]) +
527                                 (lxk[px] - patternLik[px]);
528                         expectedUpdate(data, rowMap, px, item, observed, outcomes, out);
529                 }
530         }
531 }
532
533 OMXINLINE static double
534 ba81Fit1Ordinate(omxExpectation* oo, const int *quad)
535 {
536         omxBA81State *state = (omxBA81State*) oo->argStruct;
537         omxMatrix *itemParam = state->itemParam;
538         int numItems = itemParam->cols;
539         rpf_fn_t rpf_fn = state->computeRPF;
540         int maxOutcomes = state->maxOutcomes;
541         int maxDims = state->maxDims;
542
543         double *outcomeProb = (*rpf_fn)(oo, itemParam, quad);
544         if (!outcomeProb) return 0;
545
546         double thr_ll = 0;
547         for (int ix=0; ix < numItems; ix++) {
548                 int outcomes = omxMatrixElement(state->itemSpec, RPF_ISpecOutcomes, ix);
549                 double out[outcomes];
550                 ba81Weight(oo, ix, quad, outcomes, out);
551                 for (int ox=0; ox < outcomes; ox++) {
552 #if 0
553                         if (!isnormal(outcomeProb[ix * maxOutcomes + ox])) {
554                                 pda(itemParam->data, itemParam->rows, itemParam->cols);
555                                 pda(outcomeProb, outcomes, numItems);
556                                 error("RPF produced NAs");
557                         }
558 #endif
559                         double got = out[ox] * outcomeProb[ix * maxOutcomes + ox];
560                         areaProduct(state, quad, maxDims, &got);
561                         thr_ll += got;
562                 }
563         }
564
565         Free(outcomeProb);
566         return thr_ll;
567 }
568
569 static double
570 ba81ComputeFit1(omxExpectation* oo, int want, double *gradient)
571 {
572         omxBA81State *state = (omxBA81State*) oo->argStruct;
573         ++state->fitCount;
574         omxMatrix *customPrior = state->customPrior;
575         int numSpecific = state->numSpecific;
576         int maxDims = state->maxDims;
577
578         double ll = 0;
579         if (customPrior) {
580                 omxRecompute(customPrior);
581                 ll = customPrior->data[0];
582         } else {
583                 omxMatrix *itemSpec = state->itemSpec;
584                 omxMatrix *itemParam = state->itemParam;
585                 int numItems = itemSpec->cols;
586                 for (int ix=0; ix < numItems; ix++) {
587                         const double *spec = omxMatrixColumn(itemSpec, ix);
588                         int id = spec[RPF_ISpecID];
589                         double *iparam = omxMatrixColumn(itemParam, ix);
590                         ll += (*rpf_model[id].prior)(spec, iparam);
591                 }
592         }
593
594         if (!isnormal(ll)) {
595                 omxMatrix *itemParam = state->itemParam;
596                 omxPrint(itemParam, "item param");
597                 error("Bayesian prior returned NA; do you need to add a lbound/ubound?");
598         }
599
600         if (numSpecific == 0) {
601 #pragma omp parallel for num_threads(oo->currentState->numThreads)
602                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
603                         int quad[maxDims];
604                         decodeLocation(qx, maxDims, state->quadGridSize, quad);
605                         double thr_ll = ba81Fit1Ordinate(oo, quad);
606
607 #pragma omp atomic
608                         ll += thr_ll;
609                 }
610         } else {
611                 int sDim = state->maxDims-1;
612                 long *quadGridSize = state->quadGridSize;
613
614 #pragma omp parallel for num_threads(oo->currentState->numThreads)
615                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
616                         int quad[maxDims];
617                         decodeLocation(qx, maxDims, quadGridSize, quad);
618
619                         double thr_ll = 0;
620                         long specificPoints = quadGridSize[sDim];
621                         for (long sx=0; sx < specificPoints; sx++) {
622                                 quad[sDim] = sx;
623                                 thr_ll += ba81Fit1Ordinate(oo, quad);
624                         }
625 #pragma omp atomic
626                         ll += thr_ll;
627                 }
628         }
629
630         if (isinf(ll)) {
631                 return 2*state->ll;
632         } else {
633                 ll = -2 * ll;
634                 state->ll = ll;
635                 return ll;
636         }
637 }
638
639 double
640 ba81ComputeFit(omxExpectation* oo, int want, double *gradient)
641 {
642         // TODO ignore call from omxInitialMatrixAlgebraCompute
643         double got = ba81ComputeFit1(oo, want, gradient);
644         return got;
645 }
646
647 OMXINLINE static void
648 ba81ItemGradientOrdinate(omxExpectation* oo, omxBA81State *state,
649                          int maxDims, int *quad, int item, const double *spec, int id,
650                          int outcomes, double *iparam, int numParam, int *paramMask, double *gq)
651 {
652         double where[maxDims];
653         pointToWhere(state, quad, where, maxDims);
654         double weight[outcomes];
655         ba81Weight(oo, item, quad, outcomes, weight);
656
657         (*rpf_model[id].gradient)(spec, iparam, paramMask, where, weight, gq);
658
659         for (int ox=0; ox < numParam; ox++) {
660                 if (paramMask[ox] == -1) continue;
661
662 #if 0
663                 if (!isnormal(gq[ox])) {
664                         Rprintf("item spec:\n");
665                         pda(spec, (*rpf_model[id].numSpec)(spec), 1);
666                         Rprintf("item parameters:\n");
667                         pda(iparam, numParam, 1);
668                         Rprintf("where:\n");
669                         pda(where, maxDims, 1);
670                         Rprintf("weight:\n");
671                         pda(weight, outcomes, 1);
672                         error("Gradient for item %d param %d is %f; are you missing a lbound/ubound?",
673                               item, ox, gq[ox]);
674                 }
675 #endif
676
677                 areaProduct(state, quad, maxDims, gq+ox);
678         }
679 }
680
681 OMXINLINE static void
682 ba81ItemGradient(omxExpectation* oo, omxBA81State *state, const double *spec, omxMatrix *itemParam,
683                  int item, int id, int outcomes, int numParam, int *paramMask, double *out)
684 {
685         int maxDims = state->maxDims;
686         double *iparam = omxMatrixColumn(itemParam, item);
687         double gradient[numParam];
688         OMXZERO(gradient, numParam);
689
690         if (state->numSpecific == 0) {
691 #pragma omp parallel for num_threads(oo->currentState->numThreads)
692                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
693                         int quad[maxDims];
694                         decodeLocation(qx, maxDims, state->quadGridSize, quad);
695                         double gq[numParam];
696                         //                      for (int gx=0; gx < numParam; gx++) gq[gx] = 3.14159; // debugging TODO
697
698                         ba81ItemGradientOrdinate(oo, state, maxDims, quad, item, spec, id,
699                                                  outcomes, iparam, numParam, paramMask, gq);
700
701 #pragma omp critical(GradientUpdate)
702                         for (int ox=0; ox < numParam; ox++) {
703                                 gradient[ox] += gq[ox];
704                         }
705                 }
706         } else {
707                 int sDim = state->maxDims-1;
708                 long *quadGridSize = state->quadGridSize;
709 #pragma omp parallel for num_threads(oo->currentState->numThreads)
710                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
711                         int quad[maxDims];
712                         decodeLocation(qx, maxDims, quadGridSize, quad);
713                         double gsubtotal[numParam];
714                         OMXZERO(gsubtotal, numParam);
715
716                         long specificPoints = quadGridSize[sDim];
717                         for (long sx=0; sx < specificPoints; sx++) {
718                                 double gq[numParam];
719                                 quad[sDim] = sx;
720                                 ba81ItemGradientOrdinate(oo, state, maxDims, quad, item, spec, id,
721                                                          outcomes, iparam, numParam, paramMask, gq);
722                                 for (int gx=0; gx < numParam; gx++) {
723                                         gsubtotal[gx] += gq[gx];
724                                 }
725                         }
726 #pragma omp critical(GradientUpdate)
727                         for (int gx=0; gx < numParam; gx++) {
728                                 gradient[gx] += gsubtotal[gx];
729                         }
730                 }
731         }
732
733         (*rpf_model[id].gradient)(spec, iparam, paramMask, NULL, NULL, gradient);
734
735         for (int px=0; px < numParam; px++) {
736                 int loc = paramMask[px];
737                 if (loc == -1) continue;
738
739                 // Need to check because this can happen if
740                 // lbounds/ubounds are not set appropriately.
741                 if (!isnormal(gradient[px])) {
742                         Rprintf("item spec:\n");
743                         pda(spec, (*rpf_model[id].numSpec)(spec), 1);
744                         Rprintf("item parameters:\n");
745                         pda(iparam, numParam, 1);
746                         error("Gradient for item %d param %d is %f; are you missing a lbound/ubound?",
747                               item, px, gradient[px]);
748                 }
749
750                 out[loc] = -2 * gradient[px];
751         }
752 }
753
754 void ba81Gradient(omxExpectation* oo, double *out)
755 {
756         omxState* currentState = oo->currentState;
757         int numFreeParams = currentState->numFreeParams;
758         omxBA81State *state = (omxBA81State *) oo->argStruct;
759         if (!state->paramMap) buildParamMap(oo);
760         ++state->gradientCount;
761         omxMatrix *itemSpec = state->itemSpec;
762         omxMatrix *itemParam = state->itemParam;
763
764         int vx = 0;
765         while (vx < numFreeParams) {
766             omxFreeVar *fv = currentState->freeVarList + state->paramMap[vx];
767             int vloc = findFreeVarLocation(itemParam, fv);
768             if (vloc < 0) {
769                     ++vx;
770                     continue;
771             }
772
773             int item = fv->col[vloc];
774             const double *spec = omxMatrixColumn(itemSpec, item);
775             int id = spec[RPF_ISpecID];
776             int outcomes = spec[RPF_ISpecOutcomes];
777             int numParam = (*rpf_model[id].numParam)(spec);
778
779             int paramMask[numParam];
780             for (int px=0; px < numParam; px++) { paramMask[px] = -1; }
781
782             if (fv->row[vloc] >= numParam) {
783                     warning("Item %d has too many free parameters", item);
784                     continue;
785             }
786             paramMask[fv->row[vloc]] = vx;
787
788             while (++vx < numFreeParams) {
789                     omxFreeVar *fv = currentState->freeVarList + state->paramMap[vx];
790                     int vloc = findFreeVarLocation(itemParam, fv);
791                     if (fv->col[vloc] != item) break;
792                     if (fv->row[vloc] >= numParam) {
793                             warning("Item %d has too many free parameters", item);
794                             continue;
795                     }
796                     paramMask[fv->row[vloc]] = vx;
797             }
798
799             ba81ItemGradient(oo, state, spec, itemParam, item,
800                              id, outcomes, numParam, paramMask, out);
801         }
802 }
803
804 static int
805 getNumThreads(omxExpectation* oo)
806 {
807         int numThreads = oo->currentState->numThreads;
808         if (numThreads < 1) numThreads = 1;
809         return numThreads;
810 }
811
812 static void
813 ba81SetupQuadrature(omxExpectation* oo, int numPoints, double *points, double *area)
814 {
815         omxBA81State *state = (omxBA81State *) oo->argStruct;
816         int numUnique = state->numUnique;
817         int numThreads = getNumThreads(oo);
818
819         state->numQpoints = numPoints;
820
821         Free(state->Qpoint);
822         Free(state->Qarea);
823         state->Qpoint = Realloc(NULL, numPoints, double);
824         state->Qarea = Realloc(NULL, numPoints, double);
825         memcpy(state->Qpoint, points, sizeof(double)*numPoints);
826         memcpy(state->Qarea, area, sizeof(double)*numPoints);
827
828         Free(state->logQarea);
829
830         state->logQarea = Realloc(NULL, state->numQpoints, double);
831         for (int px=0; px < state->numQpoints; px++) {
832                 state->logQarea[px] = log(state->Qarea[px]);
833         }
834
835         state->totalQuadPoints = 1;
836         state->totalPrimaryPoints = 1;
837         state->quadGridSize = (long*) R_alloc(state->maxDims, sizeof(long));
838         for (int dx=0; dx < state->maxDims; dx++) {
839                 state->quadGridSize[dx] = state->numQpoints;
840                 state->totalQuadPoints *= state->quadGridSize[dx];
841                 if (dx < state->maxDims-1) {
842                         state->totalPrimaryPoints *= state->quadGridSize[dx];
843                 }
844         }
845
846         Free(state->lxk);
847
848         if (!state->cacheLXK) {
849                 state->lxk = Realloc(NULL, numUnique * numThreads, double);
850         } else {
851                 int ns = state->numSpecific;
852                 if (ns == 0) ns = 1;
853                 state->lxk = Realloc(NULL, numUnique * state->totalQuadPoints * ns, double);
854         }
855 }
856
857 static void
858 ba81EAP1(omxExpectation *oo, double *workspace, long qx, int maxDims, int numUnique,
859          double *ability, double *cov, double *spstats)
860 {
861         omxBA81State *state = (omxBA81State *) oo->argStruct;
862         double *patternLik = state->patternLik;
863         int quad[maxDims];
864         decodeLocation(qx, maxDims, state->quadGridSize, quad);
865         double where[maxDims];
866         pointToWhere(state, quad, where, maxDims);
867         double logArea = logAreaProduct(state, quad, maxDims);
868         double *lxk = ba81LikelihoodFast(oo, 0, quad);
869         double *myspace = workspace + 2 * maxDims * numUnique * omx_absolute_thread_num();
870
871         for (int px=0; px < numUnique; px++) {
872                 double *piece = myspace + px * 2 * maxDims;
873                 double plik = exp(lxk[px] + logArea - patternLik[px]);
874                 for (int dx=0; dx < maxDims; dx++) {
875                         piece[dx] = where[dx] * plik;
876                 }
877                 /*
878                 for (int d1=0; d1 < maxDims; d1++) {
879                         for (int d2=0; d2 <= d1; d2++) {
880                                 covPiece[d1 * maxDims + d2] += piece[d1] * piece[d2];
881                         }
882                 }
883                 */
884         }
885 #pragma omp critical(EAP1Update)
886         for (int px=0; px < numUnique; px++) {
887                 double *piece = myspace + px * 2 * maxDims;
888                 double *arow = ability + px * 2 * maxDims;
889                 for (int dx=0; dx < maxDims; dx++) {
890                         arow[dx*2] += piece[dx];
891                 }
892                 /*
893                 for (int d1=0; d1 < maxDims; d1++) {
894                         for (int d2=0; d2 <= d1; d2++) {
895                                 int loc = d1 * maxDims + d2;
896                                 cov[loc] += covPiece[loc];
897                         }
898                 }
899                 */
900         }
901 }
902
903 static void
904 ba81EAP2(omxExpectation *oo, double *workspace, long qx, int maxDims, int numUnique,
905          double *ability, double *spstats)
906 {
907         omxBA81State *state = (omxBA81State *) oo->argStruct;
908         double *patternLik = state->patternLik;
909         int quad[maxDims];
910         decodeLocation(qx, maxDims, state->quadGridSize, quad);
911         double where[maxDims];
912         pointToWhere(state, quad, where, maxDims);
913         double logArea = logAreaProduct(state, quad, maxDims);
914         double *lxk = ba81LikelihoodFast(oo, 0, quad);
915
916         for (int px=0; px < numUnique; px++) {
917                 double psd[maxDims];
918                 double *arow = ability + px * 2 * maxDims;
919                 for (int dx=0; dx < maxDims; dx++) {
920                         // is this just sqrt(variance) and redundant with the covariance matrix? TODO
921                         double ldiff = log(fabs(where[dx] - arow[dx*2]));
922                         psd[dx] = exp(2 * ldiff + lxk[px] + logArea - patternLik[px]);
923                 }
924 #pragma omp critical(EAP1Update)
925                 for (int dx=0; dx < maxDims; dx++) {
926                         arow[dx*2+1] += psd[dx];
927                 }
928         }
929 }
930
931 /**
932  * MAP is not affected by the number of items. EAP is. Likelihood can
933  * get concentrated in a single quadrature ordinate. For 3PL, response
934  * patterns can have a bimodal likelihood. This will confuse MAP and
935  * is a key advantage of EAP (Thissen & Orlando, 2001, p. 136).
936  *
937  * Thissen, D. & Orlando, M. (2001). IRT for items scored in two
938  * categories. In D. Thissen & H. Wainer (Eds.), \emph{Test scoring}
939  * (pp 73-140). Lawrence Erlbaum Associates, Inc.
940  */
941 omxRListElement *
942 ba81EAP(omxExpectation *oo, int *numReturns)
943 {
944         omxBA81State *state = (omxBA81State *) oo->argStruct;
945         int maxDims = state->maxDims;
946         //int numSpecific = state->numSpecific;
947
948         *numReturns = 2; // + (maxDims > 1) + (numSpecific > 1);
949         omxRListElement *out = (omxRListElement*) R_alloc(*numReturns, sizeof(omxRListElement));
950
951         out[0].numValues = 1;
952         out[0].values = (double*) R_alloc(1, sizeof(double));
953         strcpy(out[0].label, "Minus2LogLikelihood");
954         out[0].values[0] = state->ll;
955
956         omxData *data = state->data;
957         int numUnique = state->numUnique;
958
959         // TODO Wainer & Thissen. (1987). Estimating ability with the wrong
960         // model. Journal of Educational Statistics, 12, 339-368.
961
962         int numQpoints = state->numQpoints * 2;  // make configurable TODO
963
964         if (numQpoints < 1 + 2.0 * sqrt(state->itemSpec->cols)) {
965                 // Thissen & Orlando (2001, p. 136)
966                 warning("EAP requires at least 2*sqrt(items) quadrature points");
967         }
968
969         double Qpoint[numQpoints];
970         double Qarea[numQpoints];
971         const double Qwidth = 4;
972         for (int qx=0; qx < numQpoints; qx++) {
973                 Qpoint[qx] = Qwidth - qx * Qwidth*2 / (numQpoints-1);
974                 Qarea[qx] = 1.0/numQpoints;
975         }
976         ba81SetupQuadrature(oo, numQpoints, Qpoint, Qarea);
977         ba81Estep(oo);   // recalc patternLik with a flat prior
978
979         double *cov = NULL;
980         /*
981         if (maxDims > 1) {
982                 strcpy(out[2].label, "ability.cov");
983                 out[2].numValues = -1;
984                 out[2].rows = maxDims;
985                 out[2].cols = maxDims;
986                 out[2].values = (double*) R_alloc(out[2].rows * out[2].cols, sizeof(double));
987                 cov = out[2].values;
988                 OMXZERO(cov, out[2].rows * out[2].cols);
989         }
990         */
991         double *spstats = NULL;
992         /*
993         if (numSpecific) {
994                 strcpy(out[3].label, "specific");
995                 out[3].numValues = -1;
996                 out[3].rows = numSpecific;
997                 out[3].cols = 2;
998                 out[3].values = (double*) R_alloc(out[3].rows * out[3].cols, sizeof(double));
999                 spstats = out[3].values;
1000         }
1001         */
1002
1003         // allocation of workspace could be optional
1004         int numThreads = getNumThreads(oo);
1005         double *workspace = Realloc(NULL, numUnique * maxDims * 2 * numThreads, double);
1006
1007         // Need a separate work space because the destination needs
1008         // to be in unsorted order with duplicated rows.
1009         double *ability = Calloc(numUnique * maxDims * 2, double);
1010
1011 #pragma omp parallel for num_threads(oo->currentState->numThreads)
1012         for (long qx=0; qx < state->totalQuadPoints; qx++) {
1013                 ba81EAP1(oo, workspace, qx, maxDims, numUnique, ability, cov, spstats);
1014         }
1015
1016         /*
1017         // make symmetric
1018         for (int d1=0; d1 < maxDims; d1++) {
1019                 for (int d2=0; d2 < d1; d2++) {
1020                         cov[d2 * maxDims + d1] = cov[d1 * maxDims + d2];
1021                 }
1022         }
1023         */
1024
1025 #pragma omp parallel for num_threads(oo->currentState->numThreads)
1026         for (long qx=0; qx < state->totalQuadPoints; qx++) {
1027                 ba81EAP2(oo, workspace, qx, maxDims, numUnique, ability, spstats);
1028         }
1029
1030         for (int px=0; px < numUnique; px++) {
1031                 double *arow = ability + px * 2 * maxDims;
1032                 for (int dx=0; dx < maxDims; dx++) {
1033                         arow[dx*2+1] = sqrt(arow[dx*2+1]);
1034                 }
1035         }
1036
1037         strcpy(out[1].label, "ability");
1038         out[1].numValues = -1;
1039         out[1].rows = data->rows;
1040         out[1].cols = 2 * maxDims;
1041         out[1].values = (double*) R_alloc(out[1].rows * out[1].cols, sizeof(double));
1042
1043         for (int rx=0; rx < numUnique; rx++) {
1044                 double *pa = ability + rx * 2 * maxDims;
1045
1046                 int dups = omxDataNumIdenticalRows(state->data, state->rowMap[rx]);
1047                 for (int dup=0; dup < dups; dup++) {
1048                         int dest = omxDataIndex(data, state->rowMap[rx]+dup);
1049                         int col=0;
1050                         for (int dx=0; dx < maxDims; dx++) {
1051                                 out[1].values[col * out[1].rows + dest] = pa[col]; ++col;
1052                                 out[1].values[col * out[1].rows + dest] = pa[col]; ++col;
1053                         }
1054                 }
1055         }
1056         Free(ability);
1057         Free(workspace);
1058         return out;
1059 }
1060
1061 static void ba81Destroy(omxExpectation *oo) {
1062         if(OMX_DEBUG) {
1063                 Rprintf("Freeing %s function.\n", NAME);
1064         }
1065         omxBA81State *state = (omxBA81State *) oo->argStruct;
1066         Rprintf("fit %d gradient %d\n", state->fitCount, state->gradientCount);
1067         omxFreeAllMatrixData(state->itemSpec);
1068         omxFreeAllMatrixData(state->itemParam);
1069         omxFreeAllMatrixData(state->EitemParam);
1070         omxFreeAllMatrixData(state->design);
1071         omxFreeAllMatrixData(state->customPrior);
1072         Free(state->logNumIdentical);
1073         Free(state->Qpoint);
1074         Free(state->Qarea);
1075         Free(state->logQarea);
1076         Free(state->rowMap);
1077         Free(state->patternLik);
1078         Free(state->lxk);
1079         Free(state->Slxk);
1080         Free(state->allSlxk);
1081         Free(state->Sgroup);
1082         Free(state->paramMap);
1083         Free(state);
1084 }
1085
1086 void omxInitExpectationBA81(omxExpectation* oo) {
1087         omxState* currentState = oo->currentState;      
1088         SEXP rObj = oo->rObj;
1089         SEXP tmp;
1090         
1091         if(OMX_DEBUG) {
1092                 Rprintf("Initializing %s.\n", NAME);
1093         }
1094         if (!rpf_model) {
1095                 get_librpf_t get_librpf = (get_librpf_t) R_GetCCallable("rpf", "get_librpf_model");
1096                 (*get_librpf)(&rpf_numModels, &rpf_model);
1097         }
1098         
1099         omxBA81State *state = Calloc(1, omxBA81State);
1100         oo->argStruct = (void*) state;
1101
1102         state->ll = 1e20;   // finite but big
1103         
1104         PROTECT(tmp = GET_SLOT(rObj, install("data")));
1105         state->data = omxNewDataFromMxDataPtr(tmp, currentState);
1106         UNPROTECT(1);
1107
1108         if (strcmp(omxDataType(state->data), "raw") != 0) {
1109                 omxRaiseErrorf(currentState, "%s unable to handle data type %s", NAME, omxDataType(state->data));
1110                 return;
1111         }
1112
1113         PROTECT(state->rpf = GET_SLOT(rObj, install("RPF")));
1114         if (state->rpf == R_NilValue) {
1115                 state->computeRPF = standardComputeRPF;
1116         } else {
1117                 state->computeRPF = RComputeRPF;
1118         }
1119
1120         state->itemSpec =
1121                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemSpec");
1122         state->design =
1123                 omxNewMatrixFromIndexSlot(rObj, currentState, "Design");
1124         state->itemParam =
1125                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemParam");
1126         state->EitemParam =
1127                 omxInitTemporaryMatrix(NULL, state->itemParam->rows, state->itemParam->cols,
1128                                        TRUE, currentState);
1129         state->customPrior =
1130                 omxNewMatrixFromIndexSlot(rObj, currentState, "CustomPrior");
1131         
1132         oo->computeFun = ba81Estep;
1133         oo->destructFun = ba81Destroy;
1134         
1135         // TODO: Exactly identical rows do not contribute any information.
1136         // The sorting algorithm ought to remove them so we don't waste RAM.
1137         // The following summary stats would be cheaper to calculate too.
1138
1139         int numUnique = 0;
1140         omxData *data = state->data;
1141         if (omxDataNumFactor(data) != data->cols) {
1142                 // verify they are ordered factors TODO
1143                 omxRaiseErrorf(currentState, "%s: all columns must be factors", NAME);
1144                 return;
1145         }
1146
1147         for (int rx=0; rx < data->rows;) {
1148                 rx += omxDataNumIdenticalRows(state->data, rx);
1149                 ++numUnique;
1150         }
1151         state->numUnique = numUnique;
1152
1153         state->rowMap = Realloc(NULL, numUnique, int);
1154         state->logNumIdentical = Realloc(NULL, numUnique, double);
1155
1156         int numItems = state->itemParam->cols;
1157
1158         for (int rx=0, ux=0; rx < data->rows; ux++) {
1159                 if (rx == 0) {
1160                         // all NA rows will sort to the top
1161                         int na=0;
1162                         for (int ix=0; ix < numItems; ix++) {
1163                                 if (omxIntDataElement(data, 0, ix) == NA_INTEGER) { ++na; }
1164                         }
1165                         if (na == numItems) {
1166                                 omxRaiseErrorf(currentState, "Remove rows with all NAs");
1167                                 return;
1168                         }
1169                 }
1170                 int dups = omxDataNumIdenticalRows(state->data, rx);
1171                 state->logNumIdentical[ux] = log(dups);
1172                 state->rowMap[ux] = rx;
1173                 rx += dups;
1174         }
1175
1176         state->patternLik = Realloc(NULL, numUnique, double);
1177
1178         int numThreads = getNumThreads(oo);
1179
1180         int maxSpec = 0;
1181         int maxParam = 0;
1182         state->maxDims = 0;
1183         state->maxOutcomes = 0;
1184
1185         for (int cx = 0; cx < data->cols; cx++) {
1186                 const double *spec = omxMatrixColumn(state->itemSpec, cx);
1187                 int id = spec[RPF_ISpecID];
1188                 if (id < 0 || id >= rpf_numModels) {
1189                         omxRaiseErrorf(currentState, "ItemSpec column %d has unknown item model %d", cx, id);
1190                         return;
1191                 }
1192
1193                 int dims = spec[RPF_ISpecDims];
1194                 if (state->maxDims < dims)
1195                         state->maxDims = dims;
1196
1197                 // TODO verify that item model can have requested number of outcomes
1198                 int no = spec[RPF_ISpecOutcomes];
1199                 if (state->maxOutcomes < no)
1200                         state->maxOutcomes = no;
1201
1202                 int numSpec = (*rpf_model[id].numSpec)(spec);
1203                 if (maxSpec < numSpec)
1204                         maxSpec = numSpec;
1205
1206                 int numParam = (*rpf_model[id].numParam)(spec);
1207                 if (maxParam < numParam)
1208                         maxParam = numParam;
1209         }
1210
1211         if (state->itemSpec->cols != data->cols || state->itemSpec->rows != maxSpec) {
1212                 omxRaiseErrorf(currentState, "ItemSpec must have %d item columns and %d rows",
1213                                data->cols, maxSpec);
1214                 return;
1215         }
1216         if (state->itemParam->rows != maxParam) {
1217                 omxRaiseErrorf(currentState, "ItemParam should have %d rows", maxParam);
1218                 return;
1219         }
1220
1221         if (state->design == NULL) {
1222                 state->maxAbilities = state->maxDims;
1223                 state->design = omxInitTemporaryMatrix(NULL, state->maxDims, numItems,
1224                                        TRUE, currentState);
1225                 for (int ix=0; ix < numItems; ix++) {
1226                         for (int dx=0; dx < state->maxDims; dx++) {
1227                                 omxSetMatrixElement(state->design, dx, ix, (double)dx+1);
1228                         }
1229                 }
1230         } else {
1231                 omxMatrix *design = state->design;
1232                 if (design->cols != numItems ||
1233                     design->rows != state->maxDims) {
1234                         omxRaiseErrorf(currentState, "Design matrix should have %d rows and %d columns",
1235                                        state->maxDims, numItems);
1236                         return;
1237                 }
1238
1239                 state->maxAbilities = 0;
1240                 for (int ix=0; ix < design->rows * design->cols; ix++) {
1241                         double got = design->data[ix];
1242                         if (!R_FINITE(got)) continue;
1243                         if (round(got) != got) error("Design matrix can only contain integers"); // TODO better way?
1244                         if (state->maxAbilities < got)
1245                                 state->maxAbilities = got;
1246                 }
1247         }
1248         if (state->maxAbilities <= state->maxDims) {
1249                 state->Sgroup = Calloc(numItems, int);
1250         } else {
1251                 // Not sure if this is correct, revisit TODO
1252                 int Sgroup0 = -1;
1253                 state->Sgroup = Realloc(NULL, numItems, int);
1254                 for (int dx=0; dx < state->maxDims; dx++) {
1255                         for (int ix=0; ix < numItems; ix++) {
1256                                 int ability = omxMatrixElement(state->design, dx, ix);
1257                                 if (dx < state->maxDims - 1) {
1258                                         if (Sgroup0 <= ability)
1259                                                 Sgroup0 = ability+1;
1260                                         continue;
1261                                 }
1262                                 int ss=-1;
1263                                 if (ability >= Sgroup0) {
1264                                         if (ss == -1) {
1265                                                 ss = ability;
1266                                         } else {
1267                                                 omxRaiseErrorf(currentState, "Item %d cannot belong to more than "
1268                                                                "1 specific dimension (both %d and %d)",
1269                                                                ix, ss, ability);
1270                                                 return;
1271                                         }
1272                                 }
1273                                 if (ss == -1) ss = Sgroup0;
1274                                 state->Sgroup[ix] = ss - Sgroup0;
1275                         }
1276                 }
1277                 state->numSpecific = state->maxAbilities - state->maxDims + 1;
1278                 state->allSlxk = Realloc(NULL, numUnique * numThreads, double);
1279                 state->Slxk = Realloc(NULL, numUnique * state->numSpecific * numThreads, double);
1280         }
1281
1282         PROTECT(tmp = GET_SLOT(rObj, install("cache")));
1283         state->cacheLXK = asLogical(tmp);
1284
1285         PROTECT(tmp = GET_SLOT(rObj, install("GHpoints")));
1286         double *qpoints = REAL(tmp);
1287         int numQPoints = length(tmp);
1288
1289         PROTECT(tmp = GET_SLOT(rObj, install("GHarea")));
1290         double *qarea = REAL(tmp);
1291         if (numQPoints != length(tmp)) error("length(GHpoints) != length(GHarea)");
1292
1293         ba81SetupQuadrature(oo, numQPoints, qpoints, qarea);
1294
1295         // verify data bounded between 1 and numOutcomes TODO
1296         // hm, looks like something could be added to omxData for column summary stats?
1297 }