Attempting to propagate major and minor iteration numbers.
[openmx:openmx.git] / src / omxState.c
1 /*
2  *  Copyright 2007-2009 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18 *
19 *  omxState.cc
20 *
21 *  Created: Timothy R. Brick    Date: 2009-06-05
22 *
23 *       omxStates carry the current optimization state
24 *
25 **********************************************************/
26
27 #include "omxState.h"
28
29 /* Initialize and Destroy */
30         void omxInitState(omxState* state, omxState *parentState, int numThreads) {
31                 int i;
32                 state->numMats = 0;
33                 state->numAlgs = 0;
34                 state->numData = 0;
35                 state->numFreeParams = 0;
36                 state->numDynamic = 0;
37                 state->maxDynamic = 0;
38                 if (numThreads > 1) {
39                         state->numChildren = numThreads;
40                         state->childList = (omxState**) Calloc(numThreads, omxState*);
41                         for(i = 0; i < numThreads; i++) {
42                                 state->childList[i] = (omxState*) R_alloc(1, sizeof(omxState));
43                                 omxInitState(state->childList[i], state, 1);
44                         }
45                 } else {
46                 state->numChildren = 0;
47                         state->childList = NULL;
48                 }
49                 state->matrixList = NULL;
50                 state->algebraList = NULL;
51                 state->dynamicList = NULL;
52         state->parentState = parentState;
53         state->parentMatrix = NULL;
54         state->parentAlgebra = NULL;
55                 state->parentConList= NULL;
56                 state->dataList = NULL;
57                 state->objectiveMatrix = NULL;
58                 state->hessian = NULL;
59                 state->conList = NULL;
60                 state->freeVarList = NULL;
61                 state->optimizerState = NULL;
62                 state->optimalValues = NULL;
63                 state->optimum = 9999999999;
64
65                 state->majorIteration = 0;
66                 state->minorIteration = 0;
67                 state->startTime = 0;
68                 state->endTime = 0;
69                 state->numCheckpoints = 0;
70                 state->checkpointList = NULL;
71                 state->chkptText1 = NULL;
72                 state->chkptText2 = NULL;
73
74                 state->computeCount = -1;
75                 state->currentRow = -1;
76
77                 state->statusCode = 0;
78                 strncpy(state->statusMsg, "", 1);
79         }
80
81         void omxFillState(omxState* state, /*omxOptimizer *oo,*/ omxMatrix** matrixList,
82                                                 omxMatrix** algebraList, omxData** dataList, omxMatrix* objective) {
83                 error("NYI: Can't fill a state from outside yet. Besides, do you really need a single function to do this?");
84         }
85         
86         omxState* omxGetState(omxState* os, int stateNumber) {
87                 // TODO: Need to implement a smarter way to enumerate children
88                 if(stateNumber == 0) return os;
89                 if((stateNumber-1) < os->numChildren) {
90                         return(os->childList[stateNumber-1]);
91                 } else {
92                         // TODO: Account for unequal numbers of grandchild states
93                         int subState = (stateNumber - os->numChildren - 1);
94                         return omxGetState(os->childList[subState % os->numChildren], subState / os->numChildren);
95                 }
96         }
97
98         void omxUpdateState(omxState* tgt, omxState* src) {
99                 tgt->computeCount               = src->computeCount;
100                 tgt->currentRow                 = src->currentRow;
101                 tgt->optimalValues              = src->optimalValues;
102                 tgt->majorIteration     = src->majorIteration;
103                 tgt->minorIteration     = src->minorIteration;
104
105                 for(int i = 0; i < src->numMats; i++) {
106                         omxCopyMatrix(tgt->matrixList[i], src->matrixList[i]);
107                 }
108                 for(int i = 0; i < src->numAlgs; i++) {
109                         omxUpdateAlgebra(tgt->algebraList[i], src->algebraList[i]);
110                 }
111                 for(int i = 0; i < src->numDynamic; i++) {
112                         omxUpdateAlgebra(tgt->dynamicList[i], src->dynamicList[i]);
113                 }
114         }
115
116     void omxSetMajorIteration(omxState *state, int value) {
117                 state->majorIteration = value;
118                 for(int i = 0; i < state->numChildren; i++) {
119                         omxSetMajorIteration(state->childList[i], value);
120                 }
121         }
122
123     void omxSetMinorIteration(omxState *state, int value) {
124                 state->minorIteration = value;
125                 for(int i = 0; i < state->numChildren; i++) {
126                         omxSetMinorIteration(state->childList[i], value);
127                 }
128         }
129
130         void omxAddDynamicMatrix(omxState* state, omxMatrix* matrix) {
131                 if (state->dynamicList == NULL) {
132                         state->dynamicList = (omxMatrix**) malloc(16 * sizeof(omxMatrix*));
133                         state->maxDynamic = 16;
134                 }
135                 if (state->numDynamic == state->maxDynamic) {
136                         state->dynamicList = realloc(state->dynamicList, state->maxDynamic * 2 * sizeof(omxMatrix*));
137                         state->maxDynamic = state->maxDynamic * 2;
138                 }
139                 matrix->matrixNumber = state->numDynamic;
140                 state->dynamicList[state->numDynamic] = matrix;
141                 state->numDynamic = state->numDynamic + 1;              
142         }
143         
144         void omxDuplicateState(omxState* tgt, omxState* src, unsigned short fullCopy) {
145                 tgt->numMats                    = src->numMats;
146                 tgt->numAlgs                    = src->numAlgs;
147                 tgt->numData                    = src->numData;
148                 tgt->dataList                   = src->dataList;
149                 tgt->numChildren                = 0;
150                 
151                 // Duplicate matrices and algebras and build parentLists.
152                 tgt->parentState                = src;
153                 tgt->parentMatrix               = src->matrixList;
154                 tgt->parentAlgebra              = src->algebraList;
155                 tgt->matrixList                 = (omxMatrix**) R_alloc(tgt->numMats, sizeof(omxMatrix*));
156                 for(int j = 0; j < tgt->numMats; j++) {
157                         // TODO: Smarter inference for which matrices to duplicate
158                         tgt->matrixList[j] = omxDuplicateMatrix(src->matrixList[j], tgt, fullCopy);
159                 }
160                                 
161                 tgt->parentConList              = src->conList;
162                 tgt->numConstraints     = src->numConstraints;
163                 tgt->conList                    = (omxConstraint*) R_alloc(tgt->numConstraints, sizeof(omxConstraint));
164                 for(int j = 0; j < tgt->numConstraints; j++) {
165                         tgt->conList[j].size   = src->conList[j].size;
166                         tgt->conList[j].opCode = src->conList[j].opCode;
167                         tgt->conList[j].lbound = src->conList[j].lbound;
168                         tgt->conList[j].ubound = src->conList[j].ubound;
169                         tgt->conList[j].result = omxDuplicateMatrix(src->conList[j].result, tgt, fullCopy);
170                 }
171
172                 tgt->algebraList                = (omxMatrix**) R_alloc(tgt->numAlgs, sizeof(omxMatrix*));
173
174                 for(int j = 0; j < tgt->numAlgs; j++) {
175                         // TODO: Smarter inference for which algebras to duplicate
176                         tgt->algebraList[j] = omxDuplicateMatrix(src->algebraList[j], tgt, fullCopy);
177                 }
178
179                 for(int j = 0; j < tgt->numAlgs; j++) {
180                         omxDuplicateAlgebra(tgt->algebraList[j], src->algebraList[j], tgt, fullCopy);
181                 }
182
183                 
184                 tgt->childList                  = NULL;
185
186                 tgt->objectiveMatrix    = omxLookupDuplicateElement(tgt, src->objectiveMatrix);
187                 tgt->hessian                    = src->hessian;
188
189                 tgt->numFreeParams                      = src->numFreeParams;
190                 tgt->freeVarList                = (omxFreeVar*) R_alloc(tgt->numFreeParams, sizeof(omxFreeVar));
191                 for(int j = 0; j < tgt->numFreeParams; j++) {
192                         tgt->freeVarList[j].lbound                      = src->freeVarList[j].lbound;
193                         tgt->freeVarList[j].ubound                      = src->freeVarList[j].ubound;
194                         tgt->freeVarList[j].numLocations        = src->freeVarList[j].numLocations;
195                         
196                         int nLocs                                                       = tgt->freeVarList[j].numLocations;
197                         tgt->freeVarList[j].matrices            = (int*) R_alloc(nLocs, sizeof(int));
198                         tgt->freeVarList[j].row                         = (int*) R_alloc(nLocs, sizeof(int));
199                         tgt->freeVarList[j].col                         = (int*) R_alloc(nLocs, sizeof(int));
200
201                         for(int k = 0; k < nLocs; k++) {
202                                 int theMat                                              = src->freeVarList[j].matrices[k];
203                                 int theRow                                              = src->freeVarList[j].row[k];
204                                 int theCol                                              = src->freeVarList[j].col[k];
205
206                                 tgt->freeVarList[j].matrices[k] = theMat;
207                                 tgt->freeVarList[j].row[k]              = theRow;
208                                 tgt->freeVarList[j].col[k]              = theCol;
209                                                                 
210                                 tgt->freeVarList[j].name                = src->freeVarList[j].name;
211                         }
212                 }
213                 
214                 if (src->optimizerState) {
215                         tgt->optimizerState                                     = (omxOptimizerState*) R_alloc(1, sizeof(omxOptimizerState));
216                         tgt->optimizerState->currentParameter   = src->optimizerState->currentParameter;
217                         tgt->optimizerState->offset                             = src->optimizerState->offset;
218                         tgt->optimizerState->alpha                              = src->optimizerState->alpha;
219                 }
220                 
221                 tgt->optimalValues              = src->optimalValues;
222                 tgt->optimum                    = 9999999999;
223                                   
224                 tgt->majorIteration     = 0;
225                 tgt->minorIteration     = 0;
226                 tgt->startTime                  = src->startTime;
227                 tgt->endTime                    = 0;
228                 
229                 // TODO: adjust checkpointing based on parallelization method
230                 tgt->numCheckpoints     = 0;
231                 tgt->checkpointList     = NULL;
232                 tgt->chkptText1                 = NULL;
233                 tgt->chkptText2                 = NULL;
234                                   
235                 tgt->computeCount               = src->computeCount;
236                 tgt->currentRow                 = src->currentRow;
237
238                 tgt->statusCode                 = 0;
239                 strncpy(tgt->statusMsg, "", 1);
240         }
241
242     omxMatrix* omxLookupDuplicateElement(omxState* os, omxMatrix* element) {
243         if(os == NULL || element == NULL) return NULL;
244
245                 if (element->hasMatrixNumber) {
246                         int matrixNumber = element->matrixNumber;
247                         if (matrixNumber >= 0) {
248                                 return(os->algebraList[matrixNumber]);
249                         } else {
250                                 return(os->matrixList[-matrixNumber - 1]);
251                         }
252                 }
253
254         for(int i = 0; i < os->numConstraints; i++) {
255             if(os->parentConList[i].result == element) {
256                                 if(os->conList[i].result != NULL)   // Not sure of proper failure behavior here.
257                     return(os->conList[i].result);
258                 else
259                     omxRaiseError(os, -2, "Initialization Copy Error: Constraint required but not yet processed.");
260             }
261         }
262
263         return NULL;
264     }
265
266         void omxFreeState(omxState *state) {
267                 int k;
268
269                 if (state->numChildren > 0) {
270                         for(k = 0; k < state->numChildren; k++) {
271                                 omxFreeState(state->childList[k]);
272                         }
273                         Free(state->childList);
274                         state->childList = NULL;
275                         state->numChildren = 0;
276                 }
277
278                 if(OMX_DEBUG) { Rprintf("Freeing %d Algebras.\n", state->numAlgs);}
279                 for(k = 0; k < state->numAlgs; k++) {
280                         if(OMX_DEBUG) { Rprintf("Freeing Algebra %d at 0x%x.\n", k, state->algebraList[k]); }
281                         omxFreeAllMatrixData(state->algebraList[k]);
282                 }
283
284                 if(OMX_DEBUG) { Rprintf("Freeing %d Matrices.\n", state->numMats);}
285                 for(k = 0; k < state->numMats; k++) {
286                         if(OMX_DEBUG) { Rprintf("Freeing Matrix %d at 0x%x.\n", k, state->matrixList[k]); }
287                         omxFreeAllMatrixData(state->matrixList[k]);
288                 }
289
290                 if(OMX_DEBUG) { Rprintf("Freeing %d Data Sets.\n", state->numData);}
291                 for(k = 0; k < state->numData; k++) {
292                         if(OMX_DEBUG) { Rprintf("Freeing Data Set %d at 0x%x.\n", k, state->dataList[k]); }
293                         omxFreeData(state->dataList[k]);
294                 }
295
296         if(OMX_DEBUG) {Rprintf("Freeing %d Children.\n", state->numChildren);}
297         for(k = 0; k < state->numChildren; k++) {
298                         if(OMX_DEBUG) { Rprintf("Freeing Child State %d at 0x%x.\n", k, state->childList[k]); }
299                         omxFreeState(state->childList[k]);            
300         }
301
302                 if(OMX_DEBUG) { Rprintf("Freeing %d Checkpoints.\n", state->numCheckpoints);}
303                 for(k = 0; k < state->numCheckpoints; k++) {
304                         if(OMX_DEBUG) { Rprintf("Freeing Data Set %d at 0x%x.\n", k, state->checkpointList[k]); }
305                         omxCheckpoint oC = state->checkpointList[k];
306                         switch(oC.type) {
307                                 case OMX_FILE_CHECKPOINT:
308                                         fclose(oC.file);
309                                         break;
310                                 case OMX_SOCKET_CHECKPOINT:     // NYI :::DEBUG:::
311                                         // TODO: Close socket
312                                         break;
313                                 case OMX_CONNECTION_CHECKPOINT: // NYI :::DEBUG:::
314                                         // Do nothing: this should be handled by R upon return.
315                                         break;
316                         }
317                         if(state->chkptText1 != NULL) {
318                                 Free(state->chkptText1);
319                         }
320                         if(state->chkptText2 != NULL) {
321                                 Free(state->chkptText2);
322                         }
323                         // Checkpoint list itself is freed by R.
324                 }
325
326                 if(state->dynamicList != NULL) free(state->dynamicList);
327
328                 if(OMX_DEBUG) { Rprintf("State Freed.\n");}
329         }
330
331         void omxSaveState(omxState *os, double* freeVals, double minimum) {
332                 if(os->optimalValues == NULL) {
333                         os->optimalValues = (double*) R_alloc(os->numFreeParams, sizeof(double));
334                 }
335
336                 for(int i = 0; i < os->numFreeParams; i++) {
337                         os->optimalValues[i] = freeVals[i];
338                 }
339                 os->optimum = minimum;
340                 os->optimumStatus = os->statusCode;
341                 strncpy(os->optimumMsg, os->statusMsg, 250);
342         }
343
344         void omxRaiseError(omxState *state, int errorCode, char* errorMsg) {
345                 if(OMX_DEBUG && errorCode) { Rprintf("Error %d raised: %s", errorCode, errorMsg);}
346                 if(OMX_DEBUG && !errorCode) { Rprintf("Error status cleared."); }
347                 state->statusCode = errorCode;
348                 strncpy(state->statusMsg, errorMsg, 249);
349                 state->statusMsg[249] = '\0';
350                 if(state->computeCount <= 0 && errorCode < 0) {
351                         state->statusCode--;                    // Decrement status for init errors.
352                 }
353         }
354
355         void omxStateNextRow(omxState *state) {
356                 state->currentRow++;
357         };
358         void omxStateNextEvaluation(omxState *state) {
359                 state->currentRow = 0;
360                 state->computeCount++;
361         };
362
363         void omxSaveCheckpoint(omxState *os, double* x, double* f) {
364                 time_t now = time(NULL);
365                 int soFar = now - os->startTime;                // Translated into minutes
366                 int n;
367                 for(int i = 0; i < os->numCheckpoints; i++) {
368                         n = 0;
369                         omxCheckpoint* oC = &(os->checkpointList[i]);
370                         // Check based on time
371                         if(oC->time > 0 && (soFar - oC->lastCheckpoint) >= oC->time) {
372                                 oC->lastCheckpoint = soFar;
373                                 n = 1;
374                         }
375                         // Or iterations
376                         if(oC->numIterations > 0 && (os->majorIteration - oC->lastCheckpoint) >= oC->numIterations) {
377                                 oC->lastCheckpoint = os->majorIteration;
378                                 n = 1;
379                         }
380
381                         if(n) {         //In either case, save a checkpoint.
382                                 if(os->chkptText1 == NULL) {    // First one: set up output
383                                         // FIXME: Is it faster to allocate this on the stack?
384                                         os->chkptText1 = (char*) Calloc((24+15*os->numFreeParams), char);
385                                         os->chkptText2 = (char*) Calloc(1.0+15.0*os->numFreeParams*
386                                                                                                                 (os->numFreeParams + 1.0)/2.0, char);
387                                         if (oC->type == OMX_FILE_CHECKPOINT) {
388                                                 fprintf(oC->file, "iterations\ttimestamp\tobjective\t");
389                                                 for(int j = 0; j < os->numFreeParams; j++) {
390                                                         if(strcmp(os->freeVarList[j].name, CHAR(NA_STRING)) == 0) {
391                                                                 fprintf(oC->file, "%s", os->freeVarList[j].name);
392                                                         } else {
393                                                                 fprintf(oC->file, "\"%s\"", os->freeVarList[j].name);
394                                                         }
395                                                         if (j != os->numFreeParams - 1) fprintf(oC->file, "\t");
396                                                 }
397                                                 fprintf(oC->file, "\n");
398                                                 fflush(oC->file);
399                                         }
400                                 }
401                                 char tempstring[25];
402                                 sprintf(tempstring, "%d", os->majorIteration);
403
404                                 if(strncmp(os->chkptText1, tempstring, strlen(tempstring))) {   // Returns zero if they're the same.
405                                         struct tm * nowTime = localtime(&now);                                          // So this only happens if the text is out of date.
406                                         strftime(tempstring, 25, "%b %d %Y %I:%M:%S %p", nowTime);
407                                         sprintf(os->chkptText1, "%d \"%s\" %9.5f", os->majorIteration, tempstring, f[0]);
408                                         for(int j = 0; j < os->numFreeParams; j++) {
409                                                 sprintf(tempstring, " %9.5f", x[j]);
410                                                 strncat(os->chkptText1, tempstring, 14);
411                                         }
412
413                                         double* hessian = os->hessian;
414                                         if(hessian != NULL) {
415                                                 for(int j = 0; j < os->numFreeParams; j++) {
416                                                         for(int k = 0; k <= j; k++) {
417                                                                 sprintf(tempstring, " %9.5f", hessian[j]);
418                                                                 strncat(os->chkptText2, tempstring, 14);
419                                                         }
420                                                 }
421                                         }
422                                 }
423
424                                 if(oC->type == OMX_FILE_CHECKPOINT) {
425                                         fprintf(oC->file, "%s", os->chkptText1);
426                                         if(oC->saveHessian)
427                                                 fprintf(oC->file, "%s", os->chkptText2);
428                                         fprintf(oC->file, "\n");
429                                         fflush(oC->file);
430                                 } else if(oC->type == OMX_SOCKET_CHECKPOINT) {
431                                         n = write(oC->socket, os->chkptText1, strlen(os->chkptText1));
432                                         if(n != strlen(os->chkptText1)) warning("Error writing checkpoint.");
433                                         if(oC->saveHessian) {
434                                                 n = write(oC->socket, os->chkptText2, strlen(os->chkptText2));
435                                                 if(n != strlen(os->chkptText1)) warning("Error writing checkpoint.");
436                                         }
437                                         n = write(oC->socket, "\n", 1);
438                                         if(n != 1) warning("Error writing checkpoint.");
439                                 } else if(oC->type == OMX_CONNECTION_CHECKPOINT) {
440                                         warning("NYI: R_connections are not yet implemented.");
441                                         oC->numIterations = 0;
442                                         oC->time = 0;
443                                 }
444                         }
445                 }
446         }