Split up fit function initialization similar to expectations
[openmx:openmx.git] / src / omxFitFunction.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18
19 *  omxFitFunction.cc
20 *
21 *  Created: Timothy R. Brick    Date: 2008-11-13 12:33:06
22 *
23 *       FitFunction objects are a subclass of data matrix that evaluates
24 *   itself anew at each iteration, so that any changes to
25 *   free parameters can be incorporated into the update.
26 *   // Question: Should FitFunction be a ``subtype'' of 
27 *   // omxAlgebra or a separate beast entirely?
28 *
29 **********************************************************/
30
31 #include "omxFitFunction.h"
32 #include "omxOptimizer.h"
33 #include "fitMultigroup.h"
34
35 typedef struct omxFitFunctionTableEntry omxFitFunctionTableEntry;
36
37 struct omxFitFunctionTableEntry {
38
39         char name[32];
40         void (*initFun)(omxFitFunction*);
41
42 };
43
44 static const omxFitFunctionTableEntry omxFitFunctionSymbolTable[] = {
45         {"MxFitFunctionAlgebra",                        &omxInitAlgebraFitFunction},
46         {"MxFitFunctionWLS",                            &omxInitWLSFitFunction},
47         {"MxFitFunctionRow",                            &omxInitRowFitFunction},
48         {"MxFitFunctionML",                             &omxInitMLFitFunction},
49         {"imxFitFunctionFIML", &omxInitFIMLFitFunction},
50         {"MxFitFunctionR",                                      &omxInitRFitFunction},
51         {"MxFitFunctionMultigroup", &initFitMultigroup}
52 };
53
54 void omxFreeFitFunctionArgs(omxFitFunction *off) {
55         if(off==NULL) return;
56     
57         /* Completely destroy the fit function structures */
58         if(OMX_DEBUG) {mxLog("Freeing fit function object at 0x%x.", off);}
59         if(off->matrix != NULL) {
60                 if(off->destructFun != NULL) {
61                         if(OMX_DEBUG) {mxLog("Calling fit function destructor for 0x%x.", off);}
62                         off->destructFun(off);
63                 }
64                 off->matrix = NULL;
65         }
66 }
67
68 void omxFitFunctionCreateChildren(omxState *globalState)
69 {
70         if (Global.numThreads <= 1) return;
71
72         int numThreads = Global.numChildren = Global.numThreads;
73
74         globalState->childList = (omxState**) Calloc(numThreads, omxState*);
75
76         for(int ii = 0; ii < numThreads; ii++) {
77                 globalState->childList[ii] = new omxState;
78                 omxInitState(globalState->childList[ii]);
79                 omxDuplicateState(globalState->childList[ii], globalState);
80         }
81 }
82
83 void omxDuplicateFitMatrix(omxMatrix *tgt, const omxMatrix *src, omxState* newState) {
84
85         if(tgt == NULL || src == NULL) return;
86
87         omxFitFunction *ff = src->fitFunction;
88         if(ff == NULL) return;
89     
90         omxFillMatrixFromMxFitFunction(tgt, ff->fitType, src->matrixNumber);
91         omxCompleteFitFunction(tgt, ff->rObj);
92 }
93
94 void omxFitFunctionCompute(omxFitFunction *off, int want, double* gradient) {
95         if (!off->initialized) error("FitFunction not initialized");
96
97         if(OMX_DEBUG_ALGEBRA) { 
98             mxLog("FitFunction compute: 0x%0x (needed: %s).", off, (off->matrix->isDirty?"Yes":"No"));
99         }
100
101         off->computeFun(off, want, gradient);
102
103         omxMarkClean(off->matrix);
104 }
105
106 void omxFillMatrixFromMxFitFunction(omxMatrix* om, const char *fitType, int matrixNumber)
107 {
108         omxFitFunction *obj = (omxFitFunction*) R_alloc(1, sizeof(omxFitFunction));
109         memset(obj, 0, sizeof(omxFitFunction));
110
111         /* Register FitFunction and Matrix with each other */
112         obj->matrix = om;
113         omxResizeMatrix(om, 1, 1, FALSE);                                       // FitFunction matrices MUST be 1x1.
114         om->fitFunction = obj;
115         om->hasMatrixNumber = TRUE;
116         om->matrixNumber = matrixNumber;
117         
118         for (size_t fx=0; fx < OMX_STATIC_ARRAY_SIZE(omxFitFunctionSymbolTable); fx++) {
119                 const omxFitFunctionTableEntry *entry = omxFitFunctionSymbolTable + fx;
120                 if(strcmp(fitType, entry->name) == 0) {
121                         obj->fitType = entry->name;
122                         obj->initFun = entry->initFun;
123                         break;
124                 }
125         }
126
127         if (obj->initFun == NULL) error("Fit function %s not implemented", fitType);
128 }
129
130 void omxCompleteFitFunction(omxMatrix *om, SEXP rObj)
131 {
132         omxFitFunction *obj = om->fitFunction;
133         if (obj->initialized) return;
134         obj->rObj = rObj;
135
136         SEXP slotValue;
137         PROTECT(slotValue = GET_SLOT(rObj, install("expectation")));
138         if (LENGTH(slotValue) == 1) {
139                 int expNumber = INTEGER(slotValue)[0];  
140                 if(expNumber != NA_INTEGER) {
141                         obj->expectation = omxExpectationFromIndex(expNumber, om->currentState);
142                 }
143         }
144         UNPROTECT(1);   /* slotValue */
145         
146         obj->initFun(obj);
147
148         if(obj->computeFun == NULL) error("Failed to initialize fit function %s", obj->fitType); 
149         
150         obj->matrix->isDirty = TRUE;
151         obj->initialized = TRUE;
152 }
153
154 void omxFitFunctionPrint(omxFitFunction* off, const char* d) {
155         mxLog("(FitFunction, type %s)", off->fitType);
156         omxPrintMatrix(off->matrix, d);
157 }
158
159
160 /* Helper functions */
161 omxMatrix* omxNewMatrixFromSlot(SEXP rObj, omxState* currentState, const char* slotName) {
162         SEXP slotValue;
163         PROTECT(slotValue = GET_SLOT(rObj, install(slotName)));
164         omxMatrix* newMatrix = omxMatrixLookupFromState1(slotValue, currentState);
165         UNPROTECT(1);
166         return newMatrix;
167 }
168