Multigroup IFA
[openmx:openmx.git] / src / omxFitFunctionBA81.cpp
1 /*
2   Copyright 2012-2013 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include "omxFitFunction.h"
19 #include "omxExpectationBA81.h"
20 #include "omxOpenmpWrap.h"
21 #include "libifa-rpf.h"
22
// Identifier string for this fit function (file-local; not referenced in the visible part of this file).
static const char *NAME = "FitFunctionBA81";
24
25 struct BA81FitState {
26
27         omxMatrix *itemParam;     // M step version
28         int derivPadSize;         // maxParam + maxParam*(1+maxParam)/2
29         double *thrDeriv;         // itemParam->cols * derivPadSize * thread
30         int *paramMap;            // itemParam->cols * derivPadSize -> index of free parameter
31         std::vector<int> NAtriangle;
32         bool rescale;
33         omxMatrix *customPrior;
34         int choleskyError;
35         double *tmpLatentMean;    // maxDims
36         double *tmpLatentCov;     // maxDims * maxDims ; only lower triangle is used
37         int fitCount;
38         int gradientCount;
39
40         std::vector< FreeVarGroup* > varGroups;
41         FreeVarGroup *latentFVG;
42
43         BA81FitState();
44 };
45
46 BA81FitState::BA81FitState()
47 {
48         paramMap = NULL;
49         latentFVG = NULL;
50         customPrior = NULL;
51         fitCount = 0;
52         gradientCount = 0;
53 }
54
55 static void buildParamMap(omxFitFunction* oo)
56 {
57         BA81FitState *state = (BA81FitState *) oo->argStruct;
58         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
59         omxMatrix *itemParam = state->itemParam;
60         int size = itemParam->cols * state->derivPadSize;
61
62         state->paramMap = Realloc(NULL, size, int);  // matrix location to free param index
63         for (int px=0; px < size; px++) {
64                 state->paramMap[px] = -1;
65         }
66
67         size_t numFreeParams = oo->freeVarGroup->vars.size();
68         int *pRow = Realloc(NULL, numFreeParams, int);
69         int *pCol = Realloc(NULL, numFreeParams, int);
70
71         for (size_t px=0; px < numFreeParams; px++) {
72                 pRow[px] = -1;
73                 pCol[px] = -1;
74                 omxFreeVar *fv = oo->freeVarGroup->vars[px];
75                 for (size_t lx=0; lx < fv->locations.size(); lx++) {
76                         omxFreeVarLocation *loc = &fv->locations[lx];
77                         if (~loc->matrix == itemParam->matrixNumber) {
78                                 pRow[px] = loc->row;
79                                 pCol[px] = loc->col;
80                                 int at = pCol[px] * state->derivPadSize + pRow[px];
81                                 state->paramMap[at] = px;
82                         }
83                 }
84         }
85
86         for (size_t p1=0; p1 < numFreeParams; p1++) {
87                 for (size_t p2=p1; p2 < numFreeParams; p2++) {
88                         if (pCol[p1] == -1 || pCol[p1] != pCol[p2]) continue;
89                         const double *spec = omxMatrixColumn(estate->itemSpec, pCol[p1]);
90                         int id = spec[RPF_ISpecID];
91                         int numParam = (*rpf_model[id].numParam)(spec);
92                         int r1 = pRow[p1];
93                         int r2 = pRow[p2];
94                         if (r1 > r2) { int tmp=r1; r1=r2; r2=tmp; }
95                         int rowOffset = 0;
96                         for (int rx=1; rx <= r2; rx++) rowOffset += rx;
97                         int at = pCol[p1] * state->derivPadSize + numParam + rowOffset + r1;
98                         state->paramMap[at] = numFreeParams + p1 * numFreeParams + p2;
99                         if (p2 != p1) state->NAtriangle.push_back(p2 * numFreeParams + p1);
100                 }
101         }
102
103         Free(pRow);
104         Free(pCol);
105
106         state->thrDeriv = Realloc(NULL, itemParam->cols * state->derivPadSize * Global->numThreads, double);
107 }
108
109 OMXINLINE static double
110 ba81Fit1Ordinate(omxFitFunction* oo, const int *quad, const double *weight, int want)
111 {
112         BA81FitState *state = (BA81FitState*) oo->argStruct;
113         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
114         omxMatrix *itemSpec = estate->itemSpec;
115         omxMatrix *itemParam = state->itemParam;
116         int numItems = itemParam->cols;
117         int maxOutcomes = estate->maxOutcomes;
118         int maxDims = estate->maxDims;
119         double *myDeriv = state->thrDeriv + itemParam->cols * state->derivPadSize * omx_absolute_thread_num();
120         int do_deriv = want & (FF_COMPUTE_GRADIENT | FF_COMPUTE_HESSIAN);
121
122         double where[maxDims];
123         pointToWhere(estate->Qpoint, quad, where, maxDims);
124
125         double *outcomeProb = computeRPF(estate->itemSpec, estate->design, itemParam, estate->maxDims,
126                                          estate->maxOutcomes, quad, estate->Qpoint); // avoid malloc/free? TODO
127         if (!outcomeProb) return 0;
128
129         double thr_ll = 0;
130         for (int ix=0; ix < numItems; ix++) {
131                 const double *spec = omxMatrixColumn(itemSpec, ix);
132                 int id = spec[RPF_ISpecID];
133                 int iOutcomes = spec[RPF_ISpecOutcomes];
134
135                 double area = exp(logAreaProduct(estate, quad, estate->Sgroup[ix]));   // avoid exp() here? TODO
136                 for (int ox=0; ox < iOutcomes; ox++) {
137 #if 0
138 #pragma omp critical(ba81Fit1OrdinateDebug1)
139                         if (!isfinite(outcomeProb[ix * maxOutcomes + ox])) {
140                                 pda(itemParam->data, itemParam->rows, itemParam->cols);
141                                 pda(outcomeProb, outcomes, numItems);
142                                 error("RPF produced NAs");
143                         }
144 #endif
145                         double got = weight[ox] * outcomeProb[ix * maxOutcomes + ox];
146                         thr_ll += got * area;
147                 }
148
149                 if (do_deriv) {
150                         double *iparam = omxMatrixColumn(itemParam, ix);
151                         double *pad = myDeriv + ix * state->derivPadSize;
152                         (*rpf_model[id].dLL1)(spec, iparam, where, area, weight, pad);
153                 }
154                 weight += iOutcomes;
155         }
156
157         Free(outcomeProb);
158
159         return thr_ll;
160 }
161
162 static double
163 ba81ComputeMFit1(omxFitFunction* oo, int want, double *gradient, double *hessian)
164 {
165         BA81FitState *state = (BA81FitState*) oo->argStruct;
166         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
167         omxMatrix *customPrior = state->customPrior;
168         omxMatrix *itemParam = state->itemParam;
169         omxMatrix *itemSpec = estate->itemSpec;
170         int maxDims = estate->maxDims;
171         const int totalOutcomes = estate->totalOutcomes;
172
173         double ll = 0;
174         if (customPrior) {
175                 omxRecompute(customPrior);
176                 ll = customPrior->data[0];
177                 // need deriv adjustment TODO
178         }
179
180         if (!isfinite(ll)) {
181                 omxPrint(itemParam, "item param");
182                 error("Bayesian prior returned %g; do you need to add a lbound/ubound?", ll);
183         }
184
185 #pragma omp parallel for num_threads(Global->numThreads)
186         for (long qx=0; qx < estate->totalQuadPoints; qx++) {
187                 //double area = exp(state->priLogQarea[qx]);  // avoid exp() here? TODO
188                 int quad[maxDims];
189                 decodeLocation(qx, maxDims, estate->quadGridSize, quad);
190                 double *weight = estate->expected + qx * totalOutcomes;
191                 double thr_ll = ba81Fit1Ordinate(oo, quad, weight, want);
192                 
193 #pragma omp atomic
194                 ll += thr_ll;
195         }
196
197         if (gradient) {
198                 double *deriv0 = state->thrDeriv;
199
200                 int perThread = itemParam->cols * state->derivPadSize;
201                 for (int th=1; th < Global->numThreads; th++) {
202                         double *thrD = state->thrDeriv + th * perThread;
203                         for (int ox=0; ox < perThread; ox++) deriv0[ox] += thrD[ox];
204                 }
205
206                 int numItems = itemParam->cols;
207                 for (int ix=0; ix < numItems; ix++) {
208                         const double *spec = omxMatrixColumn(itemSpec, ix);
209                         int id = spec[RPF_ISpecID];
210                         double *iparam = omxMatrixColumn(itemParam, ix);
211                         double *pad = deriv0 + ix * state->derivPadSize;
212                         (*rpf_model[id].dLL2)(spec, iparam, pad);
213                 }
214
215                 int numFreeParams = int(oo->freeVarGroup->vars.size());
216                 int numParams = itemParam->cols * state->derivPadSize;
217                 for (int ox=0; ox < numParams; ox++) {
218                         int to = state->paramMap[ox];
219                         if (to == -1) continue;
220
221                         // Need to check because this can happen if
222                         // lbounds/ubounds are not set appropriately.
223                         if (0 && !isfinite(deriv0[ox])) {
224                                 int item = ox / itemParam->rows;
225                                 mxLog("item parameters:\n");
226                                 const double *spec = omxMatrixColumn(itemSpec, item);
227                                 int id = spec[RPF_ISpecID];
228                                 int numParam = (*rpf_model[id].numParam)(spec);
229                                 double *iparam = omxMatrixColumn(itemParam, item);
230                                 pda(iparam, numParam, 1);
231                                 // Perhaps bounds can be pulled in from librpf? TODO
232                                 error("Deriv %d for item %d is %f; are you missing a lbound/ubound?",
233                                       ox, item, deriv0[ox]);
234                         }
235
236                         if (to < numFreeParams) {
237                                 gradient[to] -= deriv0[ox];
238                         } else {
239                                 hessian[to - numFreeParams] -= deriv0[ox];
240                         }
241                 }
242         }
243
244         return -ll;
245 }
246
247 static void
248 schilling_bock_2005_rescale(omxFitFunction *oo, FitContext *fc)
249 {
250         BA81FitState *state = (BA81FitState*) oo->argStruct;
251         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
252         omxMatrix *itemSpec = estate->itemSpec;
253         omxMatrix *itemParam = state->itemParam;
254         omxMatrix *design = estate->design;
255         double *ElatentMean = estate->ElatentMean;
256         double *ElatentCov = estate->ElatentCov;
257         double *tmpLatentMean = state->tmpLatentMean;
258         double *tmpLatentCov = state->tmpLatentCov;
259         int maxAbilities = estate->maxAbilities;
260         int maxDims = estate->maxDims;
261
262         //mxLog("schilling bock\n");
263         //pda(ElatentMean, maxAbilities, 1);
264         //pda(ElatentCov, maxAbilities, maxAbilities);
265         //omxPrint(design, "design");
266
267         // use omxDPOTRF instead? TODO
268         const char triangle = 'L';
269         F77_CALL(dpotrf)(&triangle, &maxAbilities, ElatentCov, &maxAbilities, &state->choleskyError);
270         if (state->choleskyError != 0) {
271                 warning("Cholesky failed with %d; rescaling disabled", state->choleskyError); // make error TODO?
272                 return;
273         }
274
275         //fc->log(FF_COMPUTE_ESTIMATE);
276
277         int numItems = itemParam->cols;
278         for (int ix=0; ix < numItems; ix++) {
279                 const double *spec = omxMatrixColumn(itemSpec, ix);
280                 int id = spec[RPF_ISpecID];
281                 const double *rawDesign = omxMatrixColumn(design, ix);
282                 int idesign[design->rows];
283                 int idx = 0;
284                 for (int dx=0; dx < design->rows; dx++) {
285                         if (isfinite(rawDesign[dx])) {
286                                 idesign[idx++] = rawDesign[dx]-1;
287                         } else {
288                                 idesign[idx++] = -1;
289                         }
290                 }
291                 for (int d1=0; d1 < idx; d1++) {
292                         if (idesign[d1] == -1) {
293                                 tmpLatentMean[d1] = 0;
294                         } else {
295                                 tmpLatentMean[d1] = ElatentMean[idesign[d1]];
296                         }
297                         for (int d2=0; d2 <= d1; d2++) {
298                                 int cell = idesign[d2] * maxAbilities + idesign[d1];
299                                 if (idesign[d1] == -1 || idesign[d2] == -1) {
300                                         tmpLatentCov[d2 * maxDims + d1] = d1==d2? 1 : 0;
301                                 } else {
302                                         tmpLatentCov[d2 * maxDims + d1] = ElatentCov[cell];
303                                 }
304                         }
305                 }
306                 if (1) {  // ease debugging, make optional TODO
307                         for (int d1=idx; d1 < maxDims; d1++) tmpLatentMean[d1] = nan("");
308                         for (int d1=0; d1 < maxDims; d1++) {
309                                 for (int d2=0; d2 < maxDims; d2++) {
310                                         if (d1 < idx && d2 < idx) continue;
311                                         tmpLatentCov[d2 * maxDims + d1] = nan("");
312                                 }
313                         }
314                 }
315                 double *iparam = omxMatrixColumn(itemParam, ix);
316                 int *mask = state->paramMap + state->derivPadSize * ix;
317                 rpf_model[id].rescale(spec, iparam, mask, tmpLatentMean, tmpLatentCov);
318         }
319
320         int numFreeParams = int(oo->freeVarGroup->vars.size());
321         for (int rx=0; rx < itemParam->rows; rx++) {
322                 for (int cx=0; cx < itemParam->cols; cx++) {
323                         int vx = state->paramMap[cx * state->derivPadSize + rx];
324                         if (vx >= 0 && vx < numFreeParams) {
325                                 fc->est[vx] = omxMatrixElement(itemParam, rx, cx);
326                         }
327                 }
328         }
329         fc->copyParamToModel(globalState);
330         //fc->log(FF_COMPUTE_ESTIMATE);
331 }
332
333 OMXINLINE static void
334 updateLatentParam(omxFitFunction* oo, FitContext *fc)
335 {
336         BA81FitState *state = (BA81FitState*) oo->argStruct;
337         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
338         int maxAbilities = estate->maxAbilities;
339         int meanNum = estate->latentMeanOut->matrixNumber;
340         int covNum = estate->latentCovOut->matrixNumber;
341         FreeVarGroup *latentFVG = state->latentFVG;
342
343         // TODO need denom for multigroup
344         size_t numFreeParams = latentFVG->vars.size();
345         for (size_t px=0; px < numFreeParams; px++) {
346                 omxFreeVar *fv = latentFVG->vars[px];
347                 for (size_t lx=0; lx < fv->locations.size(); ++lx) {
348                         omxFreeVarLocation *loc = &fv->locations[lx];
349                         int matNum = ~loc->matrix;
350                         if (matNum == meanNum) {
351                                 int dx = loc->row * loc->col;
352                                 fc->est[px] = estate->ElatentMean[dx];
353                         } else if (matNum == covNum) {
354                                 int cell = loc->col * maxAbilities + loc->row;
355                                 fc->est[px] = estate->ElatentCov[cell];
356                         }
357                 }
358         }
359         fc->copyParamToModel(globalState);
360 }
361
362 void ba81SetFreeVarGroup(omxFitFunction *oo, FreeVarGroup *fvg) // too ad hoc? TODO
363 {
364         if (!oo->argStruct) { // ugh!
365                 BA81FitState *state = new BA81FitState;
366                 oo->argStruct = state;
367         }
368
369         BA81FitState *state = (BA81FitState*) oo->argStruct;
370
371         state->varGroups.push_back(fvg);
372         if (state->varGroups.size() == 2) {
373                 int small = 0;
374                 if (state->varGroups[0] == state->varGroups[1])
375                         warning("Cannot recognize correct free parameter groups");
376                 if (state->varGroups[0]->vars.size() > state->varGroups[1]->vars.size())
377                         small = 1;
378                 oo->freeVarGroup = state->varGroups[small];
379                 state->latentFVG = state->varGroups[!small];
380         } else if (state->varGroups.size() > 2) {
381                 // ignore
382         }
383 }
384
385 static double
386 ba81ComputeFit(omxFitFunction* oo, int want, FitContext *fc)
387 {
388         if (!want) return 0;
389
390         BA81FitState *state = (BA81FitState*) oo->argStruct;
391
392         if (!state->paramMap) buildParamMap(oo);
393
394         if (want & FF_COMPUTE_PREOPTIMIZE) {
395                 if (state->rescale) schilling_bock_2005_rescale(oo, fc); // how does this work in multigroup? TODO
396                 return 0;
397         }
398
399         if (want & FF_COMPUTE_FIT) {
400                 ++state->fitCount;
401         }
402
403         if (want & (FF_COMPUTE_GRADIENT|FF_COMPUTE_HESSIAN)) {
404                 // M-step
405
406                 ++state->gradientCount;
407
408                 omxMatrix *itemParam = state->itemParam;
409                 OMXZERO(state->thrDeriv, state->derivPadSize * itemParam->cols * Global->numThreads);
410
411                 for (size_t nx=0; nx < state->NAtriangle.size(); ++nx) {
412                         fc->hess[ state->NAtriangle[nx] ] = nan("symmetric");
413                 }
414
415                 double got = ba81ComputeMFit1(oo, want, fc->grad, fc->hess);
416                 return got;
417         } else {
418                 // Major EM iteration, note completely different LL calculation
419
420                 updateLatentParam(oo, fc);
421
422                 BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
423                 double *patternLik = estate->patternLik;
424                 int *numIdentical = estate->numIdentical;
425                 int numUnique = estate->numUnique;
426                 double got = 0;
427                 for (int ux=0; ux < numUnique; ux++) {
428                         got += numIdentical[ux] * patternLik[ux];
429                 }
430                 return -2 * got;
431         }
432 }
433
434 static void ba81Compute(omxFitFunction *oo, int want, FitContext *fc)
435 {
436         oo->matrix->data[0] = ba81ComputeFit(oo, want, fc);
437 }
438
439 static void ba81Destroy(omxFitFunction *oo) {
440         BA81FitState *state = (BA81FitState *) oo->argStruct;
441
442         omxFreeAllMatrixData(state->customPrior);
443         Free(state->paramMap);
444         Free(state->thrDeriv);
445         Free(state->tmpLatentMean);
446         Free(state->tmpLatentCov);
447         omxFreeAllMatrixData(state->itemParam);
448         delete state;
449 }
450
451 static omxRListElement *ba81SetFinalReturns(omxFitFunction *oo, int *numReturns)
452 {
453         omxRListElement *ret = ba81EAP(oo->expectation, numReturns);
454
455         return ret;
456 }
457
458 void omxInitFitFunctionBA81(omxFitFunction* oo)
459 {
460         BA81FitState *state = (BA81FitState*) oo->argStruct;
461         SEXP rObj = oo->rObj;
462
463         omxExpectation *expectation = oo->expectation;
464         BA81Expect *estate = (BA81Expect*) expectation->argStruct;
465
466         //newObj->data = oo->expectation->data;
467
468         oo->computeFun = ba81Compute;
469         oo->setVarGroup = ba81SetFreeVarGroup;
470         oo->setFinalReturns = ba81SetFinalReturns;
471         oo->destructFun = ba81Destroy;
472         oo->gradientAvailable = TRUE;
473         oo->hessianAvailable = TRUE;
474
475         SEXP tmp;
476         PROTECT(tmp = GET_SLOT(rObj, install("rescale")));
477         state->rescale = asLogical(tmp);
478
479         state->itemParam =
480                 omxNewMatrixFromSlot(rObj, globalState, "ItemParam");
481
482         if (estate->EitemParam->rows != state->itemParam->rows ||
483             estate->EitemParam->cols != state->itemParam->cols) {
484                 error("ItemParam and EItemParam matrices must be the same dimension");
485         }
486
487         state->customPrior =
488                 omxNewMatrixFromSlot(rObj, globalState, "CustomPrior");
489         
490         int maxParam = state->itemParam->rows;
491         state->derivPadSize = maxParam + maxParam*(1+maxParam)/2;
492
493         state->tmpLatentMean = Realloc(NULL, estate->maxDims, double);
494         state->tmpLatentCov = Realloc(NULL, estate->maxDims * estate->maxDims, double);
495
496         int numItems = state->itemParam->cols;
497         for (int ix=0; ix < numItems; ix++) {
498                 double *spec = omxMatrixColumn(estate->itemSpec, ix);
499                 int id = spec[RPF_ISpecID];
500                 if (id < 0 || id >= rpf_numModels) {
501                         error("ItemSpec column %d has unknown item model %d", ix, id);
502                 }
503         }
504 }