Split omxState into truly global stuff and per-thread stuff
[openmx:openmx.git] / src / omxState.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <stdarg.h>
18 #include <errno.h>
19
20 #include "omxState.h"
21 #include "Compute.h"
22 #include "omxOpenmpWrap.h"
23
// Process-wide state shared by every omxState instance. Functions in this file
// read Global.numChildren to walk a parent state's childList.
struct omxGlobal Global;
25
26 /* Initialize and Destroy */
27         void omxInitState(omxState* state, omxState *parentState) {
28                 state->numConstraints = 0;
29                 state->numFreeParams = 0;
30                 state->childList = NULL;
31                 state->parentState = parentState;
32                 state->conList = NULL;
33                 state->freeVarList = NULL;
34
35                 state->majorIteration = 0;
36                 state->minorIteration = 0;
37                 state->startTime = 0;
38                 state->endTime = 0;
39                 state->numCheckpoints = 0;
40                 state->checkpointList = NULL;
41                 state->chkptText1 = NULL;
42                 state->chkptText2 = NULL;
43
44                 state->computeCount = 0;
45                 state->currentRow = -1;
46
47                 strncpy(state->statusMsg, "", 1);
48         }
49
50         void omxSetMajorIteration(omxState *state, int value) {
51                 state->majorIteration = value;
52                 if (!state->childList) return;
53                 for(int i = 0; i < Global.numChildren; i++) {
54                         omxSetMajorIteration(state->childList[i], value);
55                 }
56         }
57
58         void omxSetMinorIteration(omxState *state, int value) {
59                 state->minorIteration = value;
60                 if (!state->childList) return;
61                 for(int i = 0; i < Global.numChildren; i++) {
62                         omxSetMinorIteration(state->childList[i], value);
63                 }
64         }
65         
66         void omxDuplicateState(omxState* tgt, omxState* src) {
67                 tgt->dataList                   = src->dataList;
68                 
69                 // Duplicate matrices and algebras and build parentLists.
70                 tgt->parentState                = src;
71                                 
72                 for(size_t mx = 0; mx < src->matrixList.size(); mx++) {
73                         // TODO: Smarter inference for which matrices to duplicate
74                         tgt->matrixList.push_back(omxDuplicateMatrix(src->matrixList[mx], tgt));
75                 }
76
77                 tgt->numConstraints     = src->numConstraints;
78                 tgt->conList                    = (omxConstraint*) R_alloc(tgt->numConstraints, sizeof(omxConstraint));
79                 for(int j = 0; j < tgt->numConstraints; j++) {
80                         tgt->conList[j].size   = src->conList[j].size;
81                         tgt->conList[j].opCode = src->conList[j].opCode;
82                         tgt->conList[j].lbound = src->conList[j].lbound;
83                         tgt->conList[j].ubound = src->conList[j].ubound;
84                         tgt->conList[j].result = omxDuplicateMatrix(src->conList[j].result, tgt);
85                 }
86
87                 for(size_t j = 0; j < src->algebraList.size(); j++) {
88                         // TODO: Smarter inference for which algebras to duplicate
89                         tgt->algebraList.push_back(omxDuplicateMatrix(src->algebraList[j], tgt));
90                 }
91
92                 for(size_t j = 0; j < src->expectationList.size(); j++) {
93                         // TODO: Smarter inference for which expectations to duplicate
94                         tgt->expectationList.push_back(omxDuplicateExpectation(src->expectationList[j], tgt));
95                 }
96
97                 for(size_t j = 0; j < tgt->algebraList.size(); j++) {
98                         omxDuplicateAlgebra(tgt->algebraList[j], src->algebraList[j], tgt);
99                 }
100
101                 for(size_t j = 0; j < src->expectationList.size(); j++) {
102                         // TODO: Smarter inference for which expectations to duplicate
103                         omxCompleteExpectation(tgt->expectationList[j]);
104                 }
105
106                 tgt->childList                  = NULL;
107
108                 tgt->numFreeParams                      = src->numFreeParams;
109                 tgt->freeVarList                = new omxFreeVar[tgt->numFreeParams];
110                 for(int j = 0; j < tgt->numFreeParams; j++) {
111                         int numDeps                                                     = src->freeVarList[j].numDeps;
112
113                         tgt->freeVarList[j].lbound                      = src->freeVarList[j].lbound;
114                         tgt->freeVarList[j].ubound                      = src->freeVarList[j].ubound;
115                         tgt->freeVarList[j].locations                   = src->freeVarList[j].locations;
116                         tgt->freeVarList[j].numDeps                     = numDeps;
117                         
118                         tgt->freeVarList[j].deps                        = (int*) R_alloc(numDeps, sizeof(int));
119
120                         tgt->freeVarList[j].name                = src->freeVarList[j].name;
121
122                         for(int k = 0; k < numDeps; k++) {
123                                 tgt->freeVarList[j].deps[k] = src->freeVarList[j].deps[k];
124                         }
125                 }
126                 
127                 tgt->majorIteration     = 0;
128                 tgt->minorIteration     = 0;
129                 tgt->startTime                  = src->startTime;
130                 tgt->endTime                    = 0;
131                 
132                 // TODO: adjust checkpointing based on parallelization method
133                 tgt->numCheckpoints     = 0;
134                 tgt->checkpointList     = NULL;
135                 tgt->chkptText1                 = NULL;
136                 tgt->chkptText2                 = NULL;
137                                   
138                 tgt->computeCount               = src->computeCount;
139                 tgt->currentRow                 = src->currentRow;
140
141                 strncpy(tgt->statusMsg, "", 1);
142         }
143
144         omxMatrix* omxLookupDuplicateElement(omxState* os, omxMatrix* element) {
145                 if(os == NULL || element == NULL) return NULL;
146
147                 if (element->hasMatrixNumber) {
148                         int matrixNumber = element->matrixNumber;
149                         if (matrixNumber >= 0) {
150                                 return(os->algebraList[matrixNumber]);
151                         } else {
152                                 return(os->matrixList[-matrixNumber - 1]);
153                         }
154                 }
155
156                 omxConstraint* parentConList = os->parentState->conList;
157
158                 for(int i = 0; i < os->numConstraints; i++) {
159                         if(parentConList[i].result == element) {
160                                 if(os->conList[i].result != NULL) {   // Not sure of proper failure behavior here.
161                 return(os->conList[i].result);
162                                 } else {
163                     omxRaiseError(os, -2, "Initialization Copy Error: Constraint required but not yet processed.");
164             }
165                         }
166                 }
167
168                 return NULL;
169         }
170         
171 void omxFreeChildStates(omxState *state)
172 {
173         if (!state->childList || Global.numChildren == 0) return;
174
175         for(int k = 0; k < Global.numChildren; k++) {
176                 // Data are not modified and not copied. The same memory
177                 // is shared across all instances of state. We only need
178                 // to free the data once, so let the parent do it.
179                 state->childList[k]->dataList.clear();
180
181                 omxFreeState(state->childList[k]);
182         }
183         Free(state->childList);
184         state->childList = NULL;
185         Global.numChildren = 0;
186 }
187
188         void omxFreeState(omxState *state) {
189                 omxFreeChildStates(state);
190
191                 for(size_t ax = 0; ax < state->algebraList.size(); ax++) {
192                         if(OMX_DEBUG) { mxLog("Freeing Algebra %d at 0x%x.", ax, state->algebraList[ax]); }
193                         omxFreeAllMatrixData(state->algebraList[ax]);
194                 }
195
196                 if(OMX_DEBUG) { mxLog("Freeing %d Matrices.", state->matrixList.size());}
197                 for(size_t mk = 0; mk < state->matrixList.size(); mk++) {
198                         if(OMX_DEBUG) { mxLog("Freeing Matrix %d at 0x%x.", mk, state->matrixList[mk]); }
199                         omxFreeAllMatrixData(state->matrixList[mk]);
200                 }
201                 
202                 if(OMX_DEBUG) { mxLog("Freeing %d Model Expectations.", state->expectationList.size());}
203                 for(size_t ex = 0; ex < state->expectationList.size(); ex++) {
204                         if(OMX_DEBUG) { mxLog("Freeing Expectation %d at 0x%x.", ex, state->expectationList[ex]); }
205                         omxFreeExpectationArgs(state->expectationList[ex]);
206                 }
207
208                 if(OMX_DEBUG) { mxLog("Freeing %d Constraints.", state->numConstraints);}
209                 for(int k = 0; k < state->numConstraints; k++) {
210                         if(OMX_DEBUG) { mxLog("Freeing Constraint %d at 0x%x.", k, state->conList[k]); }
211                         omxFreeAllMatrixData(state->conList[k].result);
212                 }
213
214                 if(OMX_DEBUG) { mxLog("Freeing %d Data Sets.", state->dataList.size());}
215                 for(size_t dx = 0; dx < state->dataList.size(); dx++) {
216                         if(OMX_DEBUG) { mxLog("Freeing Data Set %d at 0x%x.", dx, state->dataList[dx]); }
217                         omxFreeData(state->dataList[dx]);
218                 }
219
220                 delete [] state->freeVarList;
221
222                 if(OMX_DEBUG) { mxLog("Freeing %d Checkpoints.", state->numCheckpoints);}
223                 for(int k = 0; k < state->numCheckpoints; k++) {
224                         if(OMX_DEBUG) { mxLog("Freeing Data Set %d at 0x%x.", k, state->checkpointList[k]); }
225                         omxCheckpoint oC = state->checkpointList[k];
226                         switch(oC.type) {
227                                 case OMX_FILE_CHECKPOINT:
228                                         fclose(oC.file);
229                                         break;
230                                 case OMX_CONNECTION_CHECKPOINT: // NYI :::DEBUG:::
231                                         // Do nothing: this should be handled by R upon return.
232                                         break;
233                         }
234                         if(state->chkptText1 != NULL) {
235                                 Free(state->chkptText1);
236                         }
237                         if(state->chkptText2 != NULL) {
238                                 Free(state->chkptText2);
239                         }
240                         // Checkpoint list itself is freed by R.
241                 }
242
243                 for (size_t ex = 0; ex < state->computeList.size(); ex++) {
244                         delete state->computeList[ex];
245                 }
246
247                 delete state;
248
249                 if(OMX_DEBUG) { mxLog("State Freed.");}
250         }
251
252         void omxResetStatus(omxState *state) {
253                 int numChildren = Global.numChildren;
254                 state->statusMsg[0] = '\0';
255                 if (!state->childList) return;
256                 for(int i = 0; i < numChildren; i++) {
257                         omxResetStatus(state->childList[i]);
258                 }
259         }
260
// printf-style formatting into a std::string. The buffer grows until the whole
// output fits, so the result is never truncated.
//
// NOTE(review): va_start on a parameter of class type (std::string) is not
// sanctioned by ISO C++, although common compilers accept it. The signature is
// kept so existing declarations and callers still match.
std::string string_snprintf(const std::string fmt, ...)
{
	int size = 100;
	std::string str;
	while (true) {
		str.resize(size);
		va_list ap;
		va_start(ap, fmt);
		// &str[0] is writable storage; writing through c_str() is not allowed.
		int n = vsnprintf(&str[0], size, fmt.c_str(), ap);
		va_end(ap);
		if (n < 0) {
			// Pre-C99 vsnprintf gives no size hint on truncation: grow and retry.
			size *= 2;
		} else if (n < size) {
			str.resize(n);
			return str;
		} else {
			size = n + 1;	// exact size needed, including the terminator
		}
	}
}
282
283 void mxLogBig(const std::string str)   // thread-safe
284 {
285         ssize_t len = ssize_t(str.size());
286         ssize_t wrote = 0;
287 #pragma omp critical(stderr)
288         while (1) {
289                 ssize_t got = write(2, str.data() + wrote, len - wrote);
290                 if (got == EINTR) continue;
291                 if (got <= 0) error("mxLogBig failed with errno=%d", got);
292                 wrote += got;
293                 if (wrote == len) break;
294         }
295 }
296
297 void mxLog(const char* msg, ...)   // thread-safe
298 {
299         const int maxLen = 240;
300         char buf1[maxLen];
301         char buf2[maxLen];
302
303         va_list ap;
304         va_start(ap, msg);
305         vsnprintf(buf1, maxLen, msg, ap);
306         va_end(ap);
307
308         int len = snprintf(buf2, maxLen, "[%d] %s\n", omx_absolute_thread_num(), buf1);
309
310         ssize_t wrote = 0;
311 #pragma omp critical(stderr)
312         while (1) {
313                 ssize_t got = write(2, buf2 + wrote, len - wrote);
314                 if (got == EINTR) continue;
315                 if (got <= 0) error("mxLog failed with errno=%d", got);
316                 wrote += got;
317                 if (wrote == len) break;
318         }
319 }
320
321 void omxRaiseErrorf(omxState *state, const char* errorMsg, ...)
322 {
323         va_list ap;
324         va_start(ap, errorMsg);
325         int fit = vsnprintf(state->statusMsg, MAX_STRING_LEN, errorMsg, ap);
326         va_end(ap);
327         if(OMX_DEBUG) {
328                 if (!(fit > -1 && fit < MAX_STRING_LEN)) {
329                         mxLog("Error exceeded maximum length: %s", errorMsg);
330                 } else {
331                         mxLog("Error raised: %s", state->statusMsg);
332                 }
333         }
334 }
335
336         void omxRaiseError(omxState *state, int errorCode, const char* errorMsg) { // DEPRECATED
337                 if(OMX_DEBUG && errorCode) { mxLog("Error %d raised: %s", errorCode, errorMsg);}
338                 if(OMX_DEBUG && !errorCode) { mxLog("Error status cleared."); }
339                 strncpy(state->statusMsg, errorMsg, 249);
340                 state->statusMsg[249] = '\0';
341         }
342
343         void omxStateNextRow(omxState *state) {
344                 state->currentRow++;
345         };
346
347         void omxStateNextEvaluation(omxState *state) {
348                 state->currentRow = -1;
349                 state->computeCount++;
350         };
351
352         void omxWriteCheckpointHeader(omxState *os, omxCheckpoint* oC) {
353                 // FIXME: Is it faster to allocate this on the stack?
354                 os->chkptText1 = (char*) Calloc((24 + 15 * os->numFreeParams), char);
355                 os->chkptText2 = (char*) Calloc(1.0 + 15.0 * os->numFreeParams*
356                         (os->numFreeParams + 1.0) / 2.0, char);
357                 if (oC->type == OMX_FILE_CHECKPOINT) {
358                         fprintf(oC->file, "iterations\ttimestamp\tobjective\t");
359                         for(int j = 0; j < os->numFreeParams; j++) {
360                                 if(strcmp(os->freeVarList[j].name, CHAR(NA_STRING)) == 0) {
361                                         fprintf(oC->file, "%s", os->freeVarList[j].name);
362                                 } else {
363                                         fprintf(oC->file, "\"%s\"", os->freeVarList[j].name);
364                                 }
365                                 if (j != os->numFreeParams - 1) fprintf(oC->file, "\t");
366                         }
367                         fprintf(oC->file, "\n");
368                         fflush(oC->file);
369                 }
370         }
371  
372         void omxWriteCheckpointMessage(omxState *os, char *msg) {
373                 for(int i = 0; i < os->numCheckpoints; i++) {
374                         omxCheckpoint* oC = &(os->checkpointList[i]);
375                         if(os->chkptText1 == NULL) {    // First one: set up output
376                                 omxWriteCheckpointHeader(os, oC);
377                         }
378                         if (oC->type == OMX_FILE_CHECKPOINT) {
379                                 fprintf(oC->file, "%d \"%s\" NA ", os->majorIteration, msg);
380                                 for(int j = 0; j < os->numFreeParams; j++) {
381                                         fprintf(oC->file, "NA ");
382                                 }
383                                 fprintf(oC->file, "\n");
384                         }
385                 }
386         }
387
388         void omxSaveCheckpoint(omxState *os, double* x, double* f, int force) {
389                 time_t now = time(NULL);
390                 int soFar = now - os->startTime;                // Translated into minutes
391                 int n;
392                 for(int i = 0; i < os->numCheckpoints; i++) {
393                         n = 0;
394                         omxCheckpoint* oC = &(os->checkpointList[i]);
395                         // Check based on time            
396                         if((oC->time > 0 && (soFar - oC->lastCheckpoint) >= oC->time) || force) {
397                                 oC->lastCheckpoint = soFar;
398                                 n = 1;
399                         }
400                         // Or iterations
401                         if((oC->numIterations > 0 && (os->majorIteration - oC->lastCheckpoint) >= oC->numIterations) || force) {
402                                 oC->lastCheckpoint = os->majorIteration;
403                                 n = 1;
404                         }
405
406                         if(n) {         //In either case, save a checkpoint.
407                                 if(os->chkptText1 == NULL) {    // First one: set up output
408                                         omxWriteCheckpointHeader(os, oC);
409                                 }
410                                 char tempstring[25];
411                                 sprintf(tempstring, "%d", os->majorIteration);
412
413                                 if(strncmp(os->chkptText1, tempstring, strlen(tempstring))) {   // Returns zero if they're the same.
414                                         struct tm * nowTime = localtime(&now);                                          // So this only happens if the text is out of date.
415                                         strftime(tempstring, 25, "%b %d %Y %I:%M:%S %p", nowTime);
416                                         sprintf(os->chkptText1, "%d \"%s\" %9.5f", os->majorIteration, tempstring, f[0]);
417                                         for(int j = 0; j < os->numFreeParams; j++) {
418                                                 sprintf(tempstring, " %9.5f", x[j]);
419                                                 strncat(os->chkptText1, tempstring, 14);
420                                         }
421                                 }
422
423                                 if(oC->type == OMX_FILE_CHECKPOINT) {
424                                         fprintf(oC->file, "%s", os->chkptText1);
425                                         if(oC->saveHessian)
426                                                 fprintf(oC->file, "%s", os->chkptText2);
427                                         fprintf(oC->file, "\n");
428                                         fflush(oC->file);
429                                 } else if(oC->type == OMX_CONNECTION_CHECKPOINT) {
430                                         warning("NYI: R_connections are not yet implemented.");
431                                         oC->numIterations = 0;
432                                         oC->time = 0;
433                                 }
434                         }
435                 }
436         }
437
438 void omxExamineFitOutput(omxState *state, omxMatrix *fitMatrix, int *mode)
439 {
440         if (!R_FINITE(fitMatrix->data[0])) {
441                 omxRaiseErrorf(state, "Fit function returned %g at iteration %d.%d",
442                                fitMatrix->data[0], state->majorIteration, state->minorIteration);
443                 *mode = -1;
444         }
445 }