Allow ComputeIterate to test maximum absolute change
[openmx:openmx.git] / src / omxFitFunction.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18
19 *  omxFitFunction.cc
20 *
21 *  Created: Timothy R. Brick    Date: 2008-11-13 12:33:06
22 *
23 *       FitFunction objects are a subclass of data matrix that evaluates
24 *   itself anew at each iteration, so that any changes to
25 *   free parameters can be incorporated into the update.
26 *   // Question: Should FitFunction be a ``subtype'' of 
27 *   // omxAlgebra or a separate beast entirely?
28 *
29 **********************************************************/
30
31 #include "omxFitFunction.h"
32 #include "fitMultigroup.h"
33
34 typedef struct omxFitFunctionTableEntry omxFitFunctionTableEntry;
35
36 struct omxFitFunctionTableEntry {
37
38         char name[32];
39         void (*initFun)(omxFitFunction*);
40         void (*setVarGroup)(omxFitFunction*, FreeVarGroup *);  // TODO ugh, just convert to C++
41
42 };
43
44 static void defaultSetFreeVarGroup(omxFitFunction *ff, FreeVarGroup *fvg)
45 {
46         if (ff->freeVarGroup && ff->freeVarGroup != fvg) {
47                 warning("setFreeVarGroup called with different group (%d vs %d) on %s",
48                         ff->matrix->name, ff->freeVarGroup->id, fvg->id);
49         }
50         ff->freeVarGroup = fvg;
51 }
52
53 static const omxFitFunctionTableEntry omxFitFunctionSymbolTable[] = {
54         {"MxFitFunctionAlgebra",                        &omxInitAlgebraFitFunction, defaultSetFreeVarGroup},
55         {"MxFitFunctionWLS",                            &omxInitWLSFitFunction, defaultSetFreeVarGroup},
56         {"MxFitFunctionRow",                            &omxInitRowFitFunction, defaultSetFreeVarGroup},
57         {"MxFitFunctionML",                             &omxInitMLFitFunction, defaultSetFreeVarGroup},
58         {"imxFitFunctionFIML", &omxInitFIMLFitFunction, defaultSetFreeVarGroup},
59         {"MxFitFunctionR",                                      &omxInitRFitFunction, defaultSetFreeVarGroup},
60         {"MxFitFunctionMultigroup", &initFitMultigroup, mgSetFreeVarGroup},
61 };
62
63 void omxFreeFitFunctionArgs(omxFitFunction *off) {
64         if(off==NULL) return;
65     
66         /* Completely destroy the fit function structures */
67         if(OMX_DEBUG) {mxLog("Freeing fit function object at %p.", off);}
68         if(off->matrix != NULL) {
69                 if(off->destructFun != NULL) {
70                         if(OMX_DEBUG) {mxLog("Calling fit function destructor for %p.", off);}
71                         off->destructFun(off);
72                 }
73                 off->matrix = NULL;
74         }
75 }
76
77 void omxFitFunctionCreateChildren(omxState *globalState)
78 {
79         if (Global->numThreads <= 1) return;
80
81         int numThreads = Global->numChildren = Global->numThreads;
82
83         globalState->childList = (omxState**) Calloc(numThreads, omxState*);
84
85         for(int ii = 0; ii < numThreads; ii++) {
86                 globalState->childList[ii] = new omxState;
87                 omxInitState(globalState->childList[ii]);
88                 omxDuplicateState(globalState->childList[ii], globalState);
89         }
90 }
91
92 void omxDuplicateFitMatrix(omxMatrix *tgt, const omxMatrix *src, omxState* newState) {
93
94         if(tgt == NULL || src == NULL) return;
95
96         omxFitFunction *ff = src->fitFunction;
97         if(ff == NULL) return;
98     
99         omxFillMatrixFromMxFitFunction(tgt, ff->fitType, src->matrixNumber);
100         setFreeVarGroup(tgt->fitFunction, src->fitFunction->freeVarGroup);
101         tgt->fitFunction->rObj = ff->rObj;
102         omxCompleteFitFunction(tgt);
103 }
104
105 void omxFitFunctionCompute(omxFitFunction *off, int want, FitContext *fc)
106 {
107         if (!off->initialized) error("FitFunction not initialized");
108
109         if(OMX_DEBUG_ALGEBRA) { 
110                 mxLog("FitFunction compute: %p (needed: %s).", off, (omxMatrixIsDirty(off->matrix)?"Yes":"No"));
111         }
112
113         off->computeFun(off, want, fc);
114         if (fc) fc->wanted = want;
115
116         omxMarkClean(off->matrix);
117 }
118
119 void defaultAddOutput(omxFitFunction* oo, MxRList *out)
120 {}
121
122 void omxFillMatrixFromMxFitFunction(omxMatrix* om, const char *fitType, int matrixNumber)
123 {
124         omxFitFunction *obj = (omxFitFunction*) R_alloc(1, sizeof(omxFitFunction));
125         memset(obj, 0, sizeof(omxFitFunction));
126
127         /* Register FitFunction and Matrix with each other */
128         obj->matrix = om;
129         omxResizeMatrix(om, 1, 1, FALSE);                                       // FitFunction matrices MUST be 1x1.
130         om->fitFunction = obj;
131         om->hasMatrixNumber = TRUE;
132         om->matrixNumber = matrixNumber;
133         
134         for (size_t fx=0; fx < OMX_STATIC_ARRAY_SIZE(omxFitFunctionSymbolTable); fx++) {
135                 const omxFitFunctionTableEntry *entry = omxFitFunctionSymbolTable + fx;
136                 if(strcmp(fitType, entry->name) == 0) {
137                         obj->fitType = entry->name;
138                         obj->initFun = entry->initFun;
139
140                         // We need to set up the FreeVarGroup before calling initFun
141                         // because older fit functions expect to know the number of
142                         // free variables during initFun.
143                         obj->setVarGroup = entry->setVarGroup; // ugh!
144                         obj->addOutput = defaultAddOutput;
145                         break;
146                 }
147         }
148
149         if (obj->initFun == NULL) error("Fit function %s not implemented", fitType);
150 }
151
152 void omxCompleteFitFunction(omxMatrix *om)
153 {
154         omxFitFunction *obj = om->fitFunction;
155         if (obj->initialized) return;
156         SEXP rObj = obj->rObj;
157
158         SEXP slotValue;
159         PROTECT(slotValue = GET_SLOT(rObj, install("expectation")));
160         if (LENGTH(slotValue) == 1) {
161                 int expNumber = INTEGER(slotValue)[0];  
162                 if(expNumber != NA_INTEGER) {
163                         obj->expectation = omxExpectationFromIndex(expNumber, om->currentState);
164                         setFreeVarGroup(obj->expectation, obj->freeVarGroup);
165                         omxCompleteExpectation(obj->expectation);
166                 }
167         }
168         UNPROTECT(1);   /* slotValue */
169         
170         obj->initFun(obj);
171
172         if(obj->computeFun == NULL) error("Failed to initialize fit function %s", obj->fitType); 
173         
174         omxMarkDirty(obj->matrix);
175         obj->initialized = TRUE;
176 }
177
178 void setFreeVarGroup(omxFitFunction *ff, FreeVarGroup *fvg)
179 {
180         (*ff->setVarGroup)(ff, fvg);
181 }
182
183 void omxFitFunctionPrint(omxFitFunction* off, const char* d) {
184         mxLog("(FitFunction, type %s)", off->fitType);
185         omxPrintMatrix(off->matrix, d);
186 }
187
188
189 /* Helper functions */
190 omxMatrix* omxNewMatrixFromSlot(SEXP rObj, omxState* currentState, const char* slotName) {
191         SEXP slotValue;
192         PROTECT(slotValue = GET_SLOT(rObj, install(slotName)));
193         omxMatrix* newMatrix = omxMatrixLookupFromState1(slotValue, currentState);
194         UNPROTECT(1);
195         return newMatrix;
196 }
197