ifa: EAP reorg
[openmx:openmx.git] / src / omxExpectationBA81.c
1 /*
2   Copyright 2012 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 // Consider replacing log() with log2() in some places? Not worth it?
19
20 #include "omxExpectation.h"
21 #include "omxOpenmpWrap.h"
22 #include "npsolWrap.h"
23 #include "libirt-rpf.h"
24 #include "merge.h"
25
// Name of this expectation; printed in the E-step debug trace.
static const char *NAME = "ExpectationBA81";
27
// Evaluates every item's response function at one quadrature location.
// The implementations in this file return a heap buffer that the caller
// releases with Free(), or NULL when an error was raised.
typedef double *(*rpf_fn_t)(omxExpectation *oo, omxMatrix *itemParam, const int *quad);

// Number of parameters of an item model with the given dimensions and outcomes.
typedef int (*rpf_numParam_t)(const int numDims, const int numOutcomes);
// TODO arguments ought to be in the same order
// Writes one value per outcome to out for an item evaluated at th.
typedef void (*rpf_logprob_t)(const int numDims, const double *restrict param,
                              const double *restrict th,
                              const int numOutcomes, double *restrict out);
// Prior contribution of one item's parameters; it is added to the fit's
// log-likelihood accumulator.
typedef double (*rpf_prior_t)(const int numDims, const int numOutcomes,
                              const double *restrict param);

// Gradient of one item. It is called with where/weight per ordinate and
// once more with where == NULL and weight == NULL after all ordinates
// have been accumulated.
typedef void (*rpf_gradient_t)(const int numDims, const int numOutcomes,
                               const double *restrict param, const int *paramMask,
                               const double *where, const double *weight, double *out);
41
// One entry of the table of built-in item response functions. The
// ISpecID row of itemSpec is used as the index into that table.
struct rpf {
        const char name[8];        // short identifier of the model
        rpf_numParam_t numParam;   // parameter count
        rpf_logprob_t logprob;     // per-outcome values at an ability point
        rpf_prior_t prior;         // prior on the item parameters
        rpf_gradient_t gradient;   // gradient with respect to item parameters
};
49
50 // configuration of priors, probably via itemSpec TODO
51
52 static const struct rpf rpf_table[] = {
53         { "drm1",
54           irt_rpf_1dim_drm_numParam,
55           irt_rpf_1dim_drm_logprob,
56           irt_rpf_1dim_drm_prior,
57           irt_rpf_1dim_drm_gradient,
58         },
59         { "drm",
60           irt_rpf_mdim_drm_numParam,
61           irt_rpf_mdim_drm_logprob,
62           irt_rpf_mdim_drm_prior,
63           irt_rpf_mdim_drm_gradient,
64         },
65         { "gpcm1",
66           irt_rpf_1dim_gpcm_numParam,
67           irt_rpf_1dim_gpcm_logprob,
68           irt_rpf_1dim_gpcm_prior,
69           irt_rpf_1dim_gpcm_gradient,
70         }
71 };
72 static const int numStandardRPF = (sizeof(rpf_table) / sizeof(struct rpf));
73
// Per-expectation state; stored in omxExpectation::argStruct and cast
// back to omxBA81State by every function in this file.
typedef struct {

        // data characteristics
        omxData *data;
        int numUnique;
        double *logNumIdentical;  // length numUnique
        int *rowMap;              // length numUnique

        // item description related
        omxMatrix *itemSpec;      // rows are enum ISpecRow, columns are items
        int maxOutcomes;
        int maxDims;
        int maxAbilities;
        int numSpecific;
        int *Sgroup;              // item's specific group 0..numSpecific-1
        omxMatrix *design;        // items * maxDims

        // quadrature related
        int numQpoints;
        double *Qpoint;
        double *Qarea;
        double *logQarea;         // log() of each Qarea entry
        long *quadGridSize;       // maxDims
        long totalPrimaryPoints;  // product of quadGridSize except specific dim
        long totalQuadPoints;     // product of quadGridSize

        // estimation related
        omxMatrix *EitemParam;    // E step version
        omxMatrix *itemParam;     // M step version
        SEXP rpf;
        rpf_fn_t computeRPF;
        omxMatrix *customPrior;   // when set, replaces the per-item priors in the fit
        int *paramMap;            // free variable indices sorted by compareFV; built lazily
        int cacheLXK;             // w/cache,  numUnique * #specific quad points * totalQuadPoints
        double *lxk;              // wo/cache, numUnique * thread
        double *allSlxk;          // numUnique * thread
        double *Slxk;             // numUnique * #specific dimensions * thread
        double *patternLik;       // length numUnique
        double ll;                // the most recent finite ll

        int gradientCount;        // incremented on every ba81Gradient call
        int fitCount;             // incremented on every ba81ComputeFit1 call
} omxBA81State;
117
// Row indices of the itemSpec matrix (one column per item).
enum ISpecRow {
        ISpecID,         // index into rpf_table
        ISpecOutcomes,   // number of outcomes of the item
        ISpecDims,       // number of dimensions of the item
        ISpecRowCount    // number of rows listed above
};
124
125 /*
126 static void
127 pda(const double *ar, int rows, int cols) {
128         for (int rx=0; rx < rows; rx++) {
129                 for (int cx=0; cx < cols; cx++) {
130                         Rprintf("%.6g ", ar[cx * rows + rx]);
131                 }
132                 Rprintf("\n");
133         }
134
135 }
136 */
137
138 static int
139 findFreeVarLocation(omxMatrix *itemParam, const omxFreeVar *fv)
140 {
141         for (int lx=0; lx < fv->numLocations; lx++) {
142                 if (~fv->matrices[lx] == itemParam->matrixNumber) return lx;
143         }
144         return -1;
145 }
146
147 static int
148 compareFV(const int *fv1x, const int *fv2x, omxExpectation* oo)
149 {
150         omxState* currentState = oo->currentState;
151         omxBA81State *state = (omxBA81State *) oo->argStruct;
152         omxMatrix *itemParam = state->itemParam;
153         omxFreeVar *fv1 = currentState->freeVarList + *fv1x;
154         omxFreeVar *fv2 = currentState->freeVarList + *fv2x;
155         int l1 = findFreeVarLocation(itemParam, fv1);
156         int l2 = findFreeVarLocation(itemParam, fv2);
157         if (l1 == -1 && l2 == -1) return 0;
158         if ((l1 == -1) ^ (l2 == -1)) return l1 == -1? 1:-1;  // TODO reversed?
159         // Columns are items. Sort columns together
160         return fv1->col[l1] - fv2->col[l1];
161 }
162
163 static void buildParamMap(omxExpectation* oo)
164 {
165         omxState* currentState = oo->currentState;
166         omxBA81State *state = (omxBA81State *) oo->argStruct;
167         int numFreeParams = currentState->numFreeParams;
168         state->paramMap = Realloc(NULL, numFreeParams, int);
169         for (int px=0; px < numFreeParams; px++) { state->paramMap[px] = px; }
170         freebsd_mergesort(state->paramMap, numFreeParams, sizeof(int),
171                           (mergesort_cmp_t)compareFV, oo);
172 }
173
174 OMXINLINE static void
175 pointToWhere(omxBA81State *state, const int *quad, double *where, int upto)
176 {
177         for (int dx=0; dx < upto; dx++) {
178                 where[dx] = state->Qpoint[quad[dx]];
179         }
180 }
181
182 OMXINLINE static void
183 assignDims(omxMatrix *itemSpec, omxMatrix *design, int dims, int maxDims, int ix,
184            const double *restrict theta, double *restrict ptheta)
185 {
186         for (int dx=0; dx < dims; dx++) {
187                 int ability = (int)omxMatrixElement(design, dx, ix) - 1;
188                 if (ability >= maxDims) ability = maxDims-1;
189                 ptheta[dx] = theta[ability];
190         }
191 }
192
193 /**
194  * This is the main function needed to generate simulated data from
195  * the model. It could be argued that the rest of the estimation
196  * machinery belongs in the fitfunction.
197  *
198  * \param theta Vector of ability parameters, one per ability
199  * \returns A numItems by maxOutcomes colMajor vector of doubles. Caller must Free it.
200  */
201 static double *
202 standardComputeRPF(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
203 {
204         omxBA81State *state = (omxBA81State*) oo->argStruct;
205         omxMatrix *itemSpec = state->itemSpec;
206         int numItems = itemSpec->cols;
207         omxMatrix *design = state->design;
208         int maxDims = state->maxDims;
209
210         double theta[maxDims];
211         pointToWhere(state, quad, theta, maxDims);
212
213         double *outcomeProb = Realloc(NULL, numItems * state->maxOutcomes, double);
214
215         for (int ix=0; ix < numItems; ix++) {
216                 int outcomes = omxMatrixElement(itemSpec, ISpecOutcomes, ix);
217                 double *iparam = omxMatrixColumn(itemParam, ix);
218                 double *out = outcomeProb + ix * state->maxOutcomes;
219                 int id = omxMatrixElement(itemSpec, ISpecID, ix);
220                 int dims = omxMatrixElement(itemSpec, ISpecDims, ix);
221                 double ptheta[dims];
222                 assignDims(itemSpec, design, dims, maxDims, ix, theta, ptheta);
223                 (*rpf_table[id].logprob)(dims, iparam, ptheta, outcomes, out);
224         }
225
226         return outcomeProb;
227 }
228
229 static double *
230 RComputeRPF1(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
231 {
232         omxBA81State *state = (omxBA81State*) oo->argStruct;
233         int maxOutcomes = state->maxOutcomes;
234         omxMatrix *design = state->design;
235         omxMatrix *itemSpec = state->itemSpec;
236         int maxDims = state->maxDims;
237
238         double theta[maxDims];
239         pointToWhere(state, quad, theta, maxDims);
240
241         SEXP invoke;
242         PROTECT(invoke = allocVector(LANGSXP, 4));
243         SETCAR(invoke, state->rpf);
244         SETCADR(invoke, omxExportMatrix(itemParam));
245         SETCADDR(invoke, omxExportMatrix(itemSpec));
246
247         SEXP where;
248         PROTECT(where = allocMatrix(REALSXP, maxDims, itemParam->cols));
249         double *ptheta = REAL(where);
250         for (int ix=0; ix < itemParam->cols; ix++) {
251                 int dims = omxMatrixElement(itemSpec, ISpecDims, ix);
252                 assignDims(itemSpec, design, dims, maxDims, ix, theta, ptheta + ix*maxDims);
253                 for (int dx=dims; dx < maxDims; dx++) {
254                         ptheta[ix*maxDims + dx] = NA_REAL;
255                 }
256         }
257         SETCADDDR(invoke, where);
258
259         SEXP matrix;
260         PROTECT(matrix = eval(invoke, R_GlobalEnv));
261
262         if (!isMatrix(matrix)) {
263                 omxRaiseError(oo->currentState, -1,
264                               "RPF must return an item by outcome matrix");
265                 return NULL;
266         }
267
268         SEXP matrixDims;
269         PROTECT(matrixDims = getAttrib(matrix, R_DimSymbol));
270         int *dimList = INTEGER(matrixDims);
271         int numItems = state->itemSpec->cols;
272         if (dimList[0] != maxOutcomes || dimList[1] != numItems) {
273                 const int errlen = 200;
274                 char errstr[errlen];
275                 snprintf(errstr, errlen, "RPF must return a %d outcomes by %d items matrix",
276                          maxOutcomes, numItems);
277                 omxRaiseError(oo->currentState, -1, errstr);
278                 return NULL;
279         }
280
281         // Unlikely to be of type INTSXP, but just to be safe
282         PROTECT(matrix = coerceVector(matrix, REALSXP));
283         double *restrict got = REAL(matrix);
284
285         // Need to copy because threads cannot share SEXP
286         double *restrict outcomeProb = Realloc(NULL, numItems * maxOutcomes, double);
287
288         // Double check there aren't NAs in the wrong place
289         for (int ix=0; ix < numItems; ix++) {
290                 int numOutcomes = omxMatrixElement(state->itemSpec, ISpecOutcomes, ix);
291                 for (int ox=0; ox < numOutcomes; ox++) {
292                         int vx = ix * maxOutcomes + ox;
293                         if (isnan(got[vx])) {
294                                 const int errlen = 200;
295                                 char errstr[errlen];
296                                 snprintf(errstr, errlen, "RPF returned NA in [%d,%d]", ox,ix);
297                                 omxRaiseError(oo->currentState, -1, errstr);
298                         }
299                         outcomeProb[vx] = got[vx];
300                 }
301         }
302
303         return outcomeProb;
304 }
305
306 static double *
307 RComputeRPF(omxExpectation *oo, omxMatrix *itemParam, const int *quad)
308 {
309         omx_omp_set_lock(&GlobalRLock);
310         PROTECT_INDEX pi = omxProtectSave();
311         double *ret = RComputeRPF1(oo, itemParam, quad);
312         omxProtectRestore(pi);
313         omx_omp_unset_lock(&GlobalRLock);  // hope there was no exception!
314         return ret;
315 }
316
/**
 * Flatten a multidimensional quadrature index into a single offset.
 * Dimension 0 varies fastest.
 *
 * \param dims number of dimensions
 * \param grid number of points along each dimension
 * \param quad index along each dimension
 * \returns quad[0] + grid[0] * (quad[1] + grid[1] * (...))
 */
OMXINLINE static long
encodeLocation(const int dims, const long *restrict grid, const int *restrict quad)
{
        long offset = 0;
        int dx = dims;
        while (dx-- > 0) {
                offset = offset * grid[dx] + quad[dx];
        }
        return offset;
}
327
// Address of the cached likelihood vector (numUnique doubles) for the
// quadrature location quad and specific group `specific`. tqp is the
// total number of quadrature points.
#define CALC_LXK_CACHED(state, numUnique, quad, tqp, specific) \
        ((state)->lxk + \
         (numUnique) * encodeLocation((state)->maxDims, (state)->quadGridSize, (quad)) + \
         (numUnique) * (tqp) * (specific))
332
333 OMXINLINE static double *
334 ba81Likelihood(omxExpectation *oo, int specific, const int *restrict quad)
335 {
336         omxBA81State *state = (omxBA81State*) oo->argStruct;
337         int numUnique = state->numUnique;
338         int maxOutcomes = state->maxOutcomes;
339         omxData *data = state->data;
340         int numItems = state->itemSpec->cols;
341         rpf_fn_t rpf_fn = state->computeRPF;
342         int *restrict Sgroup = state->Sgroup;
343         double *restrict lxk;
344
345         if (!state->cacheLXK) {
346                 lxk = state->lxk + numUnique * omx_absolute_thread_num();
347         } else {
348                 lxk = CALC_LXK_CACHED(state, numUnique, quad, state->totalQuadPoints, specific);
349         }
350
351         const double *outcomeProb = (*rpf_fn)(oo, state->EitemParam, quad);
352         if (!outcomeProb) {
353                 OMXZERO(lxk, numUnique);
354                 return lxk;
355         }
356
357         const int *rowMap = state->rowMap;
358         for (int px=0; px < numUnique; px++) {
359                 double lxk1 = 0;
360                 for (int ix=0; ix < numItems; ix++) {
361                         if (specific != Sgroup[ix]) continue;
362                         int pick = omxIntDataElementUnsafe(data, rowMap[px], ix);
363                         if (pick == NA_INTEGER) continue;
364                         lxk1 += outcomeProb[ix * maxOutcomes + pick-1];
365                 }
366                 lxk[px] = lxk1;
367         }
368
369         Free(outcomeProb);
370
371         return lxk;
372 }
373
374 OMXINLINE static double *
375 ba81LikelihoodFast(omxExpectation *oo, int specific, const int *restrict quad)
376 {
377         omxBA81State *state = (omxBA81State*) oo->argStruct;
378         if (!state->cacheLXK) {
379                 return ba81LikelihoodFast(oo, specific, quad);
380         } else {
381                 return CALC_LXK_CACHED(state, state->numUnique, quad, state->totalQuadPoints, specific);
382         }
383
384 }
385
// The calling thread's slice (numUnique doubles) of state->allSlxk.
#define CALC_ALLSLXK(state, numUnique) \
        ((state)->allSlxk + omx_absolute_thread_num() * (numUnique))
388
// The calling thread's slice (numUnique * numSpecific doubles) of state->Slxk.
#define CALC_SLXK(state, numUnique, numSpecific) \
        ((state)->Slxk + omx_absolute_thread_num() * (numUnique) * (numSpecific))
391
392 OMXINLINE static void
393 cai2010(omxExpectation* oo, int recompute, const int *restrict primaryQuad,
394         double *restrict allSlxk, double *restrict Slxk)
395 {
396         omxBA81State *state = (omxBA81State*) oo->argStruct;
397         int numUnique = state->numUnique;
398         int numSpecific = state->numSpecific;
399         int maxDims = state->maxDims;
400         int sDim = maxDims-1;
401
402         int quad[maxDims];
403         memcpy(quad, primaryQuad, sizeof(int)*sDim);
404
405         OMXZERO(Slxk, numUnique * numSpecific);
406         OMXZERO(allSlxk, numUnique);
407
408         for (int sx=0; sx < numSpecific; sx++) {
409                 double *eis = Slxk + numUnique * sx;
410                 int quadGridSize = state->quadGridSize[sDim];
411
412                 for (int qx=0; qx < quadGridSize; qx++) {
413                         quad[sDim] = qx;
414                         double *lxk;
415                         if (recompute) {
416                                 lxk = ba81Likelihood(oo, sx, quad);
417                         } else {
418                                 lxk = CALC_LXK_CACHED(state, numUnique, quad, state->totalQuadPoints, sx);
419                         }
420
421                         for (int ix=0; ix < numUnique; ix++) {
422                                 eis[ix] += exp(lxk[ix] + state->logQarea[qx]);
423                         }
424                 }
425
426                 for (int px=0; px < numUnique; px++) {
427                         eis[px] = log(eis[px]);
428                         allSlxk[px] += eis[px];
429                 }
430         }
431 }
432
433 OMXINLINE static double
434 logAreaProduct(omxBA81State *state, const int *restrict quad, const int upto)
435 {
436         double logArea = 0;
437         for (int dx=0; dx < upto; dx++) {
438                 logArea += state->logQarea[quad[dx]];
439         }
440         return logArea;
441 }
442
443 // The idea of this API is to allow passing in a number larger than 1.
444 OMXINLINE static void
445 areaProduct(omxBA81State *state, const int *restrict quad, const int upto, double *restrict out)
446 {
447         for (int dx=0; dx < upto; dx++) {
448                 *out *= state->Qarea[quad[dx]];
449         }
450 }
451
/**
 * Split a flat quadrature offset into one index per dimension.
 * Dimension 0 varies fastest; this is the inverse of encodeLocation.
 *
 * \param qx flat offset
 * \param dims number of dimensions to decode
 * \param grid number of points along each dimension
 * \param quad receives dims indices
 */
OMXINLINE static void
decodeLocation(long qx, const int dims, const long *restrict grid,
               int *restrict quad)
{
        long remaining = qx;
        int dx = 0;
        while (dx < dims) {
                const long size = grid[dx];
                quad[dx] = remaining % size;
                remaining = remaining / size;
                ++dx;
        }
}
461
462 static void
463 ba81Estep(omxExpectation *oo) {
464         if(OMX_DEBUG_MML) {Rprintf("Beginning %s Computation.\n", NAME);}
465
466         omxBA81State *state = (omxBA81State*) oo->argStruct;
467         double *patternLik = state->patternLik;
468         int numUnique = state->numUnique;
469         int numSpecific = state->numSpecific;
470
471         omxCopyMatrix(state->EitemParam, state->itemParam);
472
473         OMXZERO(patternLik, numUnique);
474
475         // E-step, marginalize person ability
476         //
477         // Note: In the notation of Bock & Aitkin (1981) and
478         // Cai~(2010), these loops are reversed.  That is, the inner
479         // loop is over quadrature points and the outer loop is over
480         // all response patterns.
481         //
482         if (numSpecific == 0) {
483 #pragma omp parallel for num_threads(oo->currentState->numThreads)
484                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
485                         int quad[state->maxDims];
486                         decodeLocation(qx, state->maxDims, state->quadGridSize, quad);
487
488                         double *lxk = ba81Likelihood(oo, 0, quad);
489
490                         double logArea = logAreaProduct(state, quad, state->maxDims);
491 #pragma omp critical(EstepUpdate)
492                         for (int px=0; px < numUnique; px++) {
493                                 double tmp = exp(lxk[px] + logArea);
494                                 patternLik[px] += tmp;
495                         }
496                 }
497         } else {
498                 int sDim = state->maxDims-1;
499
500 #pragma omp parallel for num_threads(oo->currentState->numThreads)
501                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
502                         int quad[state->maxDims];
503                         decodeLocation(qx, sDim, state->quadGridSize, quad);
504
505                         double *allSlxk = CALC_ALLSLXK(state, numUnique);
506                         double *Slxk = CALC_SLXK(state, numUnique, numSpecific);
507                         cai2010(oo, TRUE, quad, allSlxk, Slxk);
508
509                         double logArea = logAreaProduct(state, quad, sDim);
510 #pragma omp critical(EstepUpdate)
511                         for (int px=0; px < numUnique; px++) {
512                                 double tmp = exp(allSlxk[px] + logArea);
513                                 patternLik[px] += tmp;
514                         }
515                 }
516         }
517
518         for (int px=0; px < numUnique; px++) {
519                 patternLik[px] = log(patternLik[px]);
520         }
521 }
522
523 OMXINLINE static void
524 expectedUpdate(omxData *restrict data, const int *rowMap, const int px, const int item,
525                const double observed, const int outcomes, double *out)
526 {
527         int pick = omxIntDataElementUnsafe(data, rowMap[px], item);
528         if (pick == NA_INTEGER) {
529                 double slice = exp(observed - log(outcomes));
530                 for (int ox=0; ox < outcomes; ox++) {
531                         out[ox] += slice;
532                 }
533         } else {
534                 out[pick-1] += exp(observed);
535         }
536 }
537
538 /** 
539  * \param quad a vector that indexes into a multidimensional quadrature
540  * \param out points to an array numOutcomes wide
541  */
542 OMXINLINE static void
543 ba81Weight(omxExpectation* oo, const int item, const int *quad, int outcomes, double *out)
544 {
545         omxBA81State *state = (omxBA81State*) oo->argStruct;
546         omxData *data = state->data;
547         const int *rowMap = state->rowMap;
548         int specific = state->Sgroup[item];
549         double *patternLik = state->patternLik;
550         double *logNumIdentical = state->logNumIdentical;
551         int numUnique = state->numUnique;
552         int numSpecific = state->numSpecific;
553         int sDim = state->maxDims-1;
554
555         OMXZERO(out, outcomes);
556
557         if (numSpecific == 0) {
558                 double *lxk = ba81LikelihoodFast(oo, specific, quad);
559                 for (int px=0; px < numUnique; px++) {
560                         double observed = logNumIdentical[px] + lxk[px] - patternLik[px];
561                         expectedUpdate(data, rowMap, px, item, observed, outcomes, out);
562                 }
563         } else {
564                 double *allSlxk = CALC_ALLSLXK(state, numUnique);
565                 double *Slxk = CALC_SLXK(state, numUnique, numSpecific);
566                 if (quad[sDim] == 0) {
567                         // allSlxk, Slxk only depend on the ordinate of the primary dimensions
568                         cai2010(oo, !state->cacheLXK, quad, allSlxk, Slxk);
569                 }
570                 double *eis = Slxk + numUnique * specific;
571
572                 // Avoid recalc when cache disabled with modest buffer? TODO
573                 double *lxk = ba81LikelihoodFast(oo, specific, quad);
574
575                 for (int px=0; px < numUnique; px++) {
576                         double observed = logNumIdentical[px] + (allSlxk[px] - eis[px]) +
577                                 (lxk[px] - patternLik[px]);
578                         expectedUpdate(data, rowMap, px, item, observed, outcomes, out);
579                 }
580         }
581 }
582
583 OMXINLINE static double
584 ba81Fit1Ordinate(omxExpectation* oo, const int *quad)
585 {
586         omxBA81State *state = (omxBA81State*) oo->argStruct;
587         omxMatrix *itemParam = state->itemParam;
588         int numItems = itemParam->cols;
589         rpf_fn_t rpf_fn = state->computeRPF;
590         int maxOutcomes = state->maxOutcomes;
591         int maxDims = state->maxDims;
592
593         double *outcomeProb = (*rpf_fn)(oo, itemParam, quad);
594         if (!outcomeProb) return 0;
595
596         double thr_ll = 0;
597         for (int ix=0; ix < numItems; ix++) {
598                 int outcomes = omxMatrixElement(state->itemSpec, ISpecOutcomes, ix);
599                 double out[outcomes];
600                 ba81Weight(oo, ix, quad, outcomes, out);
601                 for (int ox=0; ox < outcomes; ox++) {
602                         double got = out[ox] * outcomeProb[ix * maxOutcomes + ox];
603                         areaProduct(state, quad, maxDims, &got);
604                         thr_ll += got;
605                 }
606         }
607
608         Free(outcomeProb);
609         return thr_ll;
610 }
611
612 static double
613 ba81ComputeFit1(omxExpectation* oo)
614 {
615         omxBA81State *state = (omxBA81State*) oo->argStruct;
616         ++state->fitCount;
617         omxMatrix *customPrior = state->customPrior;
618         int numSpecific = state->numSpecific;
619         int maxDims = state->maxDims;
620
621         double ll = 0;
622         if (customPrior) {
623                 omxRecompute(customPrior);
624                 ll = customPrior->data[0];
625         } else {
626                 omxMatrix *itemSpec = state->itemSpec;
627                 omxMatrix *itemParam = state->itemParam;
628                 int numItems = itemSpec->cols;
629                 for (int ix=0; ix < numItems; ix++) {
630                         int id = omxMatrixElement(itemSpec, ISpecID, ix);
631                         int dims = omxMatrixElement(itemSpec, ISpecDims, ix);
632                         int outcomes = omxMatrixElement(itemSpec, ISpecOutcomes, ix);
633                         double *iparam = omxMatrixColumn(itemParam, ix);
634                         ll += (*rpf_table[id].prior)(dims, outcomes, iparam);
635                 }
636         }
637
638         if (numSpecific == 0) {
639 #pragma omp parallel for num_threads(oo->currentState->numThreads)
640                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
641                         int quad[maxDims];
642                         decodeLocation(qx, maxDims, state->quadGridSize, quad);
643                         double thr_ll = ba81Fit1Ordinate(oo, quad);
644
645 #pragma omp atomic
646                         ll += thr_ll;
647                 }
648         } else {
649                 int sDim = state->maxDims-1;
650                 long *quadGridSize = state->quadGridSize;
651
652 #pragma omp parallel for num_threads(oo->currentState->numThreads)
653                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
654                         int quad[maxDims];
655                         decodeLocation(qx, maxDims, quadGridSize, quad);
656
657                         double thr_ll = 0;
658                         long specificPoints = quadGridSize[sDim];
659                         for (long sx=0; sx < specificPoints; sx++) {
660                                 quad[sDim] = sx;
661                                 thr_ll += ba81Fit1Ordinate(oo, quad);
662                         }
663 #pragma omp atomic
664                         ll += thr_ll;
665                 }
666         }
667
668         if (isinf(ll)) {
669                 return 2*state->ll;
670         } else {
671                 ll = -2 * ll;
672                 state->ll = ll;
673                 return ll;
674         }
675 }
676
677 double
678 ba81ComputeFit(omxExpectation* oo)
679 {
680         double got = ba81ComputeFit1(oo);
681         return got;
682 }
683
684 OMXINLINE static void
685 ba81ItemGradientOrdinate(omxExpectation* oo, omxBA81State *state,
686                          int maxDims, int *quad, int item, int id,
687                          int dims, int outcomes,
688                          double *iparam, int *paramMask, double *gq)
689 {
690         double where[maxDims];
691         pointToWhere(state, quad, where, maxDims);
692         double weight[outcomes];
693         ba81Weight(oo, item, quad, outcomes, weight);
694
695         (*rpf_table[id].gradient)(dims, outcomes, iparam, paramMask, where, weight, gq);
696
697         for (int ox=0; ox < outcomes; ox++) {
698                 areaProduct(state, quad, maxDims, gq+ox);
699         }
700 }
701
702 OMXINLINE static void
703 ba81ItemGradient(omxExpectation* oo, omxBA81State *state, omxMatrix *itemParam,
704                  int item, int id, int dims, int outcomes, int numParam, int *paramMask, double *out)
705 {
706         int maxDims = state->maxDims;
707         double *iparam = omxMatrixColumn(itemParam, item);
708         double gradient[numParam];
709         OMXZERO(gradient, numParam);
710
711         if (state->numSpecific == 0) {
712 #pragma omp parallel for num_threads(oo->currentState->numThreads)
713                 for (long qx=0; qx < state->totalQuadPoints; qx++) {
714                         int quad[maxDims];
715                         decodeLocation(qx, maxDims, state->quadGridSize, quad);
716                         double gq[numParam];
717                         OMXZERO(gq, numParam);
718
719                         ba81ItemGradientOrdinate(oo, state, maxDims, quad, item, id, dims,
720                                                  outcomes, iparam, paramMask, gq);
721
722 #pragma omp critical(GradientUpdate)
723                         for (int ox=0; ox < outcomes; ox++) {
724                                 gradient[ox] += gq[ox];
725                         }
726                 }
727         } else {
728                 int sDim = state->maxDims-1;
729                 long *quadGridSize = state->quadGridSize;
730 #pragma omp parallel for num_threads(oo->currentState->numThreads)
731                 for (long qx=0; qx < state->totalPrimaryPoints; qx++) {
732                         int quad[maxDims];
733                         decodeLocation(qx, maxDims, quadGridSize, quad);
734                         double gq[numParam];
735                         OMXZERO(gq, numParam);
736
737                         long specificPoints = quadGridSize[sDim];
738                         for (long sx=0; sx < specificPoints; sx++) {
739                                 quad[sDim] = sx;
740                                 ba81ItemGradientOrdinate(oo, state, maxDims, quad, item, id, dims,
741                                                          outcomes, iparam, paramMask, gq);
742                         }
743 #pragma omp critical(GradientUpdate)
744                         for (int ox=0; ox < outcomes; ox++) {
745                                 gradient[ox] += gq[ox];
746                         }
747                 }
748         }
749
750         (*rpf_table[id].gradient)(dims, outcomes, iparam, paramMask, NULL, NULL, gradient);
751
752         for (int px=0; px < numParam; px++) {
753                 if (paramMask[px] == -1) continue;
754                 out[paramMask[px]] = -2 * gradient[px];
755         }
756 }
757
758 void ba81Gradient(omxExpectation* oo, double *out)
759 {
760         omxState* currentState = oo->currentState;
761         int numFreeParams = currentState->numFreeParams;
762         omxBA81State *state = (omxBA81State *) oo->argStruct;
763         if (!state->paramMap) buildParamMap(oo);
764         ++state->gradientCount;
765         omxMatrix *itemSpec = state->itemSpec;
766         omxMatrix *itemParam = state->itemParam;
767
768         int vx = 0;
769         while (vx < numFreeParams) {
770             omxFreeVar *fv = currentState->freeVarList + state->paramMap[vx];
771             int vloc = findFreeVarLocation(itemParam, fv);
772             if (vloc < 0) {
773                     ++vx;
774                     continue;
775             }
776
777             int item = fv->col[vloc];
778             int id = omxMatrixElement(itemSpec, ISpecID, item);
779             int dims = omxMatrixElement(itemSpec, ISpecDims, item);
780             int outcomes = omxMatrixElement(itemSpec, ISpecOutcomes, item);
781             int numParam = (*rpf_table[id].numParam)(dims, outcomes);
782
783             int paramMask[numParam];
784             for (int px=0; px < numParam; px++) { paramMask[px] = -1; }
785
786             paramMask[fv->row[vloc]] = vx;
787
788             while (++vx < numFreeParams) {
789                     omxFreeVar *fv = currentState->freeVarList + state->paramMap[vx];
790                     int vloc = findFreeVarLocation(itemParam, fv);
791                     if (fv->col[vloc] != item) break;
792                     paramMask[fv->row[vloc]] = vx;
793             }
794
795             ba81ItemGradient(oo, state, itemParam, item,
796                              id, dims, outcomes, numParam, paramMask, out);
797         }
798 }
799
800 static int
801 getNumThreads(omxExpectation* oo)
802 {
803         int numThreads = oo->currentState->numThreads;
804         if (numThreads < 1) numThreads = 1;
805         return numThreads;
806 }
807
808 static void
809 ba81SetupQuadrature(omxExpectation* oo, int numPoints, double *points, double *area)
810 {
811         omxBA81State *state = (omxBA81State *) oo->argStruct;
812         int numUnique = state->numUnique;
813         int numThreads = getNumThreads(oo);
814
815         state->numQpoints = numPoints;
816
817         Free(state->Qpoint);
818         Free(state->Qarea);
819         state->Qpoint = Realloc(NULL, numPoints, double);
820         state->Qarea = Realloc(NULL, numPoints, double);
821         memcpy(state->Qpoint, points, sizeof(double)*numPoints);
822         memcpy(state->Qarea, area, sizeof(double)*numPoints);
823
824         Free(state->logQarea);
825
826         state->logQarea = Realloc(NULL, state->numQpoints, double);
827         for (int px=0; px < state->numQpoints; px++) {
828                 state->logQarea[px] = log(state->Qarea[px]);
829         }
830
831         state->totalQuadPoints = 1;
832         state->totalPrimaryPoints = 1;
833         state->quadGridSize = (long*) R_alloc(state->maxDims, sizeof(long));
834         for (int dx=0; dx < state->maxDims; dx++) {
835                 state->quadGridSize[dx] = state->numQpoints;
836                 state->totalQuadPoints *= state->quadGridSize[dx];
837                 if (dx < state->maxDims-1) {
838                         state->totalPrimaryPoints *= state->quadGridSize[dx];
839                 }
840         }
841
842         Free(state->lxk);
843
844         if (!state->cacheLXK) {
845                 state->lxk = Realloc(NULL, numUnique * numThreads, double);
846         } else {
847                 int ns = state->numSpecific;
848                 if (ns == 0) ns = 1;
849                 state->lxk = Realloc(NULL, numUnique * state->totalQuadPoints * ns, double);
850         }
851 }
852
853 static void
854 ba81EAP1(omxExpectation *oo, long qx, int maxDims, int numUnique, double *ability)
855 {
856         omxBA81State *state = (omxBA81State *) oo->argStruct;
857         double *patternLik = state->patternLik;
858         int quad[maxDims];
859         decodeLocation(qx, maxDims, state->quadGridSize, quad);
860         double where[maxDims];
861         pointToWhere(state, quad, where, maxDims);
862         double logArea = logAreaProduct(state, quad, maxDims);
863         double *lxk = ba81LikelihoodFast(oo, 0, quad);
864
865         for (int px=0; px < numUnique; px++) {
866                 double piece[maxDims];
867                 double plik = exp(lxk[px] + logArea - patternLik[px]);
868                 for (int dx=0; dx < maxDims; dx++) {
869                         piece[dx] = where[dx] * plik;
870                 }
871                 double *arow = ability + px * 2 * maxDims;
872 #pragma omp critical(EAP1Update)
873                 for (int dx=0; dx < maxDims; dx++) {
874                         arow[dx*2] += piece[dx];
875                 }
876         }
877 }
878
879 static void
880 ba81EAP2(omxExpectation *oo, long qx, int maxDims, int numUnique, double *ability)
881 {
882         omxBA81State *state = (omxBA81State *) oo->argStruct;
883         double *patternLik = state->patternLik;
884         int quad[maxDims];
885         decodeLocation(qx, maxDims, state->quadGridSize, quad);
886         double where[maxDims];
887         pointToWhere(state, quad, where, maxDims);
888         double logArea = logAreaProduct(state, quad, maxDims);
889         double *lxk = ba81LikelihoodFast(oo, 0, quad);
890
891         for (int px=0; px < numUnique; px++) {
892                 double psd[maxDims];
893                 double *arow = ability + px * 2 * maxDims;
894                 for (int dx=0; dx < maxDims; dx++) {
895                         double ldiff = log(fabs(where[dx] - arow[dx*2]));
896                         psd[dx] = exp(2 * ldiff + lxk[px] + logArea - patternLik[px]);
897                 }
898 #pragma omp critical(EAP1Update)
899                 for (int dx=0; dx < maxDims; dx++) {
900                         arow[dx*2+1] += psd[dx];
901                 }
902         }
903 }
904
905 void ba81EAP(omxExpectation *oo, omxRListElement *out)
906 {
907         omxBA81State *state = (omxBA81State *) oo->argStruct;
908         int maxDims = state->maxDims;
909         omxData *data = state->data;
910         int numUnique = state->numUnique;
911
912         // TODO Wainer & Thissen. (1987). Estimating ability with the wrong
913         // model. Journal of Educational Statistics, 12, 339-368.
914
915         int numQpoints = state->numQpoints * 2;  // make configurable TODO
916         double Qpoint[numQpoints];
917         double Qarea[numQpoints];
918         const double Qwidth = 4;
919         for (int qx=0; qx < numQpoints; qx++) {
920                 Qpoint[qx] = Qwidth - qx * Qwidth*2 / (numQpoints-1);
921                 Qarea[qx] = 1.0/numQpoints;
922         }
923         ba81SetupQuadrature(oo, numQpoints, Qpoint, Qarea);
924         ba81Estep(oo);   // recalc patternLik with a flat prior
925
926         // Need a separate work space because the destination needs
927         // to be in unsorted order with duplicated rows.
928         double *ability = Calloc(numUnique * maxDims * 2, double);
929
930 #pragma omp parallel for num_threads(oo->currentState->numThreads)
931         for (long qx=0; qx < state->totalQuadPoints; qx++) {
932                 ba81EAP1(oo, qx, maxDims, numUnique, ability);
933         }
934
935 #pragma omp parallel for num_threads(oo->currentState->numThreads)
936         for (long qx=0; qx < state->totalQuadPoints; qx++) {
937                 ba81EAP2(oo, qx, maxDims, numUnique, ability);
938         }
939
940         for (int px=0; px < numUnique; px++) {
941                 double *arow = ability + px * 2 * maxDims;
942                 for (int dx=0; dx < maxDims; dx++) {
943                         arow[dx*2+1] = sqrt(arow[dx*2+1]);
944                 }
945         }
946
947         strcpy(out->label, "ability");
948         out->numValues = -1;
949         out->rows = data->rows;
950         out->cols = 2 * maxDims;
951         out->values = (double*) R_alloc(out->rows * out->cols, sizeof(double));
952
953         for (int rx=0; rx < numUnique; rx++) {
954                 double *pa = ability + rx * 2 * maxDims;
955
956                 int dups = omxDataNumIdenticalRows(state->data, state->rowMap[rx]);
957                 for (int dup=0; dup < dups; dup++) {
958                         int dest = omxDataIndex(data, state->rowMap[rx]+dup);
959                         int col=-1;
960                         for (int dx=0; dx < maxDims; dx++) {
961                                 out->values[++col * out->rows + dest] = pa[col];
962                                 out->values[++col * out->rows + dest] = pa[col];
963                         }
964                 }
965         }
966         Free(ability);
967 }
968
969 static void ba81Destroy(omxExpectation *oo) {
970         if(OMX_DEBUG) {
971                 Rprintf("Freeing %s function.\n", NAME);
972         }
973         omxBA81State *state = (omxBA81State *) oo->argStruct;
974         Rprintf("fit %d gradient %d\n", state->fitCount, state->gradientCount);
975         omxFreeAllMatrixData(state->itemSpec);
976         omxFreeAllMatrixData(state->itemParam);
977         omxFreeAllMatrixData(state->EitemParam);
978         omxFreeAllMatrixData(state->design);
979         omxFreeAllMatrixData(state->customPrior);
980         Free(state->logNumIdentical);
981         Free(state->Qpoint);
982         Free(state->Qarea);
983         Free(state->logQarea);
984         Free(state->rowMap);
985         Free(state->patternLik);
986         Free(state->lxk);
987         Free(state->Slxk);
988         Free(state->allSlxk);
989         Free(state->Sgroup);
990         Free(state->paramMap);
991         Free(state);
992 }
993
994 int ba81ExpectationHasGradients(omxExpectation* oo)
995 {
996         omxBA81State *state = (omxBA81State *) oo->argStruct;
997         return state->computeRPF == standardComputeRPF;
998 }
999
1000 void omxInitExpectationBA81(omxExpectation* oo) {
1001         omxState* currentState = oo->currentState;      
1002         SEXP rObj = oo->rObj;
1003         SEXP tmp;
1004         
1005         if(OMX_DEBUG) {
1006                 Rprintf("Initializing %s.\n", NAME);
1007         }
1008         
1009         omxBA81State *state = Calloc(1, omxBA81State);
1010         oo->argStruct = (void*) state;
1011
1012         state->ll = 1e20;   // finite but big
1013         
1014         PROTECT(tmp = GET_SLOT(rObj, install("data")));
1015         state->data = omxNewDataFromMxDataPtr(tmp, currentState);
1016         UNPROTECT(1);
1017
1018         if (strcmp(omxDataType(state->data), "raw") != 0) {
1019                 omxRaiseErrorf(currentState, "%s unable to handle data type %s", NAME, omxDataType(state->data));
1020                 return;
1021         }
1022
1023         PROTECT(state->rpf = GET_SLOT(rObj, install("RPF")));
1024         if (state->rpf == R_NilValue) {
1025                 state->computeRPF = standardComputeRPF;
1026         } else {
1027                 state->computeRPF = RComputeRPF;
1028         }
1029
1030         state->itemSpec =
1031                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemSpec");
1032         state->design =
1033                 omxNewMatrixFromIndexSlot(rObj, currentState, "Design");
1034         state->itemParam =
1035                 omxNewMatrixFromIndexSlot(rObj, currentState, "ItemParam");
1036         state->EitemParam =
1037                 omxInitTemporaryMatrix(NULL, state->itemParam->rows, state->itemParam->cols,
1038                                        TRUE, currentState);
1039         state->customPrior =
1040                 omxNewMatrixFromIndexSlot(rObj, currentState, "CustomPrior");
1041         
1042         oo->computeFun = ba81Estep;
1043         oo->destructFun = ba81Destroy;
1044         
1045         // TODO: Exactly identical rows do not contribute any information.
1046         // The sorting algorithm ought to remove them so we don't waste RAM.
1047         // The following summary stats would be cheaper to calculate too.
1048
1049         int numUnique = 0;
1050         omxData *data = state->data;
1051         if (omxDataNumFactor(data) != data->cols) {
1052                 // verify they are ordered factors TODO
1053                 omxRaiseErrorf(currentState, "%s: all columns must be factors", NAME);
1054                 return;
1055         }
1056
1057         for (int rx=0; rx < data->rows;) {
1058                 rx += omxDataNumIdenticalRows(state->data, rx);
1059                 ++numUnique;
1060         }
1061         state->numUnique = numUnique;
1062
1063         state->rowMap = Realloc(NULL, numUnique, int);
1064         state->logNumIdentical = Realloc(NULL, numUnique, double);
1065
1066         int numItems = state->itemParam->cols;
1067
1068         for (int rx=0, ux=0; rx < data->rows; ux++) {
1069                 if (rx == 0) {
1070                         // all NA rows will sort to the top
1071                         int na=0;
1072                         for (int ix=0; ix < numItems; ix++) {
1073                                 if (omxIntDataElement(data, 0, ix) == NA_INTEGER) { ++na; }
1074                         }
1075                         if (na == numItems) {
1076                                 omxRaiseErrorf(currentState, "Remove rows with all NAs");
1077                                 return;
1078                         }
1079                 }
1080                 int dups = omxDataNumIdenticalRows(state->data, rx);
1081                 state->logNumIdentical[ux] = log(dups);
1082                 state->rowMap[ux] = rx;
1083                 rx += dups;
1084         }
1085
1086         state->patternLik = Realloc(NULL, numUnique, double);
1087
1088         int numThreads = getNumThreads(oo);
1089
1090         if (state->itemSpec->cols != data->cols || state->itemSpec->rows != ISpecRowCount) {
1091                 omxRaiseErrorf(currentState, "ItemSpec must have %d item columns and %d rows",
1092                                data->cols, ISpecRowCount);
1093                 return;
1094         }
1095
1096         int maxParam = 0;
1097         state->maxDims = 0;
1098         state->maxOutcomes = 0;
1099
1100         for (int cx = 0; cx < data->cols; cx++) {
1101                 int id = omxMatrixElement(state->itemSpec, ISpecID, cx);
1102                 if (id < 0 || id >= numStandardRPF) {
1103                         omxRaiseErrorf(currentState, "ItemSpec column %d has unknown item model %d", cx, id);
1104                         return;
1105                 }
1106
1107                 int dims = omxMatrixElement(state->itemSpec, ISpecDims, cx);
1108                 if (state->maxDims < dims)
1109                         state->maxDims = dims;
1110
1111                 // TODO verify that item model can have requested number of outcomes
1112                 int no = omxMatrixElement(state->itemSpec, ISpecOutcomes, cx);
1113                 if (state->maxOutcomes < no)
1114                         state->maxOutcomes = no;
1115
1116                 int numParam = (*rpf_table[id].numParam)(dims, no);
1117                 if (maxParam < numParam)
1118                         maxParam = numParam;
1119         }
1120
1121         if (state->itemParam->rows != maxParam) {
1122                 omxRaiseErrorf(currentState, "ItemParam should have %d rows", maxParam);
1123                 return;
1124         }
1125
1126         if (state->design == NULL) {
1127                 state->maxAbilities = state->maxDims;
1128                 state->design = omxInitTemporaryMatrix(NULL, state->maxDims, numItems,
1129                                        TRUE, currentState);
1130                 for (int ix=0; ix < numItems; ix++) {
1131                         for (int dx=0; dx < state->maxDims; dx++) {
1132                                 omxSetMatrixElement(state->design, dx, ix, (double)dx+1);
1133                         }
1134                 }
1135         } else {
1136                 omxMatrix *design = state->design;
1137                 if (design->cols != numItems ||
1138                     design->rows != state->maxDims) {
1139                         omxRaiseErrorf(currentState, "Design matrix should have %d rows and %d columns",
1140                                        state->maxDims, numItems);
1141                         return;
1142                 }
1143
1144                 state->maxAbilities = 0;
1145                 for (int ix=0; ix < design->rows * design->cols; ix++) {
1146                         double got = design->data[ix];
1147                         if (!R_FINITE(got)) continue;
1148                         if (round(got) != got) error("Design matrix can only contain integers"); // TODO better way?
1149                         if (state->maxAbilities < got)
1150                                 state->maxAbilities = got;
1151                 }
1152         }
1153         if (state->maxAbilities <= state->maxDims) {
1154                 state->Sgroup = Calloc(numItems, int);
1155         } else {
1156                 int Sgroup0 = state->maxDims;
1157                 state->Sgroup = Realloc(NULL, numItems, int);
1158                 for (int ix=0; ix < numItems; ix++) {
1159                         int ss=-1;
1160                         for (int dx=0; dx < state->maxDims; dx++) {
1161                                 int ability = omxMatrixElement(state->design, dx, ix);
1162                                 if (ability >= Sgroup0) {
1163                                         if (ss == -1) {
1164                                                 ss = ability;
1165                                         } else {
1166                                                 omxRaiseErrorf(currentState, "Item %d cannot belong to more than "
1167                                                                "1 specific dimension (both %d and %d)",
1168                                                                ix, ss, ability);
1169                                                 return;
1170                                         }
1171                                 }
1172                         }
1173                         if (ss == -1) ss = 0;
1174                         state->Sgroup[ix] = ss - Sgroup0;
1175                 }
1176                 state->numSpecific = state->maxAbilities - state->maxDims + 1;
1177                 state->allSlxk = Realloc(NULL, numUnique * numThreads, double);
1178                 state->Slxk = Realloc(NULL, numUnique * state->numSpecific * numThreads, double);
1179         }
1180
1181         PROTECT(tmp = GET_SLOT(rObj, install("cache")));
1182         state->cacheLXK = asLogical(tmp);
1183
1184         PROTECT(tmp = GET_SLOT(rObj, install("GHpoints")));
1185         double *qpoints = REAL(tmp);
1186         int numQPoints = length(tmp);
1187
1188         PROTECT(tmp = GET_SLOT(rObj, install("GHarea")));
1189         double *qarea = REAL(tmp);
1190         if (numQPoints != length(tmp)) error("length(GHpoints) != length(GHarea)");
1191
1192         ba81SetupQuadrature(oo, numQPoints, qpoints, qarea);
1193
1194         // verify data bounded between 1 and numOutcomes TODO
1195         // hm, looks like something could be added to omxData for column summary stats?
1196 }
1197
1198 SEXP omx_get_rpf_names()
1199 {
1200         SEXP outsxp;
1201         PROTECT(outsxp = allocVector(STRSXP, numStandardRPF));
1202         for (int sx=0; sx < numStandardRPF; sx++) {
1203                 SET_STRING_ELT(outsxp, sx, mkChar(rpf_table[sx].name));
1204         }
1205         UNPROTECT(1);
1206         return outsxp;
1207 }