Remove unused argument from omxResizeMatrix
[openmx:openmx.git] / src / omxMatrix.h
1 /*
2  *  Copyright 2007-2014 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  *
16  */
17
18 /***********************************************************
19  *
20  *  omxMatrix.h
21  *
22  *  Created: Timothy R. Brick   Date: 2008-11-13 12:33:06
23  *
24  *      Contains header information for the omxMatrix class
25  *   omxDataMatrices hold necessary information to simplify
26  *      dealings between the OpenMX back end and BLAS.
27  *
28  **********************************************************/
29
30 #ifndef _OMXMATRIX_H_
31 #define _OMXMATRIX_H_
32
33 #include "omxDefines.h"
34 #include "omxBLAS.h"
35
36 #include "omxAlgebra.h"
37 #include "omxFitFunction.h"
38 #include "omxExpectation.h"
39 #include "omxState.h"
40
41 struct populateLocation {
42         int from;
43         int srcRow, srcCol;
44         int destRow, destCol;
45
46         void transpose() { std::swap(destRow, destCol); }
47 };
48
49 class omxMatrix {
50 /* For inclusion in(or of) other matrices */
51         std::vector< populateLocation > populate;
52  public:
53         void transposePopulate();
54         void omxProcessMatrixPopulationList(SEXP matStruct);
55         bool omxPopulateSubstitutions();
56                                                                                 //TODO: Improve encapsulation
57 /* Actually Useful Members */
58         int rows, cols;                                         // Matrix size  (specifically, its leading edge)
59         double* data;                                           // Actual Data Pointer
60         unsigned short colMajor;                        // used for quick transpose
61         unsigned short hasMatrixNumber;         // is this object in the matrix or algebra arrays?
62         int matrixNumber;                                       // the offset into the matrices or algebras arrays
63
64         SEXP owner;     // The R object owning data or NULL if we own it.
65
66         // size of allocated memory of data pointer
67         int originalRows;
68         int originalCols;
69
70 /* For BLAS Multiplication Speedup */   // TODO: Replace some of these with inlines or macros.
71         const char* majority;                           // Filled by compute(), included for speed
72         const char* minority;                           // Filled by compute(), included for speed
73         int leading;                                            // Leading edge; depends on original majority
74         int lagging;                                            // Non-leading edge.
75
76 /* Curent State */
77         omxState* currentState;                         // Optimizer State
78         int cleanVersion;
79         int version;
80
81 /* For Algebra Functions */                             // At most, one of these may be non-NULL.
82         omxAlgebra* algebra;                            // If it's not an algebra, this is NULL.
83         omxFitFunction* fitFunction;            // If it's not a fit function, this is NULL.
84
85         const char* name;
86
87         // Currently, this is only used by BA81 expectations to deal with
88         // equality constraints among latent distribution parameters.
89         // This should really be a vector because more than one expectation
90         // can "own" the same matrix. However, we can't use nice C++
91         // std::vector here until the allocation model of omxMatrix
92         // is cleaned up. Currently, we allocate omxMatrix from both
93         // R and the regular C allocator.
94         struct omxExpectation *expectation;       // weak reference
95
96         friend void omxCopyMatrix(omxMatrix *dest, omxMatrix *src);  // turn into method later TODO
97 };
98
99 // If you call these functions directly then you need to free the memory with omxFreeMatrix.
100 // If you obtain a matrix from omxNewMatrixFromSlot then you must NOT free it.
101 omxMatrix* omxInitMatrix(int nrows, int ncols, unsigned short colMajor, omxState* os);
102
103         void omxFreeMatrix(omxMatrix* om);                                              // Ditto, traversing argument trees
104
105 /* Matrix Creation Functions */
106         omxMatrix* omxNewMatrixFromRPrimitive(SEXP rObject, omxState *state,
107         unsigned short hasMatrixNumber, int matrixNumber);                                                      // Create an omxMatrix from an R object
108         omxMatrix* omxNewIdentityMatrix(int nrows, omxState* state);                            // Creates an Identity Matrix of a given size
109         extern omxMatrix* omxMatrixLookupFromState1(SEXP matrix, omxState* os); // Create a matrix/algebra from a matrix pointer
110
111         omxMatrix* omxDuplicateMatrix(omxMatrix* src, omxState* newState);
112         SEXP omxExportMatrix(omxMatrix *om);
113
114 /* Getters 'n Setters (static functions declared below) */
115         // static OMXINLINE double omxMatrixElement(omxMatrix *om, int row, int col);
116         // static OMXINLINE double omxVectorElement(omxMatrix *om, int index);
117         // static OMXINLINE void omxSetMatrixElement(omxMatrix *om, int row, int col, double value);
118         // static OMXINLINE void omxSetVectorElement(omxMatrix *om, int index, double value);
119
120         double* omxLocationOfMatrixElement(omxMatrix *om, int row, int col);
121         void omxMarkDirty(omxMatrix *om);
122         void omxMarkClean(omxMatrix *om);
123
124 /* Matrix Modification Functions */
125         void omxZeroByZeroMatrix(omxMatrix *source);
126 void omxResizeMatrix(omxMatrix *source, int nrows, int ncols);
127         omxMatrix* omxFillMatrixFromRPrimitive(omxMatrix* om, SEXP rObject, omxState *state,
128                 unsigned short hasMatrixNumber, int matrixNumber);                                                              // Populate an omxMatrix from an R object
129         void omxTransposeMatrix(omxMatrix *mat);                                                                                                // Transpose a matrix in place.
130         void omxToggleRowColumnMajor(omxMatrix *mat);                                                                           // Transform row-major into col-major and vice versa 
131
132 /* Function wrappers that switch based on inclusion of algebras */
133         void omxPrint(omxMatrix *source, const char* d);
134         unsigned short int omxNeedsUpdate(omxMatrix *matrix);                                                           // Does this need to be recomputed?
135 void omxInitialCompute(omxMatrix *matrix);
136         void omxRecompute(omxMatrix *matrix);                                                                                           // Recompute the matrix if needed.
137         void omxForceCompute(omxMatrix *matrix);
138
139 void omxRemoveElements(omxMatrix *om, int numRemoved, int removed[]);
140 void omxRemoveRowsAndColumns(omxMatrix* om, int numRowsRemoved, int numColsRemoved, int rowsRemoved[], int colsRemoved[]);
141
142 /* Matrix-Internal Helper functions */
143         void omxMatrixLeadingLagging(omxMatrix *matrix);
144 void omxPrintMatrix(omxMatrix *source, const char* header);
145
146 /* OMXINLINE functions and helper functions */
147
148 void setMatrixError(omxMatrix *om, int row, int col, int numrow, int numcol);
149 void setVectorError(int index, int numrow, int numcol);
150 void matrixElementError(int row, int col, int numrow, int numcol);
151 void vectorElementError(int index, int numrow, int numcol);
152
153 OMXINLINE static bool omxMatrixIsDirty(omxMatrix *om) { return om->cleanVersion != om->version; }
154 OMXINLINE static bool omxMatrixIsClean(omxMatrix *om) { return om->cleanVersion == om->version; }
155 OMXINLINE static int omxGetMatrixVersion(omxMatrix *om) { return om->version; }
156
157 static OMXINLINE int omxIsMatrix(omxMatrix *mat) {
158     return (mat->algebra == NULL && mat->fitFunction == NULL);
159 }
160
161 /* BLAS Wrappers */
162
163 static OMXINLINE void omxSetMatrixElement(omxMatrix *om, int row, int col, double value) {
164         if((row < 0) || (col < 0) || (row >= om->rows) || (col >= om->cols)) {
165                 setMatrixError(om, row + 1, col + 1, om->rows, om->cols);
166                 return;
167         }
168         int index = 0;
169         if(om->colMajor) {
170                 index = col * om->rows + row;
171         } else {
172                 index = row * om->cols + col;
173         }
174         om->data[index] = value;
175 }
176
177 static OMXINLINE void omxAccumulateMatrixElement(omxMatrix *om, int row, int col, double value) {
178         if((row < 0) || (col < 0) || (row >= om->rows) || (col >= om->cols)) {
179                 setMatrixError(om, row + 1, col + 1, om->rows, om->cols);
180                 return;
181         }
182         int index = 0;
183         if(om->colMajor) {
184                 index = col * om->rows + row;
185         } else {
186                 index = row * om->cols + col;
187         }
188         om->data[index] += value;
189 }
190
191 static OMXINLINE double omxMatrixElement(omxMatrix *om, int row, int col) {
192         int index = 0;
193         if((row < 0) || (col < 0) || (row >= om->rows) || (col >= om->cols)) {
194                 matrixElementError(row + 1, col + 1, om->rows, om->cols);
195         return (NA_REAL);
196         }
197         if(om->colMajor) {
198                 index = col * om->rows + row;
199         } else {
200                 index = row * om->cols + col;
201         }
202         return om->data[index];
203 }
204
205 static OMXINLINE double *omxMatrixColumn(omxMatrix *om, int col) {
206   if (!om->colMajor) Rf_error("omxMatrixColumn requires colMajor order");
207   if (col < 0 || col >= om->cols) Rf_error(0, col, om->rows, om->cols);
208   return om->data + col * om->rows;
209 }
210
211 static OMXINLINE void omxAccumulateVectorElement(omxMatrix *om, int index, double value) {
212         if (index < 0 || index >= (om->rows * om->cols)) {
213                 setVectorError(index + 1, om->rows, om->cols);
214                 return;
215         } else {
216                 om->data[index] += value;
217     }
218 }
219
220 static OMXINLINE void omxSetVectorElement(omxMatrix *om, int index, double value) {
221         if (index < 0 || index >= (om->rows * om->cols)) {
222                 setVectorError(index + 1, om->rows, om->cols);
223                 return;
224         } else {
225                 om->data[index] = value;
226     }
227 }
228
229 static OMXINLINE double omxVectorElement(omxMatrix *om, int index) {
230         if (index < 0 || index >= (om->rows * om->cols)) {
231                 vectorElementError(index + 1, om->rows, om->cols);
232         return (NA_REAL);
233         } else {
234                 return om->data[index];
235     }
236 }
237
238 static OMXINLINE void omxUnsafeSetVectorElement(omxMatrix *om, int index, double value) {
239         om->data[index] = value;
240 }
241
242 static OMXINLINE double omxUnsafeVectorElement(omxMatrix *om, int index) {
243         return om->data[index];
244 }
245
246
247 static OMXINLINE void omxDGEMM(unsigned short int transposeA, unsigned short int transposeB,            // result <- alpha * A %*% B + beta * C
248                                 double alpha, omxMatrix* a, omxMatrix *b, double beta, omxMatrix* result) {
249         int nrow = (transposeA?a->cols:a->rows);
250         int nmid = (transposeA?a->rows:a->cols);
251         int ncol = (transposeB?b->rows:b->cols);
252
253         F77_CALL(omxunsafedgemm)((transposeA?a->minority:a->majority), (transposeB?b->minority:b->majority), 
254                                                         &(nrow), &(ncol), &(nmid),
255                                                         &alpha, a->data, &(a->leading), 
256                                                         b->data, &(b->leading),
257                                                         &beta, result->data, &(result->leading));
258
259         if(!result->colMajor) omxToggleRowColumnMajor(result);
260 }
261
262 static OMXINLINE void omxDGEMV(unsigned short int transposeMat, double alpha, omxMatrix* mat,   // result <- alpha * A %*% B + beta * C
263                                 omxMatrix* vec, double beta, omxMatrix*result) {                                                        // where B is treated as a vector
264         int onei = 1;
265         int nrows = mat->rows;
266         int ncols = mat->cols;
267         if(OMX_DEBUG_DEVELOPER) {
268                 int nVecEl = vec->rows * vec->cols;
269                 // mxLog("DGEMV: %c, %d, %d, %f, 0x%x, %d, 0x%x, %d, 0x%x, %d\n", *(transposeMat?mat->minority:mat->majority), (nrows), (ncols), 
270                 // alpha, mat->data, (mat->leading), vec->data, onei, beta, result->data, onei); //:::DEBUG:::
271                 if((transposeMat && nrows != nVecEl) || (!transposeMat && ncols != nVecEl)) {
272                         Rf_error("Mismatch in vector/matrix multiply: %s (%d x %d) * (%d x 1).\n", (transposeMat?"transposed":""), mat->rows, mat->cols, nVecEl); // :::DEBUG:::
273                 }
274         }
275         F77_CALL(omxunsafedgemv)((transposeMat?mat->minority:mat->majority), &(nrows), &(ncols), 
276                 &alpha, mat->data, &(mat->leading), vec->data, &onei, &beta, result->data, &onei);
277         if(!result->colMajor) omxToggleRowColumnMajor(result);
278 }
279
280 static OMXINLINE void omxDSYMV(double alpha, omxMatrix* mat,            // result <- alpha * A %*% B + beta * C
281                                 omxMatrix* vec, double beta, omxMatrix* result) {       // only A is symmetric, and B is a vector
282         char u='U';
283     int onei = 1;
284
285         if(OMX_DEBUG_DEVELOPER) {
286                 int nVecEl = vec->rows * vec->cols;
287                 // mxLog("DSYMV: %c, %d, %f, 0x%x, %d, 0x%x, %d, %f, 0x%x, %d\n", u, (mat->cols),alpha, mat->data, (mat->leading), 
288                             // vec->data, onei, beta, result->data, onei); //:::DEBUG:::
289                 if(mat->cols != nVecEl) {
290                         Rf_error("Mismatch in symmetric vector/matrix multiply: %s (%d x %d) * (%d x 1).\n", "symmetric", mat->rows, mat->cols, nVecEl); // :::DEBUG:::
291                 }
292         }
293
294     F77_CALL(dsymv)(&u, &(mat->cols), &alpha, mat->data, &(mat->leading), 
295                     vec->data, &onei, &beta, result->data, &onei);
296
297     // if(!result->colMajor) omxToggleRowColumnMajor(result);
298 }
299
300 static OMXINLINE void omxDSYMM(unsigned short int symmOnLeft, double alpha, omxMatrix* symmetric,               // result <- alpha * A %*% B + beta * C
301                                 omxMatrix *other, double beta, omxMatrix* result) {                                 // One of A or B is symmetric
302
303         char r='R', l = 'L';
304         char u='U';
305         F77_CALL(dsymm)((symmOnLeft?&l:&r), &u, &(result->rows), &(result->cols),
306                                         &alpha, symmetric->data, &(symmetric->leading),
307                                         other->data, &(other->leading),
308                                         &beta, result->data, &(result->leading));
309
310         if(!result->colMajor) omxToggleRowColumnMajor(result);
311 }
312
313 static OMXINLINE void omxDAXPY(double alpha, omxMatrix* lhs, omxMatrix* rhs) {              // RHS += alpha*lhs  
314     // N.B.  Not fully tested.                                                              // Assumes common majority or vectordom.
315     if(lhs->colMajor != rhs->colMajor) { omxToggleRowColumnMajor(rhs);}
316     int len = lhs->rows * lhs->cols;
317     int onei = 1;
318     F77_CALL(daxpy)(&len, &alpha, lhs->data, &onei, rhs->data, &onei);
319
320 }
321
322 static OMXINLINE double omxDDOT(omxMatrix* lhs, omxMatrix* rhs) {              // returns dot product, as if they were vectors
323     // N.B.  Not fully tested.                                                  // Assumes common majority or vectordom.
324     if(lhs->colMajor != rhs->colMajor) { omxToggleRowColumnMajor(rhs);}
325     int len = lhs->rows * lhs->cols;
326     int onei = 1;
327     return(F77_CALL(ddot)(&len, lhs->data, &onei, rhs->data, &onei));
328 }
329
330 static OMXINLINE void omxDPOTRF(omxMatrix* mat, int* info) {                                                                            // Cholesky decomposition of mat
331         // TODO: Add Rf_error checking, and/or adjustments for row vs. column majority.
332         // N.B. Not fully tested.
333         char u = 'U'; //l = 'L'; //U for storing upper triangle
334         F77_CALL(dpotrf)(&u, &(mat->rows), mat->data, &(mat->cols), info);
335 }
336 static OMXINLINE void omxDPOTRI(omxMatrix* mat, int* info) {                                                                            // Invert mat from Cholesky
337         // TODO: Add Rf_error checking, and/or adjustments for row vs. column majority.
338         // N.B. Not fully tested.
339         char u = 'U'; //l = 'L'; // U for storing upper triangle
340         F77_CALL(dpotri)(&u, &(mat->rows), mat->data, &(mat->cols), info);
341 }
342
343 void omxShallowInverse(FitContext *fc, int numIters, omxMatrix* A, omxMatrix* Z, omxMatrix* Ax, omxMatrix* I );
344
345 double omxMaxAbsDiff(omxMatrix *m1, omxMatrix *m2);
346
347 OMXINLINE static int
348 triangleLoc1(int diag)
349 {
350         //if (diag < 1) error("Out of domain");
351         return (diag) * (diag+1) / 2;   // 0 1 3 6 10 15 ..
352 }
353
354 OMXINLINE static int
355 triangleLoc0(int diag)
356 {
357         //if (diag < 0) error("Out of domain");
358         return triangleLoc1(diag+1) - 1;  // 0 2 5 9 14 ..
359 }
360
361 #endif /* _OMXMATRIX_H_ */