Several changes: Fixed substitution of matrix/algebra results into other matrices...
[openmx:openmx.git] / src / omxMatrix.c
1 /*
2  *  Copyright 2007-2009 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18
19 *  omxMatrix.cc
20 *
21 *  Created: Timothy R. Brick    Date: 2008-11-13 12:33:06
22 *
23 *       Contains code for the omxMatrix class
24 *   omxDataMatrices hold necessary information to simplify
25 *       dealings between the OpenMX back end and BLAS.
26 *
27 **********************************************************/
28 #include "omxMatrix.h"
29
30 const char omxMatrixMajorityList[3] = "Tn";             // BLAS Column Majority.
31
32 void omxPrintMatrixHelper(omxMatrix *source, char* header) {
33         int j, k;
34         
35         Rprintf("%s: (%d x %d) [%s-major]\n", header, source->rows, source->cols, (source->colMajor?"col":"row"));
36         if(OMX_DEBUG) {Rprintf("Matrix Printing is at %0x\n", source);}
37         
38         if(source->colMajor) {
39                 for(j = 0; j < source->rows; j++) {
40                         for(k = 0; k < source->cols; k++) {
41                                 Rprintf("\t%3.6f", source->data[k*source->rows+j]);
42                         }
43                         Rprintf("\n");
44                 }
45         } else {
46                 for(j = 0; j < source->cols; j++) {
47                         for(k = 0; k < source->rows; k++) {
48                                 Rprintf("\t%3.6f", source->data[k*source->cols+j]);
49                         }
50                         Rprintf("\n");
51                 }
52         }
53 }
54
55 omxMatrix* omxInitMatrix(omxMatrix* om, int nrows, int ncols, unsigned short isColMajor, omxState* os) {
56         
57         if(om == NULL) om = (omxMatrix*) R_alloc(1, sizeof(omxMatrix));
58         if(OMX_DEBUG) { Rprintf("Initializing 0x%0x to (%d, %d).\n", om, nrows, ncols); }
59
60         om->rows = nrows;
61         om->cols = ncols;
62         om->colMajor = (isColMajor?1:0);
63
64         om->originalRows = om->rows;
65         om->originalCols = om->cols;
66         om->originalColMajor=om->colMajor;
67         
68         if(om->rows == 0 || om->cols == 0) {
69                 om->data = NULL;
70                 om->localData = FALSE;
71         } else {
72                 om->data = (double*) Calloc(nrows * ncols, double);
73                 om->localData = TRUE;
74         }
75
76         om->populateFrom = NULL;
77         om->populateToCol = NULL;
78         om->populateToRow = NULL;
79         om->numPopulateLocations = 0;
80
81         om->aliasedPtr = NULL;
82         om->algebra = NULL;
83         om->objective = NULL;
84         
85         om->currentState = os;
86         om->lastCompute = -1;
87         om->lastRow = -1;
88         
89         omxComputeMatrixHelper(om);
90         
91         return om;
92         
93 }
94
95 void omxCopyMatrix(omxMatrix *dest, omxMatrix *orig) {
96         /* Duplicate a matrix.  NOTE: Matrix maintains its algebra bindings. */
97         
98         if(OMX_DEBUG) { Rprintf("omxCopyMatrix"); }
99         
100         omxFreeMatrixData(dest);
101
102         dest->rows = orig->rows;
103         dest->cols = orig->cols;
104         dest->colMajor = orig->colMajor;
105         dest->originalRows = dest->rows;
106         dest->originalCols = dest->cols;
107         dest->originalColMajor = dest->colMajor;
108         dest->currentState = orig->currentState;
109         dest->lastCompute = orig->lastCompute;
110         dest->lastRow = orig->lastRow;
111
112         if(dest->rows == 0 || dest->cols == 0) {
113                 dest->data = NULL;
114                 dest->localData=FALSE;
115         } else {
116                 dest->data = (double*) Calloc(dest->rows * dest->cols, double);
117                 memcpy(dest->data, orig->data, dest->rows * dest->cols * sizeof(double));
118                 dest->localData = TRUE;
119         }
120
121         dest->aliasedPtr = NULL;
122
123         omxComputeMatrixHelper(dest);
124         
125 }
126
127 void omxAliasMatrix(omxMatrix *dest, omxMatrix *src) {
128         omxCopyMatrix(dest, src);
129         dest->aliasedPtr = src->data;                   // Interesting Aside: back matrix can change without alias
130         dest->algebra = NULL;                                   // Have to look at how this effect interacts with populating
131         dest->objective = NULL;                                 //  matrix values to other locations.
132 }
133
134 void omxFreeMatrixData(omxMatrix * om) {
135  
136         if(om->localData && om->data != NULL) {
137                 if(OMX_DEBUG) { Rprintf("Freeing 0x%0x. Localdata = %d.\n", om->data, om->localData); }
138                 Free(om->data);
139                 om->data = NULL;
140                 om->localData = FALSE;
141         }
142
143 }
144
145 void omxFreeAllMatrixData(omxMatrix *om) {
146         
147         if(OMX_DEBUG) { Rprintf("Freeing 0x%0x with data = %0x and algebra %0x.\n", om, om->data, om->algebra); }
148         
149         if(om->localData && om->data != NULL) {
150                 Free(om->data);
151                 om->data = NULL;
152                 om->localData = FALSE;
153         }
154         
155         if(om->algebra != NULL) {
156                 omxFreeAlgebraArgs(om->algebra);
157                 om->algebra = NULL;
158         }
159         
160         if(om->objective != NULL) {
161                 omxFreeObjectiveArgs(om->objective);
162                 om->objective = NULL;
163         }
164
165 }
166
167 void omxResizeMatrix(omxMatrix *om, int nrows, int ncols, unsigned short keepMemory) {
168         // Always Recompute() before you Resize().
169         if(OMX_DEBUG) { Rprintf("Resizing matrix from (%d, %d) to (%d, %d) (keepMemory: %d)", om->rows, om->cols, nrows, ncols, keepMemory); }
170         if(keepMemory == FALSE) { 
171                 if(OMX_DEBUG) { Rprintf(" and regenerating memory to do it"); }
172                 omxFreeMatrixData(om);
173                 om->data = (double*) Calloc(nrows * ncols, double);
174                 om->localData = TRUE;
175         } else if(om->originalRows * om->originalCols < nrows * ncols) {
176                 warning("Upsizing an existing matrix may cause undefined behavior.\n"); // TODO: Define this behavior?
177         }
178
179         if(OMX_DEBUG) { Rprintf(".\n"); }
180         om->rows = nrows;
181         om->cols = ncols;
182         if(keepMemory == FALSE) {
183                 om->originalRows = om->rows;
184                 om->originalCols = om->cols;
185         }
186         
187         omxComputeMatrixHelper(om);
188 }
189
190 void omxResetAliasedMatrix(omxMatrix *om) {
191         om->rows = om->originalRows;
192         om->cols = om->originalCols;
193         om->colMajor = om->originalColMajor;
194         if(om->aliasedPtr != NULL) {
195 //              if(OMX_DEBUG) { omxPrintMatrix(om, "I was");}
196                 memcpy(om->data, om->aliasedPtr, om->rows*om->cols*sizeof(double));
197 //              if(OMX_DEBUG) { omxPrintMatrix(om, "I am");}
198         }
199         omxComputeMatrixHelper(om);
200 }
201
202 void omxComputeMatrixHelper(omxMatrix *om) {
203         
204         if(OMX_DEBUG) { Rprintf("Matrix compute: 0x%0x, 0x%0x, %d.\n", om, om->currentState, om->colMajor); }
205         om->majority = &(omxMatrixMajorityList[(om->colMajor?1:0)]);
206         om->minority = &(omxMatrixMajorityList[(om->colMajor?0:1)]);
207         om->leading = (om->colMajor?om->rows:om->cols);
208         om->lagging = (om->colMajor?om->cols:om->rows);
209         
210         for(int i = 0; i < om->numPopulateLocations; i++) {
211                 omxRecomputeMatrix(om->populateFrom[i]);                                // Make sure it's up to date
212                 omxSetMatrixElement(om, om->populateToRow[i], om->populateToCol[i], om->populateFrom[i]->data[0]);      
213                 // And then fill in the details.  Use the Helper here in case of transposition/downsampling.
214         }
215         
216         om->isDirty = FALSE;
217         om->lastCompute = om->currentState->computeCount;
218         om->lastRow = om->currentState->currentRow;
219 }
220
221 double* omxLocationOfMatrixElement(omxMatrix *om, int row, int col) {
222         int index = 0;
223         if(om->colMajor) {
224                 index = col * om->rows + row;
225         } else {
226                 index = row * om->cols + col;
227         }
228         return om->data + index;
229 }
230
231 double omxMatrixElement(omxMatrix *om, int row, int col) {
232         int index = 0;
233         if(om->colMajor) {
234                 index = col * om->rows + row;
235         } else {
236                 index = row * om->cols + col;
237         }
238         return om->data[index];
239 }
240
241 void omxSetMatrixElement(omxMatrix *om, int row, int col, double value) {
242         int index = 0;
243         if(om->colMajor) {
244                 index = col * om->rows + row;
245         } else {
246                 index = row * om->cols + col;
247         }
248         om->data[index] = value;
249 }
250
251 void omxMarkDirty(omxMatrix *om) { om->isDirty = TRUE; }
252
253 unsigned short omxMatrixNeedsUpdate(omxMatrix *om) { 
254
255         for(int i = 0; i < om->numPopulateLocations; i++) {
256                 if(omxNeedsUpdate(om->populateFrom[i])) return TRUE;    // Make sure it's up to date
257         }
258         
259 };
260
261 omxMatrix* omxNewMatrixFromMxMatrix(SEXP matrix, omxState* state) {
262 /* Populates the fields of a omxMatrix with details from an R Matrix. */ 
263         
264         SEXP className;
265         SEXP matrixDims;
266         int* dimList;
267         
268         omxMatrix *om = NULL;
269         om = omxInitMatrix(NULL, 0, 0, FALSE, state);
270         
271         if(OMX_DEBUG) { Rprintf("Filling omxMatrix from R matrix.\n"); }
272         
273         /* Sanity Check */
274         if(!isMatrix(matrix) && !isVector(matrix)) {
275                 SEXP values;
276                 int *rowVec, *colVec;
277                 double  *dataVec;
278                 const char *stringName;
279                 if(OMX_DEBUG) {Rprintf("R Matrix is an object of some sort.\n");}
280                 PROTECT(className = getAttrib(matrix, install("class")));
281                 if(strncmp(CHAR(STRING_ELT(className, 0)), "Symm", 2) != 0) { // Should be "Mx"
282                         error("omxMatrix::fillFromMatrix--Passed Non-vector, non-matrix SEXP.\n");
283                 }
284                 stringName = CHAR(STRING_ELT(className, 0));
285                 if(strncmp(stringName, "SymmMatrix", 12) == 0) {
286                         if(OMX_DEBUG) { Rprintf("R matrix is SymmMatrix.  Processing.\n"); }
287                         PROTECT(values = GET_SLOT(matrix, install("values")));
288                         om->rows = INTEGER(GET_SLOT(values, install("nrow")))[0];
289                         om->cols = INTEGER(GET_SLOT(values, install("ncol")))[0];
290                         
291                         om->data = (double*) S_alloc(om->rows * om->cols, sizeof(double));              // We can afford to keep through the whole call
292                         rowVec = INTEGER(GET_SLOT(values, install("rowVector")));
293                         colVec = INTEGER(GET_SLOT(values, install("colVector")));
294                         dataVec = REAL(GET_SLOT(values, install("dataVector")));
295                         for(int j = 0; j < length(GET_SLOT(values, install("dataVector"))); j++) {
296                                 om->data[(rowVec[j]-1) + (colVec[j]-1) * om->rows] = dataVec[j];
297                                 om->data[(rowVec[j]-1) * om->cols + (colVec[j]-1)] = dataVec[j];  // Symmetric, after all.
298                         }
299                         UNPROTECT(1); // values
300                 }
301                 UNPROTECT(1); // className
302         } else {
303                 if(OMX_DEBUG) { Rprintf("R matrix is Mx Matrix.  Processing.\n"); }
304                 
305                 om->data = REAL(matrix);        // TODO: Class-check first?
306                 
307                 if(isMatrix(matrix)) {
308                         PROTECT(matrixDims = getAttrib(matrix, R_DimSymbol));
309                         dimList = INTEGER(matrixDims);
310                         om->rows = dimList[0];
311                         om->cols = dimList[1];
312                         UNPROTECT(1);   // MatrixDims
313                 } else if (isVector(matrix)) {          // If it's a vector, assume it's a row vector. BLAS doesn't care.
314                         if(OMX_DEBUG) { Rprintf("Vector discovered.  Assuming rowity.\n"); }
315                         om->rows = 1;
316                         om->cols = length(matrix);
317                 }
318                 if(OMX_DEBUG) { Rprintf("Data connected to (%d, %d) matrix.\n", om->rows, om->cols); }
319         }       
320         
321         om->localData = FALSE;
322         om->colMajor = TRUE;
323         om->originalRows = om->rows;
324         om->originalCols = om->cols;
325         om->originalColMajor = TRUE;
326         om->aliasedPtr = om->data;
327         om->algebra = NULL;
328         om->objective = NULL;
329         om->currentState = state;
330         om->lastCompute = -1;
331         om->lastRow = -1;
332         
333         if(OMX_DEBUG) { Rprintf("Pre-compute call.\n");}
334         omxComputeMatrixHelper(om);
335         if(OMX_DEBUG) { Rprintf("Post-compute call.\n");}
336
337         if(OMX_DEBUG) {
338                 omxPrintMatrixHelper(om, "Finished importing matrix");
339         }
340
341         return om;
342 }
343
344 void omxProcessMatrixPopulationList(omxMatrix* matrix, SEXP matStruct) {
345         
346         if(OMX_DEBUG) { Rprintf("Processing Population List: %d elements.\n", length(matStruct) - 1); }
347         SEXP subList;
348         SEXP matLoc, xLoc, yLoc;
349         
350         if(length(matStruct) > 1) {
351                 int numPopLocs = length(matStruct) - 1;
352                 matrix->numPopulateLocations = numPopLocs;
353                 matrix->populateFrom = (omxMatrix**)R_alloc(numPopLocs, sizeof(omxMatrix*));
354                 matrix->populateToRow = (int*)R_alloc(numPopLocs, sizeof(int));
355                 matrix->populateToCol = (int*)R_alloc(numPopLocs, sizeof(int));
356         }
357         
358         for(int i = 0; i < length(matStruct)-1; i++) {
359                 PROTECT(subList = AS_INTEGER(VECTOR_ELT(matStruct, i+1)));
360                 
361                 int* locations = INTEGER(subList);
362                 int loc = locations[2];
363                 Rprintf("."); //:::
364                 if(loc < 0) {                   // NOTE: This duplicates some of the functionality of NewMatrixFromMxIndex
365                         matrix->populateFrom[i] = matrix->currentState->matrixList[(~loc)];
366                 } else {
367                         matrix->populateFrom[i] = matrix->currentState->algebraList[(loc)];
368                 }
369                 
370                 matrix->populateToRow[i] = locations[0];
371                 matrix->populateToCol[i] = locations[1];
372                 
373                 UNPROTECT(1); // subList
374         }
375 }
376
377 void omxRemoveRowsAndColumns(omxMatrix *om, int numRowsRemoved, int numColsRemoved, int rowsRemoved[], int colsRemoved[])
378 {
379         if(OMX_DEBUG) { Rprintf("Removing %d rows and %d columns from 0x%0x.\n", numRowsRemoved, numColsRemoved, om);}
380         
381         if(om->aliasedPtr == NULL) {  // This is meant only for aliased matrices.  Maybe Need a subclass?
382                 error("removeRowsAndColumns intended only for aliased matrices.\n");
383         }
384         
385         if(numRowsRemoved < 1 || numColsRemoved < 1) { return; }
386                 
387         int numCols = 0;
388         int nextCol = 0;
389         int nextRow = 0;
390         int oldRows = om->rows;
391         int oldCols = om->cols;
392         int j,k;
393         
394         om->rows = om->rows - numRowsRemoved;
395         om->cols = om->cols - numColsRemoved;
396         
397         // Note:  This really aught to be done using a matrix multiply.  Why isn't it?
398         if(om->colMajor) {
399                 for(int j = 0; j < oldCols; j++) {
400                         if(OMX_DEBUG) { Rprintf("Handling %d rows.\n", j);}
401                         if(colsRemoved[j]) {
402                                 continue;
403                         } else {
404                                 nextRow = 0;
405                                 for(int k = 0; k < oldRows; k++) {
406                                         if(rowsRemoved[k]) {
407                                                 continue;
408                                         } else {
409                                                 omxSetMatrixElement(om, nextRow, nextCol, om->aliasedPtr[k + j * oldRows]);
410                                                 nextRow++;
411                                         }
412                                 }
413                                 nextCol++;
414                         }
415                 }
416         } else {
417                 for(int j = 0; j < oldRows; j++) {
418                         if(rowsRemoved[j]) {
419                                 continue;
420                         } else {
421                                 nextCol = 0;
422                                 for(int k = 0; k < oldCols; k++) {
423                                         if(colsRemoved[k]) {
424                                                 continue;
425                                         } else {
426                                                 omxSetMatrixElement(om, nextRow, nextCol, om->aliasedPtr[k + j * oldCols]);
427                                                 nextCol++;
428                                         }
429                                 }
430                                 nextRow++;
431                         }
432                 }
433         }
434
435         omxComputeMatrixHelper(om);
436 }
437
438 /* Function wrappers that switch based on inclusion of algebras */
439 void omxPrintMatrix(omxMatrix *source, char* d) {                                       // Pretty-print a (small) matrix
440         if(source->algebra != NULL) omxAlgebraPrint(source->algebra, d);
441         else if(source->objective != NULL) omxObjectivePrint(source->objective, d);
442         else omxPrintMatrixHelper(source, d);
443 }
444
445 unsigned short omxNeedsUpdate(omxMatrix *matrix) {
446         /* Simplest update check: If we're dirty or haven't computed this cycle (iteration or row), we need to. */
447         if(OMX_DEBUG) {Rprintf("MatrixNeedsUpdate?");}
448         if(matrix->isDirty) return TRUE;
449         if(matrix->lastCompute < matrix->currentState->computeCount) return TRUE;       // No need to check args if oa's dirty.
450         if(matrix->lastRow < matrix->currentState->currentRow) return TRUE;                     // Ditto.
451         
452         if(matrix->algebra != NULL) return omxAlgebraNeedsUpdate(matrix->algebra); 
453         else if(matrix->objective != NULL) return omxObjectiveNeedsUpdate(matrix->objective);
454         else return omxMatrixNeedsUpdate(matrix);
455
456 }
457
458 void inline omxRecomputeMatrix(omxMatrix *matrix) {
459         if(!omxNeedsUpdate(matrix)) return;
460         if(matrix->algebra != NULL) omxAlgebraCompute(matrix->algebra);
461         else if(matrix->objective != NULL) omxObjectiveCompute(matrix->objective);
462         else omxComputeMatrixHelper(matrix);
463 }
464
465 void inline omxComputeMatrix(omxMatrix *matrix) {
466         if(matrix->algebra != NULL) omxAlgebraCompute(matrix->algebra);
467         else if(matrix->objective != NULL) omxObjectiveCompute(matrix->objective);
468         else omxComputeMatrixHelper(matrix);
469 }