No need to report EM.LL in a special way
[openmx:openmx.git] / src / omxFitFunctionBA81.cpp
1 /*
2   Copyright 2012-2013 Joshua Nathaniel Pritikin and contributors
3
4   This is free software: you can redistribute it and/or modify
5   it under the terms of the GNU General Public License as published by
6   the Free Software Foundation, either version 3 of the License, or
7   (at your option) any later version.
8
9   This program is distributed in the hope that it will be useful,
10   but WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12   GNU General Public License for more details.
13
14   You should have received a copy of the GNU General Public License
15   along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include "omxFitFunction.h"
19 #include "omxExpectationBA81.h"
20 #include "omxOpenmpWrap.h"
21 #include "libifa-rpf.h"
22
// Human-readable identifier of this fit function (file-local).
static const char* NAME = "FitFunctionBA81";
24
25 struct BA81FitState {
26
27         omxMatrix *itemParam;     // M step version
28         int derivPadSize;         // maxParam + maxParam*(1+maxParam)/2
29         double *thrDeriv;         // itemParam->cols * derivPadSize * thread
30         int *paramMap;            // itemParam->cols * derivPadSize -> index of free parameter
31         bool rescale;
32         omxMatrix *customPrior;
33         int choleskyError;
34         double *tmpLatentMean;    // maxDims
35         double *tmpLatentCov;     // maxDims * maxDims ; only lower triangle is used
36         int fitCount;
37         int gradientCount;
38
39         std::vector< FreeVarGroup* > varGroups;
40         FreeVarGroup *latentFVG;
41
42         BA81FitState();
43 };
44
45 BA81FitState::BA81FitState()
46 {
47         paramMap = NULL;
48         latentFVG = NULL;
49         customPrior = NULL;
50         fitCount = 0;
51         gradientCount = 0;
52 }
53
54 static void buildParamMap(omxFitFunction* oo)
55 {
56         BA81FitState *state = (BA81FitState *) oo->argStruct;
57         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
58         omxMatrix *itemParam = state->itemParam;
59         int size = itemParam->cols * state->derivPadSize;
60
61         state->paramMap = Realloc(NULL, size, int);  // matrix location to free param index
62         for (int px=0; px < size; px++) {
63                 state->paramMap[px] = -1;
64         }
65
66         size_t numFreeParams = oo->freeVarGroup->vars.size();
67         int *pRow = Realloc(NULL, numFreeParams, int);
68         int *pCol = Realloc(NULL, numFreeParams, int);
69
70         for (size_t px=0; px < numFreeParams; px++) {
71                 pRow[px] = -1;
72                 pCol[px] = -1;
73                 omxFreeVar *fv = oo->freeVarGroup->vars[px];
74                 for (size_t lx=0; lx < fv->locations.size(); lx++) {
75                         omxFreeVarLocation *loc = &fv->locations[lx];
76                         if (~loc->matrix == itemParam->matrixNumber) {
77                                 pRow[px] = loc->row;
78                                 pCol[px] = loc->col;
79                                 int at = pCol[px] * state->derivPadSize + pRow[px];
80                                 state->paramMap[at] = px;
81                         }
82                 }
83         }
84
85         for (size_t p1=0; p1 < numFreeParams; p1++) {
86                 for (size_t p2=p1; p2 < numFreeParams; p2++) {
87                         if (pCol[p1] == -1 || pCol[p1] != pCol[p2]) continue;
88                         const double *spec = omxMatrixColumn(estate->itemSpec, pCol[p1]);
89                         int id = spec[RPF_ISpecID];
90                         int numParam = (*rpf_model[id].numParam)(spec);
91                         int r1 = pRow[p1];
92                         int r2 = pRow[p2];
93                         if (r1 > r2) { int tmp=r1; r1=r2; r2=tmp; }
94                         int rowOffset = 0;
95                         for (int rx=1; rx <= r2; rx++) rowOffset += rx;
96                         int at = pCol[p1] * state->derivPadSize + numParam + rowOffset + r1;
97                         state->paramMap[at] = numFreeParams + p1 * numFreeParams + p2;
98                 }
99         }
100
101         Free(pRow);
102         Free(pCol);
103
104         state->thrDeriv = Realloc(NULL, itemParam->cols * state->derivPadSize * Global->numThreads, double);
105 }
106
107 OMXINLINE static double
108 ba81Fit1Ordinate(omxFitFunction* oo, const int *quad, const double *weight, int want)
109 {
110         BA81FitState *state = (BA81FitState*) oo->argStruct;
111         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
112         omxMatrix *itemSpec = estate->itemSpec;
113         omxMatrix *itemParam = state->itemParam;
114         int numItems = itemParam->cols;
115         int maxOutcomes = estate->maxOutcomes;
116         int maxDims = estate->maxDims;
117         double *myDeriv = state->thrDeriv + itemParam->cols * state->derivPadSize * omx_absolute_thread_num();
118         int do_deriv = want & (FF_COMPUTE_GRADIENT | FF_COMPUTE_HESSIAN);
119
120         double where[maxDims];
121         pointToWhere(estate->Qpoint, quad, where, maxDims);
122
123         double *outcomeProb = computeRPF(estate->itemSpec, estate->design, itemParam, estate->maxDims,
124                                          estate->maxOutcomes, quad, estate->Qpoint); // avoid malloc/free? TODO
125         if (!outcomeProb) return 0;
126
127         double thr_ll = 0;
128         for (int ix=0; ix < numItems; ix++) {
129                 const double *spec = omxMatrixColumn(itemSpec, ix);
130                 int id = spec[RPF_ISpecID];
131                 int iOutcomes = spec[RPF_ISpecOutcomes];
132
133                 double area = exp(logAreaProduct(estate, quad, estate->Sgroup[ix]));   // avoid exp() here? TODO
134                 for (int ox=0; ox < iOutcomes; ox++) {
135 #if 0
136 #pragma omp critical(ba81Fit1OrdinateDebug1)
137                         if (!isfinite(outcomeProb[ix * maxOutcomes + ox])) {
138                                 pda(itemParam->data, itemParam->rows, itemParam->cols);
139                                 pda(outcomeProb, outcomes, numItems);
140                                 error("RPF produced NAs");
141                         }
142 #endif
143                         double got = weight[ox] * outcomeProb[ix * maxOutcomes + ox];
144                         thr_ll += got * area;
145                 }
146
147                 if (do_deriv) {
148                         double *iparam = omxMatrixColumn(itemParam, ix);
149                         double *pad = myDeriv + ix * state->derivPadSize;
150                         (*rpf_model[id].dLL1)(spec, iparam, where, area, weight, pad);
151                 }
152                 weight += iOutcomes;
153         }
154
155         Free(outcomeProb);
156
157         return thr_ll;
158 }
159
160 static double
161 ba81ComputeMFit1(omxFitFunction* oo, int want, double *gradient, double *hessian)
162 {
163         BA81FitState *state = (BA81FitState*) oo->argStruct;
164         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
165         omxMatrix *customPrior = state->customPrior;
166         omxMatrix *itemParam = state->itemParam;
167         omxMatrix *itemSpec = estate->itemSpec;
168         int maxDims = estate->maxDims;
169         const int totalOutcomes = estate->totalOutcomes;
170
171         double ll = 0;
172         if (customPrior) {
173                 omxRecompute(customPrior);
174                 ll = customPrior->data[0];
175                 // need deriv adjustment TODO
176         }
177
178         if (!isfinite(ll)) {
179                 omxPrint(itemParam, "item param");
180                 error("Bayesian prior returned %g; do you need to add a lbound/ubound?", ll);
181         }
182
183 #pragma omp parallel for num_threads(Global->numThreads)
184         for (long qx=0; qx < estate->totalQuadPoints; qx++) {
185                 //double area = exp(state->priLogQarea[qx]);  // avoid exp() here? TODO
186                 int quad[maxDims];
187                 decodeLocation(qx, maxDims, estate->quadGridSize, quad);
188                 double *weight = estate->expected + qx * totalOutcomes;
189                 double thr_ll = ba81Fit1Ordinate(oo, quad, weight, want);
190                 
191 #pragma omp atomic
192                 ll += thr_ll;
193         }
194
195         if (gradient) {
196                 double *deriv0 = state->thrDeriv;
197
198                 int perThread = itemParam->cols * state->derivPadSize;
199                 for (int th=1; th < Global->numThreads; th++) {
200                         double *thrD = state->thrDeriv + th * perThread;
201                         for (int ox=0; ox < perThread; ox++) deriv0[ox] += thrD[ox];
202                 }
203
204                 int numItems = itemParam->cols;
205                 for (int ix=0; ix < numItems; ix++) {
206                         const double *spec = omxMatrixColumn(itemSpec, ix);
207                         int id = spec[RPF_ISpecID];
208                         double *iparam = omxMatrixColumn(itemParam, ix);
209                         double *pad = deriv0 + ix * state->derivPadSize;
210                         (*rpf_model[id].dLL2)(spec, iparam, pad);
211                 }
212
213                 int numFreeParams = int(oo->freeVarGroup->vars.size());
214                 int numParams = itemParam->cols * state->derivPadSize;
215                 for (int ox=0; ox < numParams; ox++) {
216                         int to = state->paramMap[ox];
217                         if (to == -1) continue;
218
219                         // Need to check because this can happen if
220                         // lbounds/ubounds are not set appropriately.
221                         if (0 && !isfinite(deriv0[ox])) {
222                                 int item = ox / itemParam->rows;
223                                 mxLog("item parameters:\n");
224                                 const double *spec = omxMatrixColumn(itemSpec, item);
225                                 int id = spec[RPF_ISpecID];
226                                 int numParam = (*rpf_model[id].numParam)(spec);
227                                 double *iparam = omxMatrixColumn(itemParam, item);
228                                 pda(iparam, numParam, 1);
229                                 // Perhaps bounds can be pulled in from librpf? TODO
230                                 error("Deriv %d for item %d is %f; are you missing a lbound/ubound?",
231                                       ox, item, deriv0[ox]);
232                         }
233
234                         if (to < numFreeParams) {
235                                 gradient[to] -= deriv0[ox];
236                         } else {
237                                 hessian[to - numFreeParams] -= deriv0[ox];
238                         }
239                 }
240         }
241
242         return -ll;
243 }
244
245 static void
246 schilling_bock_2005_rescale(omxFitFunction *oo, FitContext *fc)
247 {
248         BA81FitState *state = (BA81FitState*) oo->argStruct;
249         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
250         omxMatrix *itemSpec = estate->itemSpec;
251         omxMatrix *itemParam = state->itemParam;
252         omxMatrix *design = estate->design;
253         double *ElatentMean = estate->ElatentMean;
254         double *ElatentCov = estate->ElatentCov;
255         double *tmpLatentMean = state->tmpLatentMean;
256         double *tmpLatentCov = state->tmpLatentCov;
257         int maxAbilities = estate->maxAbilities;
258         int maxDims = estate->maxDims;
259
260         //mxLog("schilling bock\n");
261         //pda(ElatentMean, maxAbilities, 1);
262         //pda(ElatentCov, maxAbilities, maxAbilities);
263         //omxPrint(design, "design");
264
265         const char triangle = 'L';
266         F77_CALL(dpotrf)(&triangle, &maxAbilities, ElatentCov, &maxAbilities, &state->choleskyError);
267         if (state->choleskyError != 0) {
268                 warning("Cholesky failed with %d; rescaling disabled", state->choleskyError); // make error TODO?
269                 return;
270         }
271
272         //fc->log(FF_COMPUTE_ESTIMATE);
273
274         int numItems = itemParam->cols;
275         for (int ix=0; ix < numItems; ix++) {
276                 const double *spec = omxMatrixColumn(itemSpec, ix);
277                 int id = spec[RPF_ISpecID];
278                 const double *rawDesign = omxMatrixColumn(design, ix);
279                 int idesign[design->rows];
280                 int idx = 0;
281                 for (int dx=0; dx < design->rows; dx++) {
282                         if (isfinite(rawDesign[dx])) {
283                                 idesign[idx++] = rawDesign[dx]-1;
284                         } else {
285                                 idesign[idx++] = -1;
286                         }
287                 }
288                 for (int d1=0; d1 < idx; d1++) {
289                         if (idesign[d1] == -1) {
290                                 tmpLatentMean[d1] = 0;
291                         } else {
292                                 tmpLatentMean[d1] = ElatentMean[idesign[d1]];
293                         }
294                         for (int d2=0; d2 <= d1; d2++) {
295                                 int cell = idesign[d2] * maxAbilities + idesign[d1];
296                                 if (idesign[d1] == -1 || idesign[d2] == -1) {
297                                         tmpLatentCov[d2 * maxDims + d1] = d1==d2? 1 : 0;
298                                 } else {
299                                         tmpLatentCov[d2 * maxDims + d1] = ElatentCov[cell];
300                                 }
301                         }
302                 }
303                 if (1) {  // ease debugging, make optional TODO
304                         for (int d1=idx; d1 < maxDims; d1++) tmpLatentMean[d1] = nan("");
305                         for (int d1=0; d1 < maxDims; d1++) {
306                                 for (int d2=0; d2 < maxDims; d2++) {
307                                         if (d1 < idx && d2 < idx) continue;
308                                         tmpLatentCov[d2 * maxDims + d1] = nan("");
309                                 }
310                         }
311                 }
312                 double *iparam = omxMatrixColumn(itemParam, ix);
313                 int *mask = state->paramMap + state->derivPadSize * ix;
314                 rpf_model[id].rescale(spec, iparam, mask, tmpLatentMean, tmpLatentCov);
315         }
316
317         int numFreeParams = int(oo->freeVarGroup->vars.size());
318         for (int rx=0; rx < itemParam->rows; rx++) {
319                 for (int cx=0; cx < itemParam->cols; cx++) {
320                         int vx = state->paramMap[cx * state->derivPadSize + rx];
321                         if (vx >= 0 && vx < numFreeParams) {
322                                 fc->est[vx] = omxMatrixElement(itemParam, rx, cx);
323                         }
324                 }
325         }
326         fc->copyParamToModel(globalState);
327         //fc->log(FF_COMPUTE_ESTIMATE);
328 }
329
330 OMXINLINE static void
331 updateLatentParam(omxFitFunction* oo, FitContext *fc)
332 {
333         BA81FitState *state = (BA81FitState*) oo->argStruct;
334         BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
335         int maxAbilities = estate->maxAbilities;
336         int meanNum = estate->latentMeanOut->matrixNumber;
337         int covNum = estate->latentCovOut->matrixNumber;
338         FreeVarGroup *latentFVG = state->latentFVG;
339
340         // TODO need denom for multigroup
341         size_t numFreeParams = latentFVG->vars.size();
342         for (size_t px=0; px < numFreeParams; px++) {
343                 omxFreeVar *fv = latentFVG->vars[px];
344                 for (size_t lx=0; lx < fv->locations.size(); ++lx) {
345                         omxFreeVarLocation *loc = &fv->locations[lx];
346                         int matNum = ~loc->matrix;
347                         if (matNum == meanNum) {
348                                 int dx = loc->row * loc->col;
349                                 fc->est[px] = estate->ElatentMean[dx];
350                         } else if (matNum == covNum) {
351                                 int cell = loc->col * maxAbilities + loc->row;
352                                 fc->est[px] = estate->ElatentCov[cell];
353                         }
354                 }
355         }
356         fc->copyParamToModel(globalState);
357 }
358
359 void ba81SetFreeVarGroup(omxFitFunction *oo, FreeVarGroup *fvg) // too ad hoc? TODO
360 {
361         if (!oo->argStruct) { // ugh!
362                 BA81FitState *state = new BA81FitState;
363                 oo->argStruct = state;
364         }
365
366         BA81FitState *state = (BA81FitState*) oo->argStruct;
367
368         state->varGroups.push_back(fvg);
369         if (state->varGroups.size() == 2) {
370                 int small = 0;
371                 if (state->varGroups[0]->vars.size() > state->varGroups[1]->vars.size())
372                         small = 1;
373                 oo->freeVarGroup = state->varGroups[small];
374                 state->latentFVG = state->varGroups[!small];
375         } else if (state->varGroups.size() > 2) {
376                 // ignore
377         }
378 }
379
380 static double
381 ba81ComputeFit(omxFitFunction* oo, int want, FitContext *fc)
382 {
383         if (!want) return 0;
384
385         BA81FitState *state = (BA81FitState*) oo->argStruct;
386
387         if (!state->paramMap) buildParamMap(oo);
388
389         if (want & FF_COMPUTE_PREOPTIMIZE) {
390                 if (state->rescale) schilling_bock_2005_rescale(oo, fc);
391                 return 0;
392         }
393
394         if (want & FF_COMPUTE_FIT) {
395                 ++state->fitCount;
396         }
397
398         if (want & (FF_COMPUTE_GRADIENT|FF_COMPUTE_HESSIAN)) {
399                 // M-step
400
401                 ++state->gradientCount;
402
403                 size_t numFreeParams = oo->freeVarGroup->vars.size();
404                 double *gradient = fc->grad;
405                 double *hessian = fc->hess;
406                 OMXZERO(gradient, numFreeParams);
407                 OMXZERO(hessian, numFreeParams * numFreeParams);
408
409                 omxMatrix *itemParam = state->itemParam;
410                 OMXZERO(state->thrDeriv, state->derivPadSize * itemParam->cols * Global->numThreads);
411
412                 double got = ba81ComputeMFit1(oo, want, gradient, hessian);
413                 return got;
414         } else {
415                 // Major EM iteration, note completely different LL calculation
416
417                 updateLatentParam(oo, fc);
418
419                 BA81Expect *estate = (BA81Expect*) oo->expectation->argStruct;
420                 double *patternLik = estate->patternLik;
421                 int *numIdentical = estate->numIdentical;
422                 int numUnique = estate->numUnique;
423                 double got = 0;
424                 for (int ux=0; ux < numUnique; ux++) {
425                         got += numIdentical[ux] * patternLik[ux];
426                 }
427                 return -2 * got;
428         }
429 }
430
431 static void ba81Compute(omxFitFunction *oo, int want, FitContext *fc)
432 {
433         oo->matrix->data[0] = ba81ComputeFit(oo, want, fc);
434 }
435
436 static void ba81Destroy(omxFitFunction *oo) {
437         BA81FitState *state = (BA81FitState *) oo->argStruct;
438
439         omxFreeAllMatrixData(state->customPrior);
440         Free(state->paramMap);
441         Free(state->thrDeriv);
442         Free(state->tmpLatentMean);
443         Free(state->tmpLatentCov);
444         omxFreeAllMatrixData(state->itemParam);
445         delete state;
446 }
447
448 static omxRListElement *ba81SetFinalReturns(omxFitFunction *oo, int *numReturns)
449 {
450         omxRListElement *ret = ba81EAP(oo->expectation, numReturns);
451
452         return ret;
453 }
454
455 void omxInitFitFunctionBA81(omxFitFunction* oo)
456 {
457         BA81FitState *state = (BA81FitState*) oo->argStruct;
458         SEXP rObj = oo->rObj;
459
460         omxExpectation *expectation = oo->expectation;
461         BA81Expect *estate = (BA81Expect*) expectation->argStruct;
462
463         //newObj->data = oo->expectation->data;
464
465         oo->computeFun = ba81Compute;
466         oo->setVarGroup = ba81SetFreeVarGroup;
467         oo->setFinalReturns = ba81SetFinalReturns;
468         oo->destructFun = ba81Destroy;
469
470         SEXP tmp;
471         PROTECT(tmp = GET_SLOT(rObj, install("rescale")));
472         state->rescale = asLogical(tmp);
473
474         state->itemParam =
475                 omxNewMatrixFromSlot(rObj, globalState, "ItemParam");
476
477         if (estate->EitemParam->rows != state->itemParam->rows ||
478             estate->EitemParam->cols != state->itemParam->cols) {
479                 error("ItemParam and EItemParam matrices must be the same dimension");
480         }
481
482         state->customPrior =
483                 omxNewMatrixFromSlot(rObj, globalState, "CustomPrior");
484         
485         int maxParam = state->itemParam->rows;
486         state->derivPadSize = maxParam + maxParam*(1+maxParam)/2;
487
488         state->tmpLatentMean = Realloc(NULL, estate->maxDims, double);
489         state->tmpLatentCov = Realloc(NULL, estate->maxDims * estate->maxDims, double);
490
491         int numItems = state->itemParam->cols;
492         for (int ix=0; ix < numItems; ix++) {
493                 double *spec = omxMatrixColumn(estate->itemSpec, ix);
494                 int id = spec[RPF_ISpecID];
495                 if (id < 0 || id >= rpf_numModels) {
496                         error("ItemSpec column %d has unknown item model %d", ix, id);
497                 }
498         }
499 }