Split omxState into truly global stuff and per-thread stuff
[openmx:openmx.git] / src / omxFitFunction.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18
19 *  omxFitFunction.cc
20 *
21 *  Created: Timothy R. Brick    Date: 2008-11-13 12:33:06
22 *
23 *       FitFunction objects are a subclass of data matrix that evaluates
24 *   itself anew at each iteration, so that any changes to
25 *   free parameters can be incorporated into the update.
26 *   // Question: Should FitFunction be a ``subtype'' of 
27 *   // omxAlgebra or a separate beast entirely?
28 *
29 **********************************************************/
30
31 #include "omxFitFunction.h"
32 #include "omxOptimizer.h"
33 #include "fitMultigroup.h"
34
35 typedef struct omxFitFunctionTableEntry omxFitFunctionTableEntry;
36
37 struct omxFitFunctionTableEntry {
38
39         char name[32];
40         void (*initFun)(omxFitFunction*);
41
42 };
43
44 extern void omxInitAlgebraFitFunction(omxFitFunction *off);
45 extern void omxInitWLSFitFunction(omxFitFunction *off);
46 extern void omxInitRowFitFunction(omxFitFunction *off);
47 extern void omxInitMLFitFunction(omxFitFunction *off);
48 extern void omxInitRFitFunction(omxFitFunction *off);
49
50 static const omxFitFunctionTableEntry omxFitFunctionSymbolTable[] = {
51         {"MxFitFunctionAlgebra",                        &omxInitAlgebraFitFunction},
52         {"MxFitFunctionWLS",                            &omxInitWLSFitFunction},
53         {"MxFitFunctionRow",                            &omxInitRowFitFunction},
54         {"MxFitFunctionML",                             &omxInitMLFitFunction},
55         {"MxFitFunctionR",                                      &omxInitRFitFunction},
56         {"MxFitFunctionMultigroup", &initFitMultigroup}
57 };
58
59 void omxInitEmptyFitFunction(omxFitFunction *off) {
60         /* Sets everything to NULL to avoid bad pointer calls */
61         
62         memset(off, 0, sizeof(omxFitFunction));
63 }
64
65 void omxFreeFitFunctionArgs(omxFitFunction *off) {
66         if(off==NULL) return;
67     
68         /* Completely destroy the fit function structures */
69         if(OMX_DEBUG) {mxLog("Freeing fit function object at 0x%x.", off);}
70         if(off->matrix != NULL) {
71                 if(off->destructFun != NULL) {
72                         if(OMX_DEBUG) {mxLog("Calling fit function destructor for 0x%x.", off);}
73                         off->destructFun(off);
74                 }
75                 off->matrix = NULL;
76         }
77 }
78
79 void omxFitFunctionCreateChildren(omxState *globalState)
80 {
81         if (Global.numThreads <= 1) return;
82
83         int numThreads = Global.numChildren = Global.numThreads;
84
85         globalState->childList = (omxState**) Calloc(numThreads, omxState*);
86
87         for(int ii = 0; ii < numThreads; ii++) {
88                 globalState->childList[ii] = new omxState;
89                 omxInitState(globalState->childList[ii], globalState);
90                 omxDuplicateState(globalState->childList[ii], globalState);
91         }
92 }
93
94 void omxDuplicateFitMatrix(omxMatrix *tgt, const omxMatrix *src, omxState* newState) {
95
96         if(tgt == NULL || src == NULL) return;
97         if(src->fitFunction == NULL) return;
98     
99         omxFillMatrixFromMxFitFunction(tgt, src->fitFunction->rObj, src->hasMatrixNumber, src->matrixNumber);
100
101 }
102
103 void omxFitFunctionCompute(omxFitFunction *off, int want, double* gradient) {
104         if(OMX_DEBUG_ALGEBRA) { 
105             mxLog("FitFunction compute: 0x%0x (needed: %s).", off, (off->matrix->isDirty?"Yes":"No"));
106         }
107
108         off->computeFun(off, want, gradient);
109
110         omxMarkClean(off->matrix);
111 }
112
113 void omxFillMatrixFromMxFitFunction(omxMatrix* om, SEXP rObj,
114         unsigned short hasMatrixNumber, int matrixNumber) {
115
116         SEXP slotValue, fitFunctionClass;
117         omxFitFunction *obj = (omxFitFunction*) R_alloc(1, sizeof(omxFitFunction));
118         omxInitEmptyFitFunction(obj);
119
120         /* Register FitFunction and Matrix with each other */
121         obj->matrix = om;
122         omxResizeMatrix(om, 1, 1, FALSE);                                       // FitFunction matrices MUST be 1x1.
123         om->fitFunction = obj;
124         om->hasMatrixNumber = hasMatrixNumber;
125         om->matrixNumber = matrixNumber;
126         
127         /* Get FitFunction Type */
128         PROTECT(fitFunctionClass = STRING_ELT(getAttrib(rObj, install("class")), 0));
129         {
130           const char *fitType = CHAR(fitFunctionClass);
131         
132           for (size_t fx=0; fx < OMX_STATIC_ARRAY_SIZE(omxFitFunctionSymbolTable); fx++) {
133                   const omxFitFunctionTableEntry *entry = omxFitFunctionSymbolTable + fx;
134                   if(strcmp(fitType, entry->name) == 0) {
135                           obj->fitType = entry->name;
136                           obj->initFun = entry->initFun;
137                           break;
138                   }
139           }
140
141           if(obj->initFun == NULL) {
142             const int MaxErrorLen = 256;
143             char newError[MaxErrorLen];
144             snprintf(newError, MaxErrorLen, "Fit function %s not implemented.\n", fitType);
145             omxRaiseError(om->currentState, -1, newError);
146             return;
147           }
148         }
149         UNPROTECT(1);   /* fitType */
150
151         PROTECT(slotValue = GET_SLOT(rObj, install("expectation")));
152         if (LENGTH(slotValue) == 1) {
153                 int expNumber = INTEGER(slotValue)[0];  
154                 if(expNumber != NA_INTEGER) {
155                         obj->expectation = omxExpectationFromIndex(expNumber, om->currentState);
156                 }
157         }
158         UNPROTECT(1);   /* slotValue */
159         
160         if (om->currentState->statusMsg[0]) return;
161
162         obj->rObj = rObj;
163         obj->initFun(obj);
164
165         if(obj->computeFun == NULL) {// If initialization fails, error code goes in argStruct
166                 const char *errorCode;
167                 if(isErrorRaised(om->currentState)) {
168                         errorCode = om->currentState->statusMsg;
169                 } else {
170                         // If no error code is reported, we report that.
171                         errorCode = "No error code reported.";
172                 }
173                 if(obj->argStruct != NULL) {
174                         errorCode = (char*)(obj->argStruct);
175                 }
176         const int MaxErrorLen = 256;
177         char newError[MaxErrorLen];
178         snprintf(newError, MaxErrorLen, "Could not initialize fit function %s.  Error: %s\n",
179                         obj->fitType, errorCode); 
180                 omxRaiseError(om->currentState, -1, newError);
181         }
182         
183         obj->matrix->isDirty = TRUE;
184
185 }
186
187 void omxFitFunctionPrint(omxFitFunction* off, const char* d) {
188         mxLog("(FitFunction, type %s)", off->fitType);
189         omxPrintMatrix(off->matrix, d);
190 }
191
192
193 /* Helper functions */
194 omxMatrix* omxNewMatrixFromSlot(SEXP rObj, omxState* currentState, const char* slotName) {
195         SEXP slotValue;
196         PROTECT(slotValue = GET_SLOT(rObj, install(slotName)));
197         omxMatrix* newMatrix = omxMatrixLookupFromState1(slotValue, currentState);
198         if (newMatrix) omxRecompute(newMatrix);
199         UNPROTECT(1);
200         return newMatrix;
201 }
202