Allow ComputeIterate to test maximum absolute change
[openmx:openmx.git] / src / omxWLSFitFunction.cpp
1  /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <R.h>
18 #include <Rinternals.h>
19 #include <Rdefines.h>
20 #include <R_ext/Rdynload.h>
21 #include <R_ext/BLAS.h>
22 #include <R_ext/Lapack.h>
23 #include "omxAlgebraFunctions.h"
24 #include "omxWLSFitFunction.h"
25
26 void flattenDataToVector(omxMatrix* cov, omxMatrix* means, omxThresholdColumn* thresholds, int nThresholds, omxMatrix* vector) {
27     // TODO: vectorize data flattening
28     // if(OMX_DEBUG) { mxLog("Flattening out data vectors: cov 0x%x, mean 0x%x, thresh 0x%x[n=%d] ==> 0x%x", 
29     //         cov, means, thresholds, nThresholds, vector); }
30     
31     int nextLoc = 0;
32     for(int j = 0; j < cov->rows; j++) {
33         for(int k = 0; k <= j; k++) {
34             omxSetVectorElement(vector, nextLoc, omxMatrixElement(cov, k, j)); // Use upper triangle in case of SYMM-style mat.
35             nextLoc++;
36         }
37     }
38     if (means != NULL) {
39         for(int j = 0; j < cov->rows; j++) {
40             omxSetVectorElement(vector, nextLoc, omxVectorElement(means, j));
41             nextLoc++;
42         }
43     }
44     if (thresholds != NULL) {
45         for(int j = 0; j < nThresholds; j++) {
46             omxThresholdColumn* thresh = thresholds + j;
47             for(int k = 0; k < thresh->numThresholds; k++) {
48                 omxSetVectorElement(vector, nextLoc, omxMatrixElement(thresh->matrix, k, thresh->column));
49                 nextLoc++;
50             }
51         }
52     }
53 }
54
55 void omxDestroyWLSFitFunction(omxFitFunction *oo) {
56
57         if(OMX_DEBUG) {mxLog("Freeing WLS FitFunction.");}
58     if(oo->argStruct == NULL) return;
59
60         omxWLSFitFunction* owo = ((omxWLSFitFunction*)oo->argStruct);
61     omxFreeMatrixData(owo->observedFlattened);
62     omxFreeMatrixData(owo->expectedFlattened);
63     omxFreeMatrixData(owo->weights);
64     omxFreeMatrixData(owo->B);
65     omxFreeMatrixData(owo->P);
66 }
67
68 static void omxCallWLSFitFunction(omxFitFunction *oo, int want, FitContext *) {
69         if (want & (FF_COMPUTE_PREOPTIMIZE)) return;
70
71         if(OMX_DEBUG) { mxLog("Beginning WLS Evaluation.");}
72         // Requires: Data, means, covariances.
73
74         double sum = 0.0;
75
76         omxMatrix *oCov, *oMeans, *eCov, *eMeans, *P, *B, *weights, *oFlat, *eFlat;
77         
78     omxThresholdColumn *oThresh, *eThresh;
79
80         omxWLSFitFunction *owo = ((omxWLSFitFunction*)oo->argStruct);
81         
82     /* Locals for readability.  Compiler should cut through this. */
83         oCov            = owo->observedCov;
84         oMeans          = owo->observedMeans;
85         oThresh         = owo->observedThresholds;
86         eCov            = owo->expectedCov;
87         eMeans          = owo->expectedMeans;
88         eThresh         = owo->expectedThresholds;
89         oFlat           = owo->observedFlattened;
90         eFlat           = owo->expectedFlattened;
91         weights         = owo->weights;
92         B                       = owo->B;
93         P                       = owo->P;
94     int nThresh = owo->nThresholds;
95     int onei    = 1;
96         
97         omxExpectation* expectation = oo->expectation;
98
99     /* Recompute and recopy */
100         if(OMX_DEBUG) { mxLog("WLSFitFunction Computing expectation"); }
101         omxExpectationCompute(expectation, NULL);
102
103     // TODO: Flatten data only once.
104         flattenDataToVector(oCov, oMeans, oThresh, nThresh, oFlat);
105         flattenDataToVector(eCov, eMeans, eThresh, nThresh, eFlat);
106
107         omxCopyMatrix(B, oFlat);
108
109         omxDAXPY(-1.0, eFlat, B);
110         
111     if(weights != NULL) {
112         omxDGEMV(TRUE, 1.0, weights, B, 0.0, P);
113     } else {
114         // ULS Case: Memcpy faster than dgemv.
115         // TODO: Better to use an omxMatrix duplicator here.
116         memcpy(P, B, B->cols*sizeof(double));
117     }
118
119     sum = F77_CALL(ddot)(&(P->cols), P->data, &onei, B->data, &onei);
120
121     oo->matrix->data[0] = sum;
122
123         if(OMX_DEBUG) { mxLog("WLSFitFunction value comes to: %f.", oo->matrix->data[0]); }
124
125 }
126
127 void omxPopulateWLSAttributes(omxFitFunction *oo, SEXP algebra) {
128     if(OMX_DEBUG) { mxLog("Populating WLS Attributes."); }
129
130         omxWLSFitFunction *argStruct = ((omxWLSFitFunction*)oo->argStruct);
131         omxMatrix *expCovInt = argStruct->expectedCov;                  // Expected covariance
132         omxMatrix *expMeanInt = argStruct->expectedMeans;                       // Expected means
133         omxMatrix *weightInt = argStruct->weights;                      // Expected means
134
135         SEXP expCovExt, expMeanExt, weightExt, gradients;
136         PROTECT(expCovExt = allocMatrix(REALSXP, expCovInt->rows, expCovInt->cols));
137         for(int row = 0; row < expCovInt->rows; row++)
138                 for(int col = 0; col < expCovInt->cols; col++)
139                         REAL(expCovExt)[col * expCovInt->rows + row] =
140                                 omxMatrixElement(expCovInt, row, col);
141
142         if (expMeanInt != NULL) {
143                 PROTECT(expMeanExt = allocMatrix(REALSXP, expMeanInt->rows, expMeanInt->cols));
144                 for(int row = 0; row < expMeanInt->rows; row++)
145                         for(int col = 0; col < expMeanInt->cols; col++)
146                                 REAL(expMeanExt)[col * expMeanInt->rows + row] =
147                                         omxMatrixElement(expMeanInt, row, col);
148         } else {
149                 PROTECT(expMeanExt = allocMatrix(REALSXP, 0, 0));               
150         }
151         
152         PROTECT(weightExt = allocMatrix(REALSXP, weightInt->rows, weightInt->cols));
153         for(int row = 0; row < weightInt->rows; row++)
154                 for(int col = 0; col < weightInt->cols; col++)
155                         REAL(weightExt)[col * weightInt->rows + row] =
156                                 omxMatrixElement(weightInt, row, col);
157         
158         
159         if(0) {  /* TODO fix for new internal API
160                 int nLocs = Global->numFreeParams;
161                 double gradient[Global->numFreeParams];
162                 for(int loc = 0; loc < nLocs; loc++) {
163                         gradient[loc] = NA_REAL;
164                 }
165                 //oo->gradientFun(oo, gradient);
166                 PROTECT(gradients = allocMatrix(REALSXP, 1, nLocs));
167
168                 for(int loc = 0; loc < nLocs; loc++)
169                         REAL(gradients)[loc] = gradient[loc];
170                  */
171         } else {
172                 PROTECT(gradients = allocMatrix(REALSXP, 0, 0));
173         }
174     
175         setAttrib(algebra, install("expCov"), expCovExt);
176         setAttrib(algebra, install("expMean"), expMeanExt);
177         setAttrib(algebra, install("weights"), weightExt);
178         setAttrib(algebra, install("gradients"), gradients);
179         
180         setAttrib(algebra, install("SaturatedLikelihood"), ScalarReal(0));
181         setAttrib(algebra, install("IndependenceLikelihood"), ScalarReal(0));
182         setAttrib(algebra, install("ADFMisfit"), ScalarReal(omxMatrixElement(oo->matrix, 0, 0)));
183
184         UNPROTECT(4);
185 }
186
187 void omxSetWLSFitFunctionCalls(omxFitFunction* oo) {
188         
189         /* Set FitFunction Calls to WLS FitFunction Calls */
190         oo->computeFun = omxCallWLSFitFunction;
191         oo->destructFun = omxDestroyWLSFitFunction;
192         oo->populateAttrFun = omxPopulateWLSAttributes;
193 }
194
195 void omxInitWLSFitFunction(omxFitFunction* oo) {
196     
197         omxMatrix *cov, *means, *weights;
198         
199     if(OMX_DEBUG) { mxLog("Initializing WLS FitFunction function."); }
200         
201     int vectorSize = 0;
202         
203         omxSetWLSFitFunctionCalls(oo);
204         
205         if(OMX_DEBUG) { mxLog("Retrieving expectation.\n"); }
206         if (!oo->expectation) { error("%s requires an expectation", oo->fitType); }
207         
208         if(OMX_DEBUG) { mxLog("Retrieving data.\n"); }
209     omxData* dataMat = oo->expectation->data;
210         
211         if(strncmp(omxDataType(dataMat), "acov", 4) != 0 && strncmp(omxDataType(dataMat), "cov", 3) != 0) {
212                 char *errstr = (char*) calloc(250, sizeof(char));
213                 sprintf(errstr, "WLS FitFunction unable to handle data type %s.  Data must be of type 'acov'.\n", omxDataType(dataMat));
214                 omxRaiseError(oo->matrix->currentState, -1, errstr);
215                 free(errstr);
216                 if(OMX_DEBUG) { mxLog("WLS FitFunction unable to handle data type %s.  Aborting.", omxDataType(dataMat)); }
217                 return;
218         }
219
220         omxWLSFitFunction *newObj = (omxWLSFitFunction*) R_alloc(1, sizeof(omxWLSFitFunction));
221         
222     if(OMX_DEBUG) { mxLog("WLS being intialized is at %p (within %p).", oo, newObj); }
223
224     /* Get Expectation Elements */
225         newObj->expectedCov = omxGetExpectationComponent(oo->expectation, oo, "cov");
226         newObj->expectedMeans = omxGetExpectationComponent(oo->expectation, oo, "means");
227     newObj->nThresholds = oo->expectation->numOrdinal;
228     newObj->expectedThresholds = oo->expectation->thresholds;
229     // FIXME: threshold structure should be asked for by omxGetExpectationComponent
230
231         /* Read and set expected means, variances, and weights */
232     cov = omxDataMatrix(dataMat, NULL);
233     means = omxDataMeans(dataMat, NULL, NULL);
234     weights = omxDataAcov(dataMat, NULL);
235         newObj->observedThresholds  = omxDataThresholds(dataMat);
236
237     newObj->observedCov = cov;
238     newObj->observedMeans = means;
239     newObj->weights = weights;
240     newObj->n = omxDataNumObs(dataMat);
241     newObj->nThresholds = omxDataNumFactor(dataMat);
242         UNPROTECT(1);
243         
244         // Error Checking: Observed/Expected means must agree.  
245         // ^ is XOR: true when one is false and the other is not.
246         if((newObj->expectedMeans == NULL) ^ (newObj->observedMeans == NULL)) {
247             if(newObj->expectedMeans != NULL) {
248                     omxRaiseError(oo->matrix->currentState, OMX_ERROR,
249                             "Observed means not detected, but an expected means matrix was specified.\n  If you  wish to model the means, you must provide observed means.\n");
250                     return;
251             } else {
252                     omxRaiseError(oo->matrix->currentState, OMX_ERROR,
253                             "Observed means were provided, but an expected means matrix was not specified.\n  If you provide observed means, you must specify a model for the means.\n");
254                     return;             
255             }
256         }
257
258         if((newObj->expectedThresholds == NULL) ^ (newObj->observedThresholds == NULL)) {
259             if(newObj->expectedThresholds != NULL) {
260                     omxRaiseError(oo->matrix->currentState, OMX_ERROR,
261                             "Observed thresholds not detected, but an expected thresholds matrix was specified.\n   If you wish to model the thresholds, you must provide observed thresholds.\n ");
262                     return;
263             } else {
264                     omxRaiseError(oo->matrix->currentState, OMX_ERROR,
265                             "Observed thresholds were provided, but an expected thresholds matrix was not specified.\nIf you provide observed thresholds, you must specify a model for the thresholds.\n");
266                     return;             
267             }
268         }
269
270     /* Error check weight matrix size */
271     int ncol = newObj->observedCov->cols;
272     vectorSize = (ncol * (ncol + 1) ) / 2;
273     if(newObj->expectedMeans != NULL) {
274         vectorSize = vectorSize + ncol;
275     }
276     if(newObj->observedThresholds != NULL) {
277         for(int i = 0; i < newObj->nThresholds; i++) {
278             vectorSize = vectorSize + newObj->observedThresholds[i].numThresholds;
279         }
280     }
281
282     if(weights != NULL && weights->rows != weights->cols && weights->cols != vectorSize) {
283         omxRaiseError(oo->matrix->currentState, OMX_DEVELOPER_ERROR,
284          "Developer Error in WLS-based FitFunction object: WLS-based expectation specified an incorrectly-sized weight matrix.\nIf you are not developing a new expectation type, you should probably post this to the OpenMx forums.");
285      return;
286     }
287
288         
289         // FIXME: More error checking for incoming Fit Functions
290
291         /* Temporary storage for calculation */
292         newObj->observedFlattened = omxInitMatrix(NULL, vectorSize, 1, TRUE, oo->matrix->currentState);
293         newObj->expectedFlattened = omxInitMatrix(NULL, vectorSize, 1, TRUE, oo->matrix->currentState);
294         newObj->P = omxInitMatrix(NULL, 1, vectorSize, TRUE, oo->matrix->currentState);
295         newObj->B = omxInitMatrix(NULL, vectorSize, 1, TRUE, oo->matrix->currentState);
296
297     flattenDataToVector(newObj->observedCov, newObj->observedMeans, newObj->observedThresholds, newObj->nThresholds, newObj->observedFlattened);
298     flattenDataToVector(newObj->expectedCov, newObj->expectedMeans, newObj->expectedThresholds, newObj->nThresholds, newObj->expectedFlattened);
299
300     oo->argStruct = (void*)newObj;
301
302 }