added square bracket operator to MxAlgebra expressions.
[openmx:openmx.git] / src / omxMatrix.c
1 /*
2  *  Copyright 2007-2009 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18 *
19 *  omxMatrix.cc
20 *
21 *  Created: Timothy R. Brick    Date: 2008-11-13 12:33:06
22 *
23 *       Contains code for the omxMatrix class
24 *   omxDataMatrices hold necessary information to simplify
25 *       dealings between the OpenMX back end and BLAS.
26 *
27 **********************************************************/
28 #include "omxMatrix.h"
29
30 const char omxMatrixMajorityList[3] = "Tn";             // BLAS Column Majority.
31
32 void omxPrintMatrix(omxMatrix *source, char* header) {
33         int j, k;
34
35         Rprintf("%s: (%d x %d) [%s-major]\n", header, source->rows, source->cols, (source->colMajor?"col":"row"));
36         if(OMX_DEBUG_MATRIX) {Rprintf("Matrix Printing is at %0x\n", source);}
37
38         if(source->colMajor) {
39                 for(j = 0; j < source->rows; j++) {
40                         for(k = 0; k < source->cols; k++) {
41                                 Rprintf("\t%3.6f", source->data[k*source->rows+j]);
42                         }
43                         Rprintf("\n");
44                 }
45         } else {
46                 for(j = 0; j < source->cols; j++) {
47                         for(k = 0; k < source->rows; k++) {
48                                 Rprintf("\t%3.6f", source->data[k*source->cols+j]);
49                         }
50                         Rprintf("\n");
51                 }
52         }
53 }
54
55 omxMatrix* omxInitMatrix(omxMatrix* om, int nrows, int ncols, unsigned short isColMajor, omxState* os) {
56
57         if(om == NULL) om = (omxMatrix*) R_alloc(1, sizeof(omxMatrix));
58         if(OMX_DEBUG) { Rprintf("Initializing matrix 0x%0x to (%d, %d) with state at 0x%x.\n", om, nrows, ncols, os); }
59
60         om->rows = nrows;
61         om->cols = ncols;
62         om->colMajor = (isColMajor?1:0);
63
64         om->originalRows = om->rows;
65         om->originalCols = om->cols;
66         om->originalColMajor=om->colMajor;
67
68         if(om->rows == 0 || om->cols == 0) {
69                 om->data = NULL;
70                 om->localData = FALSE;
71         } else {
72                 om->data = (double*) Calloc(nrows * ncols, double);
73                 om->localData = TRUE;
74         }
75
76         om->populateFrom = NULL;
77         om->populateFromCol = NULL;
78         om->populateFromRow = NULL;
79         om->populateToCol = NULL;
80         om->populateToRow = NULL;
81
82         om->numPopulateLocations = 0;
83
84         om->aliasedPtr = NULL;
85         om->algebra = NULL;
86         om->objective = NULL;
87
88         om->currentState = os;
89         om->lastCompute = -2;
90         om->lastRow = -2;
91
92         omxMatrixCompute(om);
93
94         return om;
95
96 }
97
98 void omxCopyMatrix(omxMatrix *dest, omxMatrix *orig) {
99         /* Duplicate a matrix.  NOTE: Matrix maintains its algebra bindings. */
100
101         if(OMX_DEBUG_MATRIX || OMX_DEBUG_ALGEBRA) { Rprintf("omxCopyMatrix"); }
102         
103         int regenerateMemory = TRUE;
104         
105         if(dest->localData && (dest->originalRows == orig->rows && dest->originalCols == orig->cols)) {
106                 regenerateMemory = FALSE;                               // If it's local data and the right size, we can keep memory.
107         }
108
109         dest->rows = orig->rows;
110         dest->cols = orig->cols;
111         dest->colMajor = orig->colMajor;
112         dest->originalRows = dest->rows;
113         dest->originalCols = dest->cols;
114         dest->originalColMajor = dest->colMajor;
115         dest->currentState = orig->currentState;
116         dest->lastCompute = orig->lastCompute;
117         dest->lastRow = orig->lastRow;
118
119         if(dest->rows == 0 || dest->cols == 0) {
120                 omxFreeMatrixData(dest);
121                 dest->data = NULL;
122                 dest->localData=FALSE;
123         } else {
124                 if(regenerateMemory) {
125                         omxFreeMatrixData(dest);                                                                                        // Free and regenerate memory
126                         dest->data = (double*) Calloc(dest->rows * dest->cols, double);
127                 }
128                 memcpy(dest->data, orig->data, dest->rows * dest->cols * sizeof(double));
129                 dest->localData = TRUE;
130         }
131
132         dest->aliasedPtr = NULL;
133
134         omxMatrixCompute(dest);
135
136 }
137
138 void omxAliasMatrix(omxMatrix *dest, omxMatrix *src) {
139         omxCopyMatrix(dest, src);
140         dest->aliasedPtr = src;                                 // Alias now follows back matrix precisely.
141         dest->algebra = NULL;                                   // Have to look at how this effect interacts with populating
142         dest->objective = NULL;                                 //              matrix values to other locations.
143 }
144
145 void omxFreeMatrixData(omxMatrix * om) {
146
147         if(om->localData && om->data != NULL) {
148                 if(OMX_DEBUG_MATRIX) { Rprintf("Freeing matrix at 0x%0x. Localdata = %d.\n", om->data, om->localData); }
149                 Free(om->data);
150                 om->data = NULL;
151                 om->localData = FALSE;
152         }
153
154 }
155
156 void omxFreeAllMatrixData(omxMatrix *om) {
157
158         if(OMX_DEBUG) { Rprintf("Freeing matrix at 0x%0x with data = %0x and algebra %0x.\n", om, om->data, om->algebra); }
159
160         if(om->localData && om->data != NULL) {
161                 Free(om->data);
162                 om->data = NULL;
163                 om->localData = FALSE;
164         }
165
166         if(om->algebra != NULL) {
167                 omxFreeAlgebraArgs(om->algebra);
168                 om->algebra = NULL;
169         }
170
171         if(om->objective != NULL) {
172                 omxFreeObjectiveArgs(om->objective);
173                 om->objective = NULL;
174         }
175
176 }
177
178 void omxZeroByZeroMatrix(omxMatrix *om) {
179         if (om->rows > 0 || om->cols > 0) {
180                 omxResizeMatrix(om, 0, 0, FALSE);
181         }
182 }
183
184 void omxResizeMatrix(omxMatrix *om, int nrows, int ncols, unsigned short keepMemory) {
185         // Always Recompute() before you Resize().
186         if(OMX_DEBUG_MATRIX) { Rprintf("Resizing matrix from (%d, %d) to (%d, %d) (keepMemory: %d)", om->rows, om->cols, nrows, ncols, keepMemory); }
187         if(keepMemory == FALSE) {
188                 if(OMX_DEBUG_MATRIX) { Rprintf(" and regenerating memory to do it"); }
189                 omxFreeMatrixData(om);
190                 om->data = (double*) Calloc(nrows * ncols, double);
191                 om->localData = TRUE;
192         } else if(om->originalRows * om->originalCols < nrows * ncols) {
193                 warning("Upsizing an existing matrix may cause undefined behavior.\n"); // TODO: Define this behavior?
194         }
195
196         if(OMX_DEBUG_MATRIX) { Rprintf(".\n"); }
197         om->rows = nrows;
198         om->cols = ncols;
199         if(keepMemory == FALSE) {
200                 om->originalRows = om->rows;
201                 om->originalCols = om->cols;
202         }
203
204         omxMatrixCompute(om);
205 }
206
207 void omxResetAliasedMatrix(omxMatrix *om) {
208         om->rows = om->originalRows;
209         om->cols = om->originalCols;
210         if(om->aliasedPtr != NULL) {
211                 omxRecompute(om->aliasedPtr);
212                 memcpy(om->data, om->aliasedPtr->data, om->rows*om->cols*sizeof(double));
213                 om->colMajor = om->aliasedPtr->colMajor;
214         }
215         omxMatrixCompute(om);
216 }
217
218 void omxMatrixCompute(omxMatrix *om) {
219
220         if(OMX_DEBUG_MATRIX) { Rprintf("Matrix compute: 0x%0x, 0x%0x, algebra: 0x%x.\n", om, om->currentState, om->algebra); }
221         om->majority = &(omxMatrixMajorityList[(om->colMajor?1:0)]);
222         om->minority = &(omxMatrixMajorityList[(om->colMajor?0:1)]);
223         om->leading = (om->colMajor?om->rows:om->cols);
224         om->lagging = (om->colMajor?om->cols:om->rows);
225
226         for(int i = 0; i < om->numPopulateLocations; i++) {
227                 omxRecompute(om->populateFrom[i]);                              // Make sure it's up to date
228                 double value = omxMatrixElement(om->populateFrom[i], om->populateFromRow[i], om->populateFromCol[i]);
229                 omxSetMatrixElement(om, om->populateToRow[i], om->populateToCol[i], value);
230                 // And then fill in the details.  Use the Helper here in case of transposition/downsampling.
231         }
232
233         om->isDirty = FALSE;
234         om->lastCompute = om->currentState->computeCount;
235         om->lastRow = om->currentState->currentRow;
236
237 }
238
239 double* omxLocationOfMatrixElement(omxMatrix *om, int row, int col) {
240         int index = 0;
241         if(om->colMajor) {
242                 index = col * om->rows + row;
243         } else {
244                 index = row * om->cols + col;
245         }
246         return om->data + index;
247 }
248
249 double omxVectorElement(omxMatrix *om, int index) {
250         if(index < om->rows * om->cols) {
251                 return om->data[index];
252         } else {
253                 char errstr[250];
254                 sprintf(errstr, "Requested improper index (%d) from (%d, %d) vector.", index, om->rows, om->cols);
255                 error(errstr);
256         return (NA_REAL);
257     }
258 }
259
260 double omxAliasedMatrixElement(omxMatrix *om, int row, int col) {
261         int index = 0;
262         if(row >= om->originalRows || col >= om->originalCols) {
263                 char errstr[250];
264                 sprintf(errstr, "Requested improper value (%d, %d) from (%d, %d) matrix.", row, col, om->originalRows, om->originalCols);
265                 error(errstr);
266         return (NA_REAL);
267         }
268         if(om->colMajor) {
269                 index = col * om->originalRows + row;
270         } else {
271                 index = row * om->originalCols + col;
272         }
273         return om->data[index];
274
275 }
276
277 double omxMatrixElement(omxMatrix *om, int row, int col) {
278         int index = 0;
279         if(row >= om->rows || col >= om->cols) {
280                 char errstr[250];
281                 sprintf(errstr, "Requested improper value (%d, %d) from (%d, %d) matrix.", row, col, om->rows, om->cols);
282                 error(errstr);
283         }
284         if(om->colMajor) {
285                 index = col * om->rows + row;
286         } else {
287                 index = row * om->cols + col;
288         }
289         return om->data[index];
290 }
291
292 void omxSetMatrixElement(omxMatrix *om, int row, int col, double value) {
293         int index = 0;
294         if(om->colMajor) {
295                 index = col * om->rows + row;
296         } else {
297                 index = row * om->cols + col;
298         }
299         om->data[index] = value;
300 }
301
302 void omxMarkDirty(omxMatrix *om) { om->isDirty = TRUE; }
303
304 unsigned short omxMatrixNeedsUpdate(omxMatrix *om) {
305         for(int i = 0; i < om->numPopulateLocations; i++) {
306                 if(omxNeedsUpdate(om->populateFrom[i])) return TRUE;    // Make sure it's up to date
307         }
308     return FALSE;
309 };
310
311 omxMatrix* omxNewMatrixFromMxMatrix(SEXP mxMatrix, omxState* state) {
312 /* Creates and populates an omxMatrix with details from an R Matrix. */
313
314         omxMatrix *om = NULL;
315         om = omxInitMatrix(NULL, 0, 0, FALSE, state);
316         return omxFillMatrixFromMxMatrix(om, mxMatrix, state);
317
318 }
319
320 omxMatrix* omxFillMatrixFromMxMatrix(omxMatrix* om, SEXP mxMatrix, omxState* state) {
321 /* Populates the fields of a omxMatrix with details from an R Matrix. */
322
323         SEXP matrixDims;
324         SEXP matrix = mxMatrix;
325         int* dimList;
326         unsigned short int isMxMatrix = FALSE;
327
328         if(OMX_DEBUG) { Rprintf("Filling omxMatrix from R matrix.\n"); }
329
330         if(om == NULL) {
331                 om = omxInitMatrix(NULL, 0, 0, FALSE, state);
332         }
333
334         if(!isMatrix(mxMatrix) && !isVector(mxMatrix)) { // Sanity Check
335                 if(OMX_DEBUG) { Rprintf("R matrix is an object of some sort.\n"); }
336                 if(inherits(mxMatrix, "MxMatrix")) {
337                         if(OMX_DEBUG) { Rprintf("R matrix is Mx Matrix.  Processing.\n"); }
338                         PROTECT(matrix = GET_SLOT(mxMatrix,  install("values")));
339                         isMxMatrix = TRUE; // So we remember to unprotect.
340                 } else {
341                         error("Recieved unknown matrix type.");
342                 }
343         }
344
345         om->data = REAL(AS_NUMERIC(matrix));    // TODO: Class-check first?
346
347         if(isMatrix(matrix)) {
348                 PROTECT(matrixDims = getAttrib(matrix, R_DimSymbol));
349                 dimList = INTEGER(matrixDims);
350                 om->rows = dimList[0];
351                 om->cols = dimList[1];
352                 UNPROTECT(1);   // MatrixDims
353         } else if (isVector(matrix)) {          // If it's a vector, assume it's a row vector. BLAS doesn't care.
354                 if(OMX_DEBUG) { Rprintf("Vector discovered.  Assuming rowity.\n"); }
355                 om->rows = 1;
356                 om->cols = length(matrix);
357         }
358         if(OMX_DEBUG) { Rprintf("Matrix connected to (%d, %d) mxMatrix.\n", om->rows, om->cols); }
359
360         om->localData = FALSE;
361         om->colMajor = TRUE;
362         om->originalRows = om->rows;
363         om->originalCols = om->cols;
364         om->originalColMajor = TRUE;
365         om->aliasedPtr = NULL;
366         om->algebra = NULL;
367         om->objective = NULL;
368         om->currentState = state;
369         om->lastCompute = -1;
370         om->lastRow = -1;
371
372         if(OMX_DEBUG) { Rprintf("Pre-compute call.\n");}
373         omxMatrixCompute(om);
374         if(OMX_DEBUG) { Rprintf("Post-compute call.\n");}
375
376         if(OMX_DEBUG) {
377                 omxPrintMatrix(om, "Finished importing matrix");
378         }
379
380         if(isMxMatrix) {
381                 UNPROTECT(1); // matrix
382         }
383
384         return om;
385 }
386
387 void omxProcessMatrixPopulationList(omxMatrix* matrix, SEXP matStruct) {
388
389         if(OMX_DEBUG) { Rprintf("Processing Population List: %d elements.\n", length(matStruct) - 1); }
390         SEXP subList;
391
392         if(length(matStruct) > 1) {
393                 int numPopLocs = length(matStruct) - 1;
394                 matrix->numPopulateLocations = numPopLocs;
395                 matrix->populateFrom = (omxMatrix**)R_alloc(numPopLocs, sizeof(omxMatrix*));
396                 matrix->populateFromRow = (int*)R_alloc(numPopLocs, sizeof(int));
397                 matrix->populateFromCol = (int*)R_alloc(numPopLocs, sizeof(int));
398                 matrix->populateToRow = (int*)R_alloc(numPopLocs, sizeof(int));
399                 matrix->populateToCol = (int*)R_alloc(numPopLocs, sizeof(int));
400         }
401
402         for(int i = 0; i < length(matStruct)-1; i++) {
403                 PROTECT(subList = AS_INTEGER(VECTOR_ELT(matStruct, i+1)));
404
405                 int* locations = INTEGER(subList);
406                 int loc = locations[0];
407                 if(OMX_DEBUG) { Rprintf("."); } //:::
408                 if(loc < 0) {                   // NOTE: This duplicates some of the functionality of NewMatrixFromMxIndex
409                         matrix->populateFrom[i] = matrix->currentState->matrixList[(~loc)];
410                 } else {
411                         matrix->populateFrom[i] = matrix->currentState->algebraList[(loc)];
412                 }
413                 matrix->populateFromRow[i] = locations[1];
414                 matrix->populateFromCol[i] = locations[2]; 
415                 matrix->populateToRow[i] = locations[3];
416                 matrix->populateToCol[i] = locations[4];
417
418                 UNPROTECT(1); // subList
419         }
420 }
421
422 void omxRemoveRowsAndColumns(omxMatrix *om, int numRowsRemoved, int numColsRemoved, int rowsRemoved[], int colsRemoved[])
423 {
424 //      if(OMX_DEBUG_MATRIX) { Rprintf("Removing %d rows and %d columns from 0x%0x.\n", numRowsRemoved, numColsRemoved, om);}
425
426         if(numRowsRemoved < 1 && numColsRemoved < 1) { return; }
427
428         int oldRows, oldCols;
429
430         if(om->aliasedPtr == NULL) {
431                 if(om->originalRows == 0 || om->originalCols == 0) {
432                         om->originalRows = om->rows;
433                         om->originalCols = om->cols;
434                 }
435                 oldRows = om->originalRows;
436                 oldCols = om->originalCols;
437         } else {
438                 omxRecompute(om->aliasedPtr);
439                 oldRows = om->aliasedPtr->rows;
440                 oldCols = om->aliasedPtr->cols;
441         }
442
443         int nextCol = 0;
444         int nextRow = 0;
445
446         if(om->rows > om->originalRows || om->cols > om->originalCols) {        // sanity check.
447                 error("Aliased Matrix is too small for alias.");
448         }
449
450         om->rows = oldRows - numRowsRemoved;
451         om->cols = oldCols - numColsRemoved;
452
453         // Note:  This really aught to be done using a matrix multiply.  Why isn't it?
454         for(int j = 0; j < oldCols; j++) {
455                 if(OMX_DEBUG_MATRIX || OMX_DEBUG_ALGEBRA) { Rprintf("Handling column %d/%d...", j, oldCols);}
456                 if(colsRemoved[j]) {
457                         if(OMX_DEBUG_MATRIX || OMX_DEBUG_ALGEBRA) { Rprintf("Removed.\n");}
458                         continue;
459                 } else {
460                         nextRow = 0;
461                         if(OMX_DEBUG_MATRIX || OMX_DEBUG_ALGEBRA) { Rprintf("Rows (max %d): ", oldRows); }
462                         for(int k = 0; k < oldRows; k++) {
463                                 if(rowsRemoved[k]) {
464                                         if(OMX_DEBUG_MATRIX || OMX_DEBUG_ALGEBRA) { Rprintf("%d removed....", k);}
465                                         continue;
466                                 } else {
467                                         if(OMX_DEBUG_MATRIX || OMX_DEBUG_ALGEBRA) { Rprintf("%d kept....", k);}
468                                         if(om->aliasedPtr == NULL) {
469                                                 if(OMX_DEBUG_MATRIX || OMX_DEBUG_ALGEBRA) { Rprintf("Self-aliased matrix access.\n");}
470                                                 omxSetMatrixElement(om, nextRow, nextCol, omxAliasedMatrixElement(om, k, j));
471                                         } else {
472                                                 if(OMX_DEBUG_MATRIX || OMX_DEBUG_ALGEBRA) { Rprintf("Matrix 0x%x re-aliasing to 0x%x.\n", om, om->aliasedPtr);}
473                                                 omxSetMatrixElement(om, nextRow, nextCol, omxMatrixElement(om->aliasedPtr, k,  j));
474                                         }
475                                         if(OMX_DEBUG_MATRIX || OMX_DEBUG_ALGEBRA) {
476                                                 omxPrint(om, "Now Reads: (:::DEBUG:::)");
477                                         }
478                                         nextRow++;
479                                 }
480                         }
481                         if(OMX_DEBUG_MATRIX || OMX_DEBUG_ALGEBRA) { Rprintf("\n");}
482                         nextCol++;
483                 }
484         }
485
486         omxMatrixCompute(om);
487 }
488
489 /* Function wrappers that switch based on inclusion of algebras */
490 void omxPrint(omxMatrix *source, char* d) {                                     // Pretty-print a (small) matrix
491         if(source->algebra != NULL) omxAlgebraPrint(source->algebra, d);
492         else if(source->objective != NULL) omxObjectivePrint(source->objective, d);
493         else omxPrintMatrix(source, d);
494 }
495
496 unsigned short omxNeedsUpdate(omxMatrix *matrix) {
497         unsigned short retval;
498         /* Simplest update check: If we're dirty or haven't computed this cycle (iteration or row), we need to. */
499         if(OMX_DEBUG_MATRIX) {Rprintf("Matrix 0x%x NeedsUpdate?", matrix);}
500         
501         if(matrix == NULL) {
502                 if(OMX_DEBUG_MATRIX) {Rprintf("matrix argument is NULL. ");}
503                 retval = FALSE;         // Not existing means never having to say you need to recompute.
504         } else if(matrix->isDirty) {
505                 if(OMX_DEBUG_MATRIX) {Rprintf("matrix is dirty. ");}
506                 retval = TRUE;
507         } else if(matrix->lastCompute < matrix->currentState->computeCount) {
508                 if(OMX_DEBUG_MATRIX) {Rprintf("matrix last compute is less than current compute count. ");}
509                 retval = TRUE;          // No need to check args if oa's dirty.
510         } else if(matrix->lastRow < matrix->currentState->currentRow) {
511                 if(OMX_DEBUG_MATRIX) {Rprintf("matrix last row is less than current row. ");}
512                 retval = TRUE;                  // Ditto.
513         } else if(matrix->algebra != NULL) {
514                 if(OMX_DEBUG_MATRIX) {Rprintf("checking algebra needs update. ");}
515                 retval = omxAlgebraNeedsUpdate(matrix->algebra);
516         } else if(matrix->objective != NULL) {
517                 if(OMX_DEBUG_MATRIX) {Rprintf("checking objective function needs update. ");}
518                 retval = omxObjectiveNeedsUpdate(matrix->objective);
519         } else {
520                 if(OMX_DEBUG_MATRIX) {Rprintf("checking matrix needs update. ");}
521                 retval = omxMatrixNeedsUpdate(matrix);
522         }
523         if(OMX_DEBUG_MATRIX && retval) {Rprintf("Yes.\n");}
524         if(OMX_DEBUG_MATRIX && !retval) {Rprintf("No.\n");}
525         return(retval);
526 }
527
528 void inline omxRecompute(omxMatrix *matrix) {
529         if(!omxNeedsUpdate(matrix)) return;
530         if(matrix->algebra != NULL) omxAlgebraCompute(matrix->algebra);
531         else if(matrix->objective != NULL) omxObjectiveCompute(matrix->objective);
532         else omxMatrixCompute(matrix);
533 }
534
535 void inline omxCompute(omxMatrix *matrix) {
536         if(matrix->algebra != NULL) omxAlgebraCompute(matrix->algebra);
537         else if(matrix->objective != NULL) omxObjectiveCompute(matrix->objective);
538         else omxMatrixCompute(matrix);
539 }