Better protect BA81 from parallel execution
[openmx:openmx.git] / src / omxExpectation.cpp
1 /*
2  *  Copyright 2007-2014 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18
19 *  omxExpectation.cc
20 *
21 *  Created: Timothy R. Brick    Date: 2008-11-13 12:33:06
22 *
23 *       Expectation objects carry distributional expectations
24 *               for the model.  Because they have no requirement
25 *               to produce a single matrix of output, they are
26 *               not a subclass of mxMatrix, but rather their own
27 *               strange beast.
28 *       // TODO:  Create a multi-matrix Algebra type, and make
29 *       //      MxExpectation a subtype of that.
30 *
31 **********************************************************/
32
33 #include "omxExpectation.h"
34
35 typedef struct omxExpectationTableEntry omxExpectationTableEntry;
36
37 struct omxExpectationTableEntry {
38         char name[32];
39         void (*initFun)(omxExpectation*);
40 };
41
42 static const omxExpectationTableEntry omxExpectationSymbolTable[] = {
43         {"MxExpectationLISREL",                 &omxInitLISRELExpectation},
44         {"MxExpectationStateSpace",                     &omxInitStateSpaceExpectation},
45         {"MxExpectationNormal",                 &omxInitNormalExpectation},
46         {"MxExpectationRAM",                    &omxInitRAMExpectation},
47         {"MxExpectationBA81", &omxInitExpectationBA81}
48 };
49
50 void omxFreeExpectationArgs(omxExpectation *ox) {
51         if(ox==NULL) return;
52     
53         if (ox->destructFun) ox->destructFun(ox);
54         omxFreeMatrix(ox->dataColumns);
55         Free(ox->submodels);
56         Free(ox);
57 }
58
59 void omxExpectationRecompute(omxExpectation *ox) {
60         if(ox->thresholds != NULL) {
61                 for(int i = 0; i < ox->numOrdinal; i++) {
62                         if (!ox->thresholds[i].matrix) continue;
63                         omxRecompute(ox->thresholds[i].matrix);
64                 }
65         }
66
67         omxExpectationCompute(ox, NULL);
68 }
69
70 void omxExpectationCompute(omxExpectation *ox, const char *what, const char *how)
71 {
72         if (!ox) return;
73
74         ox->computeFun(ox, what, how);
75 }
76
77 omxMatrix* omxGetExpectationComponent(omxExpectation* ox, omxFitFunction* off, const char* component) {
78
79         if(component == NULL) return NULL;
80
81         omxCompleteExpectation(ox);
82
83         /* Hard-wired expectation components */
84         if(strEQ("dataColumns", component)) {
85                 return ox->dataColumns;
86         }
87
88         if(ox->componentFun == NULL) return NULL;
89
90         return(ox->componentFun(ox, off, component));
91         
92 }
93
94 void omxSetExpectationComponent(omxExpectation* ox, omxFitFunction* off, const char* component, omxMatrix* om) {
95         if(!strcmp(ox->expType, "MxExpectationStateSpace")) {
96                 ox->mutateFun(ox, off, component, om);
97         }
98 }
99
100 omxExpectation* omxDuplicateExpectation(const omxExpectation *src, omxState* newState) {
101
102         return omxNewIncompleteExpectation(src->rObj, src->expNum, newState);
103 }
104
105 omxExpectation* omxNewIncompleteExpectation(SEXP rObj, int expNum, omxState* os) {
106
107         SEXP ExpectationClass;
108         Rf_protect(ExpectationClass = STRING_ELT(Rf_getAttrib(rObj, Rf_install("class")), 0));
109         const char* expType = CHAR(ExpectationClass);
110
111         omxExpectation* expect = omxNewInternalExpectation(expType, os);
112
113         expect->rObj = rObj;
114         expect->expNum = expNum;
115         
116         return expect;
117 }
118
119 omxExpectation* omxExpectationFromIndex(int expIndex, omxState* os)
120 {
121         omxExpectation* ox = os->expectationList.at(expIndex);
122         return ox;
123 }
124
125 void omxExpectationProcessDataStructures(omxExpectation* ox, SEXP rObj){
126
127         int index, numDefs, nextDef, numCols, numOrdinal=0;
128         SEXP nextMatrix, itemList, nextItem, threshMatrix; 
129         
130         if(rObj == NULL) return;
131
132         if(OMX_DEBUG) { mxLog("Retrieving data."); }
133         Rf_protect(nextMatrix = R_do_slot(rObj, Rf_install("data")));
134         ox->data = omxDataLookupFromState(nextMatrix, ox->currentState);
135
136         if(OMX_DEBUG) {
137                 mxLog("Accessing variable mapping structure.");
138         }
139
140         if (R_has_slot(rObj, Rf_install("dataColumns"))) {
141                 Rf_protect(nextMatrix = R_do_slot(rObj, Rf_install("dataColumns")));
142                 ox->dataColumns = omxNewMatrixFromRPrimitive(nextMatrix, ox->currentState, 0, 0);
143                 if(OMX_DEBUG) {
144                         omxPrint(ox->dataColumns, "Variable mapping");
145                 }
146         
147                 numCols = ox->dataColumns->cols;
148
149                 if (R_has_slot(rObj, Rf_install("thresholds"))) {
150                         if(OMX_DEBUG) {
151                                 mxLog("Accessing Threshold matrix.");
152                         }
153                         Rf_protect(threshMatrix = R_do_slot(rObj, Rf_install("thresholds")));
154
155                         if(INTEGER(threshMatrix)[0] != NA_INTEGER) {
156                                 if(OMX_DEBUG) {
157                                         mxLog("Accessing Threshold Mappings.");
158                                 }
159         
160                                 /* Process the data and threshold mapping structures */
161                                 /* if (threshMatrix == NA_INTEGER), then we could ignore the slot "thresholdColumns"
162                                  * and fill all the thresholds with {NULL, 0, 0}.
163                                  * However the current path does not have a lot of overhead. */
164                                 Rf_protect(nextMatrix = R_do_slot(rObj, Rf_install("thresholdColumns")));
165                                 Rf_protect(itemList = R_do_slot(rObj, Rf_install("thresholdLevels")));
166                                 int* thresholdColumn, *thresholdNumber;
167                                 thresholdColumn = INTEGER(nextMatrix);
168                                 thresholdNumber = INTEGER(itemList);
169                                 ox->thresholds = (omxThresholdColumn *) R_alloc(numCols, sizeof(omxThresholdColumn));
170                                 for(index = 0; index < numCols; index++) {
171                                         if(thresholdColumn[index] == NA_INTEGER) {      // Continuous variable
172                                                 if(OMX_DEBUG) {
173                                                         mxLog("Column %d is continuous.", index);
174                                                 }
175                                                 ox->thresholds[index].matrix = NULL;
176                                                 ox->thresholds[index].column = 0;
177                                                 ox->thresholds[index].numThresholds = 0;
178                                         } else {
179                                                 ox->thresholds[index].matrix = omxMatrixLookupFromState1(threshMatrix, 
180                                                                                                        ox->currentState);
181                                                 ox->thresholds[index].column = thresholdColumn[index];
182                                                 ox->thresholds[index].numThresholds = thresholdNumber[index];
183                                                 if(OMX_DEBUG) {
184                                                         mxLog("Column %d is ordinal with %d thresholds in threshold column %d.", 
185                                                                 index, thresholdNumber[index], thresholdColumn[index]);
186                                                 }
187                                                 numOrdinal++;
188                                         }
189                                 }
190                                 if(OMX_DEBUG) {
191                                         mxLog("%d threshold columns processed.", numOrdinal);
192                                 }
193                                 ox->numOrdinal = numOrdinal;
194                         } else {
195                                 if (OMX_DEBUG) {
196                                         mxLog("No thresholds matrix; not processing thresholds.");
197                                 }
198                                 ox->thresholds = NULL;
199                                 ox->numOrdinal = 0;
200                         }
201                 }
202         }
203
204         if(!R_has_slot(rObj, Rf_install("definitionVars"))) {
205                 ox->numDefs = 0;
206                 ox->defVars = NULL;
207         } else {        
208                 if(OMX_DEBUG) {
209                         mxLog("Accessing definition variables structure.");
210                 }
211                 Rf_protect(nextMatrix = R_do_slot(rObj, Rf_install("definitionVars")));
212                 numDefs = Rf_length(nextMatrix);
213                 ox->numDefs = numDefs;
214                 if(OMX_DEBUG) {
215                         mxLog("Number of definition variables is %d.", numDefs);
216                 }
217                 ox->defVars = (omxDefinitionVar *) R_alloc(numDefs, sizeof(omxDefinitionVar));
218                 for(nextDef = 0; nextDef < numDefs; nextDef++) {
219                         SEXP dataSource, columnSource, depsSource; 
220                         int nextDataSource, numDeps;
221
222                         Rf_protect(itemList = VECTOR_ELT(nextMatrix, nextDef));
223                         Rf_protect(dataSource = VECTOR_ELT(itemList, 0));
224                         nextDataSource = INTEGER(dataSource)[0];
225                         if(OMX_DEBUG) {
226                                 mxLog("Data source number is %d.", nextDataSource);
227                         }
228                         ox->defVars[nextDef].data = nextDataSource;
229                         ox->defVars[nextDef].source = ox->currentState->dataList[nextDataSource];
230                         Rf_protect(columnSource = VECTOR_ELT(itemList, 1));
231                         if(OMX_DEBUG) {
232                                 mxLog("Data column number is %d.", INTEGER(columnSource)[0]);
233                         }
234                         ox->defVars[nextDef].column = INTEGER(columnSource)[0];
235                         Rf_protect(depsSource = VECTOR_ELT(itemList, 2));
236                         numDeps = LENGTH(depsSource);
237                         ox->defVars[nextDef].numDeps = numDeps;
238                         ox->defVars[nextDef].deps = (int*) R_alloc(numDeps, sizeof(int));
239                         for(int i = 0; i < numDeps; i++) {
240                                 ox->defVars[nextDef].deps[i] = INTEGER(depsSource)[i];
241                         }
242
243                         ox->defVars[nextDef].numLocations = Rf_length(itemList) - 3;
244                         ox->defVars[nextDef].matrices = (int *) R_alloc(Rf_length(itemList) - 3, sizeof(int));
245                         ox->defVars[nextDef].rows = (int *) R_alloc(Rf_length(itemList) - 3, sizeof(int));
246                         ox->defVars[nextDef].cols = (int *) R_alloc(Rf_length(itemList) - 3, sizeof(int));
247                         for(index = 3; index < Rf_length(itemList); index++) {
248                                 Rf_protect(nextItem = VECTOR_ELT(itemList, index));
249                                 ox->defVars[nextDef].matrices[index-3] = INTEGER(nextItem)[0];
250                                 ox->defVars[nextDef].rows[index-3] = INTEGER(nextItem)[1];
251                                 ox->defVars[nextDef].cols[index-3] = INTEGER(nextItem)[2];
252                         }
253                 }
254         }
255         
256 }
257
258 void omxCompleteExpectation(omxExpectation *ox) {
259         
260         if(ox->isComplete) return;
261
262         omxState* os = ox->currentState;
263
264         if (ox->rObj) {
265                 SEXP slot;
266                 Rf_protect(slot = R_do_slot(ox->rObj, Rf_install("container")));
267                 if (Rf_length(slot) == 1) {
268                         int ex = INTEGER(slot)[0];
269                         ox->container = os->expectationList.at(ex);
270                 }
271
272                 Rf_protect(slot = R_do_slot(ox->rObj, Rf_install("submodels")));
273                 if (Rf_length(slot)) {
274                         ox->numSubmodels = Rf_length(slot);
275                         ox->submodels = Realloc(NULL, Rf_length(slot), omxExpectation*);
276                         int *submodel = INTEGER(slot);
277                         for (int ex=0; ex < ox->numSubmodels; ex++) {
278                                 int sx = submodel[ex];
279                                 ox->submodels[ex] = omxExpectationFromIndex(sx, os);
280                                 omxCompleteExpectation(ox->submodels[ex]);
281                         }
282                 }
283
284                 omxExpectationProcessDataStructures(ox, ox->rObj);
285         }
286
287         ox->initFun(ox);
288
289         if(ox->computeFun == NULL) {
290                 // Should never happen
291                 Rf_error("Could not initialize Expectation function %s", ox->expType);
292         }
293
294         ox->isComplete = TRUE;
295
296 }
297
298 static void defaultSetVarGroup(omxExpectation *ox, FreeVarGroup *fvg)
299 {
300         if (ox->freeVarGroup && ox->freeVarGroup != fvg) {
301                 Rf_warning("setFreeVarGroup called with different group (%d vs %d) on %s",
302                         ox->name, ox->freeVarGroup->id[0], fvg->id[0]);
303         }
304         ox->freeVarGroup = fvg;
305 }
306
307 void setFreeVarGroup(omxExpectation *ox, FreeVarGroup *fvg)
308 {
309         (*ox->setVarGroup)(ox, fvg);
310 }
311
312 omxExpectation *
313 omxNewInternalExpectation(const char *expType, omxState* os)
314 {
315         omxExpectation* expect = Calloc(1, omxExpectation);
316         expect->setVarGroup = defaultSetVarGroup;
317
318         /* Switch based on Expectation type. */ 
319         for (size_t ex=0; ex < OMX_STATIC_ARRAY_SIZE(omxExpectationSymbolTable); ex++) {
320                 const omxExpectationTableEntry *entry = omxExpectationSymbolTable + ex;
321                 if(strEQ(expType, entry->name)) {
322                         expect->expType = entry->name;
323                         expect->initFun = entry->initFun;
324                         break;
325                 }
326         }
327
328         if(!expect->initFun) {
329                 Free(expect);
330                 Rf_error("Expectation %s not implemented", expType);
331         }
332
333         expect->currentState = os;
334         expect->canDuplicate = true;
335
336         return expect;
337 }
338
339 void omxExpectationPrint(omxExpectation* ox, char* d) {
340         if(ox->printFun != NULL) {
341                 ox->printFun(ox);
342         } else {
343                 mxLog("(Expectation, type %s) ", (ox->expType==NULL?"Untyped":ox->expType));
344         }
345 }