Add ComputeAssign
[openmx:openmx.git] / src / omxState.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <stdarg.h>
18 #include <errno.h>
19
20 #include "omxState.h"
21 #include "Compute.h"
22 #include "omxOpenmpWrap.h"
23
24 struct omxGlobal *Global = NULL;
25
26 /* Initialize and Destroy */
27         void omxInitState(omxState* state) {
28                 state->numConstraints = 0;
29                 state->childList = NULL;
30                 state->conList = NULL;
31
32                 state->majorIteration = 0;
33                 state->minorIteration = 0;
34                 state->startTime = 0;
35                 state->endTime = 0;
36                 state->numCheckpoints = 0;
37                 state->checkpointList = NULL;
38                 state->chkptText1 = NULL;
39                 state->chkptText2 = NULL;
40
41                 state->computeCount = 0;
42                 state->currentRow = -1;
43
44                 strncpy(state->statusMsg, "", 1);
45         }
46
47         void omxSetMajorIteration(omxState *state, int value) {
48                 state->majorIteration = value;
49                 if (!state->childList) return;
50                 for(int i = 0; i < Global->numChildren; i++) {
51                         omxSetMajorIteration(state->childList[i], value);
52                 }
53         }
54
55         void omxSetMinorIteration(omxState *state, int value) {
56                 state->minorIteration = value;
57                 if (!state->childList) return;
58                 for(int i = 0; i < Global->numChildren; i++) {
59                         omxSetMinorIteration(state->childList[i], value);
60                 }
61         }
62         
63         void omxDuplicateState(omxState* tgt, omxState* src) {
64                 tgt->dataList                   = src->dataList;
65                 
66                 for(size_t mx = 0; mx < src->matrixList.size(); mx++) {
67                         // TODO: Smarter inference for which matrices to duplicate
68                         tgt->matrixList.push_back(omxDuplicateMatrix(src->matrixList[mx], tgt));
69                 }
70
71                 tgt->numConstraints     = src->numConstraints;
72                 tgt->conList                    = (omxConstraint*) R_alloc(tgt->numConstraints, sizeof(omxConstraint));
73                 for(int j = 0; j < tgt->numConstraints; j++) {
74                         tgt->conList[j].size   = src->conList[j].size;
75                         tgt->conList[j].opCode = src->conList[j].opCode;
76                         tgt->conList[j].lbound = src->conList[j].lbound;
77                         tgt->conList[j].ubound = src->conList[j].ubound;
78                         tgt->conList[j].result = omxDuplicateMatrix(src->conList[j].result, tgt);
79                 }
80
81                 for(size_t j = 0; j < src->algebraList.size(); j++) {
82                         // TODO: Smarter inference for which algebras to duplicate
83                         tgt->algebraList.push_back(omxDuplicateMatrix(src->algebraList[j], tgt));
84                 }
85
86                 for(size_t j = 0; j < src->expectationList.size(); j++) {
87                         // TODO: Smarter inference for which expectations to duplicate
88                         tgt->expectationList.push_back(omxDuplicateExpectation(src->expectationList[j], tgt));
89                 }
90
91                 for(size_t j = 0; j < tgt->algebraList.size(); j++) {
92                         omxDuplicateAlgebra(tgt->algebraList[j], src->algebraList[j], tgt);
93                 }
94
95                 for(size_t j = 0; j < src->expectationList.size(); j++) {
96                         // TODO: Smarter inference for which expectations to duplicate
97                         omxCompleteExpectation(tgt->expectationList[j]);
98                 }
99
100                 tgt->childList                  = NULL;
101
102                 tgt->majorIteration     = 0;
103                 tgt->minorIteration     = 0;
104                 tgt->startTime                  = src->startTime;
105                 tgt->endTime                    = 0;
106                 
107                 // TODO: adjust checkpointing based on parallelization method
108                 tgt->numCheckpoints     = 0;
109                 tgt->checkpointList     = NULL;
110                 tgt->chkptText1                 = NULL;
111                 tgt->chkptText2                 = NULL;
112                                   
113                 tgt->computeCount               = src->computeCount;
114                 tgt->currentRow                 = src->currentRow;
115
116                 strncpy(tgt->statusMsg, "", 1);
117         }
118
119         omxMatrix* omxLookupDuplicateElement(omxState* os, omxMatrix* element) {
120                 if(os == NULL || element == NULL) return NULL;
121
122                 if (element->hasMatrixNumber) {
123                         int matrixNumber = element->matrixNumber;
124                         if (matrixNumber >= 0) {
125                                 return(os->algebraList[matrixNumber]);
126                         } else {
127                                 return(os->matrixList[-matrixNumber - 1]);
128                         }
129                 }
130
131                 omxConstraint* parentConList = globalState->conList;
132
133                 for(int i = 0; i < os->numConstraints; i++) {
134                         if(parentConList[i].result == element) {
135                                 if(os->conList[i].result != NULL) {   // Not sure of proper failure behavior here.
136                 return(os->conList[i].result);
137                                 } else {
138                     omxRaiseError(os, -2, "Initialization Copy Error: Constraint required but not yet processed.");
139             }
140                         }
141                 }
142
143                 return NULL;
144         }
145         
146 void omxFreeChildStates(omxState *state)
147 {
148         if (!state->childList || Global->numChildren == 0) return;
149
150         for(int k = 0; k < Global->numChildren; k++) {
151                 // Data are not modified and not copied. The same memory
152                 // is shared across all instances of state. We only need
153                 // to free the data once, so let the parent do it.
154                 state->childList[k]->dataList.clear();
155
156                 omxFreeState(state->childList[k]);
157         }
158         Free(state->childList);
159         state->childList = NULL;
160         Global->numChildren = 0;
161 }
162
163         void omxFreeState(omxState *state) {
164                 omxFreeChildStates(state);
165
166                 for(size_t ax = 0; ax < state->algebraList.size(); ax++) {
167                         if(OMX_DEBUG) { mxLog("Freeing Algebra %d at 0x%x.", ax, state->algebraList[ax]); }
168                         omxFreeAllMatrixData(state->algebraList[ax]);
169                 }
170
171                 if(OMX_DEBUG) { mxLog("Freeing %d Matrices.", state->matrixList.size());}
172                 for(size_t mk = 0; mk < state->matrixList.size(); mk++) {
173                         if(OMX_DEBUG) { mxLog("Freeing Matrix %d at 0x%x.", mk, state->matrixList[mk]); }
174                         omxFreeAllMatrixData(state->matrixList[mk]);
175                 }
176                 
177                 if(OMX_DEBUG) { mxLog("Freeing %d Model Expectations.", state->expectationList.size());}
178                 for(size_t ex = 0; ex < state->expectationList.size(); ex++) {
179                         if(OMX_DEBUG) { mxLog("Freeing Expectation %d at 0x%x.", ex, state->expectationList[ex]); }
180                         omxFreeExpectationArgs(state->expectationList[ex]);
181                 }
182
183                 if(OMX_DEBUG) { mxLog("Freeing %d Constraints.", state->numConstraints);}
184                 for(int k = 0; k < state->numConstraints; k++) {
185                         if(OMX_DEBUG) { mxLog("Freeing Constraint %d at 0x%x.", k, state->conList[k]); }
186                         omxFreeAllMatrixData(state->conList[k].result);
187                 }
188
189                 if(OMX_DEBUG) { mxLog("Freeing %d Data Sets.", state->dataList.size());}
190                 for(size_t dx = 0; dx < state->dataList.size(); dx++) {
191                         if(OMX_DEBUG) { mxLog("Freeing Data Set %d at 0x%x.", dx, state->dataList[dx]); }
192                         omxFreeData(state->dataList[dx]);
193                 }
194
195                 if(OMX_DEBUG) { mxLog("Freeing %d Checkpoints.", state->numCheckpoints);}
196                 for(int k = 0; k < state->numCheckpoints; k++) {
197                         if(OMX_DEBUG) { mxLog("Freeing Data Set %d at 0x%x.", k, state->checkpointList[k]); }
198                         omxCheckpoint oC = state->checkpointList[k];
199                         switch(oC.type) {
200                                 case OMX_FILE_CHECKPOINT:
201                                         fclose(oC.file);
202                                         break;
203                                 case OMX_CONNECTION_CHECKPOINT: // NYI :::DEBUG:::
204                                         // Do nothing: this should be handled by R upon return.
205                                         break;
206                         }
207                         if(state->chkptText1 != NULL) {
208                                 Free(state->chkptText1);
209                         }
210                         if(state->chkptText2 != NULL) {
211                                 Free(state->chkptText2);
212                         }
213                         // Checkpoint list itself is freed by R.
214                 }
215
216                 delete state;
217
218                 if(OMX_DEBUG) { mxLog("State Freed.");}
219         }
220
221 FreeVarGroup::~FreeVarGroup()
222 {
223         for (size_t vx=0; vx < vars.size(); ++vx) {
224                 delete vars[vx];
225         }
226 }
227
228 omxGlobal::~omxGlobal()
229 {
230         for (size_t cx=0; cx < computeList.size(); ++cx) {
231                 delete computeList[cx];
232         }
233         for (size_t gx=0; gx < freeGroup.size(); ++gx) {
234                 delete freeGroup[gx];
235         }
236 }
237
238         void omxResetStatus(omxState *state) {
239                 int numChildren = Global->numChildren;
240                 state->statusMsg[0] = '\0';
241                 if (!state->childList) return;
242                 for(int i = 0; i < numChildren; i++) {
243                         omxResetStatus(state->childList[i]);
244                 }
245         }
246
247 std::string string_snprintf(const std::string fmt, ...)
248 {
249     int size = 100;
250     std::string str;
251     va_list ap;
252     while (1) {
253         str.resize(size);
254         va_start(ap, fmt);
255         int n = vsnprintf((char *)str.c_str(), size, fmt.c_str(), ap);
256         va_end(ap);
257         if (n > -1 && n < size) {
258             str.resize(n);
259             return str;
260         }
261         if (n > -1)
262             size = n + 1;
263         else
264             size *= 2;
265     }
266     return str;
267 }
268
269 void mxLogBig(const std::string str)   // thread-safe
270 {
271         ssize_t len = ssize_t(str.size());
272         ssize_t wrote = 0;
273 #pragma omp critical(stderr)
274         while (1) {
275                 ssize_t got = write(2, str.data() + wrote, len - wrote);
276                 if (got == EINTR) continue;
277                 if (got <= 0) error("mxLogBig failed with errno=%d", got);
278                 wrote += got;
279                 if (wrote == len) break;
280         }
281 }
282
283 void mxLog(const char* msg, ...)   // thread-safe
284 {
285         const int maxLen = 240;
286         char buf1[maxLen];
287         char buf2[maxLen];
288
289         va_list ap;
290         va_start(ap, msg);
291         vsnprintf(buf1, maxLen, msg, ap);
292         va_end(ap);
293
294         int len = snprintf(buf2, maxLen, "[%d] %s\n", omx_absolute_thread_num(), buf1);
295
296         ssize_t wrote = 0;
297 #pragma omp critical(stderr)
298         while (1) {
299                 ssize_t got = write(2, buf2 + wrote, len - wrote);
300                 if (got == EINTR) continue;
301                 if (got <= 0) error("mxLog failed with errno=%d", got);
302                 wrote += got;
303                 if (wrote == len) break;
304         }
305 }
306
307 void _omxRaiseError()
308 {
309         // keep for debugger breakpoints
310 }
311
312 void omxRaiseErrorf(omxState *state, const char* errorMsg, ...)
313 {
314         _omxRaiseError();
315         va_list ap;
316         va_start(ap, errorMsg);
317         int fit = vsnprintf(state->statusMsg, MAX_STRING_LEN, errorMsg, ap);
318         va_end(ap);
319         if(OMX_DEBUG) {
320                 if (!(fit > -1 && fit < MAX_STRING_LEN)) {
321                         mxLog("Error exceeded maximum length: %s", errorMsg);
322                 } else {
323                         mxLog("Error raised: %s", state->statusMsg);
324                 }
325         }
326 }
327
328         void omxRaiseError(omxState *state, int errorCode, const char* errorMsg) { // DEPRECATED
329                 _omxRaiseError();
330                 if(OMX_DEBUG && errorCode) { mxLog("Error %d raised: %s", errorCode, errorMsg);}
331                 if(OMX_DEBUG && !errorCode) { mxLog("Error status cleared."); }
332                 strncpy(state->statusMsg, errorMsg, 249);
333                 state->statusMsg[249] = '\0';
334         }
335
336         void omxStateNextRow(omxState *state) {
337                 state->currentRow++;
338         };
339
340         void omxStateNextEvaluation(omxState *state) {
341                 state->currentRow = -1;
342                 state->computeCount++;
343         };
344
345 static void omxWriteCheckpointHeader(omxState *os, omxCheckpoint* oC) {
346         // rewrite with std::string TODO
347         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
348         size_t numParam = vars.size();
349
350                 os->chkptText1 = (char*) Calloc((24 + 15 * numParam), char);
351                 os->chkptText2 = (char*) Calloc(1.0 + 15.0 * numParam*
352                         (numParam + 1.0) / 2.0, char);
353                 if (oC->type == OMX_FILE_CHECKPOINT) {
354                         fprintf(oC->file, "iterations\ttimestamp\tobjective\t");
355                         for(size_t j = 0; j < numParam; j++) {
356                                 if(strcmp(vars[j]->name, CHAR(NA_STRING)) == 0) {
357                                         fprintf(oC->file, "%s", vars[j]->name);
358                                 } else {
359                                         fprintf(oC->file, "\"%s\"", vars[j]->name);
360                                 }
361                                 if (j != numParam - 1) fprintf(oC->file, "\t");
362                         }
363                         fprintf(oC->file, "\n");
364                         fflush(oC->file);
365                 }
366         }
367  
368 void omxWriteCheckpointMessage(char *msg) {
369         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
370         size_t numParam = vars.size();
371
372                 omxState *os = globalState;
373                 for(int i = 0; i < os->numCheckpoints; i++) {
374                         omxCheckpoint* oC = &(os->checkpointList[i]);
375                         if(os->chkptText1 == NULL) {    // First one: set up output
376                                 omxWriteCheckpointHeader(os, oC);
377                         }
378                         if (oC->type == OMX_FILE_CHECKPOINT) {
379                                 fprintf(oC->file, "%d \"%s\" NA ", os->majorIteration, msg);
380                                 for(size_t j = 0; j < numParam; j++) {
381                                         fprintf(oC->file, "NA ");
382                                 }
383                                 fprintf(oC->file, "\n");
384                         }
385                 }
386         }
387
388 void omxSaveCheckpoint(double* x, double f, int force) {
389         // rewrite with std::string TODO
390         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
391         size_t numParam = vars.size();
392
393                 omxState *os = globalState;
394                 time_t now = time(NULL);
395                 int soFar = now - os->startTime;                // Translated into minutes
396                 int n;
397                 for(int i = 0; i < os->numCheckpoints; i++) {
398                         n = 0;
399                         omxCheckpoint* oC = &(os->checkpointList[i]);
400                         // Check based on time            
401                         if((oC->time > 0 && (soFar - oC->lastCheckpoint) >= oC->time) || force) {
402                                 oC->lastCheckpoint = soFar;
403                                 n = 1;
404                         }
405                         // Or iterations
406                         if((oC->numIterations > 0 && (os->majorIteration - oC->lastCheckpoint) >= oC->numIterations) || force) {
407                                 oC->lastCheckpoint = os->majorIteration;
408                                 n = 1;
409                         }
410
411                         if(n) {         //In either case, save a checkpoint.
412                                 if(os->chkptText1 == NULL) {    // First one: set up output
413                                         omxWriteCheckpointHeader(os, oC);
414                                 }
415                                 char tempstring[25];
416                                 sprintf(tempstring, "%d", os->majorIteration);
417
418                                 if(strncmp(os->chkptText1, tempstring, strlen(tempstring))) {   // Returns zero if they're the same.
419                                         struct tm * nowTime = localtime(&now);                                          // So this only happens if the text is out of date.
420                                         strftime(tempstring, 25, "%b %d %Y %I:%M:%S %p", nowTime);
421                                         sprintf(os->chkptText1, "%d \"%s\" %9.5f", os->majorIteration, tempstring, f);
422                                         for(size_t j = 0; j < numParam; j++) {
423                                                 sprintf(tempstring, " %9.5f", x[j]);
424                                                 strncat(os->chkptText1, tempstring, 14);
425                                         }
426                                 }
427
428                                 if(oC->type == OMX_FILE_CHECKPOINT) {
429                                         fprintf(oC->file, "%s", os->chkptText1);
430                                         if(oC->saveHessian)
431                                                 fprintf(oC->file, "%s", os->chkptText2);
432                                         fprintf(oC->file, "\n");
433                                         fflush(oC->file);
434                                 } else if(oC->type == OMX_CONNECTION_CHECKPOINT) {
435                                         warning("NYI: R_connections are not yet implemented.");
436                                         oC->numIterations = 0;
437                                         oC->time = 0;
438                                 }
439                         }
440                 }
441         }
442
443 void omxExamineFitOutput(omxState *state, omxMatrix *fitMatrix, int *mode)
444 {
445         if (!R_FINITE(fitMatrix->data[0])) {
446                 omxRaiseErrorf(state, "Fit function returned %g at iteration %d.%d",
447                                fitMatrix->data[0], state->majorIteration, state->minorIteration);
448                 *mode = -1;
449         }
450 }
451
452 omxFreeVarLocation *omxFreeVar::getLocation(int matrix)
453 {
454         for (size_t lx=0; lx < locations.size(); lx++) {
455                 omxFreeVarLocation *loc = &locations[lx];
456                 if (~loc->matrix == matrix) return loc;
457         }
458         return NULL;
459 }