Switch over to new structured Compute system
[openmx:openmx.git] / src / Compute.cpp
1 /*
2  *  Copyright 2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include "omxDefines.h"
18 #include "Compute.h"
19 #include "omxState.h"
20 #include "omxExportBackendState.h"
21
22 class omxComputeSequence : public omxCompute {
23         std::vector< omxCompute* > clist;
24         double *est;
25
26  public:
27         virtual void initFromFrontend(SEXP rObj);
28         virtual void compute(double *startVals);
29         virtual void reportResults(MxRList *out);
30         virtual double getFit() { return 0; }
31         virtual double *getEstimate() { return est; }
32         virtual double getOptimizerStatus();
33         virtual ~omxComputeSequence();
34 };
35
36 class omxComputeOnce : public omxCompute {
37         omxMatrix *fitMatrix;
38         double fit;
39         double *est;
40
41  public:
42         virtual void initFromFrontend(SEXP rObj);
43         virtual void compute(double *startVals);
44         virtual void reportResults(MxRList *out);
45         virtual double getFit() { return fit; }
46         virtual double *getEstimate() { return est; }
47 };
48
49 static class omxCompute *newComputeSequence()
50 { return new omxComputeSequence(); }
51
52 static class omxCompute *newComputeOnce()
53 { return new omxComputeOnce(); }
54
55 struct omxComputeTableEntry {
56         char name[32];
57         omxCompute *(*ctor)();
58 };
59
60 static const struct omxComputeTableEntry omxComputeTable[] = {
61         {"MxComputeEstimatedHessian", &newComputeEstimatedHessian},
62         {"MxComputeGradientDescent", &newComputeGradientDescent},
63         {"MxComputeSequence", &newComputeSequence },
64         {"MxComputeOnce", &newComputeOnce },
65 };
66
67 omxCompute *omxNewCompute(omxState* os, const char *type)
68 {
69         omxCompute *got = NULL;
70
71         for (size_t fx=0; fx < OMX_STATIC_ARRAY_SIZE(omxComputeTable); fx++) {
72                 const struct omxComputeTableEntry *entry = omxComputeTable + fx;
73                 if(strcmp(type, entry->name) == 0) {
74                         got = entry->ctor();
75                         break;
76                 }
77         }
78
79         if (!got) error("Compute %s is not implemented", type);
80
81         return got;
82 }
83
84 void omxComputeSequence::initFromFrontend(SEXP rObj)
85 {
86         SEXP slotValue;
87         PROTECT(slotValue = GET_SLOT(rObj, install("steps")));
88
89         for (int cx = 0; cx < length(slotValue); cx++) {
90                 SEXP step = VECTOR_ELT(slotValue, cx);
91                 SEXP s4class;
92                 PROTECT(s4class = STRING_ELT(getAttrib(step, install("class")), 0));
93                 omxCompute *compute = omxNewCompute(globalState, CHAR(s4class));
94                 compute->initFromFrontend(step);
95                 if (isErrorRaised(globalState)) break;
96                 clist.push_back(compute);
97         }
98 }
99
100 void omxComputeSequence::compute(double *startVals)
101 {
102         est = startVals;
103         for (size_t cx=0; cx < clist.size(); ++cx) {
104                 clist[cx]->compute(est);
105                 est = clist[cx]->getEstimate();
106                 if (isErrorRaised(globalState)) break;
107         }
108 }
109
110 void omxComputeSequence::reportResults(MxRList *out)
111 {
112         for (size_t cx=0; cx < clist.size(); ++cx) {
113                 clist[cx]->reportResults(out);
114         }
115 }
116
117 double omxComputeSequence::getOptimizerStatus()
118 {
119         // for backward compatibility, not indended to work generally
120         for (size_t cx=0; cx < clist.size(); ++cx) {
121                 double got = clist[cx]->getOptimizerStatus();
122                 if (got != NA_REAL) return got;
123         }
124         return NA_REAL;
125 }
126
127 omxComputeSequence::~omxComputeSequence()
128 {
129         for (size_t cx=0; cx < clist.size(); ++cx) {
130                 delete clist[cx];
131         }
132 }
133
134 void omxComputeOnce::initFromFrontend(SEXP rObj)
135 {
136         fitMatrix = omxNewMatrixFromSlot(rObj, globalState, "fitfunction");
137 }
138
139 void omxComputeOnce::compute(double *startVals)
140 {
141         est = startVals;
142         omxFitFunctionCompute(fitMatrix->fitFunction, FF_COMPUTE_FIT, NULL);
143         fit = fitMatrix->data[0];
144 }
145
146 void omxComputeOnce::reportResults(MxRList *out)
147 {
148         omxPopulateFitFunction(fitMatrix, out);
149
150         out->push_back(std::make_pair(mkChar("minimum"), ScalarReal(fit)));
151
152         if (est) {
153                 int numFree = globalState->numFreeParams;
154                 SEXP estimate;
155                 PROTECT(estimate = allocVector(REALSXP, numFree));
156                 memcpy(REAL(estimate), est, sizeof(double)*numFree);
157                 out->push_back(std::make_pair(mkChar("estimate"), estimate));
158         }
159 }