Remove most instances of setFinalReturns
[openmx:openmx.git] / src / ComputeGD.cpp
1 /*
2  *  Copyright 2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include "omxState.h"
18 #include "omxFitFunction.h"
19 #include "omxNPSOLSpecific.h"
20 #include "omxExportBackendState.h"
21 #include "Compute.h"
22
23 class omxComputeGD : public omxComputeOperation {
24         typedef omxComputeOperation super;
25         omxMatrix *fitMatrix;
26
27         SEXP intervals, intervalCodes; // move to FitContext? TODO
28         int inform, iter;
29
30 public:
31         omxComputeGD();
32         virtual void initFromFrontend(SEXP rObj);
33         virtual void compute(FitContext *fc);
34         virtual void reportResults(FitContext *fc, MxRList *out);
35         virtual double getOptimizerStatus() { return inform; }  // backward compatibility
36 };
37
38 class omxCompute *newComputeGradientDescent()
39 {
40         return new omxComputeGD();
41 }
42
43 omxComputeGD::omxComputeGD()
44 {
45         intervals = 0;
46         intervalCodes = 0;
47         inform = 0;
48         iter = 0;
49 }
50
51 void omxComputeGD::initFromFrontend(SEXP rObj)
52 {
53         super::initFromFrontend(rObj);
54         fitMatrix = omxNewMatrixFromSlot(rObj, globalState, "fitfunction");
55         setFreeVarGroup(fitMatrix->fitFunction, varGroup);
56         omxCompleteFitFunction(fitMatrix);
57 }
58
59 void omxComputeGD::compute(FitContext *fc)
60 {
61         size_t numParam = varGroup->vars.size();
62         if (numParam <= 0) {
63                 error("Model has no free parameters");
64                 return;
65         }
66
67         omxFitFunctionCompute(fitMatrix->fitFunction, FF_COMPUTE_PREOPTIMIZE, fc);
68
69         if (fitMatrix->fitFunction && fitMatrix->fitFunction->usesChildModels)
70                 omxFitFunctionCreateChildren(globalState);
71
72         omxInvokeNPSOL(fitMatrix, fc, &inform, &iter);
73
74         omxFreeChildStates(globalState);
75
76         if (Global->numIntervals) {
77                 if (!(inform == 0 || inform == 1 || inform == 6)) {
78                         // TODO: Throw a warning, allow force()
79                         warning("Not calculating confidence intervals because of NPSOL status %d", inform);
80                 } else {
81                         PROTECT(intervals = allocMatrix(REALSXP, Global->numIntervals, 2));
82                         PROTECT(intervalCodes = allocMatrix(INTSXP, Global->numIntervals, 2));
83
84                         omxNPSOLConfidenceIntervals(fitMatrix, fc);
85                         omxPopulateConfidenceIntervals(intervals, intervalCodes); // TODO move code here
86                 }
87         }  
88 }
89
90 void omxComputeGD::reportResults(FitContext *fc, MxRList *out)
91 {
92         omxPopulateFitFunction(fitMatrix, out);
93
94         size_t numFree = varGroup->vars.size();
95
96         SEXP estimate, gradient, hessian;
97         PROTECT(estimate = allocVector(REALSXP, numFree));
98         PROTECT(gradient = allocVector(REALSXP, numFree));
99         PROTECT(hessian = allocMatrix(REALSXP, numFree, numFree));
100
101         memcpy(REAL(estimate), fc->est, sizeof(double) * numFree);
102         memcpy(REAL(gradient), fc->grad, sizeof(double) * numFree);
103         memcpy(REAL(hessian), fc->hess, sizeof(double) * numFree * numFree);
104
105         out->push_back(std::make_pair(mkChar("minimum"), ScalarReal(fc->fit)));
106         out->push_back(std::make_pair(mkChar("Minus2LogLikelihood"), ScalarReal(fc->fit)));
107         out->push_back(std::make_pair(mkChar("estimate"), estimate));
108         out->push_back(std::make_pair(mkChar("gradient"), gradient));
109         out->push_back(std::make_pair(mkChar("hessianCholesky"), hessian));
110
111         if (intervals && intervalCodes) {
112                 out->push_back(std::make_pair(mkChar("confidenceIntervals"), intervals));
113                 out->push_back(std::make_pair(mkChar("confidenceIntervalCodes"), intervalCodes));
114         }
115
116         SEXP code, iterations;
117
118         PROTECT(code = NEW_NUMERIC(1));
119         REAL(code)[0] = inform;
120         out->push_back(std::make_pair(mkChar("npsol.code"), code));
121
122         PROTECT(iterations = NEW_NUMERIC(1));
123         REAL(iterations)[0] = iter;
124         out->push_back(std::make_pair(mkChar("npsol.iterations"), iterations));
125         out->push_back(std::make_pair(mkChar("iterations"), iterations)); // backward compatibility
126 }