C-side support for submodels
[openmx:openmx.git] / src / omxExpectation.c
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18
19 *  omxExpectation.cc
20 *
21 *  Created: Timothy R. Brick    Date: 2008-11-13 12:33:06
22 *
23 *       Expectation objects carry distributional expectations
24 *               for the model.  Because they have no requirement
25 *               to produce a single matrix of output, they are
26 *               not a subclass of mxMatrix, but rather their own
27 *               strange beast.
28 *       // TODO:  Create a multi-matrix Algebra type, and make
29 *       //      MxExpectation a subtype of that.
30 *
31 **********************************************************/
32
33 #include "omxExpectation.h"
34
35 typedef struct omxExpectationTableEntry omxExpectationTableEntry;
36
37 struct omxExpectationTableEntry {
38         char name[32];
39         void (*initFun)(omxExpectation*, SEXP);
40 };
41
42 extern void omxInitNormalExpectation(omxExpectation *ox, SEXP rObj);
43 extern void omxInitLISRELExpectation(omxExpectation *ox, SEXP rObj);
44 extern void omxInitStateSpaceExpectation(omxExpectation *ox, SEXP rObj);
45 extern void omxInitRAMExpectation(omxExpectation *ox, SEXP rObj);
46
47 static const omxExpectationTableEntry omxExpectationSymbolTable[] = {
48         {"MxExpectationLISREL",                 &omxInitLISRELExpectation},
49         {"MxExpectationStateSpace",                     &omxInitStateSpaceExpectation},
50         {"MxExpectationNormal",                 &omxInitNormalExpectation},
51         {"MxExpectationRAM",                    &omxInitRAMExpectation},
52         { "", 0 }
53 };
54
55 void omxInitEmptyExpectation(omxExpectation *ox) {
56         /* Sets everything to NULL to avoid bad pointer calls */
57         
58   memset(ox, 0, sizeof(*ox));
59 }
60
61 void omxFreeExpectationArgs(omxExpectation *ox) {
62         if(ox==NULL) return;
63     
64         /* Completely destroy the Expectation function tree */
65         if(OMX_DEBUG) {Rprintf("Freeing %s Expectation object at 0x%x.\n", (ox->expType == NULL?"untyped":ox->expType), ox);}
66         if(ox->destructFun != NULL) {
67                 if(OMX_DEBUG) {Rprintf("Calling Expectation destructor for 0x%x.\n", ox);}
68                 ox->destructFun(ox);
69         }
70         Free(ox->submodels);
71 }
72
73 void omxExpectationRecompute(omxExpectation *ox) {
74         if(OMX_DEBUG_ALGEBRA) { 
75             Rprintf("Expectation recompute: 0x%0x\n", ox);
76         }
77
78         omxExpectationCompute(ox);
79 }
80
81 void omxExpectationCompute(omxExpectation *ox) {
82         if (!ox) return;
83
84         if(OMX_DEBUG_ALGEBRA) { 
85             Rprintf("Expectation compute: 0x%0x\n", ox);
86         }
87
88         ox->computeFun(ox);
89 }
90
91 omxMatrix* omxGetExpectationComponent(omxExpectation* ox, omxFitFunction* off, char* component) {
92
93         if(component == NULL) return NULL;
94
95         /* Hard-wired expectation components */
96         if(!strncmp("dataColumns", component, 11)) {
97                 return ox->dataColumns;
98         }
99
100         if(ox->componentFun == NULL) return NULL;
101
102         return(ox->componentFun(ox, off, component));
103         
104 }
105
106 void omxSetExpectationComponent(omxExpectation* ox, omxFitFunction* off, char* component, omxMatrix* om) {
107         if(!strcmp(ox->expType, "omxStateSpaceExpectation")) {
108                 ox->mutateFun(ox, off, component, om);
109         }
110 }
111
112 omxExpectation* omxDuplicateExpectation(const omxExpectation *src, omxState* newState) {
113
114         if(OMX_DEBUG) {Rprintf("Duplicating Expectation 0x%x\n", src);}
115
116         // if(src == NULL) {
117         //      return NULL;
118         // }
119         // 
120         // omxExpectation* tgt = (omxExpectation*) R_alloc(1, sizeof(omxExpectation));
121         // omxInitEmptyExpectation(tgt);
122         // 
123         // tgt->initFun                                         = src->initFun;
124         // tgt->destructFun                             = src->destructFun;
125         // tgt->repopulateFun                           = src->repopulateFun;
126         // tgt->computeFun                              = src->computeFun;
127         // tgt->componentFun                            = src->componentFun;
128         // tgt->populateAttrFun                         = src->populateAttrFun;
129         // tgt->setFinalReturns                         = src->setFinalReturns;
130         // tgt->sharedArgs                                      = src->sharedArgs;
131         // tgt->currentState                            = newState;
132         // tgt->rObj                                            = src->rObj;
133         // tgt->data                                            = src->data;
134         // tgt->dataColumns                             = omxLookupDuplicateElement(newState, src->dataColumns);
135         // tgt->defVars                                 = src->defVars;
136         // tgt->numDefs                                 = src->numDefs;
137         // int numDefs = tgt->numDefs;
138         // // for(int i = 0; i < numDefs; i++) {
139         // //   int thisCount = tgt->defVars[i].numLocations;
140         // //   for(int index = 0; index < thisCount; index++) {
141         // //           tgt->defVars[i].matrices[index] = omxLookupDuplicateElement(newState, src->defVars[i].matrices[index]);
142         // //   }
143         // // }
144         // 
145         // tgt->numOrdinal                                      = src->numOrdinal;
146         // tgt->thresholds                                      = src->thresholds;
147         // int nCols = tgt->dataColumns->rows;
148         // for(int i = 0; i < nCols; i++) {
149         //      if(tgt->thresholds[i].matrix != NULL) {
150         //              tgt->thresholds[i].matrix = omxLookupDuplicateElement(newState, src->thresholds[i].matrix);
151         //      }
152         // }
153         // 
154         // tgt->expNum                                          = src->expNum;
155         // 
156         //     strncpy(tgt->expType, src->expType, MAX_STRING_LEN);
157         // 
158         // return tgt;
159
160         return omxNewIncompleteExpectation(src->rObj, src->expNum, newState);
161
162 }
163
164 omxExpectation* omxNewIncompleteExpectation(SEXP rObj, int expNum, omxState* os) {
165
166         SEXP ExpectationClass;
167         const char* expType;
168         omxExpectation* expect = (omxExpectation*) R_alloc(1, sizeof(omxExpectation));
169         omxInitEmptyExpectation(expect);
170         
171         /* Get Expectation Type */
172         PROTECT(ExpectationClass = STRING_ELT(getAttrib(rObj, install("class")), 0));
173         expType = CHAR(ExpectationClass);
174
175         /* Switch based on Expectation type. */ 
176         const omxExpectationTableEntry *entry = omxExpectationSymbolTable;
177         while (entry->initFun) {
178                 if(strncmp(expType, entry->name, MAX_STRING_LEN) == 0) {
179                         expect->expType = entry->name;
180                         expect->initFun = entry->initFun;
181                         break;
182                 }
183                 entry += 1;
184         }
185
186         if(!expect->initFun) {
187                 char newError[MAX_STRING_LEN];
188                 sprintf(newError, "Expectation function %s not implemented.\n", (expect->expType==NULL?"Untyped":expect->expType));
189                 omxRaiseError(os, -1, newError);
190                 return NULL;
191         }
192
193         expect->rObj = rObj;
194         expect->expNum = expNum;
195         expect->currentState = os;
196         
197         return expect;
198 }
199
200 omxExpectation* omxExpectationFromIndex(int expIndex, omxState* os)
201 {
202         if (expIndex < 0 || expIndex >= os->numExpects) {
203                 error("Expectation %d out of range [0, %d]", expIndex, os->numExpects);
204         }
205
206         omxExpectation* ox = os->expectationList[expIndex];
207         
208         if(!ox->isComplete) omxCompleteExpectation(ox);
209         
210         return ox;
211 }
212
213 void omxExpectationProcessDataStructures(omxExpectation* ox, SEXP rObj){
214
215         int index, numDefs, nextDef, numCols, numOrdinal=0;
216         SEXP nextMatrix, itemList, nextItem, threshMatrix; 
217         
218         if(rObj == NULL) return;
219
220         if(OMX_DEBUG) { Rprintf("Retrieving data.\n"); }
221         PROTECT(nextMatrix = GET_SLOT(rObj, install("data")));
222         ox->data = omxDataLookupFromState(nextMatrix, ox->currentState);
223
224         if(OMX_DEBUG && ox->currentState->parentState == NULL) {
225                 Rprintf("Accessing variable mapping structure.\n");
226         }
227
228         if (R_has_slot(rObj, install("dataColumns"))) {
229                 PROTECT(nextMatrix = GET_SLOT(rObj, install("dataColumns")));
230                 ox->dataColumns = omxNewMatrixFromRPrimitive(nextMatrix, ox->currentState, 0, 0);
231                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
232                         omxPrint(ox->dataColumns, "Variable mapping");
233                 }
234         
235                 numCols = ox->dataColumns->cols;
236
237                 if (R_has_slot(rObj, install("thresholds"))) {
238                         if(OMX_DEBUG && ox->currentState->parentState == NULL) {
239                                 Rprintf("Accessing Threshold matrix.\n");
240                         }
241                         PROTECT(threshMatrix = GET_SLOT(rObj, install("thresholds")));
242
243                         if(INTEGER(threshMatrix)[0] != NA_INTEGER) {
244                                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
245                                         Rprintf("Accessing Threshold Mappings.\n");
246                                 }
247         
248                                 /* Process the data and threshold mapping structures */
249                                 /* if (threshMatrix == NA_INTEGER), then we could ignore the slot "thresholdColumns"
250                                  * and fill all the thresholds with {NULL, 0, 0}.
251                                  * However the current path does not have a lot of overhead. */
252                                 PROTECT(nextMatrix = GET_SLOT(rObj, install("thresholdColumns")));
253                                 PROTECT(itemList = GET_SLOT(rObj, install("thresholdLevels")));
254                                 int* thresholdColumn, *thresholdNumber;
255                                 thresholdColumn = INTEGER(nextMatrix);
256                                 thresholdNumber = INTEGER(itemList);
257                                 ox->thresholds = (omxThresholdColumn *) R_alloc(numCols, sizeof(omxThresholdColumn));
258                                 for(index = 0; index < numCols; index++) {
259                                         if(thresholdColumn[index] == NA_INTEGER) {      // Continuous variable
260                                                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
261                                                         Rprintf("Column %d is continuous.\n", index);
262                                                 }
263                                                 ox->thresholds[index].matrix = NULL;
264                                                 ox->thresholds[index].column = 0;
265                                                 ox->thresholds[index].numThresholds = 0;
266                                         } else {
267                                                 ox->thresholds[index].matrix = omxMatrixLookupFromState1(threshMatrix, 
268                                                                                                        ox->currentState);
269                                                 ox->thresholds[index].column = thresholdColumn[index];
270                                                 ox->thresholds[index].numThresholds = thresholdNumber[index];
271                                                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
272                                                         Rprintf("Column %d is ordinal with %d thresholds in threshold column %d.\n", 
273                                                                 index, thresholdColumn[index], thresholdNumber[index]);
274                                                 }
275                                                 numOrdinal++;
276                                         }
277                                 }
278                                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
279                                         Rprintf("%d threshold columns processed.\n", numOrdinal);
280                                 }
281                                 ox->numOrdinal = numOrdinal;
282                         } else {
283                                 if (OMX_DEBUG && ox->currentState->parentState == NULL) {
284                                         Rprintf("No thresholds matrix; not processing thresholds.");
285                                 }
286                                 ox->thresholds = NULL;
287                                 ox->numOrdinal = 0;
288                         }
289                 }
290         }
291
292         if(!R_has_slot(rObj, install("definitionVars"))) {
293                 ox->numDefs = 0;
294                 ox->defVars = NULL;
295         } else {        
296                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
297                         Rprintf("Accessing definition variables structure.\n");
298                 }
299                 PROTECT(nextMatrix = GET_SLOT(rObj, install("definitionVars")));
300                 numDefs = length(nextMatrix);
301                 ox->numDefs = numDefs;
302                 if(OMX_DEBUG && ox->currentState->parentState == NULL) {
303                         Rprintf("Number of definition variables is %d.\n", numDefs);
304                 }
305                 ox->defVars = (omxDefinitionVar *) R_alloc(numDefs, sizeof(omxDefinitionVar));
306                 for(nextDef = 0; nextDef < numDefs; nextDef++) {
307                         SEXP dataSource, columnSource, depsSource; 
308                         int nextDataSource, numDeps;
309
310                         PROTECT(itemList = VECTOR_ELT(nextMatrix, nextDef));
311                         PROTECT(dataSource = VECTOR_ELT(itemList, 0));
312                         nextDataSource = INTEGER(dataSource)[0];
313                         if(OMX_DEBUG && ox->currentState->parentState == NULL) {
314                                 Rprintf("Data source number is %d.\n", nextDataSource);
315                         }
316                         ox->defVars[nextDef].data = nextDataSource;
317                         ox->defVars[nextDef].source = ox->currentState->dataList[nextDataSource];
318                         PROTECT(columnSource = VECTOR_ELT(itemList, 1));
319                         if(OMX_DEBUG && ox->currentState->parentState == NULL) {
320                                 Rprintf("Data column number is %d.\n", INTEGER(columnSource)[0]);
321                         }
322                         ox->defVars[nextDef].column = INTEGER(columnSource)[0];
323                         PROTECT(depsSource = VECTOR_ELT(itemList, 2));
324                         numDeps = LENGTH(depsSource);
325                         ox->defVars[nextDef].numDeps = numDeps;
326                         ox->defVars[nextDef].deps = (int*) R_alloc(numDeps, sizeof(int));
327                         for(int i = 0; i < numDeps; i++) {
328                                 ox->defVars[nextDef].deps[i] = INTEGER(depsSource)[i];
329                         }
330
331                         ox->defVars[nextDef].numLocations = length(itemList) - 3;
332                         ox->defVars[nextDef].matrices = (int *) R_alloc(length(itemList) - 3, sizeof(int));
333                         ox->defVars[nextDef].rows = (int *) R_alloc(length(itemList) - 3, sizeof(int));
334                         ox->defVars[nextDef].cols = (int *) R_alloc(length(itemList) - 3, sizeof(int));
335                         for(index = 3; index < length(itemList); index++) {
336                                 PROTECT(nextItem = VECTOR_ELT(itemList, index));
337                                 ox->defVars[nextDef].matrices[index-3] = INTEGER(nextItem)[0];
338                                 ox->defVars[nextDef].rows[index-3] = INTEGER(nextItem)[1];
339                                 ox->defVars[nextDef].cols[index-3] = INTEGER(nextItem)[2];
340                         }
341                 }
342         }
343         
344 }
345
346 void omxCompleteExpectation(omxExpectation *ox) {
347         
348         if(ox->isComplete) return;
349
350         char errorCode[MAX_STRING_LEN];
351         
352         if(OMX_DEBUG) {Rprintf("Completing Expectation 0x%x, type %s.\n", 
353                 ox, ((ox==NULL || ox->expType==NULL)?"Untyped":ox->expType));}
354                 
355         omxState* os = ox->currentState;
356
357         if(ox->rObj == NULL || ox->initFun == NULL ) {
358                 char newError[MAX_STRING_LEN];
359                 sprintf(newError, "Could not complete expectation %s.\n", (ox->expType==NULL?"Untyped":ox->expType));
360                 omxRaiseError(os, -1, newError);
361                 return;
362         }
363
364         SEXP slot;
365         PROTECT(slot = GET_SLOT(ox->rObj, install("container")));
366         if (length(slot) == 1) {
367                 int ex = INTEGER(slot)[0];
368                 if (ex < 0 || ex >= os->numExpects) error("Expectation container out of range %d", ex);
369                 ox->container = os->expectationList[ex];
370         }
371
372         PROTECT(slot = GET_SLOT(ox->rObj, install("submodels")));
373         if (length(slot)) {
374                 ox->numSubmodels = length(slot);
375                 ox->submodels = Realloc(NULL, length(slot), omxExpectation*);
376                 int *submodel = INTEGER(slot);
377                 for (int ex=0; ex < ox->numSubmodels; ex++) {
378                         int sx = submodel[ex];
379                         ox->submodels[ex] = omxExpectationFromIndex(sx, os);
380                 }
381         }
382
383         omxExpectationProcessDataStructures(ox, ox->rObj);
384
385         ox->initFun(ox, ox->rObj);
386
387         if(ox->computeFun == NULL) {// If initialization fails, error code goes in argStruct
388                 if(os->statusCode != 0) {
389                         strncpy(errorCode, os->statusMsg, 150); // Report a status error
390                 } else {
391                         // If no error code is reported, we report that.
392                         strncpy(errorCode, "No error code reported.", 25);
393                 }
394                 if(ox->argStruct != NULL) {
395                         strncpy(errorCode, (char*)(ox->argStruct), 51);
396                 }
397                 char newError[MAX_STRING_LEN];
398                 sprintf(newError, "Could not initialize Expectation function %s.  Error: %s\n", 
399                                 ox->expType, errorCode);
400                 omxRaiseError(os, -1, newError);
401                 return;
402         }
403
404         ox->isComplete = TRUE;
405
406 }
407
408 void omxExpectationPrint(omxExpectation* ox, char* d) {
409         if(ox->printFun != NULL) {
410                 ox->printFun(ox);
411         } else {
412                 Rprintf("(Expectation, type %s) ", (ox->expType==NULL?"Untyped":ox->expType));
413         }
414 }