Fix regression in GrowthMixtureModelRandomStarts
[openmx:openmx.git] / src / ComputeGD.cpp
1 /*
2  *  Copyright 2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include "omxState.h"
18 #include "omxFitFunction.h"
19 #include "omxOptimizer.h"
20 #include "omxNPSOLSpecific.h"
21 #include "omxExportBackendState.h"
22 #include "Compute.h"
23
24 class omxCompute *newComputeGradientDescent()
25 {
26         return new omxComputeGD();
27 }
28
29 void omxComputeGD::init()
30 {
31         intervals = 0;
32         intervalCodes = 0;
33         inform = 0;
34         iter = 0;
35 }
36
37 void omxComputeGD::initFromFrontend(SEXP rObj)
38 {
39         fitMatrix = omxNewMatrixFromSlot(rObj, globalState, "fitfunction");
40
41         numFree = Global.numFreeParams;
42         if (numFree <= 0) {
43                 error("Model has no free parameters");
44                 return;
45         }
46
47         PROTECT(minimum = NEW_NUMERIC(1));
48         PROTECT(estimate = allocVector(REALSXP, numFree));
49         PROTECT(gradient = allocVector(REALSXP, numFree));
50         PROTECT(hessian = allocMatrix(REALSXP, numFree, numFree));
51 }
52
53 void omxComputeGD::compute(double *startVals)
54 {
55         memcpy(REAL(estimate), startVals, sizeof(double)*numFree);
56
57         if (fitMatrix->fitFunction && fitMatrix->fitFunction->usesChildModels)
58                 omxFitFunctionCreateChildren(globalState);
59
60         omxInvokeNPSOL(fitMatrix, REAL(minimum), REAL(estimate),
61                        REAL(gradient), REAL(hessian), &inform, &iter);
62
63         omxFreeChildStates(globalState);
64
65         if (Global.numIntervals) {
66                 if (!(inform == 0 || inform == 1 || inform == 6)) {
67                         // TODO: Throw a warning, allow force()
68                         warning("Not calculating confidence intervals because of NPSOL status %d", inform);
69                 } else {
70                         PROTECT(intervals = allocMatrix(REALSXP, Global.numIntervals, 2));
71                         PROTECT(intervalCodes = allocMatrix(INTSXP, Global.numIntervals, 2));
72
73                         omxNPSOLConfidenceIntervals(fitMatrix, getFit(),
74                                                     getEstimate(), Global.ciMaxIterations);
75                         omxPopulateConfidenceIntervals(intervals, intervalCodes);
76                 }
77         }  
78 }
79
80 void omxComputeGD::reportResults(MxRList *out)
81 {
82         omxPopulateFitFunction(fitMatrix, out);
83
84         out->push_back(std::make_pair(mkChar("minimum"), minimum));
85         out->push_back(std::make_pair(mkChar("estimate"), estimate));
86         out->push_back(std::make_pair(mkChar("gradient"), gradient));
87         out->push_back(std::make_pair(mkChar("hessianCholesky"), hessian));
88
89         if (intervals && intervalCodes) {
90                 out->push_back(std::make_pair(mkChar("confidenceIntervals"), intervals));
91                 out->push_back(std::make_pair(mkChar("confidenceIntervalCodes"), intervalCodes));
92         }
93
94         SEXP code, iterations;
95
96         PROTECT(code = NEW_NUMERIC(1));
97         REAL(code)[0] = inform;
98         out->push_back(std::make_pair(mkChar("npsol.code"), code));
99
100         PROTECT(iterations = NEW_NUMERIC(1));
101         REAL(iterations)[0] = iter;
102         out->push_back(std::make_pair(mkChar("npsol.iterations"), iterations));
103         out->push_back(std::make_pair(mkChar("iterations"), iterations)); // backward compatibility
104 }