Look at expectation to determine which fit to calculate
[openmx:openmx.git] / src / omxFitFunctionBA81.cpp
1 /*
2   Copyright 2012-2013 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include "omxFitFunction.h"
19 #include "omxExpectationBA81.h"
20 #include "omxOpenmpWrap.h"
21 #include "libifa-rpf.h"
22
// Human-readable name of this fit function ("FitFunctionBA81").
// NOTE(review): presumably intended for diagnostics/log messages -- confirm usage.
static const char *NAME = "FitFunctionBA81";
24
// Private state of the BA81 fit function.  An instance is created lazily by
// ba81SetFreeVarGroup() or omxInitFitFunctionBA81() (whichever runs first),
// stored in omxFitFunction::argStruct, and released by ba81Destroy().
struct BA81FitState {

        omxMatrix *itemParam;     // M step version
        int derivPadSize;         // maxParam + maxParam*(1+maxParam)/2
        double *thrDeriv;         // itemParam->cols * derivPadSize * thread
        int *paramMap;            // itemParam->cols * derivPadSize -> index of free parameter
                                  // (-1: no free parameter; values >= numFreeParams address
                                  // the hessian, offset by numFreeParams)
        std::vector<int> latentMeanMap;  // maxAbilities entries: index into latentFVG->vars, or -1
        std::vector<int> latentCovMap;   // maxAbilities * maxAbilities entries: index into latentFVG->vars, or -1
        std::vector<int> NAtriangle;     // hessian cells that are overwritten with NaN before each M-step fit
        bool rescale;             // read from the R object's "rescale" slot; enables the
                                  // Schilling & Bock rescale during FF_COMPUTE_PREOPTIMIZE
        omxMatrix *customPrior;   // from the "CustomPrior" slot; may be NULL
        int choleskyError;        // info code reported by the most recent dpotrf call
        double *tmpLatentMean;    // maxDims
        double *tmpLatentCov;     // maxDims * maxDims ; only lower triangle is used
        int fitCount;             // incremented on every non-preoptimize ba81ComputeFit call
        int gradientCount;        // incremented when an augmented fit is asked for the gradient

        std::vector< FreeVarGroup* > varGroups;  // every group handed to ba81SetFreeVarGroup, in call order
        FreeVarGroup *latentFVG;  // of the first two groups, the one NOT chosen as oo->freeVarGroup;
                                  // NULL until two groups have been registered

        BA81FitState();
};
47
48 BA81FitState::BA81FitState()
49 {
50         itemParam = NULL;
51         thrDeriv = NULL;
52         paramMap = NULL;
53         latentFVG = NULL;
54         customPrior = NULL;
55         fitCount = 0;
56         gradientCount = 0;
57         tmpLatentMean = NULL;
58         tmpLatentCov = NULL;
59 }
60
61 static void buildParamMap(omxFitFunction* oo)
62 {
63         BA81FitState *state = (BA81FitState *) oo->argStruct;
64         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
65         omxMatrix *itemParam = state->itemParam;
66         int size = itemParam->cols * state->derivPadSize;
67         int maxAbilities = estate->maxAbilities;
68         int meanNum = estate->latentMeanOut->matrixNumber;
69         int covNum = estate->latentCovOut->matrixNumber;
70         FreeVarGroup *fvg = state->latentFVG;
71
72         state->latentMeanMap.assign(maxAbilities, -1);
73         state->latentCovMap.assign(maxAbilities * maxAbilities, -1);
74
75         for (size_t px=0; px < fvg->vars.size(); px++) {
76                 omxFreeVar *fv = fvg->vars[px];
77                 for (size_t lx=0; lx < fv->locations.size(); lx++) {
78                         omxFreeVarLocation *loc = &fv->locations[lx];
79                         int matNum = ~loc->matrix;
80                         if (matNum == meanNum) {
81                                 state->latentMeanMap[loc->row + loc->col] = px;
82                         } else if (matNum == covNum) {
83                                 state->latentCovMap[loc->col * maxAbilities + loc->row] = px;
84                         }
85                 }
86         }
87
88         state->paramMap = Realloc(NULL, size, int);  // matrix location to free param index
89         for (int px=0; px < size; px++) {
90                 state->paramMap[px] = -1;
91         }
92
93         size_t numFreeParams = oo->freeVarGroup->vars.size();
94         int *pRow = Realloc(NULL, numFreeParams, int);
95         int *pCol = Realloc(NULL, numFreeParams, int);
96
97         for (size_t px=0; px < numFreeParams; px++) {
98                 pRow[px] = -1;
99                 pCol[px] = -1;
100                 omxFreeVar *fv = oo->freeVarGroup->vars[px];
101                 for (size_t lx=0; lx < fv->locations.size(); lx++) {
102                         omxFreeVarLocation *loc = &fv->locations[lx];
103                         int matNum = ~loc->matrix;
104                         if (matNum == itemParam->matrixNumber) {
105                                 pRow[px] = loc->row;
106                                 pCol[px] = loc->col;
107                                 int at = pCol[px] * state->derivPadSize + pRow[px];
108                                 state->paramMap[at] = px;
109
110                                 const double *spec = omxMatrixColumn(estate->itemSpec, loc->col);
111                                 int id = spec[RPF_ISpecID];
112                                 double upper, lower;
113                                 (*rpf_model[id].paramBound)(spec, loc->row, &upper, &lower);
114                                 if (fv->lbound == NEG_INF && isfinite(lower)) fv->lbound = lower;
115                                 if (fv->ubound == INF && isfinite(upper)) fv->ubound = upper;
116                         }
117                 }
118         }
119
120         for (size_t p1=0; p1 < numFreeParams; p1++) {
121                 for (size_t p2=p1; p2 < numFreeParams; p2++) {
122                         if (pCol[p1] == -1 || pCol[p1] != pCol[p2]) continue;
123                         const double *spec = omxMatrixColumn(estate->itemSpec, pCol[p1]);
124                         int id = spec[RPF_ISpecID];
125                         int numParam = (*rpf_model[id].numParam)(spec);
126                         int r1 = pRow[p1];
127                         int r2 = pRow[p2];
128                         if (r1 > r2) { int tmp=r1; r1=r2; r2=tmp; }
129                         int rowOffset = 0;
130                         for (int rx=1; rx <= r2; rx++) rowOffset += rx;
131                         int at = pCol[p1] * state->derivPadSize + numParam + rowOffset + r1;
132                         state->paramMap[at] = numFreeParams + p1 * numFreeParams + p2;
133                         if (p2 != p1) state->NAtriangle.push_back(p2 * numFreeParams + p1);
134                 }
135         }
136
137         Free(pRow);
138         Free(pCol);
139
140         state->thrDeriv = Realloc(NULL, itemParam->cols * state->derivPadSize * Global->numThreads, double);
141 }
142
143 OMXINLINE static double
144 ba81Fit1Ordinate(omxFitFunction* oo, const int *quad, const double *weight, int want)
145 {
146         BA81FitState *state = (BA81FitState*) oo->argStruct;
147         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
148         omxMatrix *itemSpec = estate->itemSpec;
149         omxMatrix *itemParam = state->itemParam;
150         int numItems = itemParam->cols;
151         int maxOutcomes = estate->maxOutcomes;
152         int maxDims = estate->maxDims;
153         double *myDeriv = state->thrDeriv + itemParam->cols * state->derivPadSize * omx_absolute_thread_num();
154         int do_deriv = want & (FF_COMPUTE_GRADIENT | FF_COMPUTE_HESSIAN);
155
156         double where[maxDims];
157         pointToWhere(estate, quad, where, maxDims);
158
159         double *outcomeProb = computeRPF(estate, estate->itemSpec, estate->design, itemParam, estate->maxDims,
160                                          estate->maxOutcomes, quad); // avoid malloc/free? TODO
161         if (!outcomeProb) return 0;
162
163         double thr_ll = 0;
164         for (int ix=0; ix < numItems; ix++) {
165                 const double *spec = omxMatrixColumn(itemSpec, ix);
166                 int id = spec[RPF_ISpecID];
167                 int iOutcomes = spec[RPF_ISpecOutcomes];
168
169                 double area = exp(logAreaProduct(estate, quad, estate->Sgroup[ix]));   // avoid exp() here? TODO
170                 for (int ox=0; ox < iOutcomes; ox++) {
171 #if 0
172 #pragma omp critical(ba81Fit1OrdinateDebug1)
173                         if (!std::isfinite(outcomeProb[ix * maxOutcomes + ox])) {
174                                 pda(itemParam->data, itemParam->rows, itemParam->cols);
175                                 pda(outcomeProb, outcomes, numItems);
176                                 error("RPF produced NAs");
177                         }
178 #endif
179                         double got = weight[ox] * outcomeProb[ix * maxOutcomes + ox];
180                         thr_ll += got * area;
181                 }
182
183                 if (do_deriv) {
184                         double *iparam = omxMatrixColumn(itemParam, ix);
185                         double *pad = myDeriv + ix * state->derivPadSize;
186                         (*rpf_model[id].dLL1)(spec, iparam, where, area, weight, pad);
187                 }
188                 weight += iOutcomes;
189         }
190
191         Free(outcomeProb);
192
193         return thr_ll;
194 }
195
196 static double
197 ba81ComputeMFit1(omxFitFunction* oo, int want, double *gradient, double *hessian)
198 {
199         BA81FitState *state = (BA81FitState*) oo->argStruct;
200         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
201         omxMatrix *customPrior = state->customPrior;
202         omxMatrix *itemParam = state->itemParam;
203         omxMatrix *itemSpec = estate->itemSpec;
204         int maxDims = estate->maxDims;
205         const int totalOutcomes = estate->totalOutcomes;
206
207         double ll = 0;
208         if (customPrior) {
209                 omxRecompute(customPrior);
210                 ll = customPrior->data[0];
211                 // need deriv adjustment TODO
212         }
213
214         if (!isfinite(ll)) {
215                 omxPrint(itemParam, "item param");
216                 error("Bayesian prior returned %g; do you need to add a lbound/ubound?", ll);
217         }
218
219 #pragma omp parallel for num_threads(Global->numThreads)
220         for (long qx=0; qx < estate->totalQuadPoints; qx++) {
221                 //double area = exp(state->priLogQarea[qx]);  // avoid exp() here? TODO
222                 int quad[maxDims];
223                 decodeLocation(qx, maxDims, estate->quadGridSize, quad);
224                 double *weight = estate->expected + qx * totalOutcomes;
225                 double thr_ll = ba81Fit1Ordinate(oo, quad, weight, want);
226                 
227 #pragma omp atomic
228                 ll += thr_ll;
229         }
230
231         if (gradient) {
232                 double *deriv0 = state->thrDeriv;
233
234                 int perThread = itemParam->cols * state->derivPadSize;
235                 for (int th=1; th < Global->numThreads; th++) {
236                         double *thrD = state->thrDeriv + th * perThread;
237                         for (int ox=0; ox < perThread; ox++) deriv0[ox] += thrD[ox];
238                 }
239
240                 int numItems = itemParam->cols;
241                 for (int ix=0; ix < numItems; ix++) {
242                         const double *spec = omxMatrixColumn(itemSpec, ix);
243                         int id = spec[RPF_ISpecID];
244                         double *iparam = omxMatrixColumn(itemParam, ix);
245                         double *pad = deriv0 + ix * state->derivPadSize;
246                         (*rpf_model[id].dLL2)(spec, iparam, pad);
247                 }
248
249                 int numFreeParams = int(oo->freeVarGroup->vars.size());
250                 int numParams = itemParam->cols * state->derivPadSize;
251                 for (int ox=0; ox < numParams; ox++) {
252                         int to = state->paramMap[ox];
253                         if (to == -1) continue;
254
255                         // Need to check because this can happen if
256                         // lbounds/ubounds are not set appropriately.
257                         if (0 && !isfinite(deriv0[ox])) {
258                                 int item = ox / itemParam->rows;
259                                 mxLog("item parameters:\n");
260                                 const double *spec = omxMatrixColumn(itemSpec, item);
261                                 int id = spec[RPF_ISpecID];
262                                 int numParam = (*rpf_model[id].numParam)(spec);
263                                 double *iparam = omxMatrixColumn(itemParam, item);
264                                 pda(iparam, numParam, 1);
265                                 // Perhaps bounds can be pulled in from librpf? TODO
266                                 error("Deriv %d for item %d is %f; are you missing a lbound/ubound?",
267                                       ox, item, deriv0[ox]);
268                         }
269
270                         if (to < numFreeParams) {
271                                 gradient[to] -= deriv0[ox];
272                         } else {
273                                 hessian[to - numFreeParams] -= deriv0[ox];
274                         }
275                 }
276         }
277
278         return -ll;
279 }
280
281 static void
282 moveLatentDistribution(omxFitFunction *oo, FitContext *fc,
283                        double *ElatentMean, double *ElatentCov)
284 {
285         BA81FitState *state = (BA81FitState*) oo->argStruct;
286         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
287         omxMatrix *itemSpec = estate->itemSpec;
288         omxMatrix *itemParam = state->itemParam;
289         omxMatrix *design = estate->design;
290         double *tmpLatentMean = state->tmpLatentMean;
291         double *tmpLatentCov = state->tmpLatentCov;
292         int maxDims = estate->maxDims;
293         int maxAbilities = estate->maxAbilities;
294
295         int numItems = itemParam->cols;
296         for (int ix=0; ix < numItems; ix++) {
297                 const double *spec = omxMatrixColumn(itemSpec, ix);
298                 int id = spec[RPF_ISpecID];
299                 const double *rawDesign = omxMatrixColumn(design, ix);
300                 int idesign[design->rows];
301                 int idx = 0;
302                 for (int dx=0; dx < design->rows; dx++) {
303                         if (isfinite(rawDesign[dx])) {
304                                 idesign[idx++] = rawDesign[dx]-1;
305                         } else {
306                                 idesign[idx++] = -1;
307                         }
308                 }
309                 for (int d1=0; d1 < idx; d1++) {
310                         if (idesign[d1] == -1) {
311                                 tmpLatentMean[d1] = 0;
312                         } else {
313                                 tmpLatentMean[d1] = ElatentMean[idesign[d1]];
314                         }
315                         for (int d2=0; d2 <= d1; d2++) {
316                                 int cell = idesign[d2] * maxAbilities + idesign[d1];
317                                 if (idesign[d1] == -1 || idesign[d2] == -1) {
318                                         tmpLatentCov[d2 * maxDims + d1] = d1==d2? 1 : 0;
319                                 } else {
320                                         tmpLatentCov[d2 * maxDims + d1] = ElatentCov[cell];
321                                 }
322                         }
323                 }
324                 if (1) {  // ease debugging, make optional TODO
325                         for (int d1=idx; d1 < maxDims; d1++) tmpLatentMean[d1] = nan("");
326                         for (int d1=0; d1 < maxDims; d1++) {
327                                 for (int d2=0; d2 < maxDims; d2++) {
328                                         if (d1 < idx && d2 < idx) continue;
329                                         tmpLatentCov[d2 * maxDims + d1] = nan("");
330                                 }
331                         }
332                 }
333                 double *iparam = omxMatrixColumn(itemParam, ix);
334                 int *mask = state->paramMap + state->derivPadSize * ix;
335                 rpf_model[id].rescale(spec, iparam, mask, tmpLatentMean, tmpLatentCov);
336         }
337
338         int numFreeParams = int(oo->freeVarGroup->vars.size());
339         for (int rx=0; rx < itemParam->rows; rx++) {
340                 for (int cx=0; cx < itemParam->cols; cx++) {
341                         int vx = state->paramMap[cx * state->derivPadSize + rx];
342                         if (vx >= 0 && vx < numFreeParams) {
343                                 fc->est[vx] = omxMatrixElement(itemParam, rx, cx);
344                         }
345                 }
346         }
347 }
348
349 static void
350 schilling_bock_2005_rescale(omxFitFunction *oo, FitContext *fc)
351 {
352         BA81FitState *state = (BA81FitState*) oo->argStruct;
353         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
354         double *ElatentMean = estate->ElatentMean;
355         double *ElatentCov = estate->ElatentCov;
356         int maxAbilities = estate->maxAbilities;
357
358         //mxLog("schilling bock\n");
359         //pda(ElatentMean, maxAbilities, 1);
360         //pda(ElatentCov, maxAbilities, maxAbilities);
361         //omxPrint(design, "design");
362
363         // use omxDPOTRF instead? TODO
364         const char triangle = 'L';
365         F77_CALL(dpotrf)(&triangle, &maxAbilities, ElatentCov, &maxAbilities, &state->choleskyError);
366         if (state->choleskyError != 0) {
367                 warning("Cholesky failed with %d; rescaling disabled", state->choleskyError); // make error TODO?
368                 return;
369         }
370
371         moveLatentDistribution(oo, fc, ElatentMean, ElatentCov);
372         fc->copyParamToModel(globalState);
373 }
374
375 OMXINLINE static void
376 updateLatentParam(omxFitFunction* oo, FitContext *fc)
377 {
378         BA81FitState *state = (BA81FitState*) oo->argStruct;
379         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
380         int maxAbilities = estate->maxAbilities;
381
382         // TODO need denom for multigroup
383         for (int a1=0; a1 < maxAbilities; ++a1) {
384                 if (state->latentMeanMap[a1] >= 0) {
385                         double val = estate->ElatentMean[a1];
386                         int vx = state->latentMeanMap[a1];
387                         omxFreeVar *fv = fc->varGroup->vars[vx];
388                         if (val < fv->lbound) val = fv->lbound;
389                         if (val > fv->ubound) val = fv->ubound;
390                         fc->est[vx] = val;
391                 }
392                 for (int a2=0; a2 < maxAbilities; ++a2) {
393                         int cell = a2 * maxAbilities + a1;
394                         if (state->latentCovMap[cell] < 0) continue;
395                         double val = estate->ElatentCov[cell];
396                         int vx = state->latentCovMap[cell];
397                         omxFreeVar *fv = fc->varGroup->vars[vx];
398                         if (val < fv->lbound) val = fv->lbound;
399                         if (val > fv->ubound) val = fv->ubound;
400                         fc->est[vx] = val;
401                 }
402         }
403
404         fc->copyParamToModel(globalState);
405 }
406
407 void ba81SetFreeVarGroup(omxFitFunction *oo, FreeVarGroup *fvg) // too ad hoc? TODO
408 {
409         if (!oo->argStruct) { // ugh!
410                 BA81FitState *state = new BA81FitState;
411                 oo->argStruct = state;
412         }
413
414         BA81FitState *state = (BA81FitState*) oo->argStruct;
415
416         state->varGroups.push_back(fvg);
417         if (state->varGroups.size() == 2) {
418                 int small = 0;
419                 if (state->varGroups[0] == state->varGroups[1])
420                         warning("Cannot recognize correct free parameter groups");
421                 if (state->varGroups[0]->vars.size() > state->varGroups[1]->vars.size())
422                         small = 1;
423                 oo->freeVarGroup = state->varGroups[small];
424                 state->latentFVG = state->varGroups[!small];
425         } else if (state->varGroups.size() > 2) {
426                 // ignore
427         }
428 }
429
430 static double
431 ba81ComputeFit(omxFitFunction* oo, int want, FitContext *fc)
432 {
433         BA81FitState *state = (BA81FitState*) oo->argStruct;
434
435         if (!state->paramMap) buildParamMap(oo);
436
437         if (want & FF_COMPUTE_PREOPTIMIZE) {
438                 if (state->rescale) schilling_bock_2005_rescale(oo, fc); // how does this work in multigroup? TODO
439                 return 0;
440         }
441
442         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
443
444         ++state->fitCount;
445
446         if (estate->type == EXPECTATION_AUGMENTED) {
447                 if (fc->varGroup != oo->freeVarGroup) error("FreeVarGroup mismatch");
448
449                 if (want & FF_COMPUTE_GRADIENT) ++state->gradientCount;
450
451                 omxMatrix *itemParam = state->itemParam;
452                 OMXZERO(state->thrDeriv, state->derivPadSize * itemParam->cols * Global->numThreads);
453
454                 for (size_t nx=0; nx < state->NAtriangle.size(); ++nx) {
455                         fc->hess[ state->NAtriangle[nx] ] = nan("symmetric");
456                 }
457
458                 double got = ba81ComputeMFit1(oo, want, fc->grad, fc->hess);
459                 return got;
460         } else if (estate->type == EXPECTATION_OBSERVED) {
461                 if (fc->varGroup != state->latentFVG) error("FreeVarGroup mismatch");
462
463                 updateLatentParam(oo, fc);
464
465                 double *patternLik = estate->patternLik;
466                 int *numIdentical = estate->numIdentical;
467                 int numUnique = estate->numUnique;
468                 double got = 0;
469                 for (int ux=0; ux < numUnique; ux++) {
470                         got += numIdentical[ux] * patternLik[ux];
471                 }
472                 return -2 * got;
473         }
474 }
475
476 static void ba81Compute(omxFitFunction *oo, int want, FitContext *fc)
477 {
478         if (!want) return;
479         oo->matrix->data[0] = ba81ComputeFit(oo, want, fc);
480 }
481
482 static void ba81Destroy(omxFitFunction *oo) {
483         BA81FitState *state = (BA81FitState *) oo->argStruct;
484
485         omxFreeAllMatrixData(state->customPrior);
486         Free(state->paramMap);
487         Free(state->thrDeriv);
488         Free(state->tmpLatentMean);
489         Free(state->tmpLatentCov);
490         omxFreeAllMatrixData(state->itemParam);
491         delete state;
492 }
493
494 void omxInitFitFunctionBA81(omxFitFunction* oo)
495 {
496         if (!oo->argStruct) { // ugh!
497                 BA81FitState *state = new BA81FitState;
498                 oo->argStruct = state;
499         }
500
501         BA81FitState *state = (BA81FitState*) oo->argStruct;
502         SEXP rObj = oo->rObj;
503
504         omxExpectation *expectation = oo->expectation;
505         BA81Expect *estate = (BA81Expect*) expectation->argStruct;
506
507         //newObj->data = oo->expectation->data;
508
509         oo->computeFun = ba81Compute;
510         oo->setVarGroup = ba81SetFreeVarGroup;
511         oo->destructFun = ba81Destroy;
512         oo->gradientAvailable = TRUE;
513         oo->hessianAvailable = TRUE;
514
515         SEXP tmp;
516         PROTECT(tmp = GET_SLOT(rObj, install("rescale")));
517         state->rescale = asLogical(tmp);
518
519         state->itemParam =
520                 omxNewMatrixFromSlot(rObj, globalState, "ItemParam");
521
522         if (estate->EitemParam->rows != state->itemParam->rows ||
523             estate->EitemParam->cols != state->itemParam->cols) {
524                 error("ItemParam and EItemParam matrices must be the same dimension");
525         }
526
527         state->customPrior =
528                 omxNewMatrixFromSlot(rObj, globalState, "CustomPrior");
529         
530         int maxParam = state->itemParam->rows;
531         state->derivPadSize = maxParam + maxParam*(1+maxParam)/2;
532
533         state->tmpLatentMean = Realloc(NULL, estate->maxDims, double);
534         state->tmpLatentCov = Realloc(NULL, estate->maxDims * estate->maxDims, double);
535
536         int numItems = state->itemParam->cols;
537         for (int ix=0; ix < numItems; ix++) {
538                 double *spec = omxMatrixColumn(estate->itemSpec, ix);
539                 int id = spec[RPF_ISpecID];
540                 if (id < 0 || id >= rpf_numModels) {
541                         error("ItemSpec column %d has unknown item model %d", ix, id);
542                 }
543         }
544 }