Rewrite Newton-Raphson with better math
[openmx:openmx.git] / src / Compute.cpp
1 /*
2  *  Copyright 2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include "omxDefines.h"
18 #include "Compute.h"
19 #include "omxState.h"
20 #include "omxExportBackendState.h"
21 #include "omxRFitFunction.h"
22
23 void FitContext::init()
24 {
25         size_t numParam = varGroup->vars.size();
26         fit = parent? parent->fit : 0;
27         est = new double[numParam];
28         grad = new double[numParam];
29         hess = new double[numParam * numParam];
30 }
31
32 FitContext::FitContext(std::vector<double> &startingValues)
33 {
34         parent = NULL;
35         varGroup = Global->freeGroup[0];
36         init();
37
38         size_t numParam = varGroup->vars.size();
39         if (startingValues.size() != numParam) error("mismatch");
40         memcpy(est, startingValues.data(), sizeof(double) * numParam);
41
42         for (size_t v1=0; v1 < numParam; v1++) {
43                 grad[v1] = nan("unset");
44                 for (size_t v2=0; v2 < numParam; v2++) {
45                         hess[v1 * numParam + v2] = nan("unset");
46                 }
47         }
48 }
49
50 // arg to control what to copy? usually don't want everything TODO
51 FitContext::FitContext(FitContext *parent, FreeVarGroup *varGroup)
52 {
53         this->parent = parent;
54         this->varGroup = varGroup;
55         init();
56
57         FreeVarGroup *src = parent->varGroup;
58         FreeVarGroup *dest = varGroup;
59         size_t svars = parent->varGroup->vars.size();
60         size_t dvars = varGroup->vars.size();
61
62         size_t d1 = 0;
63         for (size_t s1=0; s1 < src->vars.size(); ++s1) {
64                 if (src->vars[s1] != dest->vars[d1]) continue;
65                 est[d1] = parent->est[s1];
66                 grad[d1] = parent->grad[s1];
67
68                 size_t d2 = 0;
69                 for (size_t s2=0; s2 < src->vars.size(); ++s2) {
70                         if (src->vars[s2] != dest->vars[d2]) continue;
71                         hess[d1 * dvars + d2] = parent->hess[s1 * svars + s2];
72                         if (++d2 == dvars) break;
73                 }
74
75                 if (++d1 == dvars) break;
76         }
77         if (d1 != dvars) error("Parent free parameter group is not a superset");
78
79         // pda(parent->est, 1, svars);
80         // pda(est, 1, dvars);
81         // pda(parent->grad, 1, svars);
82         // pda(grad, 1, dvars);
83         // pda(parent->hess, svars, svars);
84         // pda(hess, dvars, dvars);
85 }
86
87 void FitContext::copyParamToModel(omxMatrix *mat)
88 { copyParamToModel(mat->currentState); }
89
90 void FitContext::copyParamToModel(omxMatrix *mat, double *at)
91 { copyParamToModel(mat->currentState, at); }
92
93 void FitContext::updateParentAndFree()
94 {
95         FreeVarGroup *src = varGroup;
96         FreeVarGroup *dest = parent->varGroup;
97         size_t svars = varGroup->vars.size();
98         size_t dvars = parent->varGroup->vars.size();
99
100         parent->fit = fit;
101
102         size_t s1 = 0;
103         for (size_t d1=0; d1 < dest->vars.size(); ++d1) {
104                 if (dest->vars[d1] != src->vars[s1]) continue;
105                 parent->est[d1] = est[s1];
106                 parent->grad[d1] = grad[s1];
107
108                 size_t s2 = 0;
109                 for (size_t d2=0; d2 < dest->vars.size(); ++d2) {
110                         if (dest->vars[d2] != src->vars[s2]) continue;
111                         parent->hess[d1 * dvars + d2] = hess[s1 * svars + s2];
112                         if (++s2 == svars) break;
113                 }
114
115                 if (++s1 == svars) break;
116         }
117         
118         // pda(est, 1, svars);
119         // pda(parent->est, 1, dvars);
120         // pda(grad, 1, svars);
121         // pda(parent->grad, 1, dvars);
122         // pda(hess, svars, svars);
123         // pda(parent->hess, dvars, dvars);
124
125         delete this;
126 }
127
128 void FitContext::log(const char *where, int what)
129 {
130         size_t count = varGroup->vars.size();
131         std::string buf(where);
132         buf += " ---\n";
133         if (what & FF_COMPUTE_FIT) buf += string_snprintf("fit: %.5f\n", fit);
134         if (what & FF_COMPUTE_ESTIMATE) {
135                 buf += "est: c(";
136                 for (size_t vx=0; vx < count; ++vx) {
137                         buf += string_snprintf("%.5f", est[vx]);
138                         if (vx < count - 1) buf += ", ";
139                 }
140                 buf += ")\n";
141         }
142         if (what & FF_COMPUTE_GRADIENT) {
143                 buf += "grad: c(";
144                 for (size_t vx=0; vx < count; ++vx) {
145                         buf += string_snprintf("%.5f", grad[vx]);
146                         if (vx < count - 1) buf += ", ";
147                 }
148                 buf += ")\n";
149         }
150         if (what & FF_COMPUTE_HESSIAN) {
151                 buf += "hess: c(";
152                 for (size_t v1=0; v1 < count; ++v1) {
153                         for (size_t v2=0; v2 < count; ++v2) {
154                                 buf += string_snprintf("%.5f", hess[v1 * count + v2]);
155                                 if (v1 < count-1 || v2 < count-1) buf += ", ";
156                         }
157                         buf += "\n";
158                 }
159                 buf += ")\n";
160         }
161         mxLogBig(buf);
162 }
163
164 void FitContext::fixHessianSymmetry()
165 {
166         size_t numParam = varGroup->vars.size();
167         for (size_t h1=1; h1 < numParam; h1++) {
168                 for (size_t h2=0; h2 < h1; h2++) {
169                         double lower = hess[h2 * numParam + h1];
170                         hess[h1 * numParam + h2] = lower;
171                 }
172         }
173 }
174
175 static void omxRepopulateRFitFunction(omxFitFunction* oo, double* x, int n)
176 {
177         omxRFitFunction* rFitFunction = (omxRFitFunction*)oo->argStruct;
178
179         SEXP theCall, estimate;
180
181         PROTECT(estimate = allocVector(REALSXP, n));
182         double *est = REAL(estimate);
183         for(int i = 0; i < n ; i++) {
184                 est[i] = x[i];
185         }
186
187         PROTECT(theCall = allocVector(LANGSXP, 4));
188
189         SETCAR(theCall, install("imxUpdateModelValues"));
190         SETCADR(theCall, rFitFunction->model);
191         SETCADDR(theCall, rFitFunction->flatModel);
192         SETCADDDR(theCall, estimate);
193
194         REPROTECT(rFitFunction->model = eval(theCall, R_GlobalEnv), rFitFunction->modelIndex);
195
196         UNPROTECT(2); // theCall, estimate
197 }
198
199 void FitContext::copyParamToModel(omxState* os)
200 {
201         copyParamToModel(os, est);
202 }
203
204 void FitContext::copyParamToModel(omxState* os, double *at)
205 {
206         size_t numParam = varGroup->vars.size();
207         if(OMX_DEBUG) {
208                 mxLog("Copying %lu free parameter estimates to model %p", numParam, os);
209         }
210
211         if(numParam == 0) return;
212
213         os->computeCount++;
214
215         if(OMX_VERBOSE) {
216                 std::string buf;
217                 buf += string_snprintf("Call: %d.%d (%ld) ", os->majorIteration, os->minorIteration, os->computeCount);
218                 buf += ("Estimates: [");
219                 for(size_t k = 0; k < numParam; k++) {
220                         buf += string_snprintf(" %f", at[k]);
221                 }
222                 buf += ("]\n");
223                 mxLogBig(buf);
224         }
225
226         for(size_t k = 0; k < numParam; k++) {
227                 omxFreeVar* freeVar = varGroup->vars[k];
228                 for(size_t l = 0; l < freeVar->locations.size(); l++) {
229                         omxFreeVarLocation *loc = &freeVar->locations[l];
230                         omxMatrix *matrix = os->matrixList[loc->matrix];
231                         int row = loc->row;
232                         int col = loc->col;
233                         omxSetMatrixElement(matrix, row, col, at[k]);
234                         if(OMX_DEBUG) {
235                                 mxLog("Setting location (%d, %d) of matrix %d to value %f for var %lu",
236                                         row, col, loc->matrix, at[k], k);
237                         }
238                 }
239         }
240
241         if (RFitFunction) omxRepopulateRFitFunction(RFitFunction, at, numParam);
242
243         varGroup->markDirty(os);
244
245         if (!os->childList) return;
246
247         for(int i = 0; i < Global->numChildren; i++) {
248                 copyParamToModel(os->childList[i]);
249         }
250 }
251
252 FitContext::~FitContext()
253 {
254         delete [] est;
255         delete [] grad;
256         delete [] hess;
257 }
258
259 omxFitFunction *FitContext::RFitFunction = NULL;
260
261 void FitContext::setRFitFunction(omxFitFunction *rff)
262 {
263         if (rff) {
264                 Global->numThreads = 1;
265                 if (RFitFunction) {
266                         error("You can only create 1 MxRFitFunction per independent model");
267                 }
268         }
269         RFitFunction = rff;
270 }
271
272 omxCompute::~omxCompute()
273 {}
274
275 void omxCompute::initFromFrontend(SEXP rObj)
276 {
277         SEXP slotValue;
278         PROTECT(slotValue = GET_SLOT(rObj, install("id")));
279         int id = INTEGER(slotValue)[0];
280         varGroup = Global->findVarGroup(id);
281         if (!varGroup) varGroup = Global->freeGroup[0];
282 }
283
284 class omxComputeSequence : public omxCompute {
285         typedef omxCompute super;
286         std::vector< omxCompute* > clist;
287
288  public:
289         virtual void initFromFrontend(SEXP rObj);
290         virtual void compute(FitContext *fc);
291         virtual void reportResults(FitContext *fc, MxRList *out);
292         virtual double getOptimizerStatus();
293         virtual ~omxComputeSequence();
294 };
295
296 class omxComputeIterate : public omxCompute {
297         typedef omxCompute super;
298         std::vector< omxCompute* > clist;
299         int maxIter;
300         double tolerance;
301         int verbose;
302
303  public:
304         virtual void initFromFrontend(SEXP rObj);
305         virtual void compute(FitContext *fc);
306         virtual void reportResults(FitContext *fc, MxRList *out);
307         virtual double getOptimizerStatus();
308         virtual ~omxComputeIterate();
309 };
310
311 class omxComputeOnce : public omxCompute {
312         typedef omxCompute super;
313         std::vector< omxMatrix* > algebras;
314         std::vector< omxExpectation* > expectations;
315         bool adjustStart;
316         const char *context;
317         bool gradient;
318         bool hessian;
319
320  public:
321         virtual void initFromFrontend(SEXP rObj);
322         virtual void compute(FitContext *fc);
323         virtual void reportResults(FitContext *fc, MxRList *out);
324 };
325
326 static class omxCompute *newComputeSequence()
327 { return new omxComputeSequence(); }
328
329 static class omxCompute *newComputeIterate()
330 { return new omxComputeIterate(); }
331
332 static class omxCompute *newComputeOnce()
333 { return new omxComputeOnce(); }
334
335 struct omxComputeTableEntry {
336         char name[32];
337         omxCompute *(*ctor)();
338 };
339
340 static const struct omxComputeTableEntry omxComputeTable[] = {
341         {"MxComputeEstimatedHessian", &newComputeEstimatedHessian},
342         {"MxComputeGradientDescent", &newComputeGradientDescent},
343         {"MxComputeSequence", &newComputeSequence },
344         {"MxComputeIterate", &newComputeIterate },
345         {"MxComputeOnce", &newComputeOnce },
346         {"MxComputeNewtonRaphson", &newComputeNewtonRaphson},
347 };
348
349 omxCompute *omxNewCompute(omxState* os, const char *type)
350 {
351         omxCompute *got = NULL;
352
353         for (size_t fx=0; fx < OMX_STATIC_ARRAY_SIZE(omxComputeTable); fx++) {
354                 const struct omxComputeTableEntry *entry = omxComputeTable + fx;
355                 if(strcmp(type, entry->name) == 0) {
356                         got = entry->ctor();
357                         break;
358                 }
359         }
360
361         if (!got) error("Compute %s is not implemented", type);
362
363         return got;
364 }
365
366 void omxComputeSequence::initFromFrontend(SEXP rObj)
367 {
368         super::initFromFrontend(rObj);
369
370         SEXP slotValue;
371         PROTECT(slotValue = GET_SLOT(rObj, install("steps")));
372
373         for (int cx = 0; cx < length(slotValue); cx++) {
374                 SEXP step = VECTOR_ELT(slotValue, cx);
375                 SEXP s4class;
376                 PROTECT(s4class = STRING_ELT(getAttrib(step, install("class")), 0));
377                 omxCompute *compute = omxNewCompute(globalState, CHAR(s4class));
378                 compute->initFromFrontend(step);
379                 if (isErrorRaised(globalState)) break;
380                 clist.push_back(compute);
381         }
382 }
383
384 void omxComputeSequence::compute(FitContext *fc)
385 {
386         for (size_t cx=0; cx < clist.size(); ++cx) {
387                 FitContext *context = fc;
388                 if (fc->varGroup != clist[cx]->varGroup) {
389                         context = new FitContext(fc, clist[cx]->varGroup);
390                 }
391                 clist[cx]->compute(context);
392                 if (context != fc) context->updateParentAndFree();
393                 if (isErrorRaised(globalState)) break;
394         }
395 }
396
397 void omxComputeSequence::reportResults(FitContext *fc, MxRList *out)
398 {
399         // put this stuff in a new list?
400         // merge with Iterate TODO
401         for (size_t cx=0; cx < clist.size(); ++cx) {
402                 FitContext *context = fc;
403                 if (fc->varGroup != clist[cx]->varGroup) {
404                         context = new FitContext(fc, clist[cx]->varGroup);
405                 }
406                 clist[cx]->reportResults(context, out);
407                 if (context != fc) context->updateParentAndFree();
408                 if (isErrorRaised(globalState)) break;
409         }
410 }
411
412 double omxComputeSequence::getOptimizerStatus()
413 {
414         // for backward compatibility, not indended to work generally
415         for (size_t cx=0; cx < clist.size(); ++cx) {
416                 double got = clist[cx]->getOptimizerStatus();
417                 if (got != NA_REAL) return got;
418         }
419         return NA_REAL;
420 }
421
422 omxComputeSequence::~omxComputeSequence()
423 {
424         for (size_t cx=0; cx < clist.size(); ++cx) {
425                 delete clist[cx];
426         }
427 }
428
429 void omxComputeIterate::initFromFrontend(SEXP rObj)
430 {
431         SEXP slotValue;
432
433         super::initFromFrontend(rObj);
434
435         PROTECT(slotValue = GET_SLOT(rObj, install("maxIter")));
436         maxIter = INTEGER(slotValue)[0];
437
438         PROTECT(slotValue = GET_SLOT(rObj, install("tolerance")));
439         tolerance = REAL(slotValue)[0];
440         if (tolerance <= 0) error("tolerance must be positive");
441
442         PROTECT(slotValue = GET_SLOT(rObj, install("steps")));
443
444         for (int cx = 0; cx < length(slotValue); cx++) {
445                 SEXP step = VECTOR_ELT(slotValue, cx);
446                 SEXP s4class;
447                 PROTECT(s4class = STRING_ELT(getAttrib(step, install("class")), 0));
448                 omxCompute *compute = omxNewCompute(globalState, CHAR(s4class));
449                 compute->initFromFrontend(step);
450                 if (isErrorRaised(globalState)) break;
451                 clist.push_back(compute);
452         }
453
454         PROTECT(slotValue = GET_SLOT(rObj, install("verbose")));
455         verbose = asInteger(slotValue);
456 }
457
458 void omxComputeIterate::compute(FitContext *fc)
459 {
460         int iter = 0;
461         double prevFit = 0;
462         double change = tolerance * 10;
463         while (1) {
464                 for (size_t cx=0; cx < clist.size(); ++cx) {
465                         FitContext *context = fc;
466                         if (fc->varGroup != clist[cx]->varGroup) {
467                                 context = new FitContext(fc, clist[cx]->varGroup);
468                         }
469                         clist[cx]->compute(context);
470                         if (context != fc) context->updateParentAndFree();
471                         if (isErrorRaised(globalState)) break;
472                 }
473                 if (fc->fit == 0) {
474                         warning("Fit estimated at 0; something is wrong");
475                         break;
476                 }
477                 if (prevFit != 0) {
478                         change = prevFit - fc->fit;
479                         if (verbose) mxLog("fit %.9g change %.9g", fc->fit, change);
480                 }
481                 prevFit = fc->fit;
482                 if (isErrorRaised(globalState) || ++iter > maxIter || fabs(change) < tolerance) break;
483         }
484 }
485
486 void omxComputeIterate::reportResults(FitContext *fc, MxRList *out)
487 {
488         for (size_t cx=0; cx < clist.size(); ++cx) {
489                 FitContext *context = fc;
490                 if (fc->varGroup != clist[cx]->varGroup) {
491                         context = new FitContext(fc, clist[cx]->varGroup);
492                 }
493                 clist[cx]->reportResults(context, out);
494                 if (context != fc) context->updateParentAndFree();
495                 if (isErrorRaised(globalState)) break;
496         }
497 }
498
499 double omxComputeIterate::getOptimizerStatus()
500 {
501         // for backward compatibility, not indended to work generally
502         for (size_t cx=0; cx < clist.size(); ++cx) {
503                 double got = clist[cx]->getOptimizerStatus();
504                 if (got != NA_REAL) return got;
505         }
506         return NA_REAL;
507 }
508
509 omxComputeIterate::~omxComputeIterate()
510 {
511         for (size_t cx=0; cx < clist.size(); ++cx) {
512                 delete clist[cx];
513         }
514 }
515
516 void omxComputeOnce::initFromFrontend(SEXP rObj)
517 {
518         super::initFromFrontend(rObj);
519
520         SEXP slotValue;
521         PROTECT(slotValue = GET_SLOT(rObj, install("what")));
522         for (int wx=0; wx < length(slotValue); ++wx) {
523                 int objNum = INTEGER(slotValue)[wx];
524                 if (objNum >= 0) {
525                         omxMatrix *algebra = globalState->algebraList[objNum];
526                         if (algebra->fitFunction) {
527                                 setFreeVarGroup(algebra->fitFunction, varGroup);
528                                 omxCompleteFitFunction(algebra);
529                         }
530                         algebras.push_back(algebra);
531                 } else {
532                         omxExpectation *expectation = globalState->expectationList[~objNum];
533                         setFreeVarGroup(expectation, varGroup);
534                         omxCompleteExpectation(expectation);
535                         expectations.push_back(expectation);
536                 }
537         }
538
539         context = "";
540
541         PROTECT(slotValue = GET_SLOT(rObj, install("context")));
542         if (length(slotValue) == 0) {
543                 // OK
544         } else if (length(slotValue) == 1) {
545                 SEXP elem;
546                 PROTECT(elem = STRING_ELT(slotValue, 0));
547                 context = CHAR(elem);
548         }
549
550         PROTECT(slotValue = GET_SLOT(rObj, install("gradient")));
551         gradient = asLogical(slotValue);
552
553         PROTECT(slotValue = GET_SLOT(rObj, install("hessian")));
554         hessian = asLogical(slotValue);
555
556         if (algebras.size() == 1 && algebras[0]->fitFunction) {
557                 omxFitFunction *ff = algebras[0]->fitFunction;
558                 if (gradient && !ff->gradientAvailable) {
559                         error("Gradient requested but not available");
560                 }
561                 if (hessian && !ff->hessianAvailable) {
562                         error("Hessian requested but not available");
563                 }
564         }
565
566         PROTECT(slotValue = GET_SLOT(rObj, install("adjustStart")));
567         adjustStart = asLogical(slotValue);
568 }
569
570 void omxComputeOnce::compute(FitContext *fc)
571 {
572         if (algebras.size()) {
573                 int want = FF_COMPUTE_FIT;
574                 size_t numParam = fc->varGroup->vars.size();
575                 if (gradient) {
576                         want |= FF_COMPUTE_GRADIENT;
577                         OMXZERO(fc->grad, numParam);
578                 }
579                 if (hessian) {
580                         want |= FF_COMPUTE_HESSIAN;
581                         OMXZERO(fc->hess, numParam * numParam);
582                 }
583
584                 for (size_t wx=0; wx < algebras.size(); ++wx) {
585                         omxMatrix *algebra = algebras[wx];
586                         if (algebra->fitFunction) {
587                                 if (adjustStart) {
588                                         omxFitFunctionCompute(algebra->fitFunction, FF_COMPUTE_PREOPTIMIZE, fc);
589                                         fc->copyParamToModel(globalState);
590                                 }
591
592                                 omxFitFunctionCompute(algebra->fitFunction, want, fc);
593                                 fc->fit = algebra->data[0];
594                                 if (hessian) fc->fixHessianSymmetry();
595                         } else {
596                                 omxForceCompute(algebra);
597                         }
598                 }
599         } else if (expectations.size()) {
600                 for (size_t wx=0; wx < expectations.size(); ++wx) {
601                         omxExpectation *expectation = expectations[wx];
602                         omxExpectationCompute(expectation, context);
603                 }
604         }
605 }
606
607 void omxComputeOnce::reportResults(FitContext *fc, MxRList *out)
608 {
609         if (algebras.size()==0 || algebras[0]->fitFunction == NULL) return;
610
611         omxMatrix *algebra = algebras[0];
612
613         omxPopulateFitFunction(algebra, out);
614
615         out->push_back(std::make_pair(mkChar("minimum"), ScalarReal(fc->fit)));
616         out->push_back(std::make_pair(mkChar("Minus2LogLikelihood"), ScalarReal(fc->fit)));
617
618         size_t numFree = fc->varGroup->vars.size();
619         if (numFree) {
620                 SEXP estimate;
621                 PROTECT(estimate = allocVector(REALSXP, numFree));
622                 memcpy(REAL(estimate), fc->est, sizeof(double)*numFree);
623                 out->push_back(std::make_pair(mkChar("estimate"), estimate));
624
625                 if (gradient) {
626                         SEXP Rgradient;
627                         PROTECT(Rgradient = allocVector(REALSXP, numFree));
628                         memcpy(REAL(Rgradient), fc->grad, sizeof(double) * numFree);
629                         out->push_back(std::make_pair(mkChar("gradient"), Rgradient));
630                 }
631
632                 if (hessian) {
633                         SEXP Rhessian;
634                         PROTECT(Rhessian = allocMatrix(REALSXP, numFree, numFree));
635                         memcpy(REAL(Rhessian), fc->hess, sizeof(double) * numFree * numFree);
636                         out->push_back(std::make_pair(mkChar("hessian"), Rhessian));
637                 }
638         }
639 }