Enable with IMX_OPT_ENGINE=CSOLNP
[openmx:openmx.git] / src / ComputeGD.cpp
1 /*
2  *  Copyright 2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include "omxState.h"
18 #include "omxFitFunction.h"
19 #include "omxNPSOLSpecific.h"
20 #include "omxExportBackendState.h"
21 #include "omxCsolnp.h"
22 #include "Compute.h"
23
24 enum OptEngine {
25         OptEngine_NPSOL,
26         OptEngine_CSOLNP
27 };
28
29 class omxComputeGD : public omxCompute {
30         typedef omxCompute super;
31         enum OptEngine engine;
32         omxMatrix *fitMatrix;
33         bool adjustStart;
34         bool useGradient;
35         int verbose;
36
37         SEXP intervals, intervalCodes; // move to FitContext? TODO
38         int inform, iter;
39
40 public:
41         omxComputeGD();
42         virtual void initFromFrontend(SEXP rObj);
43         virtual void compute(FitContext *fc);
44         virtual void reportResults(FitContext *fc, MxRList *out);
45         virtual double getOptimizerStatus() { return inform; }  // backward compatibility
46 };
47
48 class omxCompute *newComputeGradientDescent()
49 {
50         return new omxComputeGD();
51 }
52
53 omxComputeGD::omxComputeGD()
54 {
55         intervals = 0;
56         intervalCodes = 0;
57         inform = 0;
58         iter = 0;
59 }
60
61 void omxComputeGD::initFromFrontend(SEXP rObj)
62 {
63         super::initFromFrontend(rObj);
64         fitMatrix = omxNewMatrixFromSlot(rObj, globalState, "fitfunction");
65         setFreeVarGroup(fitMatrix->fitFunction, varGroup);
66         omxCompleteFitFunction(fitMatrix);
67
68         SEXP slotValue;
69         PROTECT(slotValue = GET_SLOT(rObj, install("adjustStart")));
70         adjustStart = asLogical(slotValue);
71
72         PROTECT(slotValue = GET_SLOT(rObj, install("useGradient")));
73         if (LOGICAL(slotValue)[0] == NA_LOGICAL) {
74                 useGradient = Global->analyticGradients;
75         } else {
76                 useGradient = LOGICAL(slotValue)[0];
77         }
78
79         PROTECT(slotValue = GET_SLOT(rObj, install("verbose")));
80         verbose = asInteger(slotValue);
81
82         PROTECT(slotValue = GET_SLOT(rObj, install("engine")));
83         const char *engine_name = CHAR(asChar(slotValue));
84         if (strcmp(engine_name, "CSOLNP")==0) {
85                 engine = OptEngine_CSOLNP;
86         } else if (strcmp(engine_name, "NPSOL")==0) {
87                 engine = OptEngine_NPSOL;
88         } else {
89                 error("MxComputeGradientDescent engine %s unknown", engine_name);
90         }
91 }
92
93 void omxComputeGD::compute(FitContext *fc)
94 {
95         size_t numParam = varGroup->vars.size();
96         if (numParam <= 0) {
97                 error("Model has no free parameters");
98                 return;
99         }
100
101         if (adjustStart) {
102                 omxFitFunctionCompute(fitMatrix->fitFunction, FF_COMPUTE_PREOPTIMIZE, fc);
103                 fc->copyParamToModel(globalState);
104         }
105
106         if (fitMatrix->fitFunction && fitMatrix->fitFunction->usesChildModels)
107                 omxFitFunctionCreateChildren(globalState);
108
109         switch (engine) {
110         case OptEngine_NPSOL:
111                 omxInvokeNPSOL(fitMatrix, fc, &inform, &iter, useGradient, varGroup, verbose);
112                 break;
113         case OptEngine_CSOLNP:
114                 omxInvokeCSOLNP(fitMatrix, fc);
115                 break;
116         default: error("huh?");
117         }
118
119         omxFreeChildStates(globalState);
120
121         if (Global->numIntervals && engine == OptEngine_NPSOL) {
122                 if (!(inform == 0 || inform == 1 || inform == 6)) {
123                         // TODO: Throw a warning, allow force()
124                         warning("Not calculating confidence intervals because of NPSOL status %d", inform);
125                 } else {
126                         PROTECT(intervals = allocMatrix(REALSXP, Global->numIntervals, 2));
127                         PROTECT(intervalCodes = allocMatrix(INTSXP, Global->numIntervals, 2));
128
129                         omxNPSOLConfidenceIntervals(fitMatrix, fc);
130                         omxPopulateConfidenceIntervals(intervals, intervalCodes); // TODO move code here
131                 }
132         }  
133
134         omxFitFunctionCompute(fitMatrix->fitFunction, FF_COMPUTE_POSTOPTIMIZE, fc);
135         omxMarkDirty(fitMatrix); // not sure why it needs to be dirty
136 }
137
138 void omxComputeGD::reportResults(FitContext *fc, MxRList *out)
139 {
140         omxPopulateFitFunction(fitMatrix, out);
141
142         size_t numFree = varGroup->vars.size();
143
144         SEXP estimate, gradient, hessian;
145         PROTECT(estimate = allocVector(REALSXP, numFree));
146         PROTECT(gradient = allocVector(REALSXP, numFree));
147         PROTECT(hessian = allocMatrix(REALSXP, numFree, numFree));
148
149         memcpy(REAL(estimate), fc->est, sizeof(double) * numFree);
150         memcpy(REAL(gradient), fc->grad, sizeof(double) * numFree);
151         memcpy(REAL(hessian), fc->hess, sizeof(double) * numFree * numFree);
152
153         out->push_back(std::make_pair(mkChar("minimum"), ScalarReal(fc->fit)));
154         out->push_back(std::make_pair(mkChar("Minus2LogLikelihood"), ScalarReal(fc->fit)));
155         out->push_back(std::make_pair(mkChar("estimate"), estimate));
156         out->push_back(std::make_pair(mkChar("gradient"), gradient));
157         out->push_back(std::make_pair(mkChar("hessianCholesky"), hessian));
158
159         if (intervals && intervalCodes) {
160                 out->push_back(std::make_pair(mkChar("confidenceIntervals"), intervals));
161                 out->push_back(std::make_pair(mkChar("confidenceIntervalCodes"), intervalCodes));
162         }
163
164         SEXP code, iterations;
165
166         PROTECT(code = NEW_NUMERIC(1));
167         REAL(code)[0] = inform;
168         out->push_back(std::make_pair(mkChar("npsol.code"), code));
169
170         PROTECT(iterations = NEW_NUMERIC(1));
171         REAL(iterations)[0] = iter;
172         out->push_back(std::make_pair(mkChar("npsol.iterations"), iterations));
173         out->push_back(std::make_pair(mkChar("iterations"), iterations)); // backward compatibility
174 }