Improvements to the checkpoint interface.
[openmx:openmx.git] / src / omxState.c
1 /*
2  *  Copyright 2007-2009 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 /***********************************************************
18 *
19 *  omxState.c
20 *
21 *  Created: Timothy R. Brick    Date: 2009-06-05
22 *
23 *       omxStates carry the current optimization state
24 *
25 **********************************************************/
26
27 #include "omxState.h"
28
29 /* Initialize and Destroy */
30         void omxInitState(omxState* state) {
31                 state->numMats = 0;
32                 state->numAlgs = 0;
33                 state->numData = 0;
34                 state->matrixList = NULL;
35                 state->algebraList = NULL;
36                 state->dataList = NULL;
37                 state->objectiveMatrix = NULL;
38                 state->hessian = NULL;
39                 state->conList = NULL;
40                 state->freeVarList = NULL;
41                 state->optimizerState = NULL;
42                 state->optimalValues = NULL;
43                 state->optimum = 9999999999;
44
45                 state->majorIteration = 0;
46                 state->minorIteration = 0;
47                 state->startTime = 0;
48                 state->endTime = 0;
49                 state->numCheckpoints = 0;
50                 state->checkpointList = NULL;
51                 state->chkptText1 = NULL;
52                 state->chkptText2 = NULL;
53
54                 state->computeCount = -1;
55                 state->currentRow = -1;
56
57                 state->statusCode = 0;
58                 strncpy(state->statusMsg, "", 1);
59         }
60
61         void omxFillState(omxState* state, /*omxOptimizer *oo,*/ omxMatrix** matrixList,
62                                                 omxMatrix** algebraList, omxData** dataList, omxMatrix* objective) {
63                 error("NYI: Can't fill a state from outside yet. Besides, do you really need a single function to do this?");
64         }
65
66         void omxFreeState(omxState *oo) {
67                 int k;
68                 if(OMX_DEBUG) { Rprintf("Freeing %d Algebras.\n", oo->numAlgs);}
69                 for(k = 0; k < oo->numAlgs; k++) {
70                         if(OMX_DEBUG) { Rprintf("Freeing Algebra %d at 0x%x.\n", k, oo->algebraList[k]); }
71                         omxFreeAllMatrixData(oo->algebraList[k]);
72                 }
73
74                 if(OMX_DEBUG) { Rprintf("Freeing %d Matrices.\n", oo->numMats);}
75                 for(k = 0; k < oo->numMats; k++) {
76                         if(OMX_DEBUG) { Rprintf("Freeing Matrix %d at 0x%x.\n", k, oo->matrixList[k]); }
77                         omxFreeAllMatrixData(oo->matrixList[k]);
78                 }
79
80                 if(OMX_DEBUG) { Rprintf("Freeing %d Data Sets.\n", oo->numData);}
81                 for(k = 0; k < oo->numData; k++) {
82                         if(OMX_DEBUG) { Rprintf("Freeing Data Set %d at 0x%x.\n", k, oo->dataList[k]); }
83                         omxFreeData(oo->dataList[k]);
84                 }
85
86                 if(OMX_DEBUG) { Rprintf("Freeing %d Checkpoints.\n", oo->numCheckpoints);}
87                 for(k = 0; k < oo->numCheckpoints; k++) {
88                         if(OMX_DEBUG) { Rprintf("Freeing Data Set %d at 0x%x.\n", k, oo->checkpointList[k]); }
89                         omxCheckpoint oC = oo->checkpointList[k];
90                         switch(oC.type) {
91                                 case OMX_FILE_CHECKPOINT:
92                                         fclose(oC.file);
93                                         break;
94                                 case OMX_SOCKET_CHECKPOINT:     // NYI :::DEBUG:::
95                                         // TODO: Close socket
96                                         break;
97                                 case OMX_CONNECTION_CHECKPOINT: // NYI :::DEBUG:::
98                                         // Do nothing: this should be handled by R upon return.
99                                         break;
100                         }
101                         if(oo->chkptText1 != NULL) {
102                                 Free(oo->chkptText1);
103                         }
104                         if(oo->chkptText2 != NULL) {
105                                 Free(oo->chkptText2);
106                         }
107                         // Checkpoint list itself is freed by R.
108                 }
109
110                 if(OMX_DEBUG) { Rprintf("State Freed.\n");}
111         }
112
113         void omxSaveState(omxState *os, double* freeVals, double minimum) {
114                 if(os->optimalValues == NULL) {
115                         os->optimalValues = (double*) R_alloc(os->numFreeParams, sizeof(double));
116                 }
117
118                 for(int i = 0; i < os->numFreeParams; i++) {
119                         os->optimalValues[i] = freeVals[i];
120                 }
121                 os->optimum = minimum;
122                 os->optimumStatus = os->statusCode;
123                 strncpy(os->optimumMsg, os->statusMsg, 250);
124         }
125
126         void omxRaiseError(omxState *oo, int errorCode, char* errorMsg) {
127                 oo->statusCode = errorCode;
128                 strncpy(oo->statusMsg, errorMsg, 249);
129                 oo->statusMsg[249] = '\0';
130         }
131
132         void omxStateNextRow(omxState *oo) {
133                 oo->currentRow++;
134         };
135         void omxStateNextEvaluation(omxState *oo) {
136                 oo->currentRow = 0;
137                 oo->computeCount++;
138         };
139
140         void omxSaveCheckpoint(omxState *os, double* x, double* f) {
141                 time_t now = time(NULL);
142                 int soFar = now - os->startTime;                // Translated into minutes
143                 int n;
144                 for(int i = 0; i < os->numCheckpoints; i++) {
145                         n = 0;
146                         omxCheckpoint* oC = &(os->checkpointList[i]);
147                         // Check based on time
148                         if(oC->time > 0 && (soFar - oC->lastCheckpoint) >= oC->time) {
149                                 oC->lastCheckpoint = soFar;
150                                 n = 1;
151                         }
152                         // Or iterations
153                         if(oC->numIterations > 0 && (os->majorIteration - oC->lastCheckpoint) >= oC->numIterations) {
154                                 oC->lastCheckpoint = os->majorIteration;
155                                 n = 1;
156                         }
157
158                         if(n) {         //In either case, save a checkpoint.
159                                 if(os->chkptText1 == NULL) {    // First one: set up output
160                                         // FIXME: Is it faster to allocate this on the stack?
161                                         os->chkptText1 = (char*) Calloc((24+15*os->numFreeParams), char);
162                                         os->chkptText2 = (char*) Calloc(1.0+15.0*os->numFreeParams*
163                                                                                                                 (os->numFreeParams + 1.0)/2.0, char);
164                                         if (oC->type == OMX_FILE_CHECKPOINT) {
165                                                 fprintf(oC->file, "iterations\ttimestamp\tobjective\t");
166                                                 for(int j = 0; j < os->numFreeParams; j++) {
167                                                         if(strcmp(os->freeVarList[j].name, CHAR(NA_STRING)) == 0) {
168                                                                 fprintf(oC->file, "%s", os->freeVarList[j].name);
169                                                         } else {
170                                                                 fprintf(oC->file, "\"%s\"", os->freeVarList[j].name);
171                                                         }
172                                                         if (j != os->numFreeParams - 1) fprintf(oC->file, "\t");
173                                                 }
174                                                 fprintf(oC->file, "\n");
175                                                 fflush(oC->file);
176                                         }
177                                 }
178                                 char tempstring[25];
179                                 sprintf(tempstring, "%d", os->majorIteration);
180
181                                 if(strncmp(os->chkptText1, tempstring, strlen(tempstring))) {   // Returns zero if they're the same.
182                                         struct tm * nowTime = localtime(&now);                                          // So this only happens if the text is out of date.
183                                         strftime(tempstring, 25, "%b %d %Y %I:%M:%S %p", nowTime);
184                                         sprintf(os->chkptText1, "%d \"%s\" %9.5f", os->majorIteration, tempstring, f[0]);
185                                         for(int j = 0; j < os->numFreeParams; j++) {
186                                                 sprintf(tempstring, " %9.5f", x[j]);
187                                                 strncat(os->chkptText1, tempstring, 14);
188                                         }
189
190                                         double* hessian = os->hessian;
191                                         if(hessian != NULL) {
192                                                 for(int j = 0; j < os->numFreeParams; j++) {
193                                                         for(int k = 0; k <= j; k++) {
194                                                                 sprintf(tempstring, " %9.5f", hessian[j]);
195                                                                 strncat(os->chkptText2, tempstring, 14);
196                                                         }
197                                                 }
198                                         }
199                                 }
200
201                                 if(oC->type == OMX_FILE_CHECKPOINT) {
202                                         fprintf(oC->file, "%s", os->chkptText1);
203                                         if(oC->saveHessian)
204                                                 fprintf(oC->file, "%s", os->chkptText2);
205                                         fprintf(oC->file, "\n");
206                                         fflush(oC->file);
207                                 } else if(oC->type == OMX_SOCKET_CHECKPOINT) {
208                                         n = write(oC->socket, os->chkptText1, strlen(os->chkptText1));
209                                         if(n != strlen(os->chkptText1)) warning("Error writing checkpoint.");
210                                         if(oC->saveHessian) {
211                                                 n = write(oC->socket, os->chkptText2, strlen(os->chkptText2));
212                                                 if(n != strlen(os->chkptText1)) warning("Error writing checkpoint.");
213                                         }
214                                         n = write(oC->socket, "\n", 1);
215                                         if(n != 1) warning("Error writing checkpoint.");
216                                 } else if(oC->type == OMX_CONNECTION_CHECKPOINT) {
217                                         warning("NYI: R_connections are not yet implemented.");
218                                         oC->numIterations = 0;
219                                         oC->time = 0;
220                                 }
221                         }
222                 }
223         }