Newton-Raphson
[openmx:openmx.git] / src / Compute.cpp
1 /*
2  *  Copyright 2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include "omxDefines.h"
18 #include "Compute.h"
19 #include "omxState.h"
20 #include "omxExportBackendState.h"
21 #include "omxRFitFunction.h"
22
23 std::vector<int> FitContext::markMatrices;
24
25 void FitContext::init()
26 {
27         size_t numParam = varGroup->vars.size();
28         fit = 0;
29         est = new double[numParam];
30         grad = new double[numParam];
31         hess = new double[numParam * numParam];
32 }
33
34 FitContext::FitContext()
35 {
36         parent = NULL;
37         varGroup = Global->freeGroup[0];
38         init();
39
40         size_t numParam = varGroup->vars.size();
41         for (size_t v1=0; v1 < numParam; v1++) {
42                 est[v1] = Global->freeGroup[0]->vars[v1]->start;
43                 grad[v1] = nan("unset");
44                 for (size_t v2=0; v2 < numParam; v2++) {
45                         hess[v1 * numParam + v2] = nan("unset");
46                 }
47         }
48 }
49
50 FitContext::FitContext(FitContext *parent)
51 {
52         varGroup = parent->varGroup;
53         init();
54         fit = parent->fit;
55         size_t numParam = varGroup->vars.size();
56         memcpy(est, parent->est, sizeof(double) * numParam);
57         memcpy(grad, parent->grad, sizeof(double) * numParam);
58         memcpy(hess, parent->hess, sizeof(double) * numParam * numParam);
59 }
60
61 // arg to control what to copy? usually don't want everything TODO
62 FitContext::FitContext(FitContext *parent, FreeVarGroup *varGroup)
63 {
64         this->parent = parent;
65         this->varGroup = varGroup;
66         init();
67
68         FreeVarGroup *src = parent->varGroup;
69         FreeVarGroup *dest = varGroup;
70         size_t svars = parent->varGroup->vars.size();
71         size_t dvars = varGroup->vars.size();
72
73         size_t d1 = 0;
74         for (size_t s1=0; s1 < src->vars.size(); ++s1) {
75                 if (src->vars[s1] != dest->vars[d1]) continue;
76                 est[d1] = parent->est[s1];
77                 grad[d1] = parent->grad[s1];
78
79                 size_t d2 = 0;
80                 for (size_t s2=0; s2 < src->vars.size(); ++s2) {
81                         if (src->vars[s2] != dest->vars[d2]) continue;
82                         hess[d1 * dvars + d2] = parent->hess[s1 * svars + s2];
83                         if (++d2 == dvars) break;
84                 }
85
86                 if (++d1 == dvars) break;
87         }
88         if (d1 != dvars) error("Parent free parameter group is not a superset");
89
90         // pda(parent->est, 1, svars);
91         // pda(est, 1, dvars);
92         // pda(parent->grad, 1, svars);
93         // pda(grad, 1, dvars);
94         // pda(parent->hess, svars, svars);
95         // pda(hess, dvars, dvars);
96 }
97
98 void FitContext::copyParamToModel(omxMatrix *mat)
99 { copyParamToModel(mat->currentState); }
100
101 void FitContext::copyParamToModel(omxMatrix *mat, double *at)
102 { copyParamToModel(mat->currentState, at); }
103
104 void FitContext::updateParentAndFree()
105 {
106         FreeVarGroup *src = varGroup;
107         FreeVarGroup *dest = parent->varGroup;
108         size_t svars = varGroup->vars.size();
109         size_t dvars = parent->varGroup->vars.size();
110
111         size_t s1 = 0;
112         for (size_t d1=0; d1 < dest->vars.size(); ++d1) {
113                 if (dest->vars[d1] != src->vars[s1]) continue;
114                 parent->est[d1] = est[s1];
115                 parent->grad[d1] = grad[s1];
116
117                 size_t s2 = 0;
118                 for (size_t d2=0; d2 < dest->vars.size(); ++d2) {
119                         if (dest->vars[d2] != src->vars[s2]) continue;
120                         parent->hess[d1 * dvars + d2] = hess[s1 * svars + s2];
121                         if (++s2 == svars) break;
122                 }
123
124                 if (++s1 == svars) break;
125         }
126         
127         // pda(est, 1, svars);
128         // pda(parent->est, 1, dvars);
129         // pda(grad, 1, svars);
130         // pda(parent->grad, 1, dvars);
131         // pda(hess, svars, svars);
132         // pda(parent->hess, dvars, dvars);
133
134         delete this;
135 }
136
137 void FitContext::log(int what)
138 {
139         size_t count = varGroup->vars.size();
140         std::string buf;
141         if (what & FF_COMPUTE_FIT) buf += string_snprintf("fit: %.3f\n", fit);
142         if (what & FF_COMPUTE_ESTIMATE) {
143                 buf += "est: c(";
144                 for (size_t vx=0; vx < count; ++vx) {
145                         buf += string_snprintf("%.3f", est[vx]);
146                         if (vx < count - 1) buf += ", ";
147                 }
148                 buf += ")\n";
149         }
150         if (what & FF_COMPUTE_GRADIENT) {
151                 buf += "grad: c(";
152                 for (size_t vx=0; vx < count; ++vx) {
153                         buf += string_snprintf("%.3f", grad[vx]);
154                         if (vx < count - 1) buf += ", ";
155                 }
156                 buf += ")\n";
157         }
158         if (what & FF_COMPUTE_HESSIAN) {
159                 buf += "hess: c(";
160                 for (size_t v1=0; v1 < count; ++v1) {
161                         for (size_t v2=0; v2 < count; ++v2) {
162                                 buf += string_snprintf("%.3f", hess[v1 * count + v2]);
163                                 if (v1 < count-1 || v2 < count-1) buf += ", ";
164                         }
165                         buf += "\n";
166                 }
167                 buf += ")\n";
168         }
169         mxLogBig(buf);
170 }
171
172 void FitContext::cacheFreeVarDependencies()
173 {
174         omxState *os = globalState;
175         size_t numMats = os->matrixList.size();
176         size_t numAlgs = os->algebraList.size();
177
178         markMatrices.clear();
179         markMatrices.resize(numMats + numAlgs, 0);
180
181         // More efficient to use the appropriate group instead of the default group. TODO
182         FreeVarGroup *varGroup = Global->freeGroup[0];
183         for (size_t freeVarIndex = 0; freeVarIndex < varGroup->vars.size(); freeVarIndex++) {
184                 omxFreeVar* freeVar = varGroup->vars[freeVarIndex];
185                 int *deps   = freeVar->deps;
186                 int numDeps = freeVar->numDeps;
187                 for (int index = 0; index < numDeps; index++) {
188                         markMatrices[deps[index] + numMats] = 1;
189                 }
190         }
191 }
192
193 static void omxRepopulateRFitFunction(omxFitFunction* oo, double* x, int n)
194 {
195         omxRFitFunction* rFitFunction = (omxRFitFunction*)oo->argStruct;
196
197         SEXP theCall, estimate;
198
199         PROTECT(estimate = allocVector(REALSXP, n));
200         double *est = REAL(estimate);
201         for(int i = 0; i < n ; i++) {
202                 est[i] = x[i];
203         }
204
205         PROTECT(theCall = allocVector(LANGSXP, 4));
206
207         SETCAR(theCall, install("imxUpdateModelValues"));
208         SETCADR(theCall, rFitFunction->model);
209         SETCADDR(theCall, rFitFunction->flatModel);
210         SETCADDDR(theCall, estimate);
211
212         REPROTECT(rFitFunction->model = eval(theCall, R_GlobalEnv), rFitFunction->modelIndex);
213
214         UNPROTECT(2); // theCall, estimate
215 }
216
217 void FitContext::copyParamToModel(omxState* os)
218 {
219         copyParamToModel(os, est);
220 }
221
222 void FitContext::copyParamToModel(omxState* os, double *at)
223 {
224         size_t numParam = varGroup->vars.size();
225         if(OMX_DEBUG) {
226                 mxLog("Copying %d free parameter estimates to model %p", numParam, os);
227         }
228
229         if(numParam == 0) return;
230
231         size_t numMats = os->matrixList.size();
232         size_t numAlgs = os->algebraList.size();
233
234         os->computeCount++;
235
236         if(OMX_VERBOSE) {
237                 std::string buf;
238                 buf += string_snprintf("Call: %d.%d (%d) ", os->majorIteration, os->minorIteration, os->computeCount);
239                 buf += ("Estimates: [");
240                 for(size_t k = 0; k < numParam; k++) {
241                         buf += string_snprintf(" %f", at[k]);
242                 }
243                 buf += ("]\n");
244                 mxLogBig(buf);
245         }
246
247         for(size_t k = 0; k < numParam; k++) {
248                 omxFreeVar* freeVar = varGroup->vars[k];
249                 for(size_t l = 0; l < freeVar->locations.size(); l++) {
250                         omxFreeVarLocation *loc = &freeVar->locations[l];
251                         omxMatrix *matrix = os->matrixList[loc->matrix];
252                         int row = loc->row;
253                         int col = loc->col;
254                         omxSetMatrixElement(matrix, row, col, at[k]);
255                         if(OMX_DEBUG) {
256                                 mxLog("Setting location (%d, %d) of matrix %d to value %f for var %d",
257                                         row, col, loc->matrix, at[k], k);
258                         }
259                 }
260         }
261
262         if (RFitFunction) omxRepopulateRFitFunction(RFitFunction, at, numParam);
263
264         for(size_t i = 0; i < numMats; i++) {
265                 if (markMatrices[i]) {
266                         int offset = ~(i - numMats);
267                         omxMarkDirty(os->matrixList[offset]);
268                 }
269         }
270
271         for(size_t i = 0; i < numAlgs; i++) {
272                 if (markMatrices[i + numMats]) {
273                         omxMarkDirty(os->algebraList[i]);
274                 }
275         }
276
277         if (!os->childList) return;
278
279         for(int i = 0; i < Global->numChildren; i++) {
280                 copyParamToModel(os->childList[i]);
281         }
282 }
283
284 FitContext::~FitContext()
285 {
286         delete [] est;
287         delete [] grad;
288         delete [] hess;
289 }
290
291 omxFitFunction *FitContext::RFitFunction = NULL;
292
293 void FitContext::setRFitFunction(omxFitFunction *rff)
294 {
295         if (rff) {
296                 Global->numThreads = 1;
297                 if (RFitFunction) {
298                         error("You can only create 1 MxRFitFunction per independent model");
299                 }
300         }
301         RFitFunction = rff;
302 }
303
304 omxCompute::~omxCompute()
305 {}
306
307 void omxComputeOperation::initFromFrontend(SEXP rObj)
308 {
309         SEXP slotValue;
310         PROTECT(slotValue = GET_SLOT(rObj, install("free.group")));
311         int paramGroup = INTEGER(slotValue)[0];
312         varGroup = Global->freeGroup[paramGroup];
313 }
314
315 class omxComputeSequence : public omxCompute {
316         std::vector< omxCompute* > clist;
317
318  public:
319         virtual void initFromFrontend(SEXP rObj);
320         virtual void compute(FitContext *fc);
321         virtual void reportResults(FitContext *fc, MxRList *out);
322         virtual double getOptimizerStatus();
323         virtual ~omxComputeSequence();
324 };
325
326 class omxComputeIterate : public omxCompute {
327         std::vector< omxCompute* > clist;
328         int maxIter;
329         double tolerance;
330         bool verbose;
331
332  public:
333         virtual void initFromFrontend(SEXP rObj);
334         virtual void compute(FitContext *fc);
335         virtual void reportResults(FitContext *fc, MxRList *out);
336         virtual double getOptimizerStatus();
337         virtual ~omxComputeIterate();
338 };
339
340 class omxComputeOnce : public omxComputeOperation {
341         typedef omxComputeOperation super;
342         omxMatrix *algebra;
343         omxExpectation *expectation;
344         const char *context;
345         bool gradient;
346         bool hessian;
347
348  public:
349         virtual void initFromFrontend(SEXP rObj);
350         virtual void compute(FitContext *fc);
351         virtual void reportResults(FitContext *fc, MxRList *out);
352 };
353
354 static class omxCompute *newComputeSequence()
355 { return new omxComputeSequence(); }
356
357 static class omxCompute *newComputeIterate()
358 { return new omxComputeIterate(); }
359
360 static class omxCompute *newComputeOnce()
361 { return new omxComputeOnce(); }
362
363 struct omxComputeTableEntry {
364         char name[32];
365         omxCompute *(*ctor)();
366 };
367
368 static const struct omxComputeTableEntry omxComputeTable[] = {
369         {"MxComputeEstimatedHessian", &newComputeEstimatedHessian},
370         {"MxComputeGradientDescent", &newComputeGradientDescent},
371         {"MxComputeSequence", &newComputeSequence },
372         {"MxComputeIterate", &newComputeIterate },
373         {"MxComputeOnce", &newComputeOnce },
374         {"MxComputeNewtonRaphson", &newComputeNewtonRaphson},
375 };
376
377 omxCompute *omxNewCompute(omxState* os, const char *type)
378 {
379         omxCompute *got = NULL;
380
381         for (size_t fx=0; fx < OMX_STATIC_ARRAY_SIZE(omxComputeTable); fx++) {
382                 const struct omxComputeTableEntry *entry = omxComputeTable + fx;
383                 if(strcmp(type, entry->name) == 0) {
384                         got = entry->ctor();
385                         break;
386                 }
387         }
388
389         if (!got) error("Compute %s is not implemented", type);
390
391         return got;
392 }
393
394 void omxComputeSequence::initFromFrontend(SEXP rObj)
395 {
396         SEXP slotValue;
397         PROTECT(slotValue = GET_SLOT(rObj, install("steps")));
398
399         for (int cx = 0; cx < length(slotValue); cx++) {
400                 SEXP step = VECTOR_ELT(slotValue, cx);
401                 SEXP s4class;
402                 PROTECT(s4class = STRING_ELT(getAttrib(step, install("class")), 0));
403                 omxCompute *compute = omxNewCompute(globalState, CHAR(s4class));
404                 compute->initFromFrontend(step);
405                 if (isErrorRaised(globalState)) break;
406                 clist.push_back(compute);
407         }
408 }
409
410 void omxComputeSequence::compute(FitContext *fc)
411 {
412         for (size_t cx=0; cx < clist.size(); ++cx) {
413                 FitContext *context = fc;
414                 if (fc->varGroup != clist[cx]->varGroup) {
415                         context = new FitContext(fc, clist[cx]->varGroup);
416                 }
417                 clist[cx]->compute(context);
418                 if (context != fc) context->updateParentAndFree();
419                 if (isErrorRaised(globalState)) break;
420         }
421 }
422
423 void omxComputeSequence::reportResults(FitContext *fc, MxRList *out)
424 {
425         // put this stuff in a new list?
426         // merge with Iterate TODO
427         for (size_t cx=0; cx < clist.size(); ++cx) {
428                 FitContext *context = fc;
429                 if (fc->varGroup != clist[cx]->varGroup) {
430                         context = new FitContext(fc, clist[cx]->varGroup);
431                 }
432                 clist[cx]->reportResults(context, out);
433                 if (context != fc) context->updateParentAndFree();
434                 if (isErrorRaised(globalState)) break;
435         }
436 }
437
438 double omxComputeSequence::getOptimizerStatus()
439 {
440         // for backward compatibility, not indended to work generally
441         for (size_t cx=0; cx < clist.size(); ++cx) {
442                 double got = clist[cx]->getOptimizerStatus();
443                 if (got != NA_REAL) return got;
444         }
445         return NA_REAL;
446 }
447
448 omxComputeSequence::~omxComputeSequence()
449 {
450         for (size_t cx=0; cx < clist.size(); ++cx) {
451                 delete clist[cx];
452         }
453 }
454
455 void omxComputeIterate::initFromFrontend(SEXP rObj)
456 {
457         SEXP slotValue;
458
459         PROTECT(slotValue = GET_SLOT(rObj, install("maxIter")));
460         maxIter = INTEGER(slotValue)[0];
461
462         PROTECT(slotValue = GET_SLOT(rObj, install("tolerance")));
463         tolerance = REAL(slotValue)[0];
464         if (tolerance <= 0) error("tolerance must be positive");
465
466         PROTECT(slotValue = GET_SLOT(rObj, install("steps")));
467
468         for (int cx = 0; cx < length(slotValue); cx++) {
469                 SEXP step = VECTOR_ELT(slotValue, cx);
470                 SEXP s4class;
471                 PROTECT(s4class = STRING_ELT(getAttrib(step, install("class")), 0));
472                 omxCompute *compute = omxNewCompute(globalState, CHAR(s4class));
473                 compute->initFromFrontend(step);
474                 if (isErrorRaised(globalState)) break;
475                 clist.push_back(compute);
476         }
477
478         PROTECT(slotValue = GET_SLOT(rObj, install("verbose")));
479         verbose = asLogical(slotValue);
480 }
481
482 void omxComputeIterate::compute(FitContext *fc)
483 {
484         int iter = 0;
485         double prevFit = 0;
486         double change = tolerance * 10;
487         while (1) {
488                 for (size_t cx=0; cx < clist.size(); ++cx) {
489                         FitContext *context = fc;
490                         if (fc->varGroup != clist[cx]->varGroup) {
491                                 context = new FitContext(fc, clist[cx]->varGroup);
492                         }
493                         clist[cx]->compute(context);
494                         if (context != fc) context->updateParentAndFree();
495                         if (isErrorRaised(globalState)) break;
496                 }
497                 if (prevFit != 0) {
498                         change = prevFit - fc->fit;
499                         if (verbose) mxLog("fit %.9g change %.9g", fc->fit, change);
500                 }
501                 prevFit = fc->fit;
502                 if (isErrorRaised(globalState) || ++iter > maxIter || fabs(change) < tolerance) break;
503         }
504 }
505
506 void omxComputeIterate::reportResults(FitContext *fc, MxRList *out)
507 {
508         for (size_t cx=0; cx < clist.size(); ++cx) {
509                 FitContext *context = fc;
510                 if (fc->varGroup != clist[cx]->varGroup) {
511                         context = new FitContext(fc, clist[cx]->varGroup);
512                 }
513                 clist[cx]->reportResults(context, out);
514                 if (context != fc) context->updateParentAndFree();
515                 if (isErrorRaised(globalState)) break;
516         }
517 }
518
519 double omxComputeIterate::getOptimizerStatus()
520 {
521         // for backward compatibility, not indended to work generally
522         for (size_t cx=0; cx < clist.size(); ++cx) {
523                 double got = clist[cx]->getOptimizerStatus();
524                 if (got != NA_REAL) return got;
525         }
526         return NA_REAL;
527 }
528
529 omxComputeIterate::~omxComputeIterate()
530 {
531         for (size_t cx=0; cx < clist.size(); ++cx) {
532                 delete clist[cx];
533         }
534 }
535
536 void omxComputeOnce::initFromFrontend(SEXP rObj)
537 {
538         super::initFromFrontend(rObj);
539
540         SEXP slotValue;
541         PROTECT(slotValue = GET_SLOT(rObj, install("what")));
542         int objNum = INTEGER(slotValue)[0];
543         if (objNum >= 0) {
544                 algebra = globalState->algebraList[objNum];
545                 if (algebra->fitFunction) {
546                         setFreeVarGroup(algebra->fitFunction, varGroup);
547                         omxCompleteFitFunction(algebra);
548                 }
549         } else {
550                 expectation = globalState->expectationList[~objNum];
551                 setFreeVarGroup(expectation, varGroup);
552                 omxCompleteExpectation(expectation);
553         }
554
555         PROTECT(slotValue = GET_SLOT(rObj, install("context")));
556         if (length(slotValue) == 0) {
557                 // OK
558         } else if (length(slotValue) == 1) {
559                 SEXP elem;
560                 PROTECT(elem = STRING_ELT(slotValue, 0));
561                 context = CHAR(elem);
562         }
563
564         PROTECT(slotValue = GET_SLOT(rObj, install("gradient")));
565         gradient = asLogical(slotValue);
566
567         PROTECT(slotValue = GET_SLOT(rObj, install("hessian")));
568         hessian = asLogical(slotValue);
569 }
570
571 void omxComputeOnce::compute(FitContext *fc)
572 {
573         if (algebra) {
574                 if (algebra->fitFunction) {
575                         int want = FF_COMPUTE_FIT;
576                         if (gradient) want |= FF_COMPUTE_GRADIENT;
577                         if (hessian)  want |= FF_COMPUTE_HESSIAN;
578                         omxFitFunctionCompute(algebra->fitFunction, want, fc);
579                         fc->fit = algebra->data[0];
580                 } else {
581                         omxForceCompute(algebra);
582                 }
583         } else if (expectation) {
584                 omxExpectationCompute(expectation, context);
585         }
586 }
587
588 void omxComputeOnce::reportResults(FitContext *fc, MxRList *out)
589 {
590         if (!algebra || !algebra->fitFunction) return;
591
592         omxPopulateFitFunction(algebra, out);
593
594         out->push_back(std::make_pair(mkChar("minimum"), ScalarReal(fc->fit)));
595
596         size_t numFree = fc->varGroup->vars.size();
597         if (numFree) {
598                 SEXP estimate;
599                 PROTECT(estimate = allocVector(REALSXP, numFree));
600                 memcpy(REAL(estimate), fc->est, sizeof(double)*numFree);
601                 out->push_back(std::make_pair(mkChar("estimate"), estimate));
602
603                 if (gradient) {
604                         SEXP Rgradient;
605                         PROTECT(Rgradient = allocVector(REALSXP, numFree));
606                         memcpy(REAL(Rgradient), fc->grad, sizeof(double) * numFree);
607                         out->push_back(std::make_pair(mkChar("gradient"), Rgradient));
608                 }
609
610                 if (hessian) {
611                         SEXP Rhessian;
612                         PROTECT(Rhessian = allocMatrix(REALSXP, numFree, numFree));
613                         memcpy(REAL(Rhessian), fc->hess, sizeof(double) * numFree * numFree);
614                         out->push_back(std::make_pair(mkChar("hessian"), Rhessian));
615                 }
616         }
617 }