Added several omx-version BLAS functions to better encapsulate omxMatrix class. ...
[openmx:openmx.git] / src / omxMatrix.h
1 /*
2  *  Copyright 2007-2009 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  *
16  */
17
18 /***********************************************************
19  *
20  *  omxMatrix.h
21  *
22  *  Created: Timothy R. Brick   Date: 2008-11-13 12:33:06
23  *
24  *      Contains header information for the omxMatrix class
25  *   omxDataMatrices hold necessary information to simplify
26  *      dealings between the OpenMX back end and BLAS.
27  *
28  **********************************************************/
29
30 #ifndef _OMXMATRIX_H_
31 #define _OMXMATRIX_H_
32
33 #include "R.h"
34 #include <Rinternals.h>
35 #include <Rdefines.h>
36 #include <R_ext/Rdynload.h>
37 #include <R_ext/BLAS.h>
38 #include <R_ext/Lapack.h>
39 #include "omxDefines.h"
40 #include "omxBLAS.h"
41
42 typedef struct omxMatrix omxMatrix;
43
44 #include "omxAlgebra.h"
45 #include "omxObjective.h"
46 #include "omxState.h"
47
48
49 struct omxMatrix {                                              // A matrix
50                                                                                 //TODO: Improve encapsulation
51 /* Actually Useful Members */
52         int rows, cols;                                         // Matrix size  (specifically, its leading edge)
53         double* data;                                           // Actual Data Pointer
54         unsigned short colMajor;                        // and column-majority.
55
56 /* For Memory Administrivia */
57         unsigned short localData;                       // If data has been malloc'd, and must be freed.
58
59 /* For aliased matrices */                              // Maybe this should be a subclass, as well.
60         omxMatrix* aliasedPtr;                          // For now, assumes outside data if aliased.
61         unsigned short originalColMajor;        // Saved for reset of aliased matrix.
62         unsigned short originalRows;            // Saved for reset of aliased matrix.
63         unsigned short originalCols;            // Saved for reset of aliased matrix.
64
65 /* For BLAS Multiplication Speedup */   // TODO: Replace some of these with inlines or macros.
66         const char* majority;                           // Filled by compute(), included for speed
67         const char* minority;                           // Filled by compute(), included for speed
68         int leading;                                            // Leading edge; depends on original majority
69         int lagging;                                            // Non-leading edge.
70
71 /* Curent State */
72         omxState* currentState;                         // Optimizer State
73         unsigned short isDirty;                         // Retained, for historical purposes.
74         unsigned short isTemporary;                     // Whether or not to destroy the omxMatrix Structure when omxFreeAllMatrixData is called.
75         int lastCompute;                                        // Compute Count Number at last computation
76         int lastRow;                                            // Compute Count Number at last row update (Used for row-by-row computation only)
77
78 /* For Algebra Functions */                             // At most, one of these may be non-NULL.
79         omxAlgebra* algebra;                            // If it's not an algebra, this is NULL.
80         omxObjective* objective;                        // If it's not an objective function, this is NULL.
81
82 /* For inclusion in(or of) other matrices */
83         int numPopulateLocations;
84         omxMatrix** populateFrom;
85         int *populateFromRow, *populateFromCol;
86         int *populateToRow, *populateToCol;
87
88 };
89
90 /* Initialize and Destroy */
91         omxMatrix* omxInitMatrix(omxMatrix* om, int nrows, int ncols, unsigned short colMajor, omxState* os);                   // Set up matrix 
92         omxMatrix* omxInitTemporaryMatrix(omxMatrix* om, int nrows, int ncols, unsigned short colMajor, omxState* os);  // Set up matrix that can be freed
93         void omxFreeMatrixData(omxMatrix* om);                                                  // Release any held data.
94         void omxFreeAllMatrixData(omxMatrix* om);                                               // Ditto, traversing argument trees
95
96 /* Matrix Creation Functions */
97         omxMatrix* omxNewMatrixFromMxMatrix(SEXP matrix, omxState *state);                      // Create an omxMatrix from an R MxMatrix
98         omxMatrix* omxNewMatrixFromRPrimitive(SEXP rObject, omxState *state);                   // Create an omxMatrix from an R object
99         omxMatrix* omxNewIdentityMatrix(int nrows, omxState* state);                            // Creates an Identity Matrix of a given size
100         extern omxMatrix* omxNewMatrixFromMxIndex(SEXP matrix, omxState* os);   // Create a matrix/algebra from a matrix pointer
101         extern omxMatrix* omxNewMatrixFromIndexSlot(SEXP rObj, omxState* state, char* const slotName);  // Gets a matrix from an R SEXP slot
102     omxMatrix* omxDuplicateMatrix(omxMatrix* tgt, omxMatrix* src, omxState* newState, short fullCopy);
103
104 /* Getters 'n Setters (static functions declared below) */
105         // static inline double omxMatrixElement(omxMatrix *om, int row, int col);
106         // static inline double omxVectorElement(omxMatrix *om, int index);
107         // static inline void omxSetMatrixElement(omxMatrix *om, int row, int col, double value);
108         // static inline void omxSetVectorElement(omxMatrix *om, int index, double value);
109
110         double omxAliasedMatrixElement(omxMatrix *om, int row, int col);                        // Element from unaliased form of the same matrix
111         double* omxLocationOfMatrixElement(omxMatrix *om, int row, int col);
112         void omxMarkDirty(omxMatrix *om);
113
114 /* Matrix Modification Functions */
115         void omxZeroByZeroMatrix(omxMatrix *source);
116         void omxResizeMatrix(omxMatrix *source, int nrows, int ncols,
117                                                         unsigned short keepMemory);                                                                     // Resize, with or without re-initialization
118         omxMatrix* omxFillMatrixFromMxMatrix(omxMatrix* om, SEXP matrix, omxState *state);      // Populate an omxMatrix from an R MxMatrix
119         omxMatrix* omxFillMatrixFromRPrimitive(omxMatrix* om, SEXP rObject, omxState *state);   // Populate an omxMatrix from an R object
120         void omxProcessMatrixPopulationList(omxMatrix *matrix, SEXP matStruct);
121         void omxCopyMatrix(omxMatrix *dest, omxMatrix *src);                                                            // Copy across another matrix.
122         void omxTransposeMatrix(omxMatrix *mat);                                                                                        // Transpose a matrix in place.
123         void omxToggleRowColumnMajor(omxMatrix *mat);                                                                           // Transform row-major into col-major and vice versa 
124
125 /* Function wrappers that switch based on inclusion of algebras */
126         void omxPrint(omxMatrix *source, char* d);                                                                                      // Pretty-print a (small) matrix
127         unsigned short int omxNeedsUpdate(omxMatrix *matrix);                                                           // Does this need to be recomputed?
128         void omxRecompute(omxMatrix *matrix);                                                                                           // Recompute the matrix if needed.
129         void omxCompute(omxMatrix *matrix);                                                                                                     // Recompute the matrix no matter what.
130
131 /* Aliased Matrix Functions */
132         void omxAliasMatrix(omxMatrix *alias, omxMatrix* const source);         // Allows aliasing for faster reset.
133         void omxResetAliasedMatrix(omxMatrix *matrix);                                          // Reset to the original matrix
134         void omxRemoveRowsAndColumns(omxMatrix* om, int numRowsRemoved, int numColsRemoved, int rowsRemoved[], int colsRemoved[]);
135
136 /* Matrix-Internal Helper functions */
137         void omxMatrixCompute(omxMatrix *matrix);
138         void omxPrintMatrix(omxMatrix *source, char* d);                    // Pretty-print a (small) matrix
139         unsigned short int omxMatrixNeedsUpdate(omxMatrix *matrix);
140
141 /* Inline functions and helper functions */
142
143 void setMatrixError(int row, int col, int numrow, int numcol);
144 void setVectorError(int index, int numrow, int numcol);
145 void matrixElementError(int row, int col, int numrow, int numcol);
146 void vectorElementError(int index, int numrow, int numcol);
147
148
149 /* BLAS Wrappers */
150
151 static inline void omxSetMatrixElement(omxMatrix *om, int row, int col, double value) {
152         if(row >= om->rows || col >= om->cols) {
153                 setMatrixError(row + 1, col + 1, om->rows, om->cols);
154         }
155         int index = 0;
156         if(om->colMajor) {
157                 index = col * om->rows + row;
158         } else {
159                 index = row * om->cols + col;
160         }
161         om->data[index] = value;
162 }
163
164 static inline double omxMatrixElement(omxMatrix *om, int row, int col) {
165         int index = 0;
166         if(row >= om->rows || col >= om->cols) {
167                 matrixElementError(row + 1, col + 1, om->rows, om->cols);
168         }
169         if(om->colMajor) {
170                 index = col * om->rows + row;
171         } else {
172                 index = row * om->cols + col;
173         }
174         return om->data[index];
175 }
176
177 static inline void omxSetVectorElement(omxMatrix *om, int index, double value) {
178         if(index < om->rows * om->cols) {
179                 om->data[index] = value;
180         } else {
181                 setVectorError(index, om->rows, om->cols);
182     }
183 }
184
185 static inline double omxVectorElement(omxMatrix *om, int index) {
186         if(index < om->rows * om->cols) {
187                 return om->data[index];
188         } else {
189                 vectorElementError(index, om->rows, om->cols);
190         return (NA_REAL);
191     }
192 }
193
194 static inline void omxDGEMM(unsigned short int transposeA, unsigned short int transposeB,               // result <- alpha * A %*% B + beta * C
195                                 double alpha, omxMatrix* a, omxMatrix *b, double beta, omxMatrix* result) {
196         int nrow = (transposeA?a->cols:a->rows);
197         int nmid = (transposeA?a->rows:a->cols);
198         int ncol = (transposeB?b->rows:b->cols);
199         F77_CALL(omxunsafedgemm)((transposeA?a->minority:a->majority), (transposeB?b->minority:b->majority), 
200                                                         &(nrow), &(ncol), &(nmid),
201                                                         &alpha, a->data, &(a->leading), 
202                                                         b->data, &(b->leading),
203                                                         &beta, result->data, &(result->leading));
204
205         if(!result->colMajor) omxToggleRowColumnMajor(result);
206 }
207
208 static inline void omxDGEMV(unsigned short int transposeMat, double alpha, omxMatrix* mat,      // result <- alpha * A %*% B + beta * C
209                                 omxMatrix* vec, double beta, omxMatrix*result) {                                                        // where B is treated as a vector
210         int onei = 1;
211         int nrows = (transposeMat?mat->cols:mat->rows);
212         int ncols = (transposeMat?mat->rows:mat->cols);
213         F77_CALL(omxunsafedgemv)((transposeMat?mat->minority:mat->majority), &(nrows), &(ncols), &alpha, mat->data, &(mat->leading), vec->data, &onei, &beta, result->data, &onei);
214         if(!result->colMajor) omxToggleRowColumnMajor(result);
215 }
216
217 static inline void omxDSYMV(unsigned short int transposeMat, double* alpha, omxMatrix* mat,     // result <- alpha * A %*% B + beta * C
218                                 omxMatrix* vec, double* beta, omxMatrix*result){                                                        // only A is symmetric, and B is a vector
219
220         if(!result->colMajor) omxToggleRowColumnMajor(result);
221 }
222
223 static inline void omxDSYMM(unsigned short int symmOnLeft, double alpha, omxMatrix* a,          // result <- alpha * A %*% B + beta * C
224                                 omxMatrix *b, double beta, omxMatrix* result, unsigned short int fillMat) {     // One of A or B is symmetric
225
226         char r='R', l = 'L';
227         char u='U';
228         F77_CALL(dsymm)((symmOnLeft?&l:&r), &u, &(result->rows), &(result->cols),
229                                         &alpha, a->data, &(a->leading),
230                                         b->data, &(b->leading),
231                                         &beta, result->data, &(result->leading));
232
233         if(!result->colMajor) omxToggleRowColumnMajor(result);
234         
235         if(fillMat) {
236                 for(int j = 0; j < result->rows; j++)
237                         for(int k = j+1; k < result->cols; k++)
238                                 omxSetMatrixElement(result, j, k, omxMatrixElement(result, k, j));
239         }
240 }
241
242 static inline int omxDGETRF(omxMatrix* mat, int* ipiv) {                                                                                // LUP decomposition of mat
243         int info = 0;
244         F77_CALL(dgetrf)(&(mat->rows), &(mat->cols), mat->data, &(mat->leading), ipiv, &info);
245         return info;
246 }
247
248 static inline int omxDGETRI(omxMatrix* mat, int* ipiv, double* work, int lwork) {                               // Invert mat from LUP decomposition
249         int info = 0;
250         F77_CALL(dgetri)(&(mat->rows), mat->data, &(mat->leading), ipiv, work, &lwork, &info);
251         return info;
252 }
253
254 static inline void omxDPOTRF(omxMatrix* mat, int* info) {                                                                               // Cholesky decomposition of mat
255         // ERROR: NYI.
256 }
257 static inline void omxDPOTRI(omxMatrix* mat, int* info) {                                                                               // Invert mat from Cholesky
258         // ERROR: NYI
259 }
260
261
262 #endif /* _OMXMATRIX_H_ */