Store algebraList in std::vector
[openmx:openmx.git] / src / omxRowFitFunction.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <R.h>
18 #include <Rinternals.h>
19 #include <Rdefines.h>
20 #include <R_ext/Rdynload.h>
21 #include <R_ext/BLAS.h>
22 #include <R_ext/Lapack.h>
23 #include "omxDefines.h"
24 #include "omxAlgebraFunctions.h"
25 #include "omxSymbolTable.h"
26 #include "omxData.h"
27 #include "omxRowFitFunction.h"
28 #include "omxFIMLFitFunction.h"
29
30 void omxDestroyRowFitFunction(omxFitFunction *oo) {
31
32         omxRowFitFunction* argStruct = (omxRowFitFunction*)(oo->argStruct);
33
34         omxFreeMatrixData(argStruct->dataRow);
35 }
36
37 omxRListElement* omxSetFinalReturnsRowFitFunction(omxFitFunction *oo, int *numReturns) {
38         *numReturns = 0;
39         omxRListElement* retVal = (omxRListElement*) R_alloc(1, sizeof(omxRListElement));
40
41         retVal[0].numValues = 0;
42
43         return retVal;
44 }
45
46
47 void omxCopyMatrixToRow(omxMatrix* source, int row, omxMatrix* target) {
48         
49         int i;
50         for(i = 0; i < source->cols; i++) {
51                 omxSetMatrixElement(target, row, i, omxMatrixElement(source, 0, i));
52         }
53
54 }
55
56 void markDataRowDependencies(omxState* os, omxRowFitFunction* orff) {
57
58         int numDeps = orff->numDataRowDeps;
59         int *deps = orff->dataRowDeps;
60
61         for (int i = 0; i < numDeps; i++) {
62                 int value = deps[i];
63
64                 if(value < 0) {
65                         omxMarkDirty(os->matrixList[~value]);
66                 } else {
67                         omxMarkDirty(os->algebraList[value]);
68                 }
69         }
70
71 }
72
73 void omxRowFitFunctionSingleIteration(omxFitFunction *localobj, omxFitFunction *sharedobj, int rowbegin, int rowcount) {
74
75     omxRowFitFunction* oro = ((omxRowFitFunction*) localobj->argStruct);
76     omxRowFitFunction* shared_oro = ((omxRowFitFunction*) sharedobj->argStruct);
77
78         int numDefs;
79
80     omxMatrix *rowAlgebra, *rowResults;
81     omxMatrix *filteredDataRow, *dataRow, *existenceVector;
82     omxMatrix *dataColumns;
83         omxDefinitionVar* defVars;
84         omxData *data;
85         int isContiguous, contiguousStart, contiguousLength;
86     double* oldDefs;
87     int numCols, numRemoves;
88
89         rowAlgebra          = oro->rowAlgebra;
90         rowResults          = shared_oro->rowResults;
91         data                = oro->data;
92         defVars             = oro->defVars;
93         numDefs             = oro->numDefs;
94     oldDefs         = oro->oldDefs;
95     dataColumns     = oro->dataColumns;
96     dataRow         = oro->dataRow;
97     filteredDataRow = oro->filteredDataRow;
98     existenceVector = oro->existenceVector;
99     
100     isContiguous    = oro->contiguous.isContiguous;
101         contiguousStart = oro->contiguous.start;
102         contiguousLength = oro->contiguous.length;
103
104         numCols = dataColumns->cols;
105         int *toRemove = (int*) malloc(sizeof(int) * dataColumns->cols);
106         int *zeros = (int*) calloc(dataColumns->cols, sizeof(int));
107
108     resetDefinitionVariables(oldDefs, numDefs);
109
110         for(int row = rowbegin; row < data->rows && (row - rowbegin) < rowcount; row++) {
111
112                 // Handle Definition Variables.
113         if(OMX_DEBUG_ROWS(row)) { Rprintf("numDefs is %d", numDefs);}
114                 if(numDefs != 0) {              // With defs, just copy repeatedly to the rowResults matrix.
115                         handleDefinitionVarList(data, localobj->matrix->currentState, row, defVars, oldDefs, numDefs);
116                 }
117
118                 omxStateNextRow(localobj->matrix->currentState);                                                // Advance row
119                 
120         // Populate data row
121                 numRemoves = 0;
122         
123                 if (isContiguous) {
124                         omxContiguousDataRow(data, row, contiguousStart, contiguousLength, dataRow);
125                 } else {
126                         omxDataRow(data, row, dataColumns, dataRow);    // Populate data row
127                 }
128
129                 markDataRowDependencies(localobj->matrix->currentState, oro);
130                 
131                 for(int j = 0; j < dataColumns->cols; j++) {
132                         double dataValue = omxVectorElement(dataRow, j);
133                         if(isnan(dataValue)) {
134                                 numRemoves++;
135                                 toRemove[j] = 1;
136                 omxSetVectorElement(existenceVector, j, 0);
137                         } else {
138                             toRemove[j] = 0;
139                 omxSetVectorElement(existenceVector, j, 1);
140                         }
141                 }               
142                 // TODO: Determine if this is the correct response.
143                 
144                 if(numRemoves == numCols) {
145                         char *errstr = (char*) calloc(250, sizeof(char));
146                         sprintf(errstr, "Row %d completely missing.  omxRowFitFunction cannot have completely missing rows.", omxDataIndex(data, row));
147                         omxRaiseError(localobj->matrix->currentState, -1, errstr);
148                         free(errstr);
149                         continue;
150                 }
151
152                 omxResetAliasedMatrix(filteredDataRow);                         // Reset the row
153                 omxRemoveRowsAndColumns(filteredDataRow, 0, numRemoves, zeros, toRemove);
154
155                 omxRecompute(rowAlgebra);                                                       // Compute this row
156
157                 omxCopyMatrixToRow(rowAlgebra, omxDataIndex(data, row), rowResults);
158         }
159         free(toRemove);
160         free(zeros);
161 }
162
163 static void omxCallRowFitFunction(omxFitFunction *oo, int want, double *gradient) {     // TODO: Figure out how to give access to other per-iteration structures.
164     if(OMX_DEBUG) { Rprintf("Beginning Row Evaluation.\n");}
165         // Requires: Data, means, covariances.
166
167         omxMatrix* objMatrix  = oo->matrix;
168         omxState* parentState = objMatrix->currentState;
169         int numChildren = parentState->numChildren;
170
171     omxMatrix *reduceAlgebra;
172         omxData *data;
173
174     omxRowFitFunction* oro = ((omxRowFitFunction*) oo->argStruct);
175
176         reduceAlgebra   = oro->reduceAlgebra;
177         data                = oro->data;
178
179         /* Michael Spiegel, 7/31/12
180         * The demo "RowFitFunctionSimpleExamples" will fail in the parallel 
181         * Hessian calculation if the resizing operation is performed.
182         *
183         omxMatrix *rowAlgebra, *rowResults
184         rowAlgebra          = oro->rowAlgebra;
185         rowResults          = oro->rowResults;
186
187         if(rowResults->cols != rowAlgebra->cols || rowResults->rows != data->rows) {
188                 if(OMX_DEBUG_ROWS(1)) { 
189                         Rprintf("Resizing rowResults from %dx%d to %dx%d.\n", 
190                                 rowResults->rows, rowResults->cols, 
191                                 data->rows, rowAlgebra->cols); 
192                 }
193                 omxResizeMatrix(rowResults, data->rows, rowAlgebra->cols, FALSE);
194         }
195         */
196                 
197     int parallelism = (numChildren == 0) ? 1 : numChildren;
198
199         if (parallelism > data->rows) {
200                 parallelism = data->rows;
201         }
202
203         if (parallelism > 1) {
204                 int stride = (data->rows / parallelism);
205
206                 #pragma omp parallel for num_threads(parallelism) 
207                 for(int i = 0; i < parallelism; i++) {
208                         omxMatrix *childMatrix = omxLookupDuplicateElement(parentState->childList[i], objMatrix);
209                         omxFitFunction *childFit = childMatrix->fitFunction;
210                         if (i == parallelism - 1) {
211                                 omxRowFitFunctionSingleIteration(childFit, oo, stride * i, data->rows - stride * i);
212                         } else {
213                                 omxRowFitFunctionSingleIteration(childFit, oo, stride * i, stride);
214                         }
215                 }
216
217                 for(int i = 0; i < parallelism; i++) {
218                         if (parentState->childList[i]->statusCode < 0) {
219                                 parentState->statusCode = parentState->childList[i]->statusCode;
220                                 strncpy(parentState->statusMsg, parentState->childList[i]->statusMsg, 249);
221                                 parentState->statusMsg[249] = '\0';
222                         }
223                 }
224
225         } else {
226                 omxRowFitFunctionSingleIteration(oo, oo, 0, data->rows);
227         }
228
229         omxRecompute(reduceAlgebra);
230
231         omxCopyMatrix(oo->matrix, reduceAlgebra);
232
233 }
234
235 void omxInitRowFitFunction(omxFitFunction* oo, SEXP rObj) {
236
237         if(OMX_DEBUG) { Rprintf("Initializing Row/Reduce fit function.\n"); }
238
239         SEXP nextMatrix, itemList, nextItem;
240         int nextDef, index, numDeps;
241
242         omxRowFitFunction *newObj = (omxRowFitFunction*) R_alloc(1, sizeof(omxRowFitFunction));
243
244         if(OMX_DEBUG) {Rprintf("Accessing data source.\n"); }
245         PROTECT(nextMatrix = GET_SLOT(rObj, install("data")));
246         newObj->data = omxDataLookupFromState(nextMatrix, oo->matrix->currentState);
247         if(newObj->data == NULL) {
248                 char *errstr = (char*) calloc(250, sizeof(char));
249                 sprintf(errstr, "No data provided to omxRowFitFunction.");
250                 omxRaiseError(oo->matrix->currentState, -1, errstr);
251                 free(errstr);
252         }
253         UNPROTECT(1); // nextMatrix
254
255         PROTECT(nextMatrix = GET_SLOT(rObj, install("rowAlgebra")));
256         newObj->rowAlgebra = omxMatrixLookupFromState1(nextMatrix, oo->matrix->currentState);
257         if(newObj->rowAlgebra == NULL) {
258                 char *errstr = (char*) calloc(250, sizeof(char));
259                 sprintf(errstr, "No row-wise algebra in omxRowFitFunction.");
260                 omxRaiseError(oo->matrix->currentState, -1, errstr);
261                 free(errstr);
262         }
263         UNPROTECT(1);// nextMatrix
264
265         PROTECT(nextMatrix = GET_SLOT(rObj, install("filteredDataRow")));
266         newObj->filteredDataRow = omxMatrixLookupFromState1(nextMatrix, oo->matrix->currentState);
267         if(newObj->filteredDataRow == NULL) {
268                 char *errstr = (char*) calloc(250, sizeof(char));
269                 sprintf(errstr, "No row results matrix in omxRowFitFunction.");
270                 omxRaiseError(oo->matrix->currentState, -1, errstr);
271                 free(errstr);
272         }
273         // Create the original data row from which to filter.
274     newObj->dataRow = omxInitMatrix(NULL, newObj->filteredDataRow->rows, newObj->filteredDataRow->cols, TRUE, oo->matrix->currentState);
275     omxAliasMatrix(newObj->filteredDataRow, newObj->dataRow);
276         UNPROTECT(1);// nextMatrix
277
278         PROTECT(nextMatrix = GET_SLOT(rObj, install("existenceVector")));
279         newObj->existenceVector = omxMatrixLookupFromState1(nextMatrix, oo->matrix->currentState);
280     // Do we allow NULL existence?  (Whoa, man. That's, like, deep, or something.)
281         if(newObj->existenceVector == NULL) {
282                 char *errstr = (char*) calloc(250, sizeof(char));
283                 sprintf(errstr, "No existance matrix in omxRowFitFunction.");
284                 omxRaiseError(oo->matrix->currentState, -1, errstr);
285                 free(errstr);
286         }
287         UNPROTECT(1);// nextMatrix
288
289
290         PROTECT(nextMatrix = GET_SLOT(rObj, install("rowResults")));
291         newObj->rowResults = omxMatrixLookupFromState1(nextMatrix, oo->matrix->currentState);
292         if(newObj->rowResults == NULL) {
293                 char *errstr = (char*) calloc(250, sizeof(char));
294                 sprintf(errstr, "No row results matrix in omxRowFitFunction.");
295                 omxRaiseError(oo->matrix->currentState, -1, errstr);
296                 free(errstr);
297         }
298         UNPROTECT(1);// nextMatrix
299
300         PROTECT(nextMatrix = GET_SLOT(rObj, install("reduceAlgebra")));
301         newObj->reduceAlgebra = omxMatrixLookupFromState1(nextMatrix, oo->matrix->currentState);
302         if(newObj->reduceAlgebra == NULL) {
303                 char *errstr = (char*) calloc(250, sizeof(char));
304                 sprintf(errstr, "No row reduction algebra in omxRowFitFunction.");
305                 omxRaiseError(oo->matrix->currentState, -1, errstr);
306                 free(errstr);
307         }
308         UNPROTECT(1);// nextMatrix
309         
310         if(OMX_DEBUG) {Rprintf("Accessing variable mapping structure.\n"); }
311         PROTECT(nextMatrix = GET_SLOT(rObj, install("dataColumns")));
312         newObj->dataColumns = omxNewMatrixFromRPrimitive(nextMatrix, oo->matrix->currentState, 0, 0);
313         if(OMX_DEBUG) { omxPrint(newObj->dataColumns, "Variable mapping"); }
314         UNPROTECT(1);
315
316         if(OMX_DEBUG) {Rprintf("Accessing data row dependencies.\n"); }
317         PROTECT(nextItem = GET_SLOT(rObj, install("dataRowDeps")));
318         numDeps = LENGTH(nextItem);
319         newObj->numDataRowDeps = numDeps;
320         newObj->dataRowDeps = (int*) R_alloc(numDeps, sizeof(int));
321         for(int i = 0; i < numDeps; i++) {
322                 newObj->dataRowDeps[i] = INTEGER(nextItem)[i];
323         }
324         UNPROTECT(1);
325
326         if(OMX_DEBUG) {Rprintf("Accessing definition variables structure.\n"); }
327         PROTECT(nextMatrix = GET_SLOT(rObj, install("definitionVars")));
328         newObj->numDefs = length(nextMatrix);
329         newObj->oldDefs = (double *) R_alloc(newObj->numDefs, sizeof(double));          // Storage for Def Vars
330         if(OMX_DEBUG) {Rprintf("Number of definition variables is %d.\n", newObj->numDefs); }
331         newObj->defVars = (omxDefinitionVar *) R_alloc(newObj->numDefs, sizeof(omxDefinitionVar));
332         for(nextDef = 0; nextDef < newObj->numDefs; nextDef++) {
333                 SEXP dataSource, columnSource, depsSource; 
334
335                 PROTECT(itemList = VECTOR_ELT(nextMatrix, nextDef));
336                 PROTECT(dataSource = VECTOR_ELT(itemList, 0));
337                 if(OMX_DEBUG) {Rprintf("Data source number is %d.\n", INTEGER(dataSource)[0]); }
338                 newObj->defVars[nextDef].data = INTEGER(dataSource)[0];
339                 newObj->defVars[nextDef].source = oo->matrix->currentState->dataList[INTEGER(dataSource)[0]];
340                 PROTECT(columnSource = VECTOR_ELT(itemList, 1));
341                 if(OMX_DEBUG) {Rprintf("Data column number is %d.\n", INTEGER(columnSource)[0]); }
342                 newObj->defVars[nextDef].column = INTEGER(columnSource)[0];
343                 PROTECT(depsSource = VECTOR_ELT(itemList, 2));
344                 numDeps = LENGTH(depsSource);
345                 newObj->defVars[nextDef].numDeps = numDeps;
346                 newObj->defVars[nextDef].deps = (int*) R_alloc(numDeps, sizeof(int));
347                 for(int i = 0; i < numDeps; i++) {
348                         newObj->defVars[nextDef].deps[i] = INTEGER(depsSource)[i];
349                 }
350                 UNPROTECT(3); // unprotect dataSource, columnSource, and depsSource
351
352                 newObj->defVars[nextDef].numLocations = length(itemList) - 3;
353                 newObj->defVars[nextDef].matrices = (int *) R_alloc(length(itemList) - 3, sizeof(int));
354                 newObj->defVars[nextDef].rows = (int *) R_alloc(length(itemList) - 3, sizeof(int));
355                 newObj->defVars[nextDef].cols = (int *) R_alloc(length(itemList) - 3, sizeof(int));
356
357                 for(index = 3; index < length(itemList); index++) {
358                         PROTECT(nextItem = VECTOR_ELT(itemList, index));
359                         newObj->defVars[nextDef].matrices[index-3] = INTEGER(nextItem)[0];
360                         newObj->defVars[nextDef].rows[index-3]     = INTEGER(nextItem)[1];
361                         newObj->defVars[nextDef].cols[index-3]     = INTEGER(nextItem)[2];
362                         UNPROTECT(1); // unprotect nextItem
363                 }
364                 UNPROTECT(1); // unprotect itemList
365         }
366         UNPROTECT(1); // unprotect nextMatrix
367         
368         /* Set up data columns */
369         omxSetContiguousDataColumns(&(newObj->contiguous), newObj->data, newObj->dataColumns);
370
371         oo->computeFun = omxCallRowFitFunction;
372         oo->setFinalReturns = omxSetFinalReturnsRowFitFunction;
373         oo->destructFun = omxDestroyRowFitFunction;
374         oo->repopulateFun = NULL;
375         oo->usesChildModels = TRUE;
376
377         oo->argStruct = (void*) newObj;
378 }
379
380