Rewrite dependency tracking
[openmx:openmx.git] / src / omxState.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <stdarg.h>
18 #include <errno.h>
19
20 #include "omxState.h"
21 #include "Compute.h"
22 #include "omxOpenmpWrap.h"
23
24 struct omxGlobal *Global = NULL;
25
26 FreeVarGroup *omxGlobal::findVarGroup(int id)
27 {
28         size_t numGroups = Global->freeGroup.size();
29         for (size_t vx=0; vx < numGroups; ++vx) {
30                 if (Global->freeGroup[vx]->id == id) return Global->freeGroup[vx];
31         }
32         return NULL;
33 }
34
35 FreeVarGroup *omxGlobal::findOrCreateVarGroup(int id)
36 {
37         FreeVarGroup *old = findVarGroup(id);
38         if (old) return old;
39
40         FreeVarGroup *fvg = new FreeVarGroup;
41         fvg->id = id;
42         Global->freeGroup.push_back(fvg);
43         return fvg;
44 }
45
46 void FreeVarGroup::cacheDependencies()
47 {
48         omxState *os = globalState;
49         size_t numMats = os->matrixList.size();
50         size_t numAlgs = os->algebraList.size();
51
52         dependencies.assign(numMats + numAlgs, false);
53
54         for (size_t vx = 0; vx < vars.size(); vx++) {
55                 omxFreeVar *freeVar = vars[vx];
56                 int *deps   = freeVar->deps;
57                 int numDeps = freeVar->numDeps;
58                 for (int index = 0; index < numDeps; index++) {
59                         dependencies[deps[index] + numMats] = true;
60                 }
61         }
62
63         //log();
64 }
65
66 void FreeVarGroup::markDirty(omxState *os)
67 {
68         size_t numMats = os->matrixList.size();
69         size_t numAlgs = os->algebraList.size();
70
71         for(size_t i = 0; i < numMats; i++) {
72                 if (dependencies[i]) {
73                         int offset = ~(i - numMats);
74                         omxMarkDirty(os->matrixList[offset]);
75                 }
76         }
77
78         for(size_t i = 0; i < numAlgs; i++) {
79                 if (dependencies[i + numMats]) {
80                         omxMarkDirty(os->algebraList[i]);
81                 }
82         }
83 }
84
85 void FreeVarGroup::log()
86 {
87         omxState *os = globalState;
88         size_t numMats = os->matrixList.size();
89         size_t numAlgs = os->algebraList.size();
90         std::string str;
91
92         str += string_snprintf("FreeVarGroup[%d] with %d variables:", id, vars.size());
93
94         for (size_t vx=0; vx < vars.size(); ++vx) {
95                 str += " ";
96                 str += vars[vx]->name;
97         }
98         str += "\nwill dirty:";
99
100         for(size_t i = 0; i < numMats; i++) {
101                 if (dependencies[i]) {
102                         int offset = ~(i - numMats);
103                         str += " ";
104                         str += os->matrixList[offset]->name;
105                 }
106         }
107
108         for(size_t i = 0; i < numAlgs; i++) {
109                 if (dependencies[i + numMats]) {
110                         str += " ";
111                         str += os->algebraList[i]->name;
112                 }
113         }
114         str += "\n";
115
116         mxLogBig(str);
117 }
118
119 /* Initialize and Destroy */
120         void omxInitState(omxState* state) {
121                 state->numConstraints = 0;
122                 state->childList = NULL;
123                 state->conList = NULL;
124
125                 state->majorIteration = 0;
126                 state->minorIteration = 0;
127                 state->startTime = 0;
128                 state->endTime = 0;
129                 state->numCheckpoints = 0;
130                 state->checkpointList = NULL;
131                 state->chkptText1 = NULL;
132                 state->chkptText2 = NULL;
133
134                 state->computeCount = 0;
135                 state->currentRow = -1;
136
137                 strncpy(state->statusMsg, "", 1);
138         }
139
140         void omxSetMajorIteration(omxState *state, int value) {
141                 state->majorIteration = value;
142                 if (!state->childList) return;
143                 for(int i = 0; i < Global->numChildren; i++) {
144                         omxSetMajorIteration(state->childList[i], value);
145                 }
146         }
147
148         void omxSetMinorIteration(omxState *state, int value) {
149                 state->minorIteration = value;
150                 if (!state->childList) return;
151                 for(int i = 0; i < Global->numChildren; i++) {
152                         omxSetMinorIteration(state->childList[i], value);
153                 }
154         }
155         
156         void omxDuplicateState(omxState* tgt, omxState* src) {
157                 tgt->dataList                   = src->dataList;
158                 
159                 for(size_t mx = 0; mx < src->matrixList.size(); mx++) {
160                         // TODO: Smarter inference for which matrices to duplicate
161                         tgt->matrixList.push_back(omxDuplicateMatrix(src->matrixList[mx], tgt));
162                 }
163
164                 tgt->numConstraints     = src->numConstraints;
165                 tgt->conList                    = (omxConstraint*) R_alloc(tgt->numConstraints, sizeof(omxConstraint));
166                 for(int j = 0; j < tgt->numConstraints; j++) {
167                         tgt->conList[j].size   = src->conList[j].size;
168                         tgt->conList[j].opCode = src->conList[j].opCode;
169                         tgt->conList[j].lbound = src->conList[j].lbound;
170                         tgt->conList[j].ubound = src->conList[j].ubound;
171                         tgt->conList[j].result = omxDuplicateMatrix(src->conList[j].result, tgt);
172                 }
173
174                 for(size_t j = 0; j < src->algebraList.size(); j++) {
175                         // TODO: Smarter inference for which algebras to duplicate
176                         tgt->algebraList.push_back(omxDuplicateMatrix(src->algebraList[j], tgt));
177                 }
178
179                 for(size_t j = 0; j < src->expectationList.size(); j++) {
180                         // TODO: Smarter inference for which expectations to duplicate
181                         tgt->expectationList.push_back(omxDuplicateExpectation(src->expectationList[j], tgt));
182                 }
183
184                 for(size_t j = 0; j < tgt->algebraList.size(); j++) {
185                         omxDuplicateAlgebra(tgt->algebraList[j], src->algebraList[j], tgt);
186                 }
187
188                 for(size_t j = 0; j < src->expectationList.size(); j++) {
189                         // TODO: Smarter inference for which expectations to duplicate
190                         omxCompleteExpectation(tgt->expectationList[j]);
191                 }
192
193                 tgt->childList                  = NULL;
194
195                 tgt->majorIteration     = 0;
196                 tgt->minorIteration     = 0;
197                 tgt->startTime                  = src->startTime;
198                 tgt->endTime                    = 0;
199                 
200                 // TODO: adjust checkpointing based on parallelization method
201                 tgt->numCheckpoints     = 0;
202                 tgt->checkpointList     = NULL;
203                 tgt->chkptText1                 = NULL;
204                 tgt->chkptText2                 = NULL;
205                                   
206                 tgt->computeCount               = src->computeCount;
207                 tgt->currentRow                 = src->currentRow;
208
209                 strncpy(tgt->statusMsg, "", 1);
210         }
211
212         omxMatrix* omxLookupDuplicateElement(omxState* os, omxMatrix* element) {
213                 if(os == NULL || element == NULL) return NULL;
214
215                 if (element->hasMatrixNumber) {
216                         int matrixNumber = element->matrixNumber;
217                         if (matrixNumber >= 0) {
218                                 return(os->algebraList[matrixNumber]);
219                         } else {
220                                 return(os->matrixList[-matrixNumber - 1]);
221                         }
222                 }
223
224                 omxConstraint* parentConList = globalState->conList;
225
226                 for(int i = 0; i < os->numConstraints; i++) {
227                         if(parentConList[i].result == element) {
228                                 if(os->conList[i].result != NULL) {   // Not sure of proper failure behavior here.
229                 return(os->conList[i].result);
230                                 } else {
231                     omxRaiseError(os, -2, "Initialization Copy Error: Constraint required but not yet processed.");
232             }
233                         }
234                 }
235
236                 return NULL;
237         }
238         
239 void omxFreeChildStates(omxState *state)
240 {
241         if (!state->childList || Global->numChildren == 0) return;
242
243         for(int k = 0; k < Global->numChildren; k++) {
244                 // Data are not modified and not copied. The same memory
245                 // is shared across all instances of state. We only need
246                 // to free the data once, so let the parent do it.
247                 state->childList[k]->dataList.clear();
248
249                 omxFreeState(state->childList[k]);
250         }
251         Free(state->childList);
252         state->childList = NULL;
253         Global->numChildren = 0;
254 }
255
256         void omxFreeState(omxState *state) {
257                 omxFreeChildStates(state);
258
259                 for(size_t ax = 0; ax < state->algebraList.size(); ax++) {
260                         if(OMX_DEBUG) { mxLog("Freeing Algebra %d at 0x%x.", ax, state->algebraList[ax]); }
261                         omxFreeAllMatrixData(state->algebraList[ax]);
262                 }
263
264                 if(OMX_DEBUG) { mxLog("Freeing %d Matrices.", state->matrixList.size());}
265                 for(size_t mk = 0; mk < state->matrixList.size(); mk++) {
266                         if(OMX_DEBUG) { mxLog("Freeing Matrix %d at 0x%x.", mk, state->matrixList[mk]); }
267                         omxFreeAllMatrixData(state->matrixList[mk]);
268                 }
269                 
270                 if(OMX_DEBUG) { mxLog("Freeing %d Model Expectations.", state->expectationList.size());}
271                 for(size_t ex = 0; ex < state->expectationList.size(); ex++) {
272                         if(OMX_DEBUG) { mxLog("Freeing Expectation %d at 0x%x.", ex, state->expectationList[ex]); }
273                         omxFreeExpectationArgs(state->expectationList[ex]);
274                 }
275
276                 if(OMX_DEBUG) { mxLog("Freeing %d Constraints.", state->numConstraints);}
277                 for(int k = 0; k < state->numConstraints; k++) {
278                         if(OMX_DEBUG) { mxLog("Freeing Constraint %d at 0x%x.", k, state->conList[k]); }
279                         omxFreeAllMatrixData(state->conList[k].result);
280                 }
281
282                 if(OMX_DEBUG) { mxLog("Freeing %d Data Sets.", state->dataList.size());}
283                 for(size_t dx = 0; dx < state->dataList.size(); dx++) {
284                         if(OMX_DEBUG) { mxLog("Freeing Data Set %d at 0x%x.", dx, state->dataList[dx]); }
285                         omxFreeData(state->dataList[dx]);
286                 }
287
288                 if(OMX_DEBUG) { mxLog("Freeing %d Checkpoints.", state->numCheckpoints);}
289                 for(int k = 0; k < state->numCheckpoints; k++) {
290                         if(OMX_DEBUG) { mxLog("Freeing Data Set %d at 0x%x.", k, state->checkpointList[k]); }
291                         omxCheckpoint oC = state->checkpointList[k];
292                         switch(oC.type) {
293                                 case OMX_FILE_CHECKPOINT:
294                                         fclose(oC.file);
295                                         break;
296                                 case OMX_CONNECTION_CHECKPOINT: // NYI :::DEBUG:::
297                                         // Do nothing: this should be handled by R upon return.
298                                         break;
299                         }
300                         if(state->chkptText1 != NULL) {
301                                 Free(state->chkptText1);
302                         }
303                         if(state->chkptText2 != NULL) {
304                                 Free(state->chkptText2);
305                         }
306                         // Checkpoint list itself is freed by R.
307                 }
308
309                 delete state;
310
311                 if(OMX_DEBUG) { mxLog("State Freed.");}
312         }
313
314 /*
315  * I don't understand why this results in double frees. TODO
316  *
317 FreeVarGroup::~FreeVarGroup()
318 {
319         for (size_t vx=0; vx < vars.size(); ++vx) {
320                 delete vars[vx];
321         }
322 }
323 */
324
325 omxGlobal::~omxGlobal()
326 {
327         for (size_t cx=0; cx < computeList.size(); ++cx) {
328                 delete computeList[cx];
329         }
330         for (size_t gx=0; gx < freeGroup.size(); ++gx) {
331                 delete freeGroup[gx];
332         }
333 }
334
335         void omxResetStatus(omxState *state) {
336                 int numChildren = Global->numChildren;
337                 state->statusMsg[0] = '\0';
338                 if (!state->childList) return;
339                 for(int i = 0; i < numChildren; i++) {
340                         omxResetStatus(state->childList[i]);
341                 }
342         }
343
344 std::string string_snprintf(const std::string fmt, ...)
345 {
346     int size = 100;
347     std::string str;
348     va_list ap;
349     while (1) {
350         str.resize(size);
351         va_start(ap, fmt);
352         int n = vsnprintf((char *)str.c_str(), size, fmt.c_str(), ap);
353         va_end(ap);
354         if (n > -1 && n < size) {
355             str.resize(n);
356             return str;
357         }
358         if (n > -1)
359             size = n + 1;
360         else
361             size *= 2;
362     }
363     return str;
364 }
365
366 void mxLogBig(const std::string str)   // thread-safe
367 {
368         ssize_t len = ssize_t(str.size());
369         ssize_t wrote = 0;
370 #pragma omp critical(stderr)
371         while (1) {
372                 ssize_t got = write(2, str.data() + wrote, len - wrote);
373                 if (got == EINTR) continue;
374                 if (got <= 0) error("mxLogBig failed with errno=%d", got);
375                 wrote += got;
376                 if (wrote == len) break;
377         }
378 }
379
380 void mxLog(const char* msg, ...)   // thread-safe
381 {
382         const int maxLen = 240;
383         char buf1[maxLen];
384         char buf2[maxLen];
385
386         va_list ap;
387         va_start(ap, msg);
388         vsnprintf(buf1, maxLen, msg, ap);
389         va_end(ap);
390
391         int len = snprintf(buf2, maxLen, "[%d] %s\n", omx_absolute_thread_num(), buf1);
392
393         ssize_t wrote = 0;
394 #pragma omp critical(stderr)
395         while (1) {
396                 ssize_t got = write(2, buf2 + wrote, len - wrote);
397                 if (got == EINTR) continue;
398                 if (got <= 0) error("mxLog failed with errno=%d", got);
399                 wrote += got;
400                 if (wrote == len) break;
401         }
402 }
403
404 void _omxRaiseError()
405 {
406         // keep for debugger breakpoints
407 }
408
409 void omxRaiseErrorf(omxState *state, const char* errorMsg, ...)
410 {
411         _omxRaiseError();
412         va_list ap;
413         va_start(ap, errorMsg);
414         int fit = vsnprintf(state->statusMsg, MAX_STRING_LEN, errorMsg, ap);
415         va_end(ap);
416         if(OMX_DEBUG) {
417                 if (!(fit > -1 && fit < MAX_STRING_LEN)) {
418                         mxLog("Error exceeded maximum length: %s", errorMsg);
419                 } else {
420                         mxLog("Error raised: %s", state->statusMsg);
421                 }
422         }
423 }
424
425         void omxRaiseError(omxState *state, int errorCode, const char* errorMsg) { // DEPRECATED
426                 _omxRaiseError();
427                 if(OMX_DEBUG && errorCode) { mxLog("Error %d raised: %s", errorCode, errorMsg);}
428                 if(OMX_DEBUG && !errorCode) { mxLog("Error status cleared."); }
429                 strncpy(state->statusMsg, errorMsg, 249);
430                 state->statusMsg[249] = '\0';
431         }
432
433         void omxStateNextRow(omxState *state) {
434                 state->currentRow++;
435         };
436
437         void omxStateNextEvaluation(omxState *state) {
438                 state->currentRow = -1;
439                 state->computeCount++;
440         };
441
442 static void omxWriteCheckpointHeader(omxState *os, omxCheckpoint* oC) {
443         // rewrite with std::string TODO
444         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
445         size_t numParam = vars.size();
446
447                 os->chkptText1 = (char*) Calloc((24 + 15 * numParam), char);
448                 os->chkptText2 = (char*) Calloc(1.0 + 15.0 * numParam*
449                         (numParam + 1.0) / 2.0, char);
450                 if (oC->type == OMX_FILE_CHECKPOINT) {
451                         fprintf(oC->file, "iterations\ttimestamp\tobjective\t");
452                         for(size_t j = 0; j < numParam; j++) {
453                                 if(strcmp(vars[j]->name, CHAR(NA_STRING)) == 0) {
454                                         fprintf(oC->file, "%s", vars[j]->name);
455                                 } else {
456                                         fprintf(oC->file, "\"%s\"", vars[j]->name);
457                                 }
458                                 if (j != numParam - 1) fprintf(oC->file, "\t");
459                         }
460                         fprintf(oC->file, "\n");
461                         fflush(oC->file);
462                 }
463         }
464  
465 void omxWriteCheckpointMessage(char *msg) {
466         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
467         size_t numParam = vars.size();
468
469                 omxState *os = globalState;
470                 for(int i = 0; i < os->numCheckpoints; i++) {
471                         omxCheckpoint* oC = &(os->checkpointList[i]);
472                         if(os->chkptText1 == NULL) {    // First one: set up output
473                                 omxWriteCheckpointHeader(os, oC);
474                         }
475                         if (oC->type == OMX_FILE_CHECKPOINT) {
476                                 fprintf(oC->file, "%d \"%s\" NA ", os->majorIteration, msg);
477                                 for(size_t j = 0; j < numParam; j++) {
478                                         fprintf(oC->file, "NA ");
479                                 }
480                                 fprintf(oC->file, "\n");
481                         }
482                 }
483         }
484
485 void omxSaveCheckpoint(double* x, double f, int force) {
486         // rewrite with std::string TODO
487         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
488         size_t numParam = vars.size();
489
490                 omxState *os = globalState;
491                 time_t now = time(NULL);
492                 int soFar = now - os->startTime;                // Translated into minutes
493                 int n;
494                 for(int i = 0; i < os->numCheckpoints; i++) {
495                         n = 0;
496                         omxCheckpoint* oC = &(os->checkpointList[i]);
497                         // Check based on time            
498                         if((oC->time > 0 && (soFar - oC->lastCheckpoint) >= oC->time) || force) {
499                                 oC->lastCheckpoint = soFar;
500                                 n = 1;
501                         }
502                         // Or iterations
503                         if((oC->numIterations > 0 && (os->majorIteration - oC->lastCheckpoint) >= oC->numIterations) || force) {
504                                 oC->lastCheckpoint = os->majorIteration;
505                                 n = 1;
506                         }
507
508                         if(n) {         //In either case, save a checkpoint.
509                                 if(os->chkptText1 == NULL) {    // First one: set up output
510                                         omxWriteCheckpointHeader(os, oC);
511                                 }
512                                 char tempstring[25];
513                                 sprintf(tempstring, "%d", os->majorIteration);
514
515                                 if(strncmp(os->chkptText1, tempstring, strlen(tempstring))) {   // Returns zero if they're the same.
516                                         struct tm * nowTime = localtime(&now);                                          // So this only happens if the text is out of date.
517                                         strftime(tempstring, 25, "%b %d %Y %I:%M:%S %p", nowTime);
518                                         sprintf(os->chkptText1, "%d \"%s\" %9.5f", os->majorIteration, tempstring, f);
519                                         for(size_t j = 0; j < numParam; j++) {
520                                                 sprintf(tempstring, " %9.5f", x[j]);
521                                                 strncat(os->chkptText1, tempstring, 14);
522                                         }
523                                 }
524
525                                 if(oC->type == OMX_FILE_CHECKPOINT) {
526                                         fprintf(oC->file, "%s", os->chkptText1);
527                                         if(oC->saveHessian)
528                                                 fprintf(oC->file, "%s", os->chkptText2);
529                                         fprintf(oC->file, "\n");
530                                         fflush(oC->file);
531                                 } else if(oC->type == OMX_CONNECTION_CHECKPOINT) {
532                                         warning("NYI: R_connections are not yet implemented.");
533                                         oC->numIterations = 0;
534                                         oC->time = 0;
535                                 }
536                         }
537                 }
538         }
539
540 void omxExamineFitOutput(omxState *state, omxMatrix *fitMatrix, int *mode)
541 {
542         if (!R_FINITE(fitMatrix->data[0])) {
543                 omxRaiseErrorf(state, "Fit function returned %g at iteration %d.%d",
544                                fitMatrix->data[0], state->majorIteration, state->minorIteration);
545                 *mode = -1;
546         }
547 }
548
549 omxFreeVarLocation *omxFreeVar::getLocation(int matrix)
550 {
551         for (size_t lx=0; lx < locations.size(); lx++) {
552                 omxFreeVarLocation *loc = &locations[lx];
553                 if (~loc->matrix == matrix) return loc;
554         }
555         return NULL;
556 }