Rewrite Newton-Raphson with better math
[openmx:openmx.git] / src / ComputeNR.cpp
1 /*
2  *  Copyright 2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include "omxState.h"
18 #include "omxFitFunction.h"
19 #include "omxExportBackendState.h"
20 #include "Compute.h"
21
22 class ComputeNR : public omxCompute {
23         typedef omxCompute super;
24         omxMatrix *fitMatrix;
25         bool adjustStart;
26
27         int maxIter;
28         double tolerance;
29         int inform, iter;
30         int verbose;
31
32 public:
33         ComputeNR();
34         virtual void initFromFrontend(SEXP rObj);
35         virtual void compute(FitContext *fc);
36         virtual void reportResults(FitContext *fc, MxRList *out);
37         virtual double getOptimizerStatus() { return inform; }  // backward compatibility
38 };
39
40 class omxCompute *newComputeNewtonRaphson()
41 {
42         return new ComputeNR();
43 }
44
45 ComputeNR::ComputeNR()
46 {
47         inform = 0;
48         iter = 0;
49 }
50
51 void ComputeNR::initFromFrontend(SEXP rObj)
52 {
53         super::initFromFrontend(rObj);
54
55         fitMatrix = omxNewMatrixFromSlot(rObj, globalState, "fitfunction");
56         setFreeVarGroup(fitMatrix->fitFunction, varGroup);
57         omxCompleteFitFunction(fitMatrix);
58
59         if (!fitMatrix->fitFunction->hessianAvailable ||
60             !fitMatrix->fitFunction->gradientAvailable) {
61                 error("Newton-Raphson requires derivatives");
62         }
63
64         SEXP slotValue;
65         PROTECT(slotValue = GET_SLOT(rObj, install("adjustStart")));
66         adjustStart = asLogical(slotValue);
67
68         PROTECT(slotValue = GET_SLOT(rObj, install("maxIter")));
69         maxIter = INTEGER(slotValue)[0];
70
71         PROTECT(slotValue = GET_SLOT(rObj, install("tolerance")));
72         tolerance = REAL(slotValue)[0];
73         if (tolerance <= 0) error("tolerance must be positive");
74
75         PROTECT(slotValue = GET_SLOT(rObj, install("verbose")));
76         verbose = asInteger(slotValue);
77 }
78
79 void ComputeNR::compute(FitContext *fc)
80 {
81         // complain if there are non-linear constraints TODO
82
83         size_t numParam = varGroup->vars.size();
84         if (numParam <= 0) {
85                 error("Model has no free parameters");
86                 return;
87         }
88
89         if (adjustStart) {
90                 omxFitFunctionCompute(fitMatrix->fitFunction, FF_COMPUTE_PREOPTIMIZE, fc);
91                 fc->copyParamToModel(globalState);
92         }
93
94         iter = 0;
95         //double prevLL = nan("unset");
96         //bool decreasing = TRUE;
97
98         while (1) {
99                 const int want = FF_COMPUTE_GRADIENT|FF_COMPUTE_HESSIAN;
100
101                 OMXZERO(fc->grad, numParam);
102                 OMXZERO(fc->hess, numParam * numParam);
103
104                 omxFitFunctionCompute(fitMatrix->fitFunction, want, fc);
105
106                 if (verbose >= 2) {
107                         fc->log("Newton-Raphson", FF_COMPUTE_ESTIMATE);
108                 }
109
110                 // Only need LL for diagnostics; Can avoid computing it? TODO
111                 //double LL = fitMatrix->data[0];
112                 //if (isfinite(prevLL) && prevLL < LL - tolerance) decreasing = FALSE;
113                 //prevLL = LL;
114
115                 //              fc->log(FF_COMPUTE_ESTIMATE|FF_COMPUTE_GRADIENT|FF_COMPUTE_HESSIAN);
116
117                 std::vector<double> ihess(numParam * numParam);
118                 memcpy(ihess.data(), fc->hess, sizeof(double) * numParam * numParam);
119
120                 int dim = int(numParam);
121                 const char uplo = 'L';
122                 int info;
123                 F77_CALL(dpotrf)(&uplo, &dim, ihess.data(), &dim, &info);
124                 if (info < 0) error("Arg %d is invalid", -info);
125                 if (info > 0) {
126                         omxRaiseErrorf(globalState, "Hessian is not positive definite");
127                         // Worth checking for zero rows? TODO
128                         for (size_t rx=0; rx < numParam; ++rx) {
129                                 double row = 0;
130                                 for (size_t cx=0; cx < numParam; ++cx) {
131                                         row += fc->hess[rx * numParam + cx];
132                                 }
133                                 if (row == 0) warning("Check %s", fc->varGroup->vars[rx]->name);
134                         }
135                         break;
136                 }
137
138                 F77_CALL(dpotri)(&uplo, &dim, ihess.data(), &dim, &info);
139                 if (info < 0) error("Arg %d is invalid", -info);
140                 if (info > 0) {
141                         omxRaiseErrorf(globalState, "Hessian is not of full rank");
142                         break;
143                 }
144
145                 std::vector<double> adj(numParam);
146                 double alpha = -1;
147                 int incx = 1;
148                 double beta = 0;
149                 F77_CALL(dsymv)(&uplo, &dim, &alpha, ihess.data(), &dim,
150                                 fc->grad, &incx, &beta, adj.data(), &incx);
151
152                 double maxAdj = 0;
153                 for (size_t px=0; px < numParam; ++px) {
154                         double param = fc->est[px];
155                         param += adj[px];
156                         omxFreeVar *fv = fc->varGroup->vars[px];
157                         if (param < fv->lbound) param = fv->lbound;
158                         if (param > fv->ubound) param = fv->ubound;
159                         double adj = fabs(param - fc->est[px]);
160                         if (maxAdj < adj)
161                                 maxAdj = adj;
162                         fc->est[px] = param;
163                 }
164                 fc->copyParamToModel(globalState);
165                 R_CheckUserInterrupt();
166                 if (maxAdj < tolerance || ++iter > maxIter) break;
167         }
168
169         if (verbose >= 1) {
170                 mxLog("Newton-Raphson converged in %d cycles", iter);
171         }
172
173         // The check is too dependent on numerical precision to enable by default.
174         // Anyway, it's just a tool for developers.
175         //if (0 && !decreasing) warning("Newton-Raphson iterations did not converge");
176
177         omxFitFunctionCompute(fitMatrix->fitFunction, FF_COMPUTE_POSTOPTIMIZE, fc);
178 }
179
180 void ComputeNR::reportResults(FitContext *fc, MxRList *out)
181 {
182         if (Global->numIntervals) {
183                 warning("Confidence intervals are not implemented for Newton-Raphson");
184         }  
185
186         omxPopulateFitFunction(fitMatrix, out);
187
188         size_t numFree = varGroup->vars.size();
189
190         SEXP estimate;
191         PROTECT(estimate = allocVector(REALSXP, numFree));
192         memcpy(REAL(estimate), fc->est, sizeof(double) * numFree);
193
194         out->push_back(std::make_pair(mkChar("minimum"), ScalarReal(fc->fit)));
195         out->push_back(std::make_pair(mkChar("Minus2LogLikelihood"), ScalarReal(fc->fit)));
196         out->push_back(std::make_pair(mkChar("estimate"), estimate));
197
198         SEXP iterations;
199
200         // SEXP code;
201         // PROTECT(code = NEW_NUMERIC(1));
202         // REAL(code)[0] = inform;
203         // out->push_back(std::make_pair(mkChar("nr.code"), code));
204
205         PROTECT(iterations = NEW_NUMERIC(1));
206         REAL(iterations)[0] = iter;
207         out->push_back(std::make_pair(mkChar("nr.iterations"), iterations));
208 }