Rework specification of free param groups per Tim's suggestion
[openmx:openmx.git] / src / Compute.cpp
1 /*
2  *  Copyright 2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include "omxDefines.h"
18 #include "Compute.h"
19 #include "omxState.h"
20 #include "omxExportBackendState.h"
21 #include "omxRFitFunction.h"
22
// Storage for the static member FitContext::markMatrices.
// cacheFreeVarDependencies() sizes it to (matrices + algebras) and sets the
// flagged slots to 1; copyParamToModel() reads it to decide what to mark dirty.
std::vector<int> FitContext::markMatrices;
24
25 void FitContext::init()
26 {
27         size_t numParam = varGroup->vars.size();
28         fit = parent? parent->fit : 0;
29         est = new double[numParam];
30         grad = new double[numParam];
31         hess = new double[numParam * numParam];
32 }
33
34 FitContext::FitContext()
35 {
36         parent = NULL;
37         varGroup = Global->freeGroup[0];
38         init();
39
40         size_t numParam = varGroup->vars.size();
41         for (size_t v1=0; v1 < numParam; v1++) {
42                 est[v1] = Global->freeGroup[0]->vars[v1]->start;
43                 grad[v1] = nan("unset");
44                 for (size_t v2=0; v2 < numParam; v2++) {
45                         hess[v1 * numParam + v2] = nan("unset");
46                 }
47         }
48 }
49
50 // arg to control what to copy? usually don't want everything TODO
51 FitContext::FitContext(FitContext *parent, FreeVarGroup *varGroup)
52 {
53         this->parent = parent;
54         this->varGroup = varGroup;
55         init();
56
57         FreeVarGroup *src = parent->varGroup;
58         FreeVarGroup *dest = varGroup;
59         size_t svars = parent->varGroup->vars.size();
60         size_t dvars = varGroup->vars.size();
61
62         size_t d1 = 0;
63         for (size_t s1=0; s1 < src->vars.size(); ++s1) {
64                 if (src->vars[s1] != dest->vars[d1]) continue;
65                 est[d1] = parent->est[s1];
66                 grad[d1] = parent->grad[s1];
67
68                 size_t d2 = 0;
69                 for (size_t s2=0; s2 < src->vars.size(); ++s2) {
70                         if (src->vars[s2] != dest->vars[d2]) continue;
71                         hess[d1 * dvars + d2] = parent->hess[s1 * svars + s2];
72                         if (++d2 == dvars) break;
73                 }
74
75                 if (++d1 == dvars) break;
76         }
77         if (d1 != dvars) error("Parent free parameter group is not a superset");
78
79         // pda(parent->est, 1, svars);
80         // pda(est, 1, dvars);
81         // pda(parent->grad, 1, svars);
82         // pda(grad, 1, dvars);
83         // pda(parent->hess, svars, svars);
84         // pda(hess, dvars, dvars);
85 }
86
87 void FitContext::copyParamToModel(omxMatrix *mat)
88 { copyParamToModel(mat->currentState); }
89
90 void FitContext::copyParamToModel(omxMatrix *mat, double *at)
91 { copyParamToModel(mat->currentState, at); }
92
93 void FitContext::updateParentAndFree()
94 {
95         FreeVarGroup *src = varGroup;
96         FreeVarGroup *dest = parent->varGroup;
97         size_t svars = varGroup->vars.size();
98         size_t dvars = parent->varGroup->vars.size();
99
100         size_t s1 = 0;
101         for (size_t d1=0; d1 < dest->vars.size(); ++d1) {
102                 if (dest->vars[d1] != src->vars[s1]) continue;
103                 parent->est[d1] = est[s1];
104                 parent->grad[d1] = grad[s1];
105
106                 size_t s2 = 0;
107                 for (size_t d2=0; d2 < dest->vars.size(); ++d2) {
108                         if (dest->vars[d2] != src->vars[s2]) continue;
109                         parent->hess[d1 * dvars + d2] = hess[s1 * svars + s2];
110                         if (++s2 == svars) break;
111                 }
112
113                 if (++s1 == svars) break;
114         }
115         
116         // pda(est, 1, svars);
117         // pda(parent->est, 1, dvars);
118         // pda(grad, 1, svars);
119         // pda(parent->grad, 1, dvars);
120         // pda(hess, svars, svars);
121         // pda(parent->hess, dvars, dvars);
122
123         delete this;
124 }
125
126 void FitContext::log(int what)
127 {
128         size_t count = varGroup->vars.size();
129         std::string buf;
130         if (what & FF_COMPUTE_FIT) buf += string_snprintf("fit: %.3f\n", fit);
131         if (what & FF_COMPUTE_ESTIMATE) {
132                 buf += "est: c(";
133                 for (size_t vx=0; vx < count; ++vx) {
134                         buf += string_snprintf("%.3f", est[vx]);
135                         if (vx < count - 1) buf += ", ";
136                 }
137                 buf += ")\n";
138         }
139         if (what & FF_COMPUTE_GRADIENT) {
140                 buf += "grad: c(";
141                 for (size_t vx=0; vx < count; ++vx) {
142                         buf += string_snprintf("%.3f", grad[vx]);
143                         if (vx < count - 1) buf += ", ";
144                 }
145                 buf += ")\n";
146         }
147         if (what & FF_COMPUTE_HESSIAN) {
148                 buf += "hess: c(";
149                 for (size_t v1=0; v1 < count; ++v1) {
150                         for (size_t v2=0; v2 < count; ++v2) {
151                                 buf += string_snprintf("%.3f", hess[v1 * count + v2]);
152                                 if (v1 < count-1 || v2 < count-1) buf += ", ";
153                         }
154                         buf += "\n";
155                 }
156                 buf += ")\n";
157         }
158         mxLogBig(buf);
159 }
160
161 void FitContext::cacheFreeVarDependencies()
162 {
163         omxState *os = globalState;
164         size_t numMats = os->matrixList.size();
165         size_t numAlgs = os->algebraList.size();
166
167         markMatrices.clear();
168         markMatrices.resize(numMats + numAlgs, 0);
169
170         // More efficient to use the appropriate group instead of the default group. TODO
171         FreeVarGroup *varGroup = Global->freeGroup[0];
172         for (size_t freeVarIndex = 0; freeVarIndex < varGroup->vars.size(); freeVarIndex++) {
173                 omxFreeVar* freeVar = varGroup->vars[freeVarIndex];
174                 int *deps   = freeVar->deps;
175                 int numDeps = freeVar->numDeps;
176                 for (int index = 0; index < numDeps; index++) {
177                         markMatrices[deps[index] + numMats] = 1;
178                 }
179         }
180 }
181
182 void FitContext::fixHessianSymmetry()
183 {
184         // make non-symmetric entries symmetric, if possible
185         size_t numParam = varGroup->vars.size();
186         for (size_t h1=1; h1 < numParam; h1++) {
187                 for (size_t h2=0; h2 < h1; h2++) {
188                         double upper = hess[h1 * numParam + h2];
189                         double lower = hess[h2 * numParam + h1];
190                         if (isfinite(upper)) continue;
191                         if (isfinite(lower)) {
192                                 hess[h1 * numParam + h2] = lower;
193                         } else {
194                                 log(FF_COMPUTE_ESTIMATE|FF_COMPUTE_GRADIENT|FF_COMPUTE_HESSIAN);
195                                 error("Hessian is not finite at [%d,%d]", h1,h2);
196                         }
197                 }
198         }
199 }
200
201 static void omxRepopulateRFitFunction(omxFitFunction* oo, double* x, int n)
202 {
203         omxRFitFunction* rFitFunction = (omxRFitFunction*)oo->argStruct;
204
205         SEXP theCall, estimate;
206
207         PROTECT(estimate = allocVector(REALSXP, n));
208         double *est = REAL(estimate);
209         for(int i = 0; i < n ; i++) {
210                 est[i] = x[i];
211         }
212
213         PROTECT(theCall = allocVector(LANGSXP, 4));
214
215         SETCAR(theCall, install("imxUpdateModelValues"));
216         SETCADR(theCall, rFitFunction->model);
217         SETCADDR(theCall, rFitFunction->flatModel);
218         SETCADDDR(theCall, estimate);
219
220         REPROTECT(rFitFunction->model = eval(theCall, R_GlobalEnv), rFitFunction->modelIndex);
221
222         UNPROTECT(2); // theCall, estimate
223 }
224
225 void FitContext::copyParamToModel(omxState* os)
226 {
227         copyParamToModel(os, est);
228 }
229
230 void FitContext::copyParamToModel(omxState* os, double *at)
231 {
232         size_t numParam = varGroup->vars.size();
233         if(OMX_DEBUG) {
234                 mxLog("Copying %d free parameter estimates to model %p", numParam, os);
235         }
236
237         if(numParam == 0) return;
238
239         size_t numMats = os->matrixList.size();
240         size_t numAlgs = os->algebraList.size();
241
242         os->computeCount++;
243
244         if(OMX_VERBOSE) {
245                 std::string buf;
246                 buf += string_snprintf("Call: %d.%d (%d) ", os->majorIteration, os->minorIteration, os->computeCount);
247                 buf += ("Estimates: [");
248                 for(size_t k = 0; k < numParam; k++) {
249                         buf += string_snprintf(" %f", at[k]);
250                 }
251                 buf += ("]\n");
252                 mxLogBig(buf);
253         }
254
255         for(size_t k = 0; k < numParam; k++) {
256                 omxFreeVar* freeVar = varGroup->vars[k];
257                 for(size_t l = 0; l < freeVar->locations.size(); l++) {
258                         omxFreeVarLocation *loc = &freeVar->locations[l];
259                         omxMatrix *matrix = os->matrixList[loc->matrix];
260                         int row = loc->row;
261                         int col = loc->col;
262                         omxSetMatrixElement(matrix, row, col, at[k]);
263                         if(OMX_DEBUG) {
264                                 mxLog("Setting location (%d, %d) of matrix %d to value %f for var %d",
265                                         row, col, loc->matrix, at[k], k);
266                         }
267                 }
268         }
269
270         if (RFitFunction) omxRepopulateRFitFunction(RFitFunction, at, numParam);
271
272         for(size_t i = 0; i < numMats; i++) {
273                 if (markMatrices[i]) {
274                         int offset = ~(i - numMats);
275                         omxMarkDirty(os->matrixList[offset]);
276                 }
277         }
278
279         for(size_t i = 0; i < numAlgs; i++) {
280                 if (markMatrices[i + numMats]) {
281                         omxMarkDirty(os->algebraList[i]);
282                 }
283         }
284
285         if (!os->childList) return;
286
287         for(int i = 0; i < Global->numChildren; i++) {
288                 copyParamToModel(os->childList[i]);
289         }
290 }
291
292 FitContext::~FitContext()
293 {
294         delete [] est;
295         delete [] grad;
296         delete [] hess;
297 }
298
// Storage for the static member FitContext::RFitFunction. Assigned by
// setRFitFunction(); when non-NULL, copyParamToModel() also pushes the
// parameter values to it.
omxFitFunction *FitContext::RFitFunction = NULL;
300
301 void FitContext::setRFitFunction(omxFitFunction *rff)
302 {
303         if (rff) {
304                 Global->numThreads = 1;
305                 if (RFitFunction) {
306                         error("You can only create 1 MxRFitFunction per independent model");
307                 }
308         }
309         RFitFunction = rff;
310 }
311
312 omxCompute::~omxCompute()
313 {}
314
315 void omxComputeOperation::initFromFrontend(SEXP rObj)
316 {
317         SEXP slotValue;
318         PROTECT(slotValue = GET_SLOT(rObj, install("id")));
319         int id = INTEGER(slotValue)[0];
320         varGroup = Global->findVarGroup(id);
321         if (!varGroup) varGroup = Global->freeGroup[0];
322 }
323
324 class omxComputeSequence : public omxCompute {
325         std::vector< omxCompute* > clist;
326
327  public:
328         virtual void initFromFrontend(SEXP rObj);
329         virtual void compute(FitContext *fc);
330         virtual void reportResults(FitContext *fc, MxRList *out);
331         virtual double getOptimizerStatus();
332         virtual ~omxComputeSequence();
333 };
334
335 class omxComputeIterate : public omxCompute {
336         std::vector< omxCompute* > clist;
337         int maxIter;
338         double tolerance;
339         bool verbose;
340
341  public:
342         virtual void initFromFrontend(SEXP rObj);
343         virtual void compute(FitContext *fc);
344         virtual void reportResults(FitContext *fc, MxRList *out);
345         virtual double getOptimizerStatus();
346         virtual ~omxComputeIterate();
347 };
348
349 class omxComputeOnce : public omxComputeOperation {
350         typedef omxComputeOperation super;
351         std::vector< omxMatrix* > algebras;
352         std::vector< omxExpectation* > expectations;
353         const char *context;
354         bool gradient;
355         bool hessian;
356
357  public:
358         virtual void initFromFrontend(SEXP rObj);
359         virtual void compute(FitContext *fc);
360         virtual void reportResults(FitContext *fc, MxRList *out);
361 };
362
363 static class omxCompute *newComputeSequence()
364 { return new omxComputeSequence(); }
365
366 static class omxCompute *newComputeIterate()
367 { return new omxComputeIterate(); }
368
369 static class omxCompute *newComputeOnce()
370 { return new omxComputeOnce(); }
371
// One row of the compute factory table: a class name and the function that
// creates the matching omxCompute object.
struct omxComputeTableEntry {
        char name[32];          // class name compared with strcmp in omxNewCompute
        omxCompute *(*ctor)();  // factory returning a newly created object
};
376
// Maps the class name passed to omxNewCompute to a factory function.
// The factories for EstimatedHessian, GradientDescent and NewtonRaphson
// are declared outside this section of the file.
static const struct omxComputeTableEntry omxComputeTable[] = {
        {"MxComputeEstimatedHessian", &newComputeEstimatedHessian},
        {"MxComputeGradientDescent", &newComputeGradientDescent},
        {"MxComputeSequence", &newComputeSequence },
        {"MxComputeIterate", &newComputeIterate },
        {"MxComputeOnce", &newComputeOnce },
        {"MxComputeNewtonRaphson", &newComputeNewtonRaphson},
};
385
386 omxCompute *omxNewCompute(omxState* os, const char *type)
387 {
388         omxCompute *got = NULL;
389
390         for (size_t fx=0; fx < OMX_STATIC_ARRAY_SIZE(omxComputeTable); fx++) {
391                 const struct omxComputeTableEntry *entry = omxComputeTable + fx;
392                 if(strcmp(type, entry->name) == 0) {
393                         got = entry->ctor();
394                         break;
395                 }
396         }
397
398         if (!got) error("Compute %s is not implemented", type);
399
400         return got;
401 }
402
403 void omxComputeSequence::initFromFrontend(SEXP rObj)
404 {
405         SEXP slotValue;
406         PROTECT(slotValue = GET_SLOT(rObj, install("steps")));
407
408         for (int cx = 0; cx < length(slotValue); cx++) {
409                 SEXP step = VECTOR_ELT(slotValue, cx);
410                 SEXP s4class;
411                 PROTECT(s4class = STRING_ELT(getAttrib(step, install("class")), 0));
412                 omxCompute *compute = omxNewCompute(globalState, CHAR(s4class));
413                 compute->initFromFrontend(step);
414                 if (isErrorRaised(globalState)) break;
415                 clist.push_back(compute);
416         }
417 }
418
419 void omxComputeSequence::compute(FitContext *fc)
420 {
421         for (size_t cx=0; cx < clist.size(); ++cx) {
422                 FitContext *context = fc;
423                 if (fc->varGroup != clist[cx]->varGroup) {
424                         context = new FitContext(fc, clist[cx]->varGroup);
425                 }
426                 clist[cx]->compute(context);
427                 if (context != fc) context->updateParentAndFree();
428                 if (isErrorRaised(globalState)) break;
429         }
430 }
431
432 void omxComputeSequence::reportResults(FitContext *fc, MxRList *out)
433 {
434         // put this stuff in a new list?
435         // merge with Iterate TODO
436         for (size_t cx=0; cx < clist.size(); ++cx) {
437                 FitContext *context = fc;
438                 if (fc->varGroup != clist[cx]->varGroup) {
439                         context = new FitContext(fc, clist[cx]->varGroup);
440                 }
441                 clist[cx]->reportResults(context, out);
442                 if (context != fc) context->updateParentAndFree();
443                 if (isErrorRaised(globalState)) break;
444         }
445 }
446
447 double omxComputeSequence::getOptimizerStatus()
448 {
449         // for backward compatibility, not indended to work generally
450         for (size_t cx=0; cx < clist.size(); ++cx) {
451                 double got = clist[cx]->getOptimizerStatus();
452                 if (got != NA_REAL) return got;
453         }
454         return NA_REAL;
455 }
456
457 omxComputeSequence::~omxComputeSequence()
458 {
459         for (size_t cx=0; cx < clist.size(); ++cx) {
460                 delete clist[cx];
461         }
462 }
463
464 void omxComputeIterate::initFromFrontend(SEXP rObj)
465 {
466         SEXP slotValue;
467
468         PROTECT(slotValue = GET_SLOT(rObj, install("maxIter")));
469         maxIter = INTEGER(slotValue)[0];
470
471         PROTECT(slotValue = GET_SLOT(rObj, install("tolerance")));
472         tolerance = REAL(slotValue)[0];
473         if (tolerance <= 0) error("tolerance must be positive");
474
475         PROTECT(slotValue = GET_SLOT(rObj, install("steps")));
476
477         for (int cx = 0; cx < length(slotValue); cx++) {
478                 SEXP step = VECTOR_ELT(slotValue, cx);
479                 SEXP s4class;
480                 PROTECT(s4class = STRING_ELT(getAttrib(step, install("class")), 0));
481                 omxCompute *compute = omxNewCompute(globalState, CHAR(s4class));
482                 compute->initFromFrontend(step);
483                 if (isErrorRaised(globalState)) break;
484                 clist.push_back(compute);
485         }
486
487         PROTECT(slotValue = GET_SLOT(rObj, install("verbose")));
488         verbose = asLogical(slotValue);
489 }
490
491 void omxComputeIterate::compute(FitContext *fc)
492 {
493         int iter = 0;
494         double prevFit = 0;
495         double change = tolerance * 10;
496         while (1) {
497                 for (size_t cx=0; cx < clist.size(); ++cx) {
498                         FitContext *context = fc;
499                         if (fc->varGroup != clist[cx]->varGroup) {
500                                 context = new FitContext(fc, clist[cx]->varGroup);
501                         }
502                         clist[cx]->compute(context);
503                         if (context != fc) context->updateParentAndFree();
504                         if (isErrorRaised(globalState)) break;
505                 }
506                 if (prevFit != 0) {
507                         change = prevFit - fc->fit;
508                         if (verbose) mxLog("fit %.9g change %.9g", fc->fit, change);
509                 }
510                 prevFit = fc->fit;
511                 if (isErrorRaised(globalState) || ++iter > maxIter || fabs(change) < tolerance) break;
512         }
513 }
514
515 void omxComputeIterate::reportResults(FitContext *fc, MxRList *out)
516 {
517         for (size_t cx=0; cx < clist.size(); ++cx) {
518                 FitContext *context = fc;
519                 if (fc->varGroup != clist[cx]->varGroup) {
520                         context = new FitContext(fc, clist[cx]->varGroup);
521                 }
522                 clist[cx]->reportResults(context, out);
523                 if (context != fc) context->updateParentAndFree();
524                 if (isErrorRaised(globalState)) break;
525         }
526 }
527
528 double omxComputeIterate::getOptimizerStatus()
529 {
530         // for backward compatibility, not indended to work generally
531         for (size_t cx=0; cx < clist.size(); ++cx) {
532                 double got = clist[cx]->getOptimizerStatus();
533                 if (got != NA_REAL) return got;
534         }
535         return NA_REAL;
536 }
537
538 omxComputeIterate::~omxComputeIterate()
539 {
540         for (size_t cx=0; cx < clist.size(); ++cx) {
541                 delete clist[cx];
542         }
543 }
544
545 void omxComputeOnce::initFromFrontend(SEXP rObj)
546 {
547         super::initFromFrontend(rObj);
548
549         SEXP slotValue;
550         PROTECT(slotValue = GET_SLOT(rObj, install("what")));
551         for (int wx=0; wx < length(slotValue); ++wx) {
552                 int objNum = INTEGER(slotValue)[wx];
553                 if (objNum >= 0) {
554                         omxMatrix *algebra = globalState->algebraList[objNum];
555                         if (algebra->fitFunction) {
556                                 setFreeVarGroup(algebra->fitFunction, varGroup);
557                                 omxCompleteFitFunction(algebra);
558                         }
559                         algebras.push_back(algebra);
560                 } else {
561                         omxExpectation *expectation = globalState->expectationList[~objNum];
562                         setFreeVarGroup(expectation, varGroup);
563                         omxCompleteExpectation(expectation);
564                         expectations.push_back(expectation);
565                 }
566         }
567
568         PROTECT(slotValue = GET_SLOT(rObj, install("context")));
569         if (length(slotValue) == 0) {
570                 // OK
571         } else if (length(slotValue) == 1) {
572                 SEXP elem;
573                 PROTECT(elem = STRING_ELT(slotValue, 0));
574                 context = CHAR(elem);
575         }
576
577         PROTECT(slotValue = GET_SLOT(rObj, install("gradient")));
578         gradient = asLogical(slotValue);
579
580         PROTECT(slotValue = GET_SLOT(rObj, install("hessian")));
581         hessian = asLogical(slotValue);
582 }
583
584 void omxComputeOnce::compute(FitContext *fc)
585 {
586         if (algebras.size()) {
587                 int want = FF_COMPUTE_FIT;
588                 if (gradient) want |= FF_COMPUTE_GRADIENT;
589                 if (hessian)  want |= FF_COMPUTE_HESSIAN;
590
591                 for (size_t wx=0; wx < algebras.size(); ++wx) {
592                         omxMatrix *algebra = algebras[wx];
593                         if (algebra->fitFunction) {
594                                 omxFitFunctionCompute(algebra->fitFunction, want, fc);
595                                 fc->fit = algebra->data[0];
596                                 if (hessian) fc->fixHessianSymmetry();
597                         } else {
598                                 omxForceCompute(algebra);
599                         }
600                 }
601         } else if (expectations.size()) {
602                 for (size_t wx=0; wx < expectations.size(); ++wx) {
603                         omxExpectation *expectation = expectations[wx];
604                         omxExpectationCompute(expectation, context);
605                 }
606         }
607 }
608
609 void omxComputeOnce::reportResults(FitContext *fc, MxRList *out)
610 {
611         if (algebras.size()==0 || algebras[0]->fitFunction == NULL) return;
612
613         omxMatrix *algebra = algebras[0];
614
615         omxPopulateFitFunction(algebra, out);
616
617         out->push_back(std::make_pair(mkChar("minimum"), ScalarReal(fc->fit)));
618         out->push_back(std::make_pair(mkChar("Minus2LogLikelihood"), ScalarReal(fc->fit)));
619
620         size_t numFree = fc->varGroup->vars.size();
621         if (numFree) {
622                 SEXP estimate;
623                 PROTECT(estimate = allocVector(REALSXP, numFree));
624                 memcpy(REAL(estimate), fc->est, sizeof(double)*numFree);
625                 out->push_back(std::make_pair(mkChar("estimate"), estimate));
626
627                 if (gradient) {
628                         SEXP Rgradient;
629                         PROTECT(Rgradient = allocVector(REALSXP, numFree));
630                         memcpy(REAL(Rgradient), fc->grad, sizeof(double) * numFree);
631                         out->push_back(std::make_pair(mkChar("gradient"), Rgradient));
632                 }
633
634                 if (hessian) {
635                         SEXP Rhessian;
636                         PROTECT(Rhessian = allocMatrix(REALSXP, numFree, numFree));
637                         memcpy(REAL(Rhessian), fc->hess, sizeof(double) * numFree * numFree);
638                         out->push_back(std::make_pair(mkChar("hessian"), Rhessian));
639                 }
640         }
641 }