Store expectationList in std::vector
[openmx:openmx.git] / src / omxExpectation.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18
19 *  omxExpectation.cc
20 *
21 *  Created: Timothy R. Brick    Date: 2008-11-13 12:33:06
22 *
23 *       Expectation objects carry distributional expectations
24 *               for the model.  Because they have no requirement
25 *               to produce a single matrix of output, they are
26 *               not a subclass of mxMatrix, but rather their own
27 *               strange beast.
28 *       // TODO:  Create a multi-matrix Algebra type, and make
29 *       //      MxExpectation a subtype of that.
30 *
31 **********************************************************/
32
33 #include "omxExpectation.h"
34
35 typedef struct omxExpectationTableEntry omxExpectationTableEntry;
36
37 struct omxExpectationTableEntry {
38         char name[32];
39         void (*initFun)(omxExpectation*);
40 };
41
42 void omxInitNormalExpectation(omxExpectation *ox);
43 void omxInitLISRELExpectation(omxExpectation *ox);
44 void omxInitStateSpaceExpectation(omxExpectation *ox);
45 void omxInitRAMExpectation(omxExpectation *ox);
46
47 static const omxExpectationTableEntry omxExpectationSymbolTable[] = {
48         {"MxExpectationLISREL",                 &omxInitLISRELExpectation},
49         {"MxExpectationStateSpace",                     &omxInitStateSpaceExpectation},
50         {"MxExpectationNormal",                 &omxInitNormalExpectation},
51         {"MxExpectationRAM",                    &omxInitRAMExpectation}
52 };
53
54 void omxFreeExpectationArgs(omxExpectation *ox) {
55         if(ox==NULL) return;
56     
57         /* Completely destroy the Expectation function tree */
58         if(OMX_DEBUG) {Rprintf("Freeing %s Expectation object at 0x%x.\n", (ox->expType == NULL?"untyped":ox->expType), ox);}
59         if(ox->destructFun != NULL) {
60                 if(OMX_DEBUG) {Rprintf("Calling Expectation destructor for 0x%x.\n", ox);}
61                 ox->destructFun(ox);
62         }
63         Free(ox->submodels);
64         Free(ox);
65 }
66
67 void omxExpectationRecompute(omxExpectation *ox) {
68         if(OMX_DEBUG_ALGEBRA) { 
69             Rprintf("Expectation recompute: 0x%0x\n", ox);
70         }
71
72         omxExpectationCompute(ox);
73 }
74
75 void omxExpectationCompute(omxExpectation *ox) {
76         if (!ox) return;
77
78         if(OMX_DEBUG_ALGEBRA) { 
79             Rprintf("Expectation compute: 0x%0x\n", ox);
80         }
81
82         ox->computeFun(ox);
83 }
84
85 omxMatrix* omxGetExpectationComponent(omxExpectation* ox, omxFitFunction* off, const char* component) {
86
87         if(component == NULL) return NULL;
88
89         /* Hard-wired expectation components */
90         if(!strncmp("dataColumns", component, 11)) {
91                 return ox->dataColumns;
92         }
93
94         if(ox->componentFun == NULL) return NULL;
95
96         return(ox->componentFun(ox, off, component));
97         
98 }
99
100 void omxSetExpectationComponent(omxExpectation* ox, omxFitFunction* off, const char* component, omxMatrix* om) {
101         if(!strcmp(ox->expType, "MxExpectationStateSpace")) {
102                 ox->mutateFun(ox, off, component, om);
103         }
104 }
105
106 omxExpectation* omxDuplicateExpectation(const omxExpectation *src, omxState* newState) {
107
108         if(OMX_DEBUG) {Rprintf("Duplicating Expectation 0x%x\n", src);}
109
110         return omxNewIncompleteExpectation(src->rObj, src->expNum, newState);
111 }
112
113 omxExpectation* omxNewIncompleteExpectation(SEXP rObj, int expNum, omxState* os) {
114
115         SEXP ExpectationClass;
116         PROTECT(ExpectationClass = STRING_ELT(getAttrib(rObj, install("class")), 0));
117         const char* expType = CHAR(ExpectationClass);
118
119         omxExpectation* expect = omxNewInternalExpectation(expType, os);
120
121         expect->rObj = rObj;
122         expect->expNum = expNum;
123         
124         return expect;
125 }
126
127 omxExpectation* omxExpectationFromIndex(int expIndex, omxState* os)
128 {
129         omxExpectation* ox = os->expectationList.at(expIndex);
130         return ox;
131 }
132
133 void omxExpectationProcessDataStructures(omxExpectation* ox, SEXP rObj){
134
135         int index, numDefs, nextDef, numCols, numOrdinal=0;
136         SEXP nextMatrix, itemList, nextItem, threshMatrix; 
137         
138         if(rObj == NULL) return;
139
140         if(OMX_DEBUG) { Rprintf("Retrieving data.\n"); }
141         PROTECT(nextMatrix = GET_SLOT(rObj, install("data")));
142         ox->data = omxDataLookupFromState(nextMatrix, ox->currentState);
143
144         if(OMX_DEBUG && ox->currentState->parentState == NULL) {
145                 Rprintf("Accessing variable mapping structure.\n");
146         }
147
148         if (R_has_slot(rObj, install("dataColumns"))) {
149                 PROTECT(nextMatrix = GET_SLOT(rObj, install("dataColumns")));
150                 ox->dataColumns = omxNewMatrixFromRPrimitive(nextMatrix, ox->currentState, 0, 0);
151                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
152                         omxPrint(ox->dataColumns, "Variable mapping");
153                 }
154         
155                 numCols = ox->dataColumns->cols;
156
157                 if (R_has_slot(rObj, install("thresholds"))) {
158                         if(OMX_DEBUG && ox->currentState->parentState == NULL) {
159                                 Rprintf("Accessing Threshold matrix.\n");
160                         }
161                         PROTECT(threshMatrix = GET_SLOT(rObj, install("thresholds")));
162
163                         if(INTEGER(threshMatrix)[0] != NA_INTEGER) {
164                                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
165                                         Rprintf("Accessing Threshold Mappings.\n");
166                                 }
167         
168                                 /* Process the data and threshold mapping structures */
169                                 /* if (threshMatrix == NA_INTEGER), then we could ignore the slot "thresholdColumns"
170                                  * and fill all the thresholds with {NULL, 0, 0}.
171                                  * However the current path does not have a lot of overhead. */
172                                 PROTECT(nextMatrix = GET_SLOT(rObj, install("thresholdColumns")));
173                                 PROTECT(itemList = GET_SLOT(rObj, install("thresholdLevels")));
174                                 int* thresholdColumn, *thresholdNumber;
175                                 thresholdColumn = INTEGER(nextMatrix);
176                                 thresholdNumber = INTEGER(itemList);
177                                 ox->thresholds = (omxThresholdColumn *) R_alloc(numCols, sizeof(omxThresholdColumn));
178                                 for(index = 0; index < numCols; index++) {
179                                         if(thresholdColumn[index] == NA_INTEGER) {      // Continuous variable
180                                                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
181                                                         Rprintf("Column %d is continuous.\n", index);
182                                                 }
183                                                 ox->thresholds[index].matrix = NULL;
184                                                 ox->thresholds[index].column = 0;
185                                                 ox->thresholds[index].numThresholds = 0;
186                                         } else {
187                                                 ox->thresholds[index].matrix = omxMatrixLookupFromState1(threshMatrix, 
188                                                                                                        ox->currentState);
189                                                 ox->thresholds[index].column = thresholdColumn[index];
190                                                 ox->thresholds[index].numThresholds = thresholdNumber[index];
191                                                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
192                                                         Rprintf("Column %d is ordinal with %d thresholds in threshold column %d.\n", 
193                                                                 index, thresholdColumn[index], thresholdNumber[index]);
194                                                 }
195                                                 numOrdinal++;
196                                         }
197                                 }
198                                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
199                                         Rprintf("%d threshold columns processed.\n", numOrdinal);
200                                 }
201                                 ox->numOrdinal = numOrdinal;
202                         } else {
203                                 if (OMX_DEBUG && ox->currentState->parentState == NULL) {
204                                         Rprintf("No thresholds matrix; not processing thresholds.");
205                                 }
206                                 ox->thresholds = NULL;
207                                 ox->numOrdinal = 0;
208                         }
209                 }
210         }
211
212         if(!R_has_slot(rObj, install("definitionVars"))) {
213                 ox->numDefs = 0;
214                 ox->defVars = NULL;
215         } else {        
216                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
217                         Rprintf("Accessing definition variables structure.\n");
218                 }
219                 PROTECT(nextMatrix = GET_SLOT(rObj, install("definitionVars")));
220                 numDefs = length(nextMatrix);
221                 ox->numDefs = numDefs;
222                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
223                         Rprintf("Number of definition variables is %d.\n", numDefs);
224                 }
225                 ox->defVars = (omxDefinitionVar *) R_alloc(numDefs, sizeof(omxDefinitionVar));
226                 for(nextDef = 0; nextDef < numDefs; nextDef++) {
227                         SEXP dataSource, columnSource, depsSource; 
228                         int nextDataSource, numDeps;
229
230                         PROTECT(itemList = VECTOR_ELT(nextMatrix, nextDef));
231                         PROTECT(dataSource = VECTOR_ELT(itemList, 0));
232                         nextDataSource = INTEGER(dataSource)[0];
233                         if(OMX_DEBUG && ox->currentState->parentState == NULL) {
234                                 Rprintf("Data source number is %d.\n", nextDataSource);
235                         }
236                         ox->defVars[nextDef].data = nextDataSource;
237                         ox->defVars[nextDef].source = ox->currentState->dataList[nextDataSource];
238                         PROTECT(columnSource = VECTOR_ELT(itemList, 1));
239                         if(OMX_DEBUG && ox->currentState->parentState == NULL) {
240                                 Rprintf("Data column number is %d.\n", INTEGER(columnSource)[0]);
241                         }
242                         ox->defVars[nextDef].column = INTEGER(columnSource)[0];
243                         PROTECT(depsSource = VECTOR_ELT(itemList, 2));
244                         numDeps = LENGTH(depsSource);
245                         ox->defVars[nextDef].numDeps = numDeps;
246                         ox->defVars[nextDef].deps = (int*) R_alloc(numDeps, sizeof(int));
247                         for(int i = 0; i < numDeps; i++) {
248                                 ox->defVars[nextDef].deps[i] = INTEGER(depsSource)[i];
249                         }
250
251                         ox->defVars[nextDef].numLocations = length(itemList) - 3;
252                         ox->defVars[nextDef].matrices = (int *) R_alloc(length(itemList) - 3, sizeof(int));
253                         ox->defVars[nextDef].rows = (int *) R_alloc(length(itemList) - 3, sizeof(int));
254                         ox->defVars[nextDef].cols = (int *) R_alloc(length(itemList) - 3, sizeof(int));
255                         for(index = 3; index < length(itemList); index++) {
256                                 PROTECT(nextItem = VECTOR_ELT(itemList, index));
257                                 ox->defVars[nextDef].matrices[index-3] = INTEGER(nextItem)[0];
258                                 ox->defVars[nextDef].rows[index-3] = INTEGER(nextItem)[1];
259                                 ox->defVars[nextDef].cols[index-3] = INTEGER(nextItem)[2];
260                         }
261                 }
262         }
263         
264 }
265
266 void omxCompleteExpectation(omxExpectation *ox) {
267         
268         if(ox->isComplete) return;
269
270         if(OMX_DEBUG) {Rprintf("Completing Expectation 0x%x, type %s.\n", 
271                 ox, ((ox==NULL || ox->expType==NULL)?"Untyped":ox->expType));}
272                 
273         omxState* os = ox->currentState;
274
275         if (ox->rObj) {
276                 SEXP slot;
277                 PROTECT(slot = GET_SLOT(ox->rObj, install("container")));
278                 if (length(slot) == 1) {
279                         int ex = INTEGER(slot)[0];
280                         ox->container = os->expectationList.at(ex);
281                 }
282
283                 PROTECT(slot = GET_SLOT(ox->rObj, install("submodels")));
284                 if (length(slot)) {
285                         ox->numSubmodels = length(slot);
286                         ox->submodels = Realloc(NULL, length(slot), omxExpectation*);
287                         int *submodel = INTEGER(slot);
288                         for (int ex=0; ex < ox->numSubmodels; ex++) {
289                                 int sx = submodel[ex];
290                                 ox->submodels[ex] = omxExpectationFromIndex(sx, os);
291                                 omxCompleteExpectation(ox->submodels[ex]);
292                         }
293                 }
294
295                 omxExpectationProcessDataStructures(ox, ox->rObj);
296         }
297
298         ox->initFun(ox);
299
300         if(ox->computeFun == NULL) {
301                 // Should never happen
302                 error("Could not initialize Expectation function %s", ox->expType);
303         }
304
305         ox->isComplete = TRUE;
306
307 }
308
309 omxExpectation *
310 omxNewInternalExpectation(const char *expType, omxState* os)
311 {
312         omxExpectation* expect = Calloc(1, omxExpectation);
313
314         /* Switch based on Expectation type. */ 
315         for (size_t ex=0; ex < OMX_STATIC_ARRAY_SIZE(omxExpectationSymbolTable); ex++) {
316                 const omxExpectationTableEntry *entry = omxExpectationSymbolTable + ex;
317                 if(strncmp(expType, entry->name, MAX_STRING_LEN) == 0) {
318                         expect->expType = entry->name;
319                         expect->initFun = entry->initFun;
320                         break;
321                 }
322         }
323
324         if(!expect->initFun) {
325                 Free(expect);
326                 error("Expectation %s not implemented", expType);
327         }
328
329         expect->currentState = os;
330
331         return expect;
332 }
333
334 void omxExpectationPrint(omxExpectation* ox, char* d) {
335         if(ox->printFun != NULL) {
336                 ox->printFun(ox);
337         } else {
338                 Rprintf("(Expectation, type %s) ", (ox->expType==NULL?"Untyped":ox->expType));
339         }
340 }