Fix 'fitType' scope and reindent
[openmx:openmx.git] / src / omxFitFunction.c
1 /*
2  *  Copyright 2007-2012 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18
19 *  omxFitFunction.cc
20 *
21 *  Created: Timothy R. Brick    Date: 2008-11-13 12:33:06
22 *
23 *       FitFunction objects are a subclass of data matrix that evaluates
24 *   itself anew at each iteration, so that any changes to
25 *   free parameters can be incorporated into the update.
26 *   // Question: Should FitFunction be a ``subtype'' of 
27 *   // omxAlgebra or a separate beast entirely?
28 *
29 **********************************************************/
30
31 #include "omxFitFunction.h"
32
33 void omxCalculateStdErrorFromHessian(double scale, omxFitFunction *off) {
34         /* This function calculates the standard errors from the hessian matrix */
35         // sqrt(diag(solve(hessian)))
36
37         if(off->hessian == NULL) return;
38         
39         int numParams = off->matrix->currentState->numFreeParams;
40         
41         if(off->stdError == NULL) {
42                 off->stdError = (double*) R_alloc(numParams, sizeof(double));
43         }
44         
45         double* stdErr = off->stdError;
46         
47         double* hessian = off->hessian;
48         double* workspace = (double *) Calloc(numParams * numParams, double);
49         
50         for(int i = 0; i < numParams; i++)
51                 for(int j = 0; j <= i; j++)
52                         workspace[i*numParams+j] = hessian[i*numParams+j];              // Populate upper triangle
53         
54         char u = 'U';
55         int ipiv[numParams];
56         int lwork = -1;
57         double temp;
58         int info = 0;
59         
60         F77_CALL(dsytrf)(&u, &numParams, workspace, &numParams, ipiv, &temp, &lwork, &info);
61         
62         lwork = (temp > numParams?temp:numParams);
63         
64         double* work = (double*) Calloc(lwork, double);
65         
66         F77_CALL(dsytrf)(&u, &numParams, workspace, &numParams, ipiv, work, &lwork, &info);
67         
68         if(info != 0) {
69                 
70                 off->stdError = NULL;
71                 
72         } else {
73                 
74                 F77_CALL(dsytri)(&u, &numParams, workspace, &numParams, ipiv, work, &info);
75         
76                 if(info != 0) {
77                         off->stdError = NULL;
78                 } else {
79                         for(int i = 0; i < numParams; i++) {
80                                 stdErr[i] = sqrt(scale) * sqrt(workspace[i * numParams + i]);
81                         }
82                 }
83         }
84         
85         Free(workspace);
86         Free(work);
87         
88 }
89
90 void omxInitEmptyFitFunction(omxFitFunction *off) {
91         /* Sets everything to NULL to avoid bad pointer calls */
92         
93         memset(off, 0, sizeof(omxFitFunction));
94 }
95
96 void omxFreeFitFunctionArgs(omxFitFunction *off) {
97         if(off==NULL) return;
98     
99         /* Completely destroy the fit function structures */
100         if(OMX_DEBUG) {Rprintf("Freeing fit function object at 0x%x.\n", off);}
101         if(off->matrix != NULL) {
102                 if(off->destructFun != NULL) {
103                         if(OMX_DEBUG) {Rprintf("Calling fit function destructor for 0x%x.\n", off);}
104                         off->destructFun(off);
105                 }
106                 off->matrix = NULL;
107         }
108 }
109
110 void omxFitFunctionCompute(omxFitFunction *off) {
111         if(OMX_DEBUG_ALGEBRA) { 
112             Rprintf("FitFunction compute: 0x%0x (needed: %s).\n", off, (off->matrix->isDirty?"Yes":"No"));
113         }
114
115         off->computeFun(off);
116
117         omxMarkClean(off->matrix);
118
119 }
120
121 void omxDuplicateFitMatrix(omxMatrix *tgt, const omxMatrix *src, omxState* newState) {
122
123         if(tgt == NULL || src == NULL) return;
124         if(src->fitFunction == NULL) return;
125     
126         omxFillMatrixFromMxFitFunction(tgt, src->fitFunction->rObj, src->hasMatrixNumber, src->matrixNumber);
127
128 }
129
130 omxFitFunction* omxCreateDuplicateFitFunction(omxFitFunction *tgt, const omxFitFunction *src, omxState* newState) {
131
132         if(OMX_DEBUG) {Rprintf("Duplicating fit function 0x%x into 0x%x.", src, tgt);}
133
134         if(src == NULL) {
135                 return NULL;
136         }
137         
138         if(tgt == NULL) {
139         tgt = (omxFitFunction*) R_alloc(1, sizeof(omxFitFunction));
140         omxInitEmptyFitFunction(tgt);
141     } else {
142                 omxRaiseError(newState, -1,
143                         "omxCreateDuplicateFitFunction requested to overwrite target");
144                 return NULL;
145         }
146
147         memcpy(tgt, src, sizeof(omxFitFunction));
148         return tgt;
149
150 }
151
152 void omxFillMatrixFromMxFitFunction(omxMatrix* om, SEXP rObj,
153         unsigned short hasMatrixNumber, int matrixNumber) {
154
155         SEXP slotValue, fitFunctionClass;
156         omxFitFunction *obj = (omxFitFunction*) R_alloc(1, sizeof(omxFitFunction));
157         omxInitEmptyFitFunction(obj);
158
159         /* Register FitFunction and Matrix with each other */
160         obj->matrix = om;
161         omxResizeMatrix(om, 1, 1, FALSE);                                       // FitFunction matrices MUST be 1x1.
162         om->fitFunction = obj;
163         om->hasMatrixNumber = hasMatrixNumber;
164         om->matrixNumber = matrixNumber;
165         
166         /* Get FitFunction Type */
167         PROTECT(fitFunctionClass = STRING_ELT(getAttrib(rObj, install("class")), 0));
168         {
169           const char *fitType = CHAR(fitFunctionClass);
170         
171           /* Switch based on fit function type. */ 
172           const omxFitFunctionTableEntry *entry = omxFitFunctionSymbolTable;
173           while (entry->initFun) {
174             if(strncmp(fitType, entry->name, MAX_STRING_LEN) == 0) {
175               obj->fitType = entry->name;
176               obj->initFun = entry->initFun;
177               break;
178             }
179             entry += 1;
180           }
181
182           if(obj->initFun == NULL) {
183             const int MaxErrorLen = 256;
184             char newError[MaxErrorLen];
185             snprintf(newError, MaxErrorLen, "Fit function %s not implemented.\n", fitType);
186             omxRaiseError(om->currentState, -1, newError);
187             return;
188           }
189         }
190         UNPROTECT(1);   /* fitType */
191
192         PROTECT(slotValue = GET_SLOT(rObj, install("expectation")));
193         int expNumber = INTEGER(slotValue)[0];  
194         if(expNumber == NA_INTEGER) {                                           // Has no expectation associated with it
195                 obj->expectation = NULL;
196         } else {
197                 obj->expectation = omxNewExpectationFromExpectationIndex(expNumber, om->currentState);
198         }
199         UNPROTECT(1);   /* slotValue */
200         
201         obj->rObj = rObj;
202         obj->initFun(obj, rObj);
203
204         if(obj->computeFun == NULL) {// If initialization fails, error code goes in argStruct
205                 char *errorCode;
206                 if(om->currentState->statusCode != 0) {
207                         errorCode = om->currentState->statusMsg;
208                 } else {
209                         // If no error code is reported, we report that.
210                         errorCode = "No error code reported.";
211                 }
212                 if(obj->argStruct != NULL) {
213                         errorCode = (char*)(obj->argStruct);
214                 }
215         const int MaxErrorLen = 256;
216         char newError[MaxErrorLen];
217         snprintf(newError, MaxErrorLen, "Could not initialize fit function %s.  Error: %s\n",
218                         obj->fitType, errorCode); 
219                 omxRaiseError(om->currentState, -1, newError);
220         }
221         
222         obj->matrix->isDirty = TRUE;
223
224 }
225
226 void omxFitFunctionGradient(omxFitFunction* off, double* gradient) {
227         if(!(off->gradientFun == NULL)) { off->gradientFun(off, gradient); }
228         return;
229 }
230
231 void omxFitFunctionPrint(omxFitFunction* off, char* d) {
232         Rprintf("(FitFunction, type %s) ", off->fitType);
233         omxPrintMatrix(off->matrix, d);
234 }
235
236
237 /* Helper functions */
238 omxMatrix* omxNewMatrixFromIndexSlot(SEXP rObj, omxState* currentState, char* const slotName) {
239         SEXP slotValue;
240         omxMatrix* newMatrix = NULL;
241         if(strncmp(slotName, "", 1) == 0) return NULL;
242         PROTECT(slotValue = GET_SLOT(rObj, install(slotName)));
243         newMatrix = omxNewMatrixFromMxIndex(slotValue, currentState);
244         if(newMatrix != NULL) omxRecompute(newMatrix);
245         else if(OMX_DEBUG) Rprintf("No slot %s found.\n", slotName);
246         UNPROTECT(1);
247         return newMatrix;
248 }
249
250 omxData* omxNewDataFromDataSlot(SEXP rObj, omxState* currentState, char* const dataSlotName) {
251         
252         SEXP slotValue;
253         
254         PROTECT(slotValue = GET_SLOT(rObj, install(dataSlotName)));
255         if(OMX_DEBUG) { Rprintf("Data Element %d.\n", AS_INTEGER(slotValue)); }
256         omxData* dataElt = omxNewDataFromMxDataPtr(slotValue, currentState);
257         UNPROTECT(1); // newMatrix
258         
259         return dataElt;
260         
261 }