Allow ComputeIterate to test maximum absolute change
[openmx:openmx.git] / src / ComputeGD.cpp
1 /*
2  *  Copyright 2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include "omxState.h"
18 #include "omxFitFunction.h"
19 #include "omxNPSOLSpecific.h"
20 #include "omxExportBackendState.h"
21 #include "omxCsolnp.h"
22 #include "Compute.h"
23
24 enum OptEngine {
25         OptEngine_NPSOL,
26         OptEngine_CSOLNP
27 };
28
29 class omxComputeGD : public omxCompute {
30         typedef omxCompute super;
31         enum OptEngine engine;
32         omxMatrix *fitMatrix;
33         bool useGradient;
34         int verbose;
35
36         SEXP intervals, intervalCodes; // move to FitContext? TODO
37         int inform, iter;
38
39 public:
40         omxComputeGD();
41         virtual void initFromFrontend(SEXP rObj);
42         virtual void compute(FitContext *fc);
43         virtual void reportResults(FitContext *fc, MxRList *out);
44         virtual double getOptimizerStatus() { return inform; }  // backward compatibility
45 };
46
47 class omxCompute *newComputeGradientDescent()
48 {
49         return new omxComputeGD();
50 }
51
52 omxComputeGD::omxComputeGD()
53 {
54         intervals = 0;
55         intervalCodes = 0;
56         inform = 0;
57         iter = 0;
58 }
59
60 void omxComputeGD::initFromFrontend(SEXP rObj)
61 {
62         super::initFromFrontend(rObj);
63         fitMatrix = omxNewMatrixFromSlot(rObj, globalState, "fitfunction");
64         setFreeVarGroup(fitMatrix->fitFunction, varGroup);
65         omxCompleteFitFunction(fitMatrix);
66
67         SEXP slotValue;
68         PROTECT(slotValue = GET_SLOT(rObj, install("useGradient")));
69         if (length(slotValue)) {
70                 useGradient = asLogical(slotValue);
71         } else {
72                 useGradient = Global->analyticGradients;
73         }
74
75         PROTECT(slotValue = GET_SLOT(rObj, install("verbose")));
76         verbose = asInteger(slotValue);
77
78         PROTECT(slotValue = GET_SLOT(rObj, install("engine")));
79         const char *engine_name = CHAR(asChar(slotValue));
80         if (strcmp(engine_name, "CSOLNP")==0) {
81                 engine = OptEngine_CSOLNP;
82         } else if (strcmp(engine_name, "NPSOL")==0) {
83                 engine = OptEngine_NPSOL;
84         } else {
85                 error("MxComputeGradientDescent engine %s unknown", engine_name);
86         }
87 }
88
89 void omxComputeGD::compute(FitContext *fc)
90 {
91         size_t numParam = varGroup->vars.size();
92         if (numParam <= 0) {
93                 error("Model has no free parameters");
94                 return;
95         }
96
97         omxFitFunctionCompute(fitMatrix->fitFunction, FF_COMPUTE_PREOPTIMIZE, fc);
98         fc->maybeCopyParamToModel(globalState);
99
100         if (fitMatrix->fitFunction && fitMatrix->fitFunction->usesChildModels)
101                 omxFitFunctionCreateChildren(globalState);
102
103         switch (engine) {
104         case OptEngine_NPSOL:
105                 omxInvokeNPSOL(fitMatrix, fc, &inform, &iter, useGradient, varGroup, verbose);
106                 break;
107         case OptEngine_CSOLNP:
108                 omxInvokeCSOLNP(fitMatrix, fc, verbose);
109                 break;
110         default: error("huh?");
111         }
112
113         omxFreeChildStates(globalState);
114
115         if (Global->numIntervals && engine == OptEngine_NPSOL) {
116                 if (!(inform == 0 || inform == 1 || inform == 6)) {
117                         // TODO: allow forcing
118                         warning("Not calculating confidence intervals because of NPSOL status %d", inform);
119                 } else {
120                         PROTECT(intervals = allocMatrix(REALSXP, Global->numIntervals, 2));
121                         PROTECT(intervalCodes = allocMatrix(INTSXP, Global->numIntervals, 2));
122
123                         omxNPSOLConfidenceIntervals(fitMatrix, fc);
124                         omxPopulateConfidenceIntervals(intervals, intervalCodes); // TODO move code here
125                 }
126         }  
127
128         omxMarkDirty(fitMatrix); // not sure why it needs to be dirty
129 }
130
131 void omxComputeGD::reportResults(FitContext *fc, MxRList *out)
132 {
133         omxPopulateFitFunction(fitMatrix, out);
134
135         size_t numFree = varGroup->vars.size();
136
137         SEXP estimate, gradient, hessian;
138         PROTECT(estimate = allocVector(REALSXP, numFree));
139         PROTECT(gradient = allocVector(REALSXP, numFree));
140         PROTECT(hessian = allocMatrix(REALSXP, numFree, numFree));
141
142         memcpy(REAL(estimate), fc->est, sizeof(double) * numFree);
143         memcpy(REAL(gradient), fc->grad, sizeof(double) * numFree);
144         memcpy(REAL(hessian), fc->hess, sizeof(double) * numFree * numFree);
145
146         out->push_back(std::make_pair(mkChar("minimum"), ScalarReal(fc->fit)));
147         out->push_back(std::make_pair(mkChar("Minus2LogLikelihood"), ScalarReal(fc->fit)));
148         out->push_back(std::make_pair(mkChar("estimate"), estimate));
149         out->push_back(std::make_pair(mkChar("gradient"), gradient));
150         out->push_back(std::make_pair(mkChar("hessianCholesky"), hessian));
151
152         if (intervals && intervalCodes) {
153                 out->push_back(std::make_pair(mkChar("confidenceIntervals"), intervals));
154                 out->push_back(std::make_pair(mkChar("confidenceIntervalCodes"), intervalCodes));
155         }
156
157         SEXP code, iterations;
158
159         PROTECT(code = NEW_NUMERIC(1));
160         REAL(code)[0] = inform;
161         out->push_back(std::make_pair(mkChar("npsol.code"), code));
162
163         PROTECT(iterations = NEW_NUMERIC(1));
164         REAL(iterations)[0] = iter;
165         out->push_back(std::make_pair(mkChar("npsol.iterations"), iterations));
166         out->push_back(std::make_pair(mkChar("iterations"), iterations)); // backward compatibility
167 }