Revert "Leave fitType alone"
[openmx:openmx.git] / src / omxFitFunction.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18
19 *  omxFitFunction.cc
20 *
21 *  Created: Timothy R. Brick    Date: 2008-11-13 12:33:06
22 *
23 *       FitFunction objects are a subclass of data matrix that evaluates
24 *   itself anew at each iteration, so that any changes to
25 *   free parameters can be incorporated into the update.
26 *   // Question: Should FitFunction be a ``subtype'' of 
27 *   // omxAlgebra or a separate beast entirely?
28 *
29 **********************************************************/
30
31 #include "omxFitFunction.h"
32 #include "omxOptimizer.h"
33 #include "fitMultigroup.h"
34
35 typedef struct omxFitFunctionTableEntry omxFitFunctionTableEntry;
36
37 struct omxFitFunctionTableEntry {
38
39         char name[32];
40         void (*initFun)(omxFitFunction*);
41
42 };
43
44 extern void omxInitAlgebraFitFunction(omxFitFunction *off);
45 extern void omxInitWLSFitFunction(omxFitFunction *off);
46 extern void omxInitRowFitFunction(omxFitFunction *off);
47 extern void omxInitMLFitFunction(omxFitFunction *off);
48 extern void omxInitRFitFunction(omxFitFunction *off);
49
50 static const omxFitFunctionTableEntry omxFitFunctionSymbolTable[] = {
51         {"MxFitFunctionAlgebra",                        &omxInitAlgebraFitFunction},
52         {"MxFitFunctionWLS",                            &omxInitWLSFitFunction},
53         {"MxFitFunctionRow",                            &omxInitRowFitFunction},
54         {"MxFitFunctionML",                             &omxInitMLFitFunction},
55         {"MxFitFunctionR",                                      &omxInitRFitFunction},
56         {"MxFitFunctionMultigroup", &initFitMultigroup}
57 };
58
59 void omxCalculateStdErrorFromHessian(double scale, omxFitFunction *off) {
60         /* This function calculates the standard errors from the hessian matrix */
61         // sqrt(diag(solve(hessian)))
62
63         if(off->hessian == NULL) return;
64         
65         int numParams = off->matrix->currentState->numFreeParams;
66         
67         if(off->stdError == NULL) {
68                 off->stdError = (double*) R_alloc(numParams, sizeof(double));
69         }
70         
71         double* stdErr = off->stdError;
72         
73         double* hessian = off->hessian;
74         double* workspace = (double *) Calloc(numParams * numParams, double);
75         
76         for(int i = 0; i < numParams; i++)
77                 for(int j = 0; j <= i; j++)
78                         workspace[i*numParams+j] = hessian[i*numParams+j];              // Populate upper triangle
79         
80         char u = 'U';
81         int ipiv[numParams];
82         int lwork = -1;
83         double temp;
84         int info = 0;
85         
86         F77_CALL(dsytrf)(&u, &numParams, workspace, &numParams, ipiv, &temp, &lwork, &info);
87         
88         lwork = (temp > numParams?temp:numParams);
89         
90         double* work = (double*) Calloc(lwork, double);
91         
92         F77_CALL(dsytrf)(&u, &numParams, workspace, &numParams, ipiv, work, &lwork, &info);
93         
94         if(info != 0) {
95                 
96                 off->stdError = NULL;
97                 
98         } else {
99                 
100                 F77_CALL(dsytri)(&u, &numParams, workspace, &numParams, ipiv, work, &info);
101         
102                 if(info != 0) {
103                         off->stdError = NULL;
104                 } else {
105                         for(int i = 0; i < numParams; i++) {
106                                 stdErr[i] = sqrt(scale) * sqrt(workspace[i * numParams + i]);
107                         }
108                 }
109         }
110         
111         Free(workspace);
112         Free(work);
113         
114 }
115
116
117 void omxFreeFitFunctionArgs(omxFitFunction *off) {
118         if(off==NULL) return;
119     
120         /* Completely destroy the fit function structures */
121         if(OMX_DEBUG) {Rprintf("Freeing fit function object at 0x%x.\n", off);}
122         if(off->matrix != NULL) {
123                 if(off->destructFun != NULL) {
124                         if(OMX_DEBUG) {Rprintf("Calling fit function destructor for 0x%x.\n", off);}
125                         off->destructFun(off);
126                 }
127                 off->matrix = NULL;
128         }
129 }
130
131 void omxFitFunctionCreateChildren(omxState *globalState, int numThreads)
132 {
133         if (numThreads <= 1) return;
134
135         omxMatrix *fm = globalState->fitMatrix;
136         if (!fm) return;
137
138         omxFitFunction *ff = fm->fitFunction;
139         if (!ff->usesChildModels) return;
140
141         globalState->numChildren = numThreads;
142
143         globalState->childList = (omxState**) Calloc(numThreads, omxState*);
144
145         for(int ii = 0; ii < numThreads; ii++) {
146                 globalState->childList[ii] = new omxState;
147                 omxInitState(globalState->childList[ii], globalState);
148                 omxDuplicateState(globalState->childList[ii], globalState);
149         }
150 }
151
152 void omxDuplicateFitMatrix(omxMatrix *tgt, const omxMatrix *src, omxState* newState) {
153
154         if(tgt == NULL || src == NULL) return;
155         if(src->fitFunction == NULL) return;
156     
157         omxFillMatrixFromMxFitFunction(src->fitFunction->rObj, tgt, src->currentState);
158
159 }
160
161 omxFitFunction* omxCreateDuplicateFitFunction(omxFitFunction *tgt, const omxFitFunction *src, omxState* newState) {
162
163         if(OMX_DEBUG) {Rprintf("Duplicating fit function 0x%x into 0x%x.", src, tgt);}
164
165         if(src == NULL) {
166                 return NULL;
167         }
168         
169         if(tgt == NULL) {
170         tgt = (omxFitFunction*) R_alloc(1, sizeof(omxFitFunction));
171         OMXZERO(tgt, 1);
172     } else {
173                 omxRaiseError(newState, -1,
174                         "omxCreateDuplicateFitFunction requested to overwrite target");
175                 return NULL;
176         }
177
178         memcpy(tgt, src, sizeof(omxFitFunction));
179         return tgt;
180
181 }
182
183 void omxFitFunctionCompute(omxFitFunction *off, int want, double* gradient) {
184         if(OMX_DEBUG_ALGEBRA) { 
185             Rprintf("FitFunction compute: 0x%0x (needed: %s).\n", off, (off->matrix->isDirty?"Yes":"No"));
186         }
187
188         off->computeFun(off, want, gradient);
189
190         omxMarkClean(off->matrix);
191 }
192
193 omxFitFunction *omxNewInternalFitFunction(omxState* os, const char *fitType,
194                                           omxExpectation *expect, SEXP rObj, omxMatrix *matrix)
195 {
196         omxFitFunction *obj = (omxFitFunction*) R_alloc(1, sizeof(omxFitFunction));
197         OMXZERO(obj, 1);
198
199         for (size_t fx=0; fx < OMX_STATIC_ARRAY_SIZE(omxFitFunctionSymbolTable); fx++) {
200                 const omxFitFunctionTableEntry *entry = omxFitFunctionSymbolTable + fx;
201                 if(strcmp(fitType, entry->name) == 0) {
202                         obj->fitType = entry->name;
203                         obj->initFun = entry->initFun;
204                         break;
205                 }
206         }
207
208         if(obj->initFun == NULL) error("Fit function %s not implemented", fitType);
209
210         if (!matrix) {
211                 obj->matrix = omxInitMatrix(NULL, 1, 1, TRUE, os);
212                 obj->matrix->hasMatrixNumber = TRUE;
213                 obj->matrix->matrixNumber = ~os->algebraList.size();
214                 os->algebraList.push_back(obj->matrix);
215         } else {
216                 obj->matrix = matrix;
217         }
218
219         obj->matrix->fitFunction = obj;
220         
221         obj->rObj = rObj;
222         obj->expectation = expect;
223
224         return obj;
225 }
226
227 void omxFillMatrixFromMxFitFunction(SEXP rObj, omxMatrix *matrix, omxState *os)
228 {
229         SEXP slotValue, fitFunctionClass;
230
231         PROTECT(fitFunctionClass = STRING_ELT(getAttrib(rObj, install("class")), 0));
232         const char *fitType = CHAR(fitFunctionClass);
233
234         omxExpectation *expect = NULL;
235         PROTECT(slotValue = GET_SLOT(rObj, install("expectation")));
236         if (LENGTH(slotValue) == 1) {
237                 int expNumber = INTEGER(slotValue)[0];  
238                 if(expNumber != NA_INTEGER) {
239                         expect = omxExpectationFromIndex(expNumber, os);
240                 }
241         }
242
243         omxNewInternalFitFunction(os, fitType, expect, rObj, matrix);
244
245         UNPROTECT(2);
246 }
247
248 void omxInitializeFitFunction(omxMatrix *om)
249 {
250         omxFitFunction *obj = om->fitFunction;
251         if (!obj) error("Matrix 0x%p has no fit function", om);
252
253         if (obj->initialized) return;
254         obj->initialized = TRUE;
255
256         obj->initFun(obj);
257
258         if(obj->computeFun == NULL) error("Could not initialize fit function %s", obj->fitType);
259         
260         obj->matrix->isDirty = TRUE;
261 }
262
263 void omxFitFunctionPrint(omxFitFunction* off, const char* d) {
264         Rprintf("(FitFunction, type %s) ", off->fitType);
265         omxPrintMatrix(off->matrix, d);
266 }
267
268
269 /* Helper functions */
270 omxMatrix* omxNewMatrixFromSlot(SEXP rObj, omxState* currentState, const char* slotName) {
271         SEXP slotValue;
272         PROTECT(slotValue = GET_SLOT(rObj, install(slotName)));
273         omxMatrix* newMatrix = omxMatrixLookupFromState1(slotValue, currentState);
274         if (newMatrix) omxRecompute(newMatrix);
275         UNPROTECT(1);
276         return newMatrix;
277 }
278