Fix a bug I introduced while refactoring omxLookupDuplicateElement()
[openmx:openmx.git] / src / omxState.c
1 /*
2  *  Copyright 2007-2009 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18 *
19 *  omxState.cc
20 *
21 *  Created: Timothy R. Brick    Date: 2009-06-05
22 *
23 *       omxStates carry the current optimization state
24 *
25 **********************************************************/
26
27 #include "omxState.h"
28
29 /* Initialize and Destroy */
30         void omxInitState(omxState* state, int numThreads) {
31                 int i;
32                 state->numMats = 0;
33                 state->numAlgs = 0;
34                 state->numData = 0;
35                 state->numFreeParams = 0;
36                 if (numThreads > 1) {
37                         state->numChildren = numThreads;
38                         state->childList = (omxState**) Calloc(numThreads, omxState*);
39                         for(i = 0; i < numThreads; i++) {
40                                 state->childList[i] = (omxState*) R_alloc(1, sizeof(omxState));
41                                 omxInitState(state->childList[i], 1);
42                         }
43                 } else {
44                 state->numChildren = 0;
45                         state->childList = NULL;
46                 }
47                 state->matrixList = NULL;
48                 state->algebraList = NULL;
49         state->parentState = NULL;
50         state->parentMatrix = NULL;
51         state->parentAlgebra = NULL;
52                 state->parentConList= NULL;
53                 state->dataList = NULL;
54                 state->objectiveMatrix = NULL;
55                 state->hessian = NULL;
56                 state->conList = NULL;
57                 state->freeVarList = NULL;
58                 state->optimizerState = NULL;
59                 state->optimalValues = NULL;
60                 state->optimum = 9999999999;
61
62                 state->majorIteration = 0;
63                 state->minorIteration = 0;
64                 state->startTime = 0;
65                 state->endTime = 0;
66                 state->numCheckpoints = 0;
67                 state->checkpointList = NULL;
68                 state->chkptText1 = NULL;
69                 state->chkptText2 = NULL;
70
71                 state->computeCount = -1;
72                 state->currentRow = -1;
73
74                 state->statusCode = 0;
75                 strncpy(state->statusMsg, "", 1);
76         }
77
78         void omxFillState(omxState* state, /*omxOptimizer *oo,*/ omxMatrix** matrixList,
79                                                 omxMatrix** algebraList, omxData** dataList, omxMatrix* objective) {
80                 error("NYI: Can't fill a state from outside yet. Besides, do you really need a single function to do this?");
81         }
82         
83         omxState* omxGetState(omxState* os, int stateNumber) {
84                 // TODO: Need to implement a smarter way to enumerate children
85                 if(stateNumber == 0) return os;
86                 if((stateNumber-1) < os->numChildren) {
87                         return(os->childList[stateNumber-1]);
88                 } else {
89                         // TODO: Account for unequal numbers of grandchild states
90                         int subState = (stateNumber - os->numChildren - 1);
91                         return omxGetState(os->childList[subState % os->numChildren], subState / os->numChildren);
92                 }
93         }
94         
95         void omxDuplicateState(omxState* tgt, omxState* src, unsigned short fullCopy) {
96                 tgt->numMats                    = src->numMats;
97                 tgt->numAlgs                    = src->numAlgs;
98                 tgt->numData                    = src->numData;
99                 tgt->dataList                   = src->dataList;
100                 tgt->numChildren                = 0;
101                 
102                 // Duplicate matrices and algebras and build parentLists.
103                 tgt->parentState                = src;
104                 tgt->parentMatrix               = src->matrixList;
105                 tgt->parentAlgebra              = src->algebraList;
106                 tgt->matrixList                 = (omxMatrix**) R_alloc(tgt->numMats, sizeof(omxMatrix*));
107                 for(int j = 0; j < tgt->numMats; j++) {
108                         // TODO: Smarter inference for which matrices to duplicate
109                         tgt->matrixList[j] = omxDuplicateMatrix(NULL, src->matrixList[j], tgt, fullCopy);
110                 }
111                                 
112                 tgt->parentConList              = src->conList;
113                 tgt->numConstraints     = src->numConstraints;
114                 tgt->conList                    = (omxConstraint*) R_alloc(tgt->numConstraints, sizeof(omxConstraint));
115                 for(int j = 0; j < tgt->numConstraints; j++) {
116                         tgt->conList[j].size   = src->conList[j].size;
117                         tgt->conList[j].opCode = src->conList[j].opCode;
118                         tgt->conList[j].lbound = src->conList[j].lbound;
119                         tgt->conList[j].ubound = src->conList[j].ubound;
120                         tgt->conList[j].result = omxDuplicateMatrix(NULL, src->conList[j].result, tgt, fullCopy);
121                 }
122
123                 tgt->algebraList                = (omxMatrix**) R_alloc(tgt->numAlgs, sizeof(omxMatrix*));
124                 for(int j = 0; j < tgt->numAlgs; j++) {
125                         // TODO: Smarter inference for which algebras to duplicate
126                         tgt->algebraList[j] = omxDuplicateMatrix(NULL, src->algebraList[j], tgt, fullCopy);
127                 }
128
129                 for(int j = 0; j < tgt->numAlgs; j++) {
130                         omxMatrix *srcMatrix = src->algebraList[j];
131                         omxMatrix *tgtMatrix = tgt->algebraList[j];
132                         if(srcMatrix->algebra != NULL) {
133                                 omxDuplicateAlgebra(tgtMatrix->algebra, srcMatrix->algebra, tgt, fullCopy);
134                     }
135                 }
136
137                 
138                 tgt->childList                  = NULL;
139
140                 tgt->objectiveMatrix    = omxLookupDuplicateElement(tgt, src->objectiveMatrix);
141                 tgt->hessian                    = src->hessian;
142
143                 tgt->numFreeParams                      = src->numFreeParams;
144                 tgt->freeVarList                = (omxFreeVar*) R_alloc(tgt->numFreeParams, sizeof(omxFreeVar));
145                 for(int j = 0; j < tgt->numFreeParams; j++) {
146                         tgt->freeVarList[j].lbound                      = src->freeVarList[j].lbound;
147                         tgt->freeVarList[j].ubound                      = src->freeVarList[j].ubound;
148                         tgt->freeVarList[j].numLocations        = src->freeVarList[j].numLocations;
149                         
150                         int nLocs                                                       = tgt->freeVarList[j].numLocations;
151                         tgt->freeVarList[j].matrices            = (int*) R_alloc(nLocs, sizeof(int));
152                         tgt->freeVarList[j].row                         = (int*) R_alloc(nLocs, sizeof(int));
153                         tgt->freeVarList[j].col                         = (int*) R_alloc(nLocs, sizeof(int));
154                         tgt->freeVarList[j].location            = (double**) R_alloc(nLocs, sizeof(double*));
155
156                         for(int k = 0; k < nLocs; k++) {
157                                 int theMat                                              = src->freeVarList[j].matrices[k];
158                                 int theRow                                              = src->freeVarList[j].row[k];
159                                 int theCol                                              = src->freeVarList[j].col[k];
160
161                                 tgt->freeVarList[j].matrices[k] = theMat;
162                                 tgt->freeVarList[j].row[k]              = theRow;
163                                 tgt->freeVarList[j].col[k]              = theCol;
164                                 
165                                 tgt->freeVarList[j].location[k] = omxLocationOfMatrixElement(tgt->matrixList[theMat], theRow, theCol);
166                                 
167                                 tgt->freeVarList[j].name                = src->freeVarList[j].name;
168                         }
169                 }
170                 
171                 if (src->optimizerState) {
172                         tgt->optimizerState                                     = (omxOptimizerState*) R_alloc(1, sizeof(omxOptimizerState));
173                         tgt->optimizerState->currentParameter   = src->optimizerState->currentParameter;
174                         tgt->optimizerState->offset                             = src->optimizerState->offset;
175                         tgt->optimizerState->alpha                              = src->optimizerState->alpha;
176                 }
177                 
178                 tgt->optimalValues              = src->optimalValues;
179                 tgt->optimum                    = 9999999999;
180                                   
181                 tgt->majorIteration     = 0;
182                 tgt->minorIteration     = 0;
183                 tgt->startTime                  = src->startTime;
184                 tgt->endTime                    = 0;
185                 
186                 // TODO: adjust checkpointing based on parallelization method
187                 tgt->numCheckpoints     = src->numCheckpoints;
188                 tgt->checkpointList     = src->checkpointList;
189                 tgt->chkptText1                 = src->chkptText1;
190                 tgt->chkptText2                 = src->chkptText2;
191                                   
192                 tgt->computeCount               = src->computeCount;
193                 tgt->currentRow                 = src->currentRow;
194
195                 tgt->statusCode                 = 0;
196                 strncpy(tgt->statusMsg, "", 1);
197         }
198
199     omxMatrix* omxLookupDuplicateElement(omxState* os, omxMatrix* element) {
200         if(os == NULL || element == NULL) return NULL;
201         if(os->parentState == NULL) return element; // FIXME: Not sure if this is the correct behavior.
202
203                 if (element->hasMatrixNumber) {
204                         int matrixNumber = element->matrixNumber;
205                         if (matrixNumber >= 0) {
206                                 return(os->algebraList[matrixNumber]);
207                         } else {
208                                 return(os->matrixList[-matrixNumber - 1]);
209                         }
210                 }
211
212         for(int i = 0; i < os->numConstraints; i++) {
213             if(os->parentConList[i].result == element) {
214                                 if(os->conList[i].result != NULL)   // Not sure of proper failure behavior here.
215                     return(os->conList[i].result);
216                 else
217                     omxRaiseError(os, -2, "Initialization Copy Error: Constraint required but not yet processed.");
218             }
219         }
220
221         return NULL;
222     }
223
224         void omxFreeState(omxState *oo) {
225                 int k;
226
227                 if (oo->numChildren > 0) {
228                         for(k = 0; k < oo->numChildren; k++) {
229                                 omxFreeState(oo->childList[k]);
230                         }
231                         Free(oo->childList);
232                         oo->childList = NULL;
233                         oo->numChildren = 0;
234                 }
235
236                 if(OMX_DEBUG) { Rprintf("Freeing %d Algebras.\n", oo->numAlgs);}
237                 for(k = 0; k < oo->numAlgs; k++) {
238                         if(OMX_DEBUG) { Rprintf("Freeing Algebra %d at 0x%x.\n", k, oo->algebraList[k]); }
239                         omxFreeAllMatrixData(oo->algebraList[k]);
240                 }
241
242                 if(OMX_DEBUG) { Rprintf("Freeing %d Matrices.\n", oo->numMats);}
243                 for(k = 0; k < oo->numMats; k++) {
244                         if(OMX_DEBUG) { Rprintf("Freeing Matrix %d at 0x%x.\n", k, oo->matrixList[k]); }
245                         omxFreeAllMatrixData(oo->matrixList[k]);
246                 }
247
248                 if(OMX_DEBUG) { Rprintf("Freeing %d Data Sets.\n", oo->numData);}
249                 for(k = 0; k < oo->numData; k++) {
250                         if(OMX_DEBUG) { Rprintf("Freeing Data Set %d at 0x%x.\n", k, oo->dataList[k]); }
251                         omxFreeData(oo->dataList[k]);
252                 }
253
254         if(OMX_DEBUG) {Rprintf("Freeing %d Children.\n", oo->numChildren);}
255         for(k = 0; k < oo->numChildren; k++) {
256                         if(OMX_DEBUG) { Rprintf("Freeing Child State %d at 0x%x.\n", k, oo->childList[k]); }
257                         omxFreeState(oo->childList[k]);            
258         }
259
260                 if(OMX_DEBUG) { Rprintf("Freeing %d Checkpoints.\n", oo->numCheckpoints);}
261                 for(k = 0; k < oo->numCheckpoints; k++) {
262                         if(OMX_DEBUG) { Rprintf("Freeing Data Set %d at 0x%x.\n", k, oo->checkpointList[k]); }
263                         omxCheckpoint oC = oo->checkpointList[k];
264                         switch(oC.type) {
265                                 case OMX_FILE_CHECKPOINT:
266                                         fclose(oC.file);
267                                         break;
268                                 case OMX_SOCKET_CHECKPOINT:     // NYI :::DEBUG:::
269                                         // TODO: Close socket
270                                         break;
271                                 case OMX_CONNECTION_CHECKPOINT: // NYI :::DEBUG:::
272                                         // Do nothing: this should be handled by R upon return.
273                                         break;
274                         }
275                         if(oo->chkptText1 != NULL) {
276                                 Free(oo->chkptText1);
277                         }
278                         if(oo->chkptText2 != NULL) {
279                                 Free(oo->chkptText2);
280                         }
281                         // Checkpoint list itself is freed by R.
282                 }
283
284                 if(OMX_DEBUG) { Rprintf("State Freed.\n");}
285         }
286
287         void omxSaveState(omxState *os, double* freeVals, double minimum) {
288                 if(os->optimalValues == NULL) {
289                         os->optimalValues = (double*) R_alloc(os->numFreeParams, sizeof(double));
290                 }
291
292                 for(int i = 0; i < os->numFreeParams; i++) {
293                         os->optimalValues[i] = freeVals[i];
294                 }
295                 os->optimum = minimum;
296                 os->optimumStatus = os->statusCode;
297                 strncpy(os->optimumMsg, os->statusMsg, 250);
298         }
299
300         void omxRaiseError(omxState *oo, int errorCode, char* errorMsg) {
301                 if(OMX_DEBUG && errorCode) { Rprintf("Error %d raised: %s", errorCode, errorMsg);}
302                 if(OMX_DEBUG && !errorCode) { Rprintf("Error status cleared."); }
303                 oo->statusCode = errorCode;
304                 strncpy(oo->statusMsg, errorMsg, 249);
305                 oo->statusMsg[249] = '\0';
306                 if(oo->computeCount <= 0 && errorCode < 0) {
307                         oo->statusCode--;                       // Decrement status for init errors.
308                 }
309         }
310
311         void omxStateNextRow(omxState *oo) {
312                 oo->currentRow++;
313         };
314         void omxStateNextEvaluation(omxState *oo) {
315                 oo->currentRow = 0;
316                 oo->computeCount++;
317         };
318
319         void omxSaveCheckpoint(omxState *os, double* x, double* f) {
320                 time_t now = time(NULL);
321                 int soFar = now - os->startTime;                // Translated into minutes
322                 int n;
323                 for(int i = 0; i < os->numCheckpoints; i++) {
324                         n = 0;
325                         omxCheckpoint* oC = &(os->checkpointList[i]);
326                         // Check based on time
327                         if(oC->time > 0 && (soFar - oC->lastCheckpoint) >= oC->time) {
328                                 oC->lastCheckpoint = soFar;
329                                 n = 1;
330                         }
331                         // Or iterations
332                         if(oC->numIterations > 0 && (os->majorIteration - oC->lastCheckpoint) >= oC->numIterations) {
333                                 oC->lastCheckpoint = os->majorIteration;
334                                 n = 1;
335                         }
336
337                         if(n) {         //In either case, save a checkpoint.
338                                 if(os->chkptText1 == NULL) {    // First one: set up output
339                                         // FIXME: Is it faster to allocate this on the stack?
340                                         os->chkptText1 = (char*) Calloc((24+15*os->numFreeParams), char);
341                                         os->chkptText2 = (char*) Calloc(1.0+15.0*os->numFreeParams*
342                                                                                                                 (os->numFreeParams + 1.0)/2.0, char);
343                                         if (oC->type == OMX_FILE_CHECKPOINT) {
344                                                 fprintf(oC->file, "iterations\ttimestamp\tobjective\t");
345                                                 for(int j = 0; j < os->numFreeParams; j++) {
346                                                         if(strcmp(os->freeVarList[j].name, CHAR(NA_STRING)) == 0) {
347                                                                 fprintf(oC->file, "%s", os->freeVarList[j].name);
348                                                         } else {
349                                                                 fprintf(oC->file, "\"%s\"", os->freeVarList[j].name);
350                                                         }
351                                                         if (j != os->numFreeParams - 1) fprintf(oC->file, "\t");
352                                                 }
353                                                 fprintf(oC->file, "\n");
354                                                 fflush(oC->file);
355                                         }
356                                 }
357                                 char tempstring[25];
358                                 sprintf(tempstring, "%d", os->majorIteration);
359
360                                 if(strncmp(os->chkptText1, tempstring, strlen(tempstring))) {   // Returns zero if they're the same.
361                                         struct tm * nowTime = localtime(&now);                                          // So this only happens if the text is out of date.
362                                         strftime(tempstring, 25, "%b %d %Y %I:%M:%S %p", nowTime);
363                                         sprintf(os->chkptText1, "%d \"%s\" %9.5f", os->majorIteration, tempstring, f[0]);
364                                         for(int j = 0; j < os->numFreeParams; j++) {
365                                                 sprintf(tempstring, " %9.5f", x[j]);
366                                                 strncat(os->chkptText1, tempstring, 14);
367                                         }
368
369                                         double* hessian = os->hessian;
370                                         if(hessian != NULL) {
371                                                 for(int j = 0; j < os->numFreeParams; j++) {
372                                                         for(int k = 0; k <= j; k++) {
373                                                                 sprintf(tempstring, " %9.5f", hessian[j]);
374                                                                 strncat(os->chkptText2, tempstring, 14);
375                                                         }
376                                                 }
377                                         }
378                                 }
379
380                                 if(oC->type == OMX_FILE_CHECKPOINT) {
381                                         fprintf(oC->file, "%s", os->chkptText1);
382                                         if(oC->saveHessian)
383                                                 fprintf(oC->file, "%s", os->chkptText2);
384                                         fprintf(oC->file, "\n");
385                                         fflush(oC->file);
386                                 } else if(oC->type == OMX_SOCKET_CHECKPOINT) {
387                                         n = write(oC->socket, os->chkptText1, strlen(os->chkptText1));
388                                         if(n != strlen(os->chkptText1)) warning("Error writing checkpoint.");
389                                         if(oC->saveHessian) {
390                                                 n = write(oC->socket, os->chkptText2, strlen(os->chkptText2));
391                                                 if(n != strlen(os->chkptText1)) warning("Error writing checkpoint.");
392                                         }
393                                         n = write(oC->socket, "\n", 1);
394                                         if(n != 1) warning("Error writing checkpoint.");
395                                 } else if(oC->type == OMX_CONNECTION_CHECKPOINT) {
396                                         warning("NYI: R_connections are not yet implemented.");
397                                         oC->numIterations = 0;
398                                         oC->time = 0;
399                                 }
400                         }
401                 }
402         }