Fix recent bugs
[openmx:openmx.git] / src / omxState.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18 *
19 *  omxState.cpp
20 *
21 *  Created: Timothy R. Brick    Date: 2009-06-05
22 *
23 *       omxStates carry the current optimization state
24 *
25 **********************************************************/
26
27 #include <stdarg.h>
28 #include "omxState.h"
29
30 /* Initialize and Destroy */
31         void omxInitState(omxState* state, omxState *parentState) {
32                 state->numConstraints = 0;
33                 state->numFreeParams = 0;
34                 state->numChildren = 0;
35                 state->childList = NULL;
36                 state->parentState = parentState;
37                 state->fitMatrix = NULL;
38                 state->hessian = NULL;
39                 state->conList = NULL;
40                 state->freeVarList = NULL;
41                 state->optimizerState = NULL;
42                 state->optimalValues = NULL;
43                 state->optimum = 9999999999;
44
45                 state->majorIteration = 0;
46                 state->minorIteration = 0;
47                 state->startTime = 0;
48                 state->endTime = 0;
49                 state->numCheckpoints = 0;
50                 state->checkpointList = NULL;
51                 state->chkptText1 = NULL;
52                 state->chkptText2 = NULL;
53
54         state->currentInterval = -1;
55
56                 state->computeCount = -1;
57                 state->currentRow = -1;
58
59                 state->statusCode = 0;
60                 strncpy(state->statusMsg, "", 1);
61         }
62
63         void omxFillState(omxState* state, /*omxOptimizer *oo,*/ omxMatrix** matrixList,
64                                                 omxMatrix** algebraList, omxData** dataList, omxMatrix* fitFunction) {
65                 error("NYI: Can't fill a state from outside yet. Besides, do you really need a single function to do this?");
66         }
67         
68         omxState* omxGetState(omxState* os, int stateNumber) {
69                 // TODO: Need to implement a smarter way to enumerate children
70                 if(stateNumber == 0) return os;
71                 if((stateNumber-1) < os->numChildren) {
72                         return(os->childList[stateNumber-1]);
73                 } else {
74                         error("Not implemented");
75                         // TODO: Account for unequal numbers of grandchild states
76                         int subState = (stateNumber - os->numChildren - 1);
77                         return omxGetState(os->childList[subState % os->numChildren], subState / os->numChildren);
78                 }
79         }
80
81         void omxSetMajorIteration(omxState *state, int value) {
82                 state->majorIteration = value;
83                 for(int i = 0; i < state->numChildren; i++) {
84                         omxSetMajorIteration(state->childList[i], value);
85                 }
86         }
87
88         void omxSetMinorIteration(omxState *state, int value) {
89                 state->minorIteration = value;
90                 for(int i = 0; i < state->numChildren; i++) {
91                         omxSetMinorIteration(state->childList[i], value);
92                 }
93         }
94         
95         void omxDuplicateState(omxState* tgt, omxState* src) {
96                 tgt->dataList                   = src->dataList;
97                 tgt->numChildren                = 0;
98                 
99                 // Duplicate matrices and algebras and build parentLists.
100                 tgt->parentState                = src;
101                 tgt->markMatrices               = src->markMatrices; // TODO, unused in children?
102                                 
103                 for(size_t mx = 0; mx < src->matrixList.size(); mx++) {
104                         // TODO: Smarter inference for which matrices to duplicate
105                         tgt->matrixList.push_back(omxDuplicateMatrix(src->matrixList[mx], tgt));
106                 }
107
108                 tgt->numConstraints     = src->numConstraints;
109                 tgt->conList                    = (omxConstraint*) R_alloc(tgt->numConstraints, sizeof(omxConstraint));
110                 for(int j = 0; j < tgt->numConstraints; j++) {
111                         tgt->conList[j].size   = src->conList[j].size;
112                         tgt->conList[j].opCode = src->conList[j].opCode;
113                         tgt->conList[j].lbound = src->conList[j].lbound;
114                         tgt->conList[j].ubound = src->conList[j].ubound;
115                         tgt->conList[j].result = omxDuplicateMatrix(src->conList[j].result, tgt);
116                 }
117
118                 for(size_t j = 0; j < src->algebraList.size(); j++) {
119                         // TODO: Smarter inference for which algebras to duplicate
120                         tgt->algebraList.push_back(omxDuplicateMatrix(src->algebraList[j], tgt));
121                 }
122
123                 for(size_t j = 0; j < src->expectationList.size(); j++) {
124                         // TODO: Smarter inference for which expectations to duplicate
125                         tgt->expectationList.push_back(omxDuplicateExpectation(src->expectationList[j], tgt));
126                 }
127
128                 for(size_t j = 0; j < tgt->algebraList.size(); j++) {
129                         omxDuplicateAlgebra(tgt->algebraList[j], src->algebraList[j], tgt);
130                 }
131
132                 for(size_t j = 0; j < src->expectationList.size(); j++) {
133                         // TODO: Smarter inference for which expectations to duplicate
134                         omxCompleteExpectation(tgt->expectationList[j]);
135                 }
136
137                 tgt->childList                  = NULL;
138
139                 tgt->fitMatrix  = omxLookupDuplicateElement(tgt, src->fitMatrix);
140                 tgt->hessian                    = src->hessian;
141
142                 tgt->numFreeParams                      = src->numFreeParams;
143                 tgt->freeVarList                = new omxFreeVar[tgt->numFreeParams];
144                 for(int j = 0; j < tgt->numFreeParams; j++) {
145                         int numDeps                                                     = src->freeVarList[j].numDeps;
146
147                         tgt->freeVarList[j].lbound                      = src->freeVarList[j].lbound;
148                         tgt->freeVarList[j].ubound                      = src->freeVarList[j].ubound;
149                         tgt->freeVarList[j].locations                   = src->freeVarList[j].locations;
150                         tgt->freeVarList[j].numDeps                     = numDeps;
151                         
152                         tgt->freeVarList[j].deps                        = (int*) R_alloc(numDeps, sizeof(int));
153
154                         tgt->freeVarList[j].name                = src->freeVarList[j].name;
155
156                         for(int k = 0; k < numDeps; k++) {
157                                 tgt->freeVarList[j].deps[k] = src->freeVarList[j].deps[k];
158                         }
159                 }
160                 
161                 if (src->optimizerState) {
162                         tgt->optimizerState                                     = (omxOptimizerState*) R_alloc(1, sizeof(omxOptimizerState));
163                         tgt->optimizerState->currentParameter   = src->optimizerState->currentParameter;
164                         tgt->optimizerState->offset                             = src->optimizerState->offset;
165                         tgt->optimizerState->alpha                              = src->optimizerState->alpha;
166                 }
167                 
168                 tgt->optimalValues              = src->optimalValues;
169                 tgt->optimum                    = 9999999999;
170                                   
171                 tgt->majorIteration     = 0;
172                 tgt->minorIteration     = 0;
173                 tgt->startTime                  = src->startTime;
174                 tgt->endTime                    = 0;
175                 
176                 // TODO: adjust checkpointing based on parallelization method
177                 tgt->numCheckpoints     = 0;
178                 tgt->checkpointList     = NULL;
179                 tgt->chkptText1                 = NULL;
180                 tgt->chkptText2                 = NULL;
181                                   
182                 tgt->computeCount               = src->computeCount;
183                 tgt->currentRow                 = src->currentRow;
184
185                 tgt->statusCode                 = 0;
186                 strncpy(tgt->statusMsg, "", 1);
187         }
188
189         omxMatrix* omxLookupDuplicateElement(omxState* os, omxMatrix* element) {
190                 if(os == NULL || element == NULL) return NULL;
191
192                 if (element->hasMatrixNumber) {
193                         int matrixNumber = element->matrixNumber;
194                         if (matrixNumber >= 0) {
195                                 return(os->algebraList[matrixNumber]);
196                         } else {
197                                 return(os->matrixList[-matrixNumber - 1]);
198                         }
199                 }
200
201                 omxConstraint* parentConList = os->parentState->conList;
202
203                 for(int i = 0; i < os->numConstraints; i++) {
204                         if(parentConList[i].result == element) {
205                                 if(os->conList[i].result != NULL) {   // Not sure of proper failure behavior here.
206                 return(os->conList[i].result);
207                                 } else {
208                     omxRaiseError(os, -2, "Initialization Copy Error: Constraint required but not yet processed.");
209             }
210                         }
211                 }
212
213                 return NULL;
214         }
215         
216         omxExpectation* omxLookupDuplicateExpectation(omxState* os, omxExpectation* ox) {
217                 if(os == NULL || ox == NULL) return NULL;
218
219                 return(os->expectationList[ox->expNum]);
220         }
221
222         int omxCountLeafNodes(omxState *state) {
223                 int children = state->numChildren;
224                 if (children == 0) {
225                         return(1);
226                 } else {
227                         int sum = 0;
228                         for(int i = 0; i < children; i++) {
229                                 sum += omxCountLeafNodes(state->childList[i]);
230                         }
231                         return(sum);
232                 }
233         }
234
235         /* Traverse to the root of the state hierarchy,
236          * and then count the number of leaf nodes */
237         int omxTotalThreadCount(omxState *state) {
238
239                 while(state->parentState != NULL) {
240                         state = state->parentState;
241                 }
242         
243                 return(omxCountLeafNodes(state));
244         }
245
246         void omxFreeState(omxState *state) {
247                 int k;
248
249                 if (state->numChildren > 0) {
250                         for(k = 0; k < state->numChildren; k++) {
251                                 // Data are not modified and not copied. The same memory
252                                 // is shared across all instances of state. We only need
253                                 // to free the data once, so let the parent do it.
254                                 state->childList[k]->dataList.clear();
255
256                                 omxFreeState(state->childList[k]);
257                         }
258                         Free(state->childList);
259                         state->childList = NULL;
260                         state->numChildren = 0;
261                 }
262
263                 for(size_t ax = 0; ax < state->algebraList.size(); ax++) {
264                         if(OMX_DEBUG) { Rprintf("Freeing Algebra %d at 0x%x.\n", ax, state->algebraList[ax]); }
265                         omxFreeAllMatrixData(state->algebraList[ax]);
266                 }
267
268                 if(OMX_DEBUG) { Rprintf("Freeing %d Matrices.\n", state->matrixList.size());}
269                 for(size_t mk = 0; mk < state->matrixList.size(); mk++) {
270                         if(OMX_DEBUG) { Rprintf("Freeing Matrix %d at 0x%x.\n", mk, state->matrixList[mk]); }
271                         omxFreeAllMatrixData(state->matrixList[mk]);
272                 }
273                 
274                 if(OMX_DEBUG) { Rprintf("Freeing %d Model Expectations.\n", state->expectationList.size());}
275                 for(size_t ex = 0; ex < state->expectationList.size(); ex++) {
276                         if(OMX_DEBUG) { Rprintf("Freeing Expectation %d at 0x%x.\n", ex, state->expectationList[ex]); }
277                         omxFreeExpectationArgs(state->expectationList[ex]);
278                 }
279
280                 if(OMX_DEBUG) { Rprintf("Freeing %d Constraints.\n", state->numConstraints);}
281                 for(k = 0; k < state->numConstraints; k++) {
282                         if(OMX_DEBUG) { Rprintf("Freeing Constraint %d at 0x%x.\n", k, state->conList[k]); }
283                         omxFreeAllMatrixData(state->conList[k].result);
284                 }
285
286                 if(OMX_DEBUG) { Rprintf("Freeing %d Data Sets.\n", state->dataList.size());}
287                 for(size_t dx = 0; dx < state->dataList.size(); dx++) {
288                         if(OMX_DEBUG) { Rprintf("Freeing Data Set %d at 0x%x.\n", dx, state->dataList[dx]); }
289                         omxFreeData(state->dataList[dx]);
290                 }
291
292                 delete [] state->freeVarList;
293
294         if(OMX_DEBUG) {Rprintf("Freeing %d Children.\n", state->numChildren);}
295         for(k = 0; k < state->numChildren; k++) {
296                         if(OMX_DEBUG) { Rprintf("Freeing Child State %d at 0x%x.\n", k, state->childList[k]); }
297                         omxFreeState(state->childList[k]);            
298         }
299
300                 if(OMX_DEBUG) { Rprintf("Freeing %d Checkpoints.\n", state->numCheckpoints);}
301                 for(k = 0; k < state->numCheckpoints; k++) {
302                         if(OMX_DEBUG) { Rprintf("Freeing Data Set %d at 0x%x.\n", k, state->checkpointList[k]); }
303                         omxCheckpoint oC = state->checkpointList[k];
304                         switch(oC.type) {
305                                 case OMX_FILE_CHECKPOINT:
306                                         fclose(oC.file);
307                                         break;
308                                 case OMX_CONNECTION_CHECKPOINT: // NYI :::DEBUG:::
309                                         // Do nothing: this should be handled by R upon return.
310                                         break;
311                         }
312                         if(state->chkptText1 != NULL) {
313                                 Free(state->chkptText1);
314                         }
315                         if(state->chkptText2 != NULL) {
316                                 Free(state->chkptText2);
317                         }
318                         // Checkpoint list itself is freed by R.
319                 }
320
321                 delete state;
322
323                 if(OMX_DEBUG) { Rprintf("State Freed.\n");}
324         }
325
326         void omxSaveState(omxState *os, double* freeVals, double minimum) {
327                 if(os->optimalValues == NULL) {
328                         os->optimalValues = (double*) R_alloc(os->numFreeParams, sizeof(double));
329                 }
330
331                 for(int i = 0; i < os->numFreeParams; i++) {
332                         os->optimalValues[i] = freeVals[i];
333                 }
334                 os->optimum = minimum;
335                 os->optimumStatus = os->statusCode;
336                 strncpy(os->optimumMsg, os->statusMsg, 250);
337         }
338
339         void omxResetStatus(omxState *state) {
340                 int numChildren = state->numChildren;
341                 state->statusCode = 0;
342                 state->statusMsg[0] = '\0';
343                 for(int i = 0; i < numChildren; i++) {
344                         omxResetStatus(state->childList[i]);
345                 }
346         }
347
348 void omxRaiseErrorf(omxState *state, const char* errorMsg, ...)
349 {
350         va_list ap;
351         va_start(ap, errorMsg);
352         int fit = vsnprintf(state->statusMsg, MAX_STRING_LEN, errorMsg, ap);
353         va_end(ap);
354         if(OMX_DEBUG) {
355                 if (!(fit > -1 && fit < MAX_STRING_LEN)) {
356                         Rprintf("Error exceeded maximum length: %s\n", errorMsg);
357                 } else {
358                         Rprintf("Error raised: %s\n", state->statusMsg);
359                 }
360         }
361         state->statusCode = -1;  // this provides no additional information beyond errorMsg[0]!=0 TODO
362 }
363
364         void omxRaiseError(omxState *state, int errorCode, const char* errorMsg) {
365                 if(OMX_DEBUG && errorCode) { Rprintf("Error %d raised: %s\n", errorCode, errorMsg);}
366                 if(OMX_DEBUG && !errorCode) { Rprintf("Error status cleared."); }
367                 state->statusCode = errorCode;
368                 strncpy(state->statusMsg, errorMsg, 249);
369                 state->statusMsg[249] = '\0';
370                 if(state->computeCount <= 0 && errorCode < 0) {
371                         state->statusCode--;                    // Decrement status for init errors.
372                 }
373         }
374
375         void omxStateNextRow(omxState *state) {
376                 state->currentRow++;
377         };
378
379         void omxStateNextEvaluation(omxState *state) {
380                 state->currentRow = -1;
381                 state->computeCount++;
382         };
383
384         void omxWriteCheckpointHeader(omxState *os, omxCheckpoint* oC) {
385                 // FIXME: Is it faster to allocate this on the stack?
386                 os->chkptText1 = (char*) Calloc((24 + 15 * os->numFreeParams), char);
387                 os->chkptText2 = (char*) Calloc(1.0 + 15.0 * os->numFreeParams*
388                         (os->numFreeParams + 1.0) / 2.0, char);
389                 if (oC->type == OMX_FILE_CHECKPOINT) {
390                         fprintf(oC->file, "iterations\ttimestamp\tobjective\t");
391                         for(int j = 0; j < os->numFreeParams; j++) {
392                                 if(strcmp(os->freeVarList[j].name, CHAR(NA_STRING)) == 0) {
393                                         fprintf(oC->file, "%s", os->freeVarList[j].name);
394                                 } else {
395                                         fprintf(oC->file, "\"%s\"", os->freeVarList[j].name);
396                                 }
397                                 if (j != os->numFreeParams - 1) fprintf(oC->file, "\t");
398                         }
399                         fprintf(oC->file, "\n");
400                         fflush(oC->file);
401                 }
402         }
403  
404         void omxWriteCheckpointMessage(omxState *os, char *msg) {
405                 for(int i = 0; i < os->numCheckpoints; i++) {
406                         omxCheckpoint* oC = &(os->checkpointList[i]);
407                         if(os->chkptText1 == NULL) {    // First one: set up output
408                                 omxWriteCheckpointHeader(os, oC);
409                         }
410                         if (oC->type == OMX_FILE_CHECKPOINT) {
411                                 fprintf(oC->file, "%d \"%s\" NA ", os->majorIteration, msg);
412                                 for(int j = 0; j < os->numFreeParams; j++) {
413                                         fprintf(oC->file, "NA ");
414                                 }
415                                 fprintf(oC->file, "\n");
416                         }
417                 }
418         }
419
420         void omxSaveCheckpoint(omxState *os, double* x, double* f, int force) {
421                 time_t now = time(NULL);
422                 int soFar = now - os->startTime;                // Translated into minutes
423                 int n;
424                 for(int i = 0; i < os->numCheckpoints; i++) {
425                         n = 0;
426                         omxCheckpoint* oC = &(os->checkpointList[i]);
427                         // Check based on time            
428                         if((oC->time > 0 && (soFar - oC->lastCheckpoint) >= oC->time) || force) {
429                                 oC->lastCheckpoint = soFar;
430                                 n = 1;
431                         }
432                         // Or iterations
433                         if((oC->numIterations > 0 && (os->majorIteration - oC->lastCheckpoint) >= oC->numIterations) || force) {
434                                 oC->lastCheckpoint = os->majorIteration;
435                                 n = 1;
436                         }
437
438                         if(n) {         //In either case, save a checkpoint.
439                                 if(os->chkptText1 == NULL) {    // First one: set up output
440                                         omxWriteCheckpointHeader(os, oC);
441                                 }
442                                 char tempstring[25];
443                                 sprintf(tempstring, "%d", os->majorIteration);
444
445                                 if(strncmp(os->chkptText1, tempstring, strlen(tempstring))) {   // Returns zero if they're the same.
446                                         struct tm * nowTime = localtime(&now);                                          // So this only happens if the text is out of date.
447                                         strftime(tempstring, 25, "%b %d %Y %I:%M:%S %p", nowTime);
448                                         sprintf(os->chkptText1, "%d \"%s\" %9.5f", os->majorIteration, tempstring, f[0]);
449                                         for(int j = 0; j < os->numFreeParams; j++) {
450                                                 sprintf(tempstring, " %9.5f", x[j]);
451                                                 strncat(os->chkptText1, tempstring, 14);
452                                         }
453
454                                         double* hessian = os->hessian;
455                                         if(hessian != NULL) {
456                                                 for(int j = 0; j < os->numFreeParams; j++) {
457                                                         for(int k = 0; k <= j; k++) {
458                                                                 sprintf(tempstring, " %9.5f", hessian[j]);
459                                                                 strncat(os->chkptText2, tempstring, 14);
460                                                         }
461                                                 }
462                                         }
463                                 }
464
465                                 if(oC->type == OMX_FILE_CHECKPOINT) {
466                                         fprintf(oC->file, "%s", os->chkptText1);
467                                         if(oC->saveHessian)
468                                                 fprintf(oC->file, "%s", os->chkptText2);
469                                         fprintf(oC->file, "\n");
470                                         fflush(oC->file);
471                                 } else if(oC->type == OMX_CONNECTION_CHECKPOINT) {
472                                         warning("NYI: R_connections are not yet implemented.");
473                                         oC->numIterations = 0;
474                                         oC->time = 0;
475                                 }
476                         }
477                 }
478         }
479
480 void omxExamineFitOutput(omxState *state, omxMatrix *fitMatrix, int *mode)
481 {
482         if (!R_FINITE(fitMatrix->data[0])) {
483                 omxRaiseErrorf(state, "Fit function returned %g at iteration %d.%d",
484                                fitMatrix->data[0], state->majorIteration, state->minorIteration);
485                 *mode = -1;
486         }
487 }