First swipe at State duplication for parallelism. Also some changes to subobjective...
[openmx:openmx.git] / src / omxObjective.c
1 /*
2  *  Copyright 2007-2009 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18
19 *  omxObjective.c
20 *
21 *  Created: Timothy R. Brick    Date: 2008-11-13 12:33:06
22 *
23 *       Objective objects are a subclass of data matrix that evaluates
24 *   itself anew at each iteration, so that any changes to
25 *   free parameters can be incorporated into the update.
26 *   // Question: Should Objective be a ``subtype'' of 
27 *   // omxAlgebra or a separate beast entirely?
28 *
29 **********************************************************/
30
31 #include "omxObjective.h"
32
33 void omxCalculateStdErrorFromHessian(double scale, omxObjective *oo) {
34         /* This function calculates the standard errors from the hessian matrix */
35         // sqrt(diag(solve(hessian)))
36
37         if(oo->hessian == NULL) return;
38         
39         int numParams = oo->matrix->currentState->numFreeParams;
40         
41         if(oo->stdError == NULL) {
42                 oo->stdError = (double*) R_alloc(numParams, sizeof(double));
43         }
44         
45         double* stdErr = oo->stdError;
46         
47         double* hessian = oo->hessian;
48         double* workspace = (double *) Calloc(numParams * numParams, double);
49         
50         for(int i = 0; i < numParams; i++)
51                 for(int j = 0; j <= i; j++)
52                         workspace[i*numParams+j] = hessian[i*numParams+j];              // Populate upper triangle
53         
54         char u = 'U';
55         int ipiv[numParams];
56         int lwork = -1;
57         double temp;
58         int info = 0;
59         
60         F77_CALL(dsytrf)(&u, &numParams, workspace, &numParams, ipiv, &temp, &lwork, &info);
61         
62         lwork = (temp > numParams?temp:numParams);
63         
64         double* work = (double*) Calloc(lwork, double);
65         
66         F77_CALL(dsytrf)(&u, &numParams, workspace, &numParams, ipiv, work, &lwork, &info);
67         
68         if(info != 0) {
69                 
70                 oo->stdError = NULL;
71                 
72         } else {
73                 
74                 F77_CALL(dsytri)(&u, &numParams, workspace, &numParams, ipiv, work, &info);
75         
76                 if(info != 0) {
77                         oo->stdError = NULL;
78                 } else {
79                         for(int i = 0; i < numParams; i++) {
80                                 stdErr[i] = sqrt(scale) * sqrt(workspace[i * numParams + i]);
81                         }
82                 }
83         }
84         
85         Free(workspace);
86         Free(work);
87         
88 }
89
90 void omxInitEmptyObjective(omxObjective *oo) {
91         /* Sets everything to NULL to avoid bad pointer calls */
92         
93         oo->initFun = NULL;
94         oo->destructFun = NULL;
95         oo->repopulateFun = NULL;
96         oo->objectiveFun = NULL;
97         oo->needsUpdateFun = NULL;
98         oo->getStandardErrorFun = NULL;
99     oo->populateAttrFun = NULL;
100         oo->setFinalReturns = NULL;
101         oo->gradientFun = NULL;
102     oo->sharedArgs = NULL;
103         oo->argStruct = NULL;
104         oo->subObjective = NULL;
105         oo->objType = (char*) calloc(MAX_STRING_LEN, sizeof(char*));
106         oo->objType[0] = '\0';
107         oo->matrix = NULL;
108         oo->stdError = NULL;
109         oo->hessian = NULL;
110         oo->gradient = NULL;
111 }
112
113 omxObjective* omxCreateSubObjective(omxObjective *oo) {
114
115         if(OMX_DEBUG) {Rprintf("Creating SubObjective Object....\n");}
116     if(oo == NULL) {
117                 if(OMX_DEBUG) {Rprintf("Got Null objective.  Returning.");}
118                 return NULL;
119         }
120     omxObjective* subObjective = (omxObjective*) Calloc(1, omxObjective);
121     omxInitEmptyObjective(subObjective);
122         omxDuplicateObjective(subObjective, oo, oo->matrix->currentState, FALSE);
123         oo->subObjective = subObjective;
124         
125     return subObjective;
126
127 }
128
129 void omxFreeObjectiveArgs(omxObjective *oo) {
130     if(oo==NULL) return;
131     
132         /* Completely destroy the objective function tree */
133     if(OMX_DEBUG) {Rprintf("Freeing objective object at 0x%x with subobjective 0x%x.\n", oo, oo->subObjective);}
134         if(oo->matrix != NULL) {
135             if(oo->objType != NULL) {
136                 free(oo->objType);
137             oo->objType = NULL;
138         }
139             if(oo->subObjective != NULL) {
140                     omxFreeObjectiveArgs(oo->subObjective);
141             }
142             if(oo->destructFun != NULL) {
143             if(OMX_DEBUG) {Rprintf("Calling objective destructor for 0x%x.\n", oo);}
144                     oo->destructFun(oo);
145             }
146             oo->matrix = NULL;
147     }
148 }
149
150 void omxObjectiveCompute(omxObjective *oo) {
151         if(OMX_DEBUG_ALGEBRA) { 
152             Rprintf("Objective compute: 0x%0x (needed: %s).\n", oo, (oo->matrix->isDirty?"Yes":"No"));
153         }
154
155         oo->objectiveFun(oo);
156
157         if(oo->matrix != NULL)
158                 omxMatrixCompute(oo->matrix);
159 }
160
161 void omxDuplicateObjectiveMatrix(omxMatrix *tgt, const omxMatrix *src, omxState* newState, short duplicateUnshared) {
162
163     if(tgt == NULL || src == NULL) return;
164     if(src->objective == NULL) {
165         return;
166     }
167     
168     omxObjective* target = tgt->objective;
169     omxObjective* source = src->objective;
170
171         tgt->objective = omxDuplicateObjective(target, source, newState, duplicateUnshared);
172
173 }
174
175 omxObjective* omxDuplicateObjective(omxObjective *tgt, const omxObjective *src, omxState* newState, short duplicateUnshared) {
176
177         if(OMX_DEBUG) {Rprintf("Duplicating objective 0x%x into 0x%x.", src, tgt);}
178
179         if(src == NULL) {
180                 return NULL;
181         }
182         
183         if(tgt == NULL) {
184         tgt = (omxObjective*) R_alloc(1, sizeof(omxObjective));
185         omxInitEmptyObjective(tgt);
186     }
187
188         tgt->initFun                            = src->initFun;
189         tgt->destructFun                        = src->destructFun;
190         tgt->repopulateFun                      = src->repopulateFun;
191         tgt->objectiveFun                       = src->objectiveFun;
192         tgt->needsUpdateFun                     = src->needsUpdateFun;
193         tgt->getStandardErrorFun        = src->getStandardErrorFun;
194         tgt->populateAttrFun            = src->populateAttrFun;
195         tgt->setFinalReturns            = src->setFinalReturns;
196         tgt->gradientFun                        = src->gradientFun;
197         tgt->sharedArgs                         = src->sharedArgs;
198         tgt->matrix                             = src->matrix;
199         tgt->subObjective                       = src->subObjective;
200         tgt->stdError                           = src->stdError;
201         tgt->hessian                            = src->hessian;
202         tgt->gradient                           = src->gradient;
203
204         if(tgt->objType == NULL) tgt->objType = (char*) calloc(MAX_STRING_LEN, sizeof(char*)); // Double-check
205     strncpy(tgt->objType, src->objType, MAX_STRING_LEN);
206
207         if(duplicateUnshared == TRUE && src->duplicateUnsharedArgs != NULL) {
208             // Duplicate function should replace any shared args from above as well
209             src->duplicateUnsharedArgs(tgt, src);
210     }
211         
212         return tgt;
213
214 }
215
216 unsigned short omxObjectiveNeedsUpdate(omxObjective *oo)
217 {
218         if(OMX_DEBUG_MATRIX) { Rprintf("omxObjectiveNeedsUpdate:"); }
219         unsigned short needsIt = TRUE;     // Defaults to TRUE if unspecified
220         if(!(oo->needsUpdateFun == NULL)) {
221                 if(OMX_DEBUG_MATRIX) {Rprintf("Calling update function 0x%x:", oo->needsUpdateFun);}
222                 needsIt = oo->needsUpdateFun(oo);
223                 if(!needsIt && !(oo->subObjective == NULL)) {
224                         needsIt = omxObjectiveNeedsUpdate(oo->subObjective);
225                 }
226         } else if(!(oo->subObjective == NULL)) {
227                 needsIt = omxObjectiveNeedsUpdate(oo->subObjective);
228         }
229         
230         if(OMX_DEBUG_MATRIX) {Rprintf("%s\n", (needsIt?"Yes":"No"));}
231         
232         return needsIt;
233 }
234
235 void omxSetObjectiveType(omxObjective* oo, const char* newName) {
236         if(oo->objType == NULL) {
237                 oo->objType = (char*) calloc(MAX_STRING_LEN, sizeof(char*));
238                 oo->objType[0] = '\0';
239         }
240         
241         strncpy(oo->objType, newName, MAX_STRING_LEN);
242 }
243
244 void omxFillMatrixFromMxObjective(omxMatrix* om, SEXP rObj) {
245
246         int i;
247         const char *objType;
248         SEXP objectiveClass;
249         char errorCode[MAX_STRING_LEN];
250         omxObjective *obj = (omxObjective*) R_alloc(1, sizeof(omxObjective));
251         omxInitEmptyObjective(obj);
252
253         /* Register Objective and Matrix with each other */
254         obj->matrix = om;
255         omxResizeMatrix(om, 1, 1, FALSE);                                       // Objective matrices MUST be 1x1.
256         om->objective = obj;
257         
258         /* Get Objective Type */
259         PROTECT(objectiveClass = STRING_ELT(getAttrib(rObj, install("class")), 0));
260         objType = CHAR(objectiveClass);
261         obj->objType[MAX_STRING_LEN] = '\0';
262         strncpy(obj->objType, objType, MAX_STRING_LEN);
263         
264         /* Switch based on objective type. */ 
265         for(i = 0; i < omxObjectiveTableLength; i++) {
266                 if(strncmp(objType, omxObjectiveSymbolTable[i].name, MAX_STRING_LEN) == 0) {
267                         obj->initFun = omxObjectiveSymbolTable[i].initFun;
268                         break;
269                 }
270         }
271
272         if(i == omxObjectiveTableLength) {
273                 char newError[MAX_STRING_LEN];
274                 sprintf(newError, "Objective function %s not implemented.\n", obj->objType);
275                 omxRaiseError(om->currentState, -1, newError);
276         }
277
278         obj->initFun(obj, rObj);
279
280         if(obj->objectiveFun == NULL) {// If initialization fails, error code goes in argStruct
281                 if(om->currentState->statusCode != 0) {
282                         strncpy(errorCode, om->currentState->statusMsg, 150); // Report a status error
283                 } else {
284                         // If no error code is reported, we report that.
285                         strncpy(errorCode, "No error code reported.", 25);
286                 }
287                 if(obj->argStruct != NULL) {
288                         strncpy(errorCode, (char*)(obj->argStruct), 51);
289                 }
290                 char newError[MAX_STRING_LEN];
291                 sprintf(newError, "Could not initialize objective function %s.  Error: %s\n", 
292                     obj->objType, errorCode);
293                 omxRaiseError(om->currentState, -1, newError);
294         }
295         
296         obj->matrix->isDirty = TRUE;
297
298         UNPROTECT(1);   /* objectiveClass */
299
300 }
301
302 void omxObjectiveGradient(omxObjective* oo, double* gradient) {
303         if(!(oo->gradientFun == NULL)) { oo->gradientFun(oo, gradient); }
304         return;
305 }
306
307 void omxObjectivePrint(omxObjective* oo, char* d) {
308         Rprintf("(Objective, type %s) ", oo->objType);
309         omxPrintMatrix(oo->matrix, d);
310 }
311
312
313 /* Helper functions */
314 omxMatrix* omxNewMatrixFromIndexSlot(SEXP rObj, omxState* currentState, char* const slotName) {
315         SEXP slotValue;
316         omxMatrix* newMatrix = NULL;
317         if(strncmp(slotName, "", 1) == 0) return NULL;
318         PROTECT(slotValue = GET_SLOT(rObj, install(slotName)));
319         newMatrix = omxNewMatrixFromMxIndex(slotValue, currentState);
320         if(newMatrix != NULL) omxRecompute(newMatrix);
321         else if(OMX_DEBUG) Rprintf("No slot %s found.\n", slotName);
322         UNPROTECT(1);
323         return newMatrix;
324 }
325
326 omxData* omxNewDataFromDataSlot(SEXP rObj, omxState* currentState, char* const dataSlotName) {
327         
328         SEXP slotValue;
329         
330         PROTECT(slotValue = GET_SLOT(rObj, install(dataSlotName)));
331         if(OMX_DEBUG) { Rprintf("Data Element %d.\n", AS_INTEGER(slotValue)); }
332         omxData* dataElt = omxNewDataFromMxDataPtr(slotValue, currentState);
333         UNPROTECT(1); // newMatrix
334         
335         return dataElt;
336         
337 }