Revert "Refrain from duplicating the model unless required by the fitfunction"
[openmx:openmx.git] / src / omxRowFitFunction.c
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <R.h>
18 #include <Rinternals.h>
19 #include <Rdefines.h>
20 #include <R_ext/Rdynload.h>
21 #include <R_ext/BLAS.h>
22 #include <R_ext/Lapack.h>
23 #include "omxDefines.h"
24 #include "omxAlgebraFunctions.h"
25 #include "omxSymbolTable.h"
26 #include "omxData.h"
27 #include "omxRowFitFunction.h"
28 #include "omxFIMLFitFunction.h"
29
30 void omxDestroyRowFitFunction(omxFitFunction *oo) {
31
32         omxRowFitFunction* argStruct = (omxRowFitFunction*)(oo->argStruct);
33
34         omxFreeMatrixData(argStruct->dataRow);
35 }
36
37 omxRListElement* omxSetFinalReturnsRowFitFunction(omxFitFunction *oo, int *numReturns) {
38         *numReturns = 0;
39         omxRListElement* retVal = (omxRListElement*) R_alloc(1, sizeof(omxRListElement));
40
41         retVal[0].numValues = 0;
42
43         return retVal;
44 }
45
46
47 void omxCopyMatrixToRow(omxMatrix* source, int row, omxMatrix* target) {
48         
49         int i;
50         for(i = 0; i < source->cols; i++) {
51                 omxSetMatrixElement(target, row, i, omxMatrixElement(source, 0, i));
52         }
53
54 }
55
56 void markDataRowDependencies(omxState* os, omxRowFitFunction* orff) {
57
58         int numDeps = orff->numDataRowDeps;
59         int *deps = orff->dataRowDeps;
60
61         omxMatrix** matrixList = os->matrixList;
62         omxMatrix** algebraList = os->algebraList;
63
64         for (int i = 0; i < numDeps; i++) {
65                 int value = deps[i];
66
67                 if(value < 0) {
68                         omxMarkDirty(matrixList[~value]);
69                 } else {
70                         omxMarkDirty(algebraList[value]);
71                 }
72         }
73
74 }
75
76 void omxRowFitFunctionSingleIteration(omxFitFunction *localobj, omxFitFunction *sharedobj, int rowbegin, int rowcount) {
77
78     omxRowFitFunction* oro = ((omxRowFitFunction*) localobj->argStruct);
79     omxRowFitFunction* shared_oro = ((omxRowFitFunction*) sharedobj->argStruct);
80
81         int numDefs;
82
83     omxMatrix *rowAlgebra, *rowResults;
84     omxMatrix *filteredDataRow, *dataRow, *existenceVector;
85     omxMatrix *dataColumns;
86         omxDefinitionVar* defVars;
87         omxData *data;
88         int isContiguous, contiguousStart, contiguousLength;
89     double* oldDefs;
90     int numCols, numRemoves;
91
92         rowAlgebra          = oro->rowAlgebra;
93         rowResults          = shared_oro->rowResults;
94         data                = oro->data;
95         defVars             = oro->defVars;
96         numDefs             = oro->numDefs;
97     oldDefs         = oro->oldDefs;
98     dataColumns     = oro->dataColumns;
99     dataRow         = oro->dataRow;
100     filteredDataRow = oro->filteredDataRow;
101     existenceVector = oro->existenceVector;
102     
103     isContiguous    = oro->contiguous.isContiguous;
104         contiguousStart = oro->contiguous.start;
105         contiguousLength = oro->contiguous.length;
106
107         numCols = dataColumns->cols;
108         int *toRemove = malloc(sizeof(int) * dataColumns->cols);
109         int *zeros = calloc(dataColumns->cols, sizeof(int));
110
111     resetDefinitionVariables(oldDefs, numDefs);
112
113         for(int row = rowbegin; row < data->rows && (row - rowbegin) < rowcount; row++) {
114
115                 // Handle Definition Variables.
116         if(OMX_DEBUG_ROWS(row)) { Rprintf("numDefs is %d", numDefs);}
117                 if(numDefs != 0) {              // With defs, just copy repeatedly to the rowResults matrix.
118                         handleDefinitionVarList(data, localobj->matrix->currentState, row, defVars, oldDefs, numDefs);
119                 }
120
121                 omxStateNextRow(localobj->matrix->currentState);                                                // Advance row
122                 
123         // Populate data row
124                 numRemoves = 0;
125         
126                 if (isContiguous) {
127                         omxContiguousDataRow(data, row, contiguousStart, contiguousLength, dataRow);
128                 } else {
129                         omxDataRow(data, row, dataColumns, dataRow);    // Populate data row
130                 }
131
132                 markDataRowDependencies(localobj->matrix->currentState, oro);
133                 
134                 for(int j = 0; j < dataColumns->cols; j++) {
135                         double dataValue = omxVectorElement(dataRow, j);
136                         if(isnan(dataValue)) {
137                                 numRemoves++;
138                                 toRemove[j] = 1;
139                 omxSetVectorElement(existenceVector, j, 0);
140                         } else {
141                             toRemove[j] = 0;
142                 omxSetVectorElement(existenceVector, j, 1);
143                         }
144                 }               
145                 // TODO: Determine if this is the correct response.
146                 
147                 if(numRemoves == numCols) {
148                     char *errstr = calloc(250, sizeof(char));
149                         sprintf(errstr, "Row %d completely missing.  omxRowFitFunction cannot have completely missing rows.", omxDataIndex(data, row));
150                         omxRaiseError(localobj->matrix->currentState, -1, errstr);
151                         free(errstr);
152                         continue;
153                 }
154
155                 omxResetAliasedMatrix(filteredDataRow);                         // Reset the row
156                 omxRemoveRowsAndColumns(filteredDataRow, 0, numRemoves, zeros, toRemove);
157
158                 omxRecompute(rowAlgebra);                                                       // Compute this row
159
160                 omxCopyMatrixToRow(rowAlgebra, omxDataIndex(data, row), rowResults);
161         }
162         free(toRemove);
163         free(zeros);
164 }
165
166 static void omxCallRowFitFunction(omxFitFunction *oo, int want, double *gradient) {     // TODO: Figure out how to give access to other per-iteration structures.
167     if(OMX_DEBUG) { Rprintf("Beginning Row Evaluation.\n");}
168         // Requires: Data, means, covariances.
169
170         omxMatrix* objMatrix  = oo->matrix;
171         omxState* parentState = objMatrix->currentState;
172         int numChildren = parentState->numChildren;
173
174     omxMatrix *reduceAlgebra;
175         omxData *data;
176
177     omxRowFitFunction* oro = ((omxRowFitFunction*) oo->argStruct);
178
179         reduceAlgebra   = oro->reduceAlgebra;
180         data                = oro->data;
181
182         /* Michael Spiegel, 7/31/12
183         * The demo "RowFitFunctionSimpleExamples" will fail in the parallel 
184         * Hessian calculation if the resizing operation is performed.
185         *
186         omxMatrix *rowAlgebra, *rowResults
187         rowAlgebra          = oro->rowAlgebra;
188         rowResults          = oro->rowResults;
189
190         if(rowResults->cols != rowAlgebra->cols || rowResults->rows != data->rows) {
191                 if(OMX_DEBUG_ROWS(1)) { 
192                         Rprintf("Resizing rowResults from %dx%d to %dx%d.\n", 
193                                 rowResults->rows, rowResults->cols, 
194                                 data->rows, rowAlgebra->cols); 
195                 }
196                 omxResizeMatrix(rowResults, data->rows, rowAlgebra->cols, FALSE);
197         }
198         */
199                 
200     int parallelism = (numChildren == 0) ? 1 : numChildren;
201
202         if (parallelism > data->rows) {
203                 parallelism = data->rows;
204         }
205
206         if (parallelism > 1) {
207                 int stride = (data->rows / parallelism);
208
209                 #pragma omp parallel for num_threads(parallelism) 
210                 for(int i = 0; i < parallelism; i++) {
211                         omxMatrix *childMatrix = omxLookupDuplicateElement(parentState->childList[i], objMatrix);
212                         omxFitFunction *childFit = childMatrix->fitFunction;
213                         if (i == parallelism - 1) {
214                                 omxRowFitFunctionSingleIteration(childFit, oo, stride * i, data->rows - stride * i);
215                         } else {
216                                 omxRowFitFunctionSingleIteration(childFit, oo, stride * i, stride);
217                         }
218                 }
219
220                 for(int i = 0; i < parallelism; i++) {
221                         if (parentState->childList[i]->statusCode < 0) {
222                                 parentState->statusCode = parentState->childList[i]->statusCode;
223                                 strncpy(parentState->statusMsg, parentState->childList[i]->statusMsg, 249);
224                                 parentState->statusMsg[249] = '\0';
225                         }
226                 }
227
228         } else {
229                 omxRowFitFunctionSingleIteration(oo, oo, 0, data->rows);
230         }
231
232         omxRecompute(reduceAlgebra);
233
234         omxCopyMatrix(oo->matrix, reduceAlgebra);
235
236 }
237
238 void omxInitRowFitFunction(omxFitFunction* oo, SEXP rObj) {
239
240         if(OMX_DEBUG) { Rprintf("Initializing Row/Reduce fit function.\n"); }
241
242         SEXP nextMatrix, itemList, nextItem;
243         int nextDef, index, numDeps;
244
245         omxRowFitFunction *newObj = (omxRowFitFunction*) R_alloc(1, sizeof(omxRowFitFunction));
246
247         if(OMX_DEBUG) {Rprintf("Accessing data source.\n"); }
248         PROTECT(nextMatrix = GET_SLOT(rObj, install("data")));
249         newObj->data = omxNewDataFromMxDataPtr(nextMatrix, oo->matrix->currentState);
250         if(newObj->data == NULL) {
251                 char *errstr = calloc(250, sizeof(char));
252                 sprintf(errstr, "No data provided to omxRowFitFunction.");
253                 omxRaiseError(oo->matrix->currentState, -1, errstr);
254                 free(errstr);
255         }
256         UNPROTECT(1); // nextMatrix
257
258         PROTECT(nextMatrix = GET_SLOT(rObj, install("rowAlgebra")));
259         newObj->rowAlgebra = omxNewMatrixFromMxIndex(nextMatrix, oo->matrix->currentState);
260         if(newObj->rowAlgebra == NULL) {
261                 char *errstr = calloc(250, sizeof(char));
262                 sprintf(errstr, "No row-wise algebra in omxRowFitFunction.");
263                 omxRaiseError(oo->matrix->currentState, -1, errstr);
264                 free(errstr);
265         }
266         UNPROTECT(1);// nextMatrix
267
268         PROTECT(nextMatrix = GET_SLOT(rObj, install("filteredDataRow")));
269         newObj->filteredDataRow = omxNewMatrixFromMxIndex(nextMatrix, oo->matrix->currentState);
270         if(newObj->filteredDataRow == NULL) {
271                 char *errstr = calloc(250, sizeof(char));
272                 sprintf(errstr, "No row results matrix in omxRowFitFunction.");
273                 omxRaiseError(oo->matrix->currentState, -1, errstr);
274                 free(errstr);
275         }
276         // Create the original data row from which to filter.
277     newObj->dataRow = omxInitMatrix(NULL, newObj->filteredDataRow->rows, newObj->filteredDataRow->cols, TRUE, oo->matrix->currentState);
278     omxAliasMatrix(newObj->filteredDataRow, newObj->dataRow);
279         UNPROTECT(1);// nextMatrix
280
281         PROTECT(nextMatrix = GET_SLOT(rObj, install("existenceVector")));
282         newObj->existenceVector = omxNewMatrixFromMxIndex(nextMatrix, oo->matrix->currentState);
283     // Do we allow NULL existence?  (Whoa, man. That's, like, deep, or something.)
284         if(newObj->existenceVector == NULL) {
285                 char *errstr = calloc(250, sizeof(char));
286                 sprintf(errstr, "No existance matrix in omxRowFitFunction.");
287                 omxRaiseError(oo->matrix->currentState, -1, errstr);
288                 free(errstr);
289         }
290         UNPROTECT(1);// nextMatrix
291
292
293         PROTECT(nextMatrix = GET_SLOT(rObj, install("rowResults")));
294         newObj->rowResults = omxNewMatrixFromMxIndex(nextMatrix, oo->matrix->currentState);
295         if(newObj->rowResults == NULL) {
296                 char *errstr = calloc(250, sizeof(char));
297                 sprintf(errstr, "No row results matrix in omxRowFitFunction.");
298                 omxRaiseError(oo->matrix->currentState, -1, errstr);
299                 free(errstr);
300         }
301         UNPROTECT(1);// nextMatrix
302
303         PROTECT(nextMatrix = GET_SLOT(rObj, install("reduceAlgebra")));
304         newObj->reduceAlgebra = omxNewMatrixFromMxIndex(nextMatrix, oo->matrix->currentState);
305         if(newObj->reduceAlgebra == NULL) {
306                 char *errstr = calloc(250, sizeof(char));
307                 sprintf(errstr, "No row reduction algebra in omxRowFitFunction.");
308                 omxRaiseError(oo->matrix->currentState, -1, errstr);
309                 free(errstr);
310         }
311         UNPROTECT(1);// nextMatrix
312         
313         if(OMX_DEBUG) {Rprintf("Accessing variable mapping structure.\n"); }
314         PROTECT(nextMatrix = GET_SLOT(rObj, install("dataColumns")));
315         newObj->dataColumns = omxNewMatrixFromRPrimitive(nextMatrix, oo->matrix->currentState, 0, 0);
316         if(OMX_DEBUG) { omxPrint(newObj->dataColumns, "Variable mapping"); }
317         UNPROTECT(1);
318
319         if(OMX_DEBUG) {Rprintf("Accessing data row dependencies.\n"); }
320         PROTECT(nextItem = GET_SLOT(rObj, install("dataRowDeps")));
321         numDeps = LENGTH(nextItem);
322         newObj->numDataRowDeps = numDeps;
323         newObj->dataRowDeps = (int*) R_alloc(numDeps, sizeof(int));
324         for(int i = 0; i < numDeps; i++) {
325                 newObj->dataRowDeps[i] = INTEGER(nextItem)[i];
326         }
327         UNPROTECT(1);
328
329         if(OMX_DEBUG) {Rprintf("Accessing definition variables structure.\n"); }
330         PROTECT(nextMatrix = GET_SLOT(rObj, install("definitionVars")));
331         newObj->numDefs = length(nextMatrix);
332         newObj->oldDefs = (double *) R_alloc(newObj->numDefs, sizeof(double));          // Storage for Def Vars
333         if(OMX_DEBUG) {Rprintf("Number of definition variables is %d.\n", newObj->numDefs); }
334         newObj->defVars = (omxDefinitionVar *) R_alloc(newObj->numDefs, sizeof(omxDefinitionVar));
335         for(nextDef = 0; nextDef < newObj->numDefs; nextDef++) {
336                 SEXP dataSource, columnSource, depsSource; 
337
338                 PROTECT(itemList = VECTOR_ELT(nextMatrix, nextDef));
339                 PROTECT(dataSource = VECTOR_ELT(itemList, 0));
340                 if(OMX_DEBUG) {Rprintf("Data source number is %d.\n", INTEGER(dataSource)[0]); }
341                 newObj->defVars[nextDef].data = INTEGER(dataSource)[0];
342                 newObj->defVars[nextDef].source = oo->matrix->currentState->dataList[INTEGER(dataSource)[0]];
343                 PROTECT(columnSource = VECTOR_ELT(itemList, 1));
344                 if(OMX_DEBUG) {Rprintf("Data column number is %d.\n", INTEGER(columnSource)[0]); }
345                 newObj->defVars[nextDef].column = INTEGER(columnSource)[0];
346                 PROTECT(depsSource = VECTOR_ELT(itemList, 2));
347                 numDeps = LENGTH(depsSource);
348                 newObj->defVars[nextDef].numDeps = numDeps;
349                 newObj->defVars[nextDef].deps = (int*) R_alloc(numDeps, sizeof(int));
350                 for(int i = 0; i < numDeps; i++) {
351                         newObj->defVars[nextDef].deps[i] = INTEGER(depsSource)[i];
352                 }
353                 UNPROTECT(3); // unprotect dataSource, columnSource, and depsSource
354
355                 newObj->defVars[nextDef].numLocations = length(itemList) - 3;
356                 newObj->defVars[nextDef].matrices = (int *) R_alloc(length(itemList) - 3, sizeof(int));
357                 newObj->defVars[nextDef].rows = (int *) R_alloc(length(itemList) - 3, sizeof(int));
358                 newObj->defVars[nextDef].cols = (int *) R_alloc(length(itemList) - 3, sizeof(int));
359
360                 for(index = 3; index < length(itemList); index++) {
361                         PROTECT(nextItem = VECTOR_ELT(itemList, index));
362                         newObj->defVars[nextDef].matrices[index-3] = INTEGER(nextItem)[0];
363                         newObj->defVars[nextDef].rows[index-3]     = INTEGER(nextItem)[1];
364                         newObj->defVars[nextDef].cols[index-3]     = INTEGER(nextItem)[2];
365                         UNPROTECT(1); // unprotect nextItem
366                 }
367                 UNPROTECT(1); // unprotect itemList
368         }
369         UNPROTECT(1); // unprotect nextMatrix
370         
371         /* Set up data columns */
372         omxSetContiguousDataColumns(&(newObj->contiguous), newObj->data, newObj->dataColumns);
373
374         oo->computeFun = omxCallRowFitFunction;
375         oo->setFinalReturns = omxSetFinalReturnsRowFitFunction;
376         oo->destructFun = omxDestroyRowFitFunction;
377         oo->repopulateFun = NULL;
378
379         oo->argStruct = (void*) newObj;
380 }
381
382