Rewrite dependency tracking
[openmx:openmx.git] / src / Compute.cpp
1 /*
2  *  Copyright 2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include "omxDefines.h"
18 #include "Compute.h"
19 #include "omxState.h"
20 #include "omxExportBackendState.h"
21 #include "omxRFitFunction.h"
22
23 void FitContext::init()
24 {
25         size_t numParam = varGroup->vars.size();
26         fit = parent? parent->fit : 0;
27         est = new double[numParam];
28         grad = new double[numParam];
29         hess = new double[numParam * numParam];
30 }
31
32 FitContext::FitContext()
33 {
34         parent = NULL;
35         varGroup = Global->freeGroup[0];
36         init();
37
38         size_t numParam = varGroup->vars.size();
39         for (size_t v1=0; v1 < numParam; v1++) {
40                 est[v1] = Global->freeGroup[0]->vars[v1]->start;
41                 grad[v1] = nan("unset");
42                 for (size_t v2=0; v2 < numParam; v2++) {
43                         hess[v1 * numParam + v2] = nan("unset");
44                 }
45         }
46 }
47
48 // arg to control what to copy? usually don't want everything TODO
49 FitContext::FitContext(FitContext *parent, FreeVarGroup *varGroup)
50 {
51         this->parent = parent;
52         this->varGroup = varGroup;
53         init();
54
55         FreeVarGroup *src = parent->varGroup;
56         FreeVarGroup *dest = varGroup;
57         size_t svars = parent->varGroup->vars.size();
58         size_t dvars = varGroup->vars.size();
59
60         size_t d1 = 0;
61         for (size_t s1=0; s1 < src->vars.size(); ++s1) {
62                 if (src->vars[s1] != dest->vars[d1]) continue;
63                 est[d1] = parent->est[s1];
64                 grad[d1] = parent->grad[s1];
65
66                 size_t d2 = 0;
67                 for (size_t s2=0; s2 < src->vars.size(); ++s2) {
68                         if (src->vars[s2] != dest->vars[d2]) continue;
69                         hess[d1 * dvars + d2] = parent->hess[s1 * svars + s2];
70                         if (++d2 == dvars) break;
71                 }
72
73                 if (++d1 == dvars) break;
74         }
75         if (d1 != dvars) error("Parent free parameter group is not a superset");
76
77         // pda(parent->est, 1, svars);
78         // pda(est, 1, dvars);
79         // pda(parent->grad, 1, svars);
80         // pda(grad, 1, dvars);
81         // pda(parent->hess, svars, svars);
82         // pda(hess, dvars, dvars);
83 }
84
85 void FitContext::copyParamToModel(omxMatrix *mat)
86 { copyParamToModel(mat->currentState); }
87
88 void FitContext::copyParamToModel(omxMatrix *mat, double *at)
89 { copyParamToModel(mat->currentState, at); }
90
91 void FitContext::updateParentAndFree()
92 {
93         FreeVarGroup *src = varGroup;
94         FreeVarGroup *dest = parent->varGroup;
95         size_t svars = varGroup->vars.size();
96         size_t dvars = parent->varGroup->vars.size();
97
98         parent->fit = fit;
99
100         size_t s1 = 0;
101         for (size_t d1=0; d1 < dest->vars.size(); ++d1) {
102                 if (dest->vars[d1] != src->vars[s1]) continue;
103                 parent->est[d1] = est[s1];
104                 parent->grad[d1] = grad[s1];
105
106                 size_t s2 = 0;
107                 for (size_t d2=0; d2 < dest->vars.size(); ++d2) {
108                         if (dest->vars[d2] != src->vars[s2]) continue;
109                         parent->hess[d1 * dvars + d2] = hess[s1 * svars + s2];
110                         if (++s2 == svars) break;
111                 }
112
113                 if (++s1 == svars) break;
114         }
115         
116         // pda(est, 1, svars);
117         // pda(parent->est, 1, dvars);
118         // pda(grad, 1, svars);
119         // pda(parent->grad, 1, dvars);
120         // pda(hess, svars, svars);
121         // pda(parent->hess, dvars, dvars);
122
123         delete this;
124 }
125
126 void FitContext::log(const char *where, int what)
127 {
128         size_t count = varGroup->vars.size();
129         std::string buf(where);
130         buf += " ---\n";
131         if (what & FF_COMPUTE_FIT) buf += string_snprintf("fit: %.5f\n", fit);
132         if (what & FF_COMPUTE_ESTIMATE) {
133                 buf += "est: c(";
134                 for (size_t vx=0; vx < count; ++vx) {
135                         buf += string_snprintf("%.5f", est[vx]);
136                         if (vx < count - 1) buf += ", ";
137                 }
138                 buf += ")\n";
139         }
140         if (what & FF_COMPUTE_GRADIENT) {
141                 buf += "grad: c(";
142                 for (size_t vx=0; vx < count; ++vx) {
143                         buf += string_snprintf("%.5f", grad[vx]);
144                         if (vx < count - 1) buf += ", ";
145                 }
146                 buf += ")\n";
147         }
148         if (what & FF_COMPUTE_HESSIAN) {
149                 buf += "hess: c(";
150                 for (size_t v1=0; v1 < count; ++v1) {
151                         for (size_t v2=0; v2 < count; ++v2) {
152                                 buf += string_snprintf("%.5f", hess[v1 * count + v2]);
153                                 if (v1 < count-1 || v2 < count-1) buf += ", ";
154                         }
155                         buf += "\n";
156                 }
157                 buf += ")\n";
158         }
159         mxLogBig(buf);
160 }
161
162 void FitContext::fixHessianSymmetry()
163 {
164         // make non-symmetric entries symmetric, if possible
165         size_t numParam = varGroup->vars.size();
166         for (size_t h1=1; h1 < numParam; h1++) {
167                 for (size_t h2=0; h2 < h1; h2++) {
168                         double upper = hess[h1 * numParam + h2];
169                         double lower = hess[h2 * numParam + h1];
170                         if (isfinite(upper)) continue;
171                         if (isfinite(lower)) {
172                                 hess[h1 * numParam + h2] = lower;
173                         } else {
174                                 log("FitContext", FF_COMPUTE_ESTIMATE|FF_COMPUTE_GRADIENT|FF_COMPUTE_HESSIAN);
175                                 error("Hessian is not finite at [%d,%d]", h1,h2);
176                         }
177                 }
178         }
179 }
180
181 static void omxRepopulateRFitFunction(omxFitFunction* oo, double* x, int n)
182 {
183         omxRFitFunction* rFitFunction = (omxRFitFunction*)oo->argStruct;
184
185         SEXP theCall, estimate;
186
187         PROTECT(estimate = allocVector(REALSXP, n));
188         double *est = REAL(estimate);
189         for(int i = 0; i < n ; i++) {
190                 est[i] = x[i];
191         }
192
193         PROTECT(theCall = allocVector(LANGSXP, 4));
194
195         SETCAR(theCall, install("imxUpdateModelValues"));
196         SETCADR(theCall, rFitFunction->model);
197         SETCADDR(theCall, rFitFunction->flatModel);
198         SETCADDDR(theCall, estimate);
199
200         REPROTECT(rFitFunction->model = eval(theCall, R_GlobalEnv), rFitFunction->modelIndex);
201
202         UNPROTECT(2); // theCall, estimate
203 }
204
205 void FitContext::copyParamToModel(omxState* os)
206 {
207         copyParamToModel(os, est);
208 }
209
210 void FitContext::copyParamToModel(omxState* os, double *at)
211 {
212         size_t numParam = varGroup->vars.size();
213         if(OMX_DEBUG) {
214                 mxLog("Copying %d free parameter estimates to model %p", numParam, os);
215         }
216
217         if(numParam == 0) return;
218
219         os->computeCount++;
220
221         if(OMX_VERBOSE) {
222                 std::string buf;
223                 buf += string_snprintf("Call: %d.%d (%d) ", os->majorIteration, os->minorIteration, os->computeCount);
224                 buf += ("Estimates: [");
225                 for(size_t k = 0; k < numParam; k++) {
226                         buf += string_snprintf(" %f", at[k]);
227                 }
228                 buf += ("]\n");
229                 mxLogBig(buf);
230         }
231
232         for(size_t k = 0; k < numParam; k++) {
233                 omxFreeVar* freeVar = varGroup->vars[k];
234                 for(size_t l = 0; l < freeVar->locations.size(); l++) {
235                         omxFreeVarLocation *loc = &freeVar->locations[l];
236                         omxMatrix *matrix = os->matrixList[loc->matrix];
237                         int row = loc->row;
238                         int col = loc->col;
239                         omxSetMatrixElement(matrix, row, col, at[k]);
240                         if(OMX_DEBUG) {
241                                 mxLog("Setting location (%d, %d) of matrix %d to value %f for var %d",
242                                         row, col, loc->matrix, at[k], k);
243                         }
244                 }
245         }
246
247         if (RFitFunction) omxRepopulateRFitFunction(RFitFunction, at, numParam);
248
249         varGroup->markDirty(os);
250
251         if (!os->childList) return;
252
253         for(int i = 0; i < Global->numChildren; i++) {
254                 copyParamToModel(os->childList[i]);
255         }
256 }
257
258 FitContext::~FitContext()
259 {
260         delete [] est;
261         delete [] grad;
262         delete [] hess;
263 }
264
265 omxFitFunction *FitContext::RFitFunction = NULL;
266
267 void FitContext::setRFitFunction(omxFitFunction *rff)
268 {
269         if (rff) {
270                 Global->numThreads = 1;
271                 if (RFitFunction) {
272                         error("You can only create 1 MxRFitFunction per independent model");
273                 }
274         }
275         RFitFunction = rff;
276 }
277
278 omxCompute::~omxCompute()
279 {}
280
281 void omxComputeOperation::initFromFrontend(SEXP rObj)
282 {
283         SEXP slotValue;
284         PROTECT(slotValue = GET_SLOT(rObj, install("id")));
285         int id = INTEGER(slotValue)[0];
286         varGroup = Global->findVarGroup(id);
287         if (!varGroup) varGroup = Global->freeGroup[0];
288 }
289
290 class omxComputeSequence : public omxCompute {
291         std::vector< omxCompute* > clist;
292
293  public:
294         virtual void initFromFrontend(SEXP rObj);
295         virtual void compute(FitContext *fc);
296         virtual void reportResults(FitContext *fc, MxRList *out);
297         virtual double getOptimizerStatus();
298         virtual ~omxComputeSequence();
299 };
300
301 class omxComputeIterate : public omxCompute {
302         std::vector< omxCompute* > clist;
303         int maxIter;
304         double tolerance;
305         bool verbose;
306
307  public:
308         virtual void initFromFrontend(SEXP rObj);
309         virtual void compute(FitContext *fc);
310         virtual void reportResults(FitContext *fc, MxRList *out);
311         virtual double getOptimizerStatus();
312         virtual ~omxComputeIterate();
313 };
314
315 class omxComputeOnce : public omxComputeOperation {
316         typedef omxComputeOperation super;
317         std::vector< omxMatrix* > algebras;
318         std::vector< omxExpectation* > expectations;
319         bool start;
320         const char *context;
321         bool gradient;
322         bool hessian;
323
324  public:
325         virtual void initFromFrontend(SEXP rObj);
326         virtual void compute(FitContext *fc);
327         virtual void reportResults(FitContext *fc, MxRList *out);
328 };
329
330 static class omxCompute *newComputeSequence()
331 { return new omxComputeSequence(); }
332
333 static class omxCompute *newComputeIterate()
334 { return new omxComputeIterate(); }
335
336 static class omxCompute *newComputeOnce()
337 { return new omxComputeOnce(); }
338
339 struct omxComputeTableEntry {
340         char name[32];
341         omxCompute *(*ctor)();
342 };
343
344 static const struct omxComputeTableEntry omxComputeTable[] = {
345         {"MxComputeEstimatedHessian", &newComputeEstimatedHessian},
346         {"MxComputeGradientDescent", &newComputeGradientDescent},
347         {"MxComputeSequence", &newComputeSequence },
348         {"MxComputeIterate", &newComputeIterate },
349         {"MxComputeOnce", &newComputeOnce },
350         {"MxComputeNewtonRaphson", &newComputeNewtonRaphson},
351 };
352
353 omxCompute *omxNewCompute(omxState* os, const char *type)
354 {
355         omxCompute *got = NULL;
356
357         for (size_t fx=0; fx < OMX_STATIC_ARRAY_SIZE(omxComputeTable); fx++) {
358                 const struct omxComputeTableEntry *entry = omxComputeTable + fx;
359                 if(strcmp(type, entry->name) == 0) {
360                         got = entry->ctor();
361                         break;
362                 }
363         }
364
365         if (!got) error("Compute %s is not implemented", type);
366
367         return got;
368 }
369
370 void omxComputeSequence::initFromFrontend(SEXP rObj)
371 {
372         SEXP slotValue;
373         PROTECT(slotValue = GET_SLOT(rObj, install("steps")));
374
375         for (int cx = 0; cx < length(slotValue); cx++) {
376                 SEXP step = VECTOR_ELT(slotValue, cx);
377                 SEXP s4class;
378                 PROTECT(s4class = STRING_ELT(getAttrib(step, install("class")), 0));
379                 omxCompute *compute = omxNewCompute(globalState, CHAR(s4class));
380                 compute->initFromFrontend(step);
381                 if (isErrorRaised(globalState)) break;
382                 clist.push_back(compute);
383         }
384 }
385
386 void omxComputeSequence::compute(FitContext *fc)
387 {
388         for (size_t cx=0; cx < clist.size(); ++cx) {
389                 FitContext *context = fc;
390                 if (fc->varGroup != clist[cx]->varGroup) {
391                         context = new FitContext(fc, clist[cx]->varGroup);
392                 }
393                 clist[cx]->compute(context);
394                 if (context != fc) context->updateParentAndFree();
395                 if (isErrorRaised(globalState)) break;
396         }
397 }
398
399 void omxComputeSequence::reportResults(FitContext *fc, MxRList *out)
400 {
401         // put this stuff in a new list?
402         // merge with Iterate TODO
403         for (size_t cx=0; cx < clist.size(); ++cx) {
404                 FitContext *context = fc;
405                 if (fc->varGroup != clist[cx]->varGroup) {
406                         context = new FitContext(fc, clist[cx]->varGroup);
407                 }
408                 clist[cx]->reportResults(context, out);
409                 if (context != fc) context->updateParentAndFree();
410                 if (isErrorRaised(globalState)) break;
411         }
412 }
413
414 double omxComputeSequence::getOptimizerStatus()
415 {
416         // for backward compatibility, not indended to work generally
417         for (size_t cx=0; cx < clist.size(); ++cx) {
418                 double got = clist[cx]->getOptimizerStatus();
419                 if (got != NA_REAL) return got;
420         }
421         return NA_REAL;
422 }
423
424 omxComputeSequence::~omxComputeSequence()
425 {
426         for (size_t cx=0; cx < clist.size(); ++cx) {
427                 delete clist[cx];
428         }
429 }
430
431 void omxComputeIterate::initFromFrontend(SEXP rObj)
432 {
433         SEXP slotValue;
434
435         PROTECT(slotValue = GET_SLOT(rObj, install("maxIter")));
436         maxIter = INTEGER(slotValue)[0];
437
438         PROTECT(slotValue = GET_SLOT(rObj, install("tolerance")));
439         tolerance = REAL(slotValue)[0];
440         if (tolerance <= 0) error("tolerance must be positive");
441
442         PROTECT(slotValue = GET_SLOT(rObj, install("steps")));
443
444         for (int cx = 0; cx < length(slotValue); cx++) {
445                 SEXP step = VECTOR_ELT(slotValue, cx);
446                 SEXP s4class;
447                 PROTECT(s4class = STRING_ELT(getAttrib(step, install("class")), 0));
448                 omxCompute *compute = omxNewCompute(globalState, CHAR(s4class));
449                 compute->initFromFrontend(step);
450                 if (isErrorRaised(globalState)) break;
451                 clist.push_back(compute);
452         }
453
454         PROTECT(slotValue = GET_SLOT(rObj, install("verbose")));
455         verbose = asLogical(slotValue);
456 }
457
458 void omxComputeIterate::compute(FitContext *fc)
459 {
460         int iter = 0;
461         double prevFit = 0;
462         double change = tolerance * 10;
463         while (1) {
464                 for (size_t cx=0; cx < clist.size(); ++cx) {
465                         FitContext *context = fc;
466                         if (fc->varGroup != clist[cx]->varGroup) {
467                                 context = new FitContext(fc, clist[cx]->varGroup);
468                         }
469                         clist[cx]->compute(context);
470                         if (context != fc) context->updateParentAndFree();
471                         if (isErrorRaised(globalState)) break;
472                 }
473                 if (fc->fit == 0) {
474                         warning("Fit estimated at 0; something is wrong");
475                         break;
476                 }
477                 if (prevFit != 0) {
478                         change = prevFit - fc->fit;
479                         if (verbose) mxLog("fit %.9g change %.9g", fc->fit, change);
480                 }
481                 prevFit = fc->fit;
482                 if (isErrorRaised(globalState) || ++iter > maxIter || fabs(change) < tolerance) break;
483         }
484 }
485
486 void omxComputeIterate::reportResults(FitContext *fc, MxRList *out)
487 {
488         for (size_t cx=0; cx < clist.size(); ++cx) {
489                 FitContext *context = fc;
490                 if (fc->varGroup != clist[cx]->varGroup) {
491                         context = new FitContext(fc, clist[cx]->varGroup);
492                 }
493                 clist[cx]->reportResults(context, out);
494                 if (context != fc) context->updateParentAndFree();
495                 if (isErrorRaised(globalState)) break;
496         }
497 }
498
499 double omxComputeIterate::getOptimizerStatus()
500 {
501         // for backward compatibility, not indended to work generally
502         for (size_t cx=0; cx < clist.size(); ++cx) {
503                 double got = clist[cx]->getOptimizerStatus();
504                 if (got != NA_REAL) return got;
505         }
506         return NA_REAL;
507 }
508
509 omxComputeIterate::~omxComputeIterate()
510 {
511         for (size_t cx=0; cx < clist.size(); ++cx) {
512                 delete clist[cx];
513         }
514 }
515
516 void omxComputeOnce::initFromFrontend(SEXP rObj)
517 {
518         super::initFromFrontend(rObj);
519
520         SEXP slotValue;
521         PROTECT(slotValue = GET_SLOT(rObj, install("what")));
522         for (int wx=0; wx < length(slotValue); ++wx) {
523                 int objNum = INTEGER(slotValue)[wx];
524                 if (objNum >= 0) {
525                         omxMatrix *algebra = globalState->algebraList[objNum];
526                         if (algebra->fitFunction) {
527                                 setFreeVarGroup(algebra->fitFunction, varGroup);
528                                 omxCompleteFitFunction(algebra);
529                         }
530                         algebras.push_back(algebra);
531                 } else {
532                         omxExpectation *expectation = globalState->expectationList[~objNum];
533                         setFreeVarGroup(expectation, varGroup);
534                         omxCompleteExpectation(expectation);
535                         expectations.push_back(expectation);
536                 }
537         }
538
539         context = "";
540
541         PROTECT(slotValue = GET_SLOT(rObj, install("context")));
542         if (length(slotValue) == 0) {
543                 // OK
544         } else if (length(slotValue) == 1) {
545                 SEXP elem;
546                 PROTECT(elem = STRING_ELT(slotValue, 0));
547                 context = CHAR(elem);
548         }
549
550         PROTECT(slotValue = GET_SLOT(rObj, install("gradient")));
551         gradient = asLogical(slotValue);
552
553         PROTECT(slotValue = GET_SLOT(rObj, install("hessian")));
554         hessian = asLogical(slotValue);
555
556         if (algebras.size() == 1 && algebras[0]->fitFunction) {
557                 omxFitFunction *ff = algebras[0]->fitFunction;
558                 if (gradient && !ff->gradientAvailable) {
559                         error("Gradient requested but not available");
560                 }
561                 if (hessian && !ff->hessianAvailable) {
562                         error("Hessian requested but not available");
563                 }
564         }
565
566         PROTECT(slotValue = GET_SLOT(rObj, install("start")));
567         start = asLogical(slotValue);
568 }
569
570 void omxComputeOnce::compute(FitContext *fc)
571 {
572         if (algebras.size()) {
573                 int want = FF_COMPUTE_FIT;
574                 size_t numParam = fc->varGroup->vars.size();
575                 if (gradient) {
576                         want |= FF_COMPUTE_GRADIENT;
577                         OMXZERO(fc->grad, numParam);
578                 }
579                 if (hessian) {
580                         want |= FF_COMPUTE_HESSIAN;
581                         OMXZERO(fc->hess, numParam * numParam);
582                 }
583
584                 for (size_t wx=0; wx < algebras.size(); ++wx) {
585                         omxMatrix *algebra = algebras[wx];
586                         if (algebra->fitFunction) {
587                                 if (start) {
588                                         omxFitFunctionCompute(algebra->fitFunction, FF_COMPUTE_PREOPTIMIZE, fc);
589                                         fc->copyParamToModel(globalState);
590                                 }
591
592                                 omxFitFunctionCompute(algebra->fitFunction, want, fc);
593                                 fc->fit = algebra->data[0];
594                                 if (hessian) fc->fixHessianSymmetry();
595                         } else {
596                                 omxForceCompute(algebra);
597                         }
598                 }
599         } else if (expectations.size()) {
600                 for (size_t wx=0; wx < expectations.size(); ++wx) {
601                         omxExpectation *expectation = expectations[wx];
602                         omxExpectationCompute(expectation, context);
603                 }
604         }
605 }
606
607 void omxComputeOnce::reportResults(FitContext *fc, MxRList *out)
608 {
609         if (algebras.size()==0 || algebras[0]->fitFunction == NULL) return;
610
611         omxMatrix *algebra = algebras[0];
612
613         omxPopulateFitFunction(algebra, out);
614
615         out->push_back(std::make_pair(mkChar("minimum"), ScalarReal(fc->fit)));
616         out->push_back(std::make_pair(mkChar("Minus2LogLikelihood"), ScalarReal(fc->fit)));
617
618         size_t numFree = fc->varGroup->vars.size();
619         if (numFree) {
620                 SEXP estimate;
621                 PROTECT(estimate = allocVector(REALSXP, numFree));
622                 memcpy(REAL(estimate), fc->est, sizeof(double)*numFree);
623                 out->push_back(std::make_pair(mkChar("estimate"), estimate));
624
625                 if (gradient) {
626                         SEXP Rgradient;
627                         PROTECT(Rgradient = allocVector(REALSXP, numFree));
628                         memcpy(REAL(Rgradient), fc->grad, sizeof(double) * numFree);
629                         out->push_back(std::make_pair(mkChar("gradient"), Rgradient));
630                 }
631
632                 if (hessian) {
633                         SEXP Rhessian;
634                         PROTECT(Rhessian = allocMatrix(REALSXP, numFree, numFree));
635                         memcpy(REAL(Rhessian), fc->hess, sizeof(double) * numFree * numFree);
636                         out->push_back(std::make_pair(mkChar("hessian"), Rhessian));
637                 }
638         }
639 }