Change how dirty matrices are indicated
[openmx:openmx.git] / src / omxState.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <stdarg.h>
18 #include <errno.h>
19
20 #include "omxState.h"
21 #include "Compute.h"
22 #include "omxOpenmpWrap.h"
23
24 struct omxGlobal *Global = NULL;
25
26 FreeVarGroup *omxGlobal::findVarGroup(int id)
27 {
28         size_t numGroups = Global->freeGroup.size();
29         for (size_t vx=0; vx < numGroups; ++vx) {
30                 if (Global->freeGroup[vx]->id == id) return Global->freeGroup[vx];
31         }
32         return NULL;
33 }
34
35 FreeVarGroup *omxGlobal::findOrCreateVarGroup(int id)
36 {
37         FreeVarGroup *old = findVarGroup(id);
38         if (old) return old;
39
40         FreeVarGroup *fvg = new FreeVarGroup;
41         fvg->id = id;
42         Global->freeGroup.push_back(fvg);
43         return fvg;
44 }
45
46 void FreeVarGroup::cacheDependencies()
47 {
48         omxState *os = globalState;
49         size_t numMats = os->matrixList.size();
50         size_t numAlgs = os->algebraList.size();
51
52         dependencies.assign(numMats + numAlgs, false);
53         locations.assign(numMats, false);
54
55         for (size_t vx = 0; vx < vars.size(); vx++) {
56                 omxFreeVar *fv = vars[vx];
57                 int *deps   = fv->deps;
58                 int numDeps = fv->numDeps;
59                 for (int index = 0; index < numDeps; index++) {
60                         dependencies[deps[index] + numMats] = true;
61                 }
62                 for (size_t lx=0; lx < fv->locations.size(); ++lx) {
63                         locations[fv->locations[lx].matrix] = true;
64                 }
65         }
66
67         //log();
68 }
69
70 void FreeVarGroup::markDirty(omxState *os)
71 {
72         size_t numMats = os->matrixList.size();
73         size_t numAlgs = os->algebraList.size();
74
75         for(size_t i = 0; i < numMats; i++) {
76                 if (!locations[i]) continue;
77                 omxMarkClean(os->matrixList[i]);
78         }
79
80         for(size_t i = 0; i < numMats; i++) {
81                 if (dependencies[i]) {
82                         int offset = ~(i - numMats);
83                         omxMarkDirty(os->matrixList[offset]);
84                 }
85         }
86
87         for(size_t i = 0; i < numAlgs; i++) {
88                 if (dependencies[i + numMats]) {
89                         omxMarkDirty(os->algebraList[i]);
90                 }
91         }
92 }
93
94 void FreeVarGroup::log()
95 {
96         omxState *os = globalState;
97         size_t numMats = os->matrixList.size();
98         size_t numAlgs = os->algebraList.size();
99         std::string str;
100
101         str += string_snprintf("FreeVarGroup[%d] with %d variables:", id, vars.size());
102
103         for (size_t vx=0; vx < vars.size(); ++vx) {
104                 str += " ";
105                 str += vars[vx]->name;
106         }
107         str += "\nwill dirty:";
108
109         for(size_t i = 0; i < numMats; i++) {
110                 if (dependencies[i]) {
111                         int offset = ~(i - numMats);
112                         str += " ";
113                         str += os->matrixList[offset]->name;
114                 }
115         }
116
117         for(size_t i = 0; i < numAlgs; i++) {
118                 if (dependencies[i + numMats]) {
119                         str += " ";
120                         str += os->algebraList[i]->name;
121                 }
122         }
123         str += "\n";
124
125         mxLogBig(str);
126 }
127
128 /* Initialize and Destroy */
129         void omxInitState(omxState* state) {
130                 state->numConstraints = 0;
131                 state->childList = NULL;
132                 state->conList = NULL;
133
134                 state->majorIteration = 0;
135                 state->minorIteration = 0;
136                 state->startTime = 0;
137                 state->endTime = 0;
138                 state->numCheckpoints = 0;
139                 state->checkpointList = NULL;
140                 state->chkptText1 = NULL;
141                 state->chkptText2 = NULL;
142
143                 state->computeCount = 0;
144                 state->currentRow = -1;
145
146                 strncpy(state->statusMsg, "", 1);
147         }
148
149         void omxSetMajorIteration(omxState *state, int value) {
150                 state->majorIteration = value;
151                 if (!state->childList) return;
152                 for(int i = 0; i < Global->numChildren; i++) {
153                         omxSetMajorIteration(state->childList[i], value);
154                 }
155         }
156
157         void omxSetMinorIteration(omxState *state, int value) {
158                 state->minorIteration = value;
159                 if (!state->childList) return;
160                 for(int i = 0; i < Global->numChildren; i++) {
161                         omxSetMinorIteration(state->childList[i], value);
162                 }
163         }
164         
165         void omxDuplicateState(omxState* tgt, omxState* src) {
166                 tgt->dataList                   = src->dataList;
167                 
168                 for(size_t mx = 0; mx < src->matrixList.size(); mx++) {
169                         // TODO: Smarter inference for which matrices to duplicate
170                         tgt->matrixList.push_back(omxDuplicateMatrix(src->matrixList[mx], tgt));
171                 }
172
173                 tgt->numConstraints     = src->numConstraints;
174                 tgt->conList                    = (omxConstraint*) R_alloc(tgt->numConstraints, sizeof(omxConstraint));
175                 for(int j = 0; j < tgt->numConstraints; j++) {
176                         tgt->conList[j].size   = src->conList[j].size;
177                         tgt->conList[j].opCode = src->conList[j].opCode;
178                         tgt->conList[j].lbound = src->conList[j].lbound;
179                         tgt->conList[j].ubound = src->conList[j].ubound;
180                         tgt->conList[j].result = omxDuplicateMatrix(src->conList[j].result, tgt);
181                 }
182
183                 for(size_t j = 0; j < src->algebraList.size(); j++) {
184                         // TODO: Smarter inference for which algebras to duplicate
185                         tgt->algebraList.push_back(omxDuplicateMatrix(src->algebraList[j], tgt));
186                 }
187
188                 for(size_t j = 0; j < src->expectationList.size(); j++) {
189                         // TODO: Smarter inference for which expectations to duplicate
190                         tgt->expectationList.push_back(omxDuplicateExpectation(src->expectationList[j], tgt));
191                 }
192
193                 for(size_t j = 0; j < tgt->algebraList.size(); j++) {
194                         omxDuplicateAlgebra(tgt->algebraList[j], src->algebraList[j], tgt);
195                 }
196
197                 for(size_t j = 0; j < src->expectationList.size(); j++) {
198                         // TODO: Smarter inference for which expectations to duplicate
199                         omxCompleteExpectation(tgt->expectationList[j]);
200                 }
201
202                 tgt->childList                  = NULL;
203
204                 tgt->majorIteration     = 0;
205                 tgt->minorIteration     = 0;
206                 tgt->startTime                  = src->startTime;
207                 tgt->endTime                    = 0;
208                 
209                 // TODO: adjust checkpointing based on parallelization method
210                 tgt->numCheckpoints     = 0;
211                 tgt->checkpointList     = NULL;
212                 tgt->chkptText1                 = NULL;
213                 tgt->chkptText2                 = NULL;
214                                   
215                 tgt->computeCount               = src->computeCount;
216                 tgt->currentRow                 = src->currentRow;
217
218                 strncpy(tgt->statusMsg, "", 1);
219         }
220
221         omxMatrix* omxLookupDuplicateElement(omxState* os, omxMatrix* element) {
222                 if(os == NULL || element == NULL) return NULL;
223
224                 if (element->hasMatrixNumber) {
225                         int matrixNumber = element->matrixNumber;
226                         if (matrixNumber >= 0) {
227                                 return(os->algebraList[matrixNumber]);
228                         } else {
229                                 return(os->matrixList[-matrixNumber - 1]);
230                         }
231                 }
232
233                 omxConstraint* parentConList = globalState->conList;
234
235                 for(int i = 0; i < os->numConstraints; i++) {
236                         if(parentConList[i].result == element) {
237                                 if(os->conList[i].result != NULL) {   // Not sure of proper failure behavior here.
238                 return(os->conList[i].result);
239                                 } else {
240                     omxRaiseError(os, -2, "Initialization Copy Error: Constraint required but not yet processed.");
241             }
242                         }
243                 }
244
245                 return NULL;
246         }
247         
248 void omxFreeChildStates(omxState *state)
249 {
250         if (!state->childList || Global->numChildren == 0) return;
251
252         for(int k = 0; k < Global->numChildren; k++) {
253                 // Data are not modified and not copied. The same memory
254                 // is shared across all instances of state. We only need
255                 // to free the data once, so let the parent do it.
256                 state->childList[k]->dataList.clear();
257
258                 omxFreeState(state->childList[k]);
259         }
260         Free(state->childList);
261         state->childList = NULL;
262         Global->numChildren = 0;
263 }
264
265         void omxFreeState(omxState *state) {
266                 omxFreeChildStates(state);
267
268                 for(size_t ax = 0; ax < state->algebraList.size(); ax++) {
269                         if(OMX_DEBUG) { mxLog("Freeing Algebra %d at 0x%x.", ax, state->algebraList[ax]); }
270                         omxFreeAllMatrixData(state->algebraList[ax]);
271                 }
272
273                 if(OMX_DEBUG) { mxLog("Freeing %d Matrices.", state->matrixList.size());}
274                 for(size_t mk = 0; mk < state->matrixList.size(); mk++) {
275                         if(OMX_DEBUG) { mxLog("Freeing Matrix %d at 0x%x.", mk, state->matrixList[mk]); }
276                         omxFreeAllMatrixData(state->matrixList[mk]);
277                 }
278                 
279                 if(OMX_DEBUG) { mxLog("Freeing %d Model Expectations.", state->expectationList.size());}
280                 for(size_t ex = 0; ex < state->expectationList.size(); ex++) {
281                         if(OMX_DEBUG) { mxLog("Freeing Expectation %d at 0x%x.", ex, state->expectationList[ex]); }
282                         omxFreeExpectationArgs(state->expectationList[ex]);
283                 }
284
285                 if(OMX_DEBUG) { mxLog("Freeing %d Constraints.", state->numConstraints);}
286                 for(int k = 0; k < state->numConstraints; k++) {
287                         if(OMX_DEBUG) { mxLog("Freeing Constraint %d at 0x%x.", k, state->conList[k]); }
288                         omxFreeAllMatrixData(state->conList[k].result);
289                 }
290
291                 if(OMX_DEBUG) { mxLog("Freeing %d Data Sets.", state->dataList.size());}
292                 for(size_t dx = 0; dx < state->dataList.size(); dx++) {
293                         if(OMX_DEBUG) { mxLog("Freeing Data Set %d at 0x%x.", dx, state->dataList[dx]); }
294                         omxFreeData(state->dataList[dx]);
295                 }
296
297                 if(OMX_DEBUG) { mxLog("Freeing %d Checkpoints.", state->numCheckpoints);}
298                 for(int k = 0; k < state->numCheckpoints; k++) {
299                         if(OMX_DEBUG) { mxLog("Freeing Data Set %d at 0x%x.", k, state->checkpointList[k]); }
300                         omxCheckpoint oC = state->checkpointList[k];
301                         switch(oC.type) {
302                                 case OMX_FILE_CHECKPOINT:
303                                         fclose(oC.file);
304                                         break;
305                                 case OMX_CONNECTION_CHECKPOINT: // NYI :::DEBUG:::
306                                         // Do nothing: this should be handled by R upon return.
307                                         break;
308                         }
309                         if(state->chkptText1 != NULL) {
310                                 Free(state->chkptText1);
311                         }
312                         if(state->chkptText2 != NULL) {
313                                 Free(state->chkptText2);
314                         }
315                         // Checkpoint list itself is freed by R.
316                 }
317
318                 delete state;
319
320                 if(OMX_DEBUG) { mxLog("State Freed.");}
321         }
322
323 /*
324  * I don't understand why this results in double frees. TODO
325  *
326 FreeVarGroup::~FreeVarGroup()
327 {
328         for (size_t vx=0; vx < vars.size(); ++vx) {
329                 delete vars[vx];
330         }
331 }
332 */
333
334 omxGlobal::~omxGlobal()
335 {
336         for (size_t cx=0; cx < computeList.size(); ++cx) {
337                 delete computeList[cx];
338         }
339         for (size_t gx=0; gx < freeGroup.size(); ++gx) {
340                 delete freeGroup[gx];
341         }
342 }
343
344         void omxResetStatus(omxState *state) {
345                 int numChildren = Global->numChildren;
346                 state->statusMsg[0] = '\0';
347                 if (!state->childList) return;
348                 for(int i = 0; i < numChildren; i++) {
349                         omxResetStatus(state->childList[i]);
350                 }
351         }
352
353 std::string string_snprintf(const std::string fmt, ...)
354 {
355     int size = 100;
356     std::string str;
357     va_list ap;
358     while (1) {
359         str.resize(size);
360         va_start(ap, fmt);
361         int n = vsnprintf((char *)str.c_str(), size, fmt.c_str(), ap);
362         va_end(ap);
363         if (n > -1 && n < size) {
364             str.resize(n);
365             return str;
366         }
367         if (n > -1)
368             size = n + 1;
369         else
370             size *= 2;
371     }
372     return str;
373 }
374
375 void mxLogBig(const std::string str)   // thread-safe
376 {
377         ssize_t len = ssize_t(str.size());
378         ssize_t wrote = 0;
379 #pragma omp critical(stderr)
380         while (1) {
381                 ssize_t got = write(2, str.data() + wrote, len - wrote);
382                 if (got == EINTR) continue;
383                 if (got <= 0) error("mxLogBig failed with errno=%d", got);
384                 wrote += got;
385                 if (wrote == len) break;
386         }
387 }
388
389 void mxLog(const char* msg, ...)   // thread-safe
390 {
391         const int maxLen = 240;
392         char buf1[maxLen];
393         char buf2[maxLen];
394
395         va_list ap;
396         va_start(ap, msg);
397         vsnprintf(buf1, maxLen, msg, ap);
398         va_end(ap);
399
400         int len = snprintf(buf2, maxLen, "[%d] %s\n", omx_absolute_thread_num(), buf1);
401
402         ssize_t wrote = 0;
403 #pragma omp critical(stderr)
404         while (1) {
405                 ssize_t got = write(2, buf2 + wrote, len - wrote);
406                 if (got == EINTR) continue;
407                 if (got <= 0) error("mxLog failed with errno=%d", got);
408                 wrote += got;
409                 if (wrote == len) break;
410         }
411 }
412
413 void _omxRaiseError()
414 {
415         // keep for debugger breakpoints
416 }
417
418 void omxRaiseErrorf(omxState *state, const char* errorMsg, ...)
419 {
420         _omxRaiseError();
421         va_list ap;
422         va_start(ap, errorMsg);
423         int fit = vsnprintf(state->statusMsg, MAX_STRING_LEN, errorMsg, ap);
424         va_end(ap);
425         if(OMX_DEBUG) {
426                 if (!(fit > -1 && fit < MAX_STRING_LEN)) {
427                         mxLog("Error exceeded maximum length: %s", errorMsg);
428                 } else {
429                         mxLog("Error raised: %s", state->statusMsg);
430                 }
431         }
432 }
433
434         void omxRaiseError(omxState *state, int errorCode, const char* errorMsg) { // DEPRECATED
435                 _omxRaiseError();
436                 if(OMX_DEBUG && errorCode) { mxLog("Error %d raised: %s", errorCode, errorMsg);}
437                 if(OMX_DEBUG && !errorCode) { mxLog("Error status cleared."); }
438                 strncpy(state->statusMsg, errorMsg, 249);
439                 state->statusMsg[249] = '\0';
440         }
441
442         void omxStateNextRow(omxState *state) {
443                 state->currentRow++;
444         };
445
446         void omxStateNextEvaluation(omxState *state) {
447                 state->currentRow = -1;
448                 state->computeCount++;
449         };
450
451 static void omxWriteCheckpointHeader(omxState *os, omxCheckpoint* oC) {
452         // rewrite with std::string TODO
453         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
454         size_t numParam = vars.size();
455
456                 os->chkptText1 = (char*) Calloc((24 + 15 * numParam), char);
457                 os->chkptText2 = (char*) Calloc(1.0 + 15.0 * numParam*
458                         (numParam + 1.0) / 2.0, char);
459                 if (oC->type == OMX_FILE_CHECKPOINT) {
460                         fprintf(oC->file, "iterations\ttimestamp\tobjective\t");
461                         for(size_t j = 0; j < numParam; j++) {
462                                 if(strcmp(vars[j]->name, CHAR(NA_STRING)) == 0) {
463                                         fprintf(oC->file, "%s", vars[j]->name);
464                                 } else {
465                                         fprintf(oC->file, "\"%s\"", vars[j]->name);
466                                 }
467                                 if (j != numParam - 1) fprintf(oC->file, "\t");
468                         }
469                         fprintf(oC->file, "\n");
470                         fflush(oC->file);
471                 }
472         }
473  
474 void omxWriteCheckpointMessage(char *msg) {
475         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
476         size_t numParam = vars.size();
477
478                 omxState *os = globalState;
479                 for(int i = 0; i < os->numCheckpoints; i++) {
480                         omxCheckpoint* oC = &(os->checkpointList[i]);
481                         if(os->chkptText1 == NULL) {    // First one: set up output
482                                 omxWriteCheckpointHeader(os, oC);
483                         }
484                         if (oC->type == OMX_FILE_CHECKPOINT) {
485                                 fprintf(oC->file, "%d \"%s\" NA ", os->majorIteration, msg);
486                                 for(size_t j = 0; j < numParam; j++) {
487                                         fprintf(oC->file, "NA ");
488                                 }
489                                 fprintf(oC->file, "\n");
490                         }
491                 }
492         }
493
494 void omxSaveCheckpoint(double* x, double f, int force) {
495         // rewrite with std::string TODO
496         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
497         size_t numParam = vars.size();
498
499                 omxState *os = globalState;
500                 time_t now = time(NULL);
501                 int soFar = now - os->startTime;                // Translated into minutes
502                 int n;
503                 for(int i = 0; i < os->numCheckpoints; i++) {
504                         n = 0;
505                         omxCheckpoint* oC = &(os->checkpointList[i]);
506                         // Check based on time            
507                         if((oC->time > 0 && (soFar - oC->lastCheckpoint) >= oC->time) || force) {
508                                 oC->lastCheckpoint = soFar;
509                                 n = 1;
510                         }
511                         // Or iterations
512                         if((oC->numIterations > 0 && (os->majorIteration - oC->lastCheckpoint) >= oC->numIterations) || force) {
513                                 oC->lastCheckpoint = os->majorIteration;
514                                 n = 1;
515                         }
516
517                         if(n) {         //In either case, save a checkpoint.
518                                 if(os->chkptText1 == NULL) {    // First one: set up output
519                                         omxWriteCheckpointHeader(os, oC);
520                                 }
521                                 char tempstring[25];
522                                 sprintf(tempstring, "%d", os->majorIteration);
523
524                                 if(strncmp(os->chkptText1, tempstring, strlen(tempstring))) {   // Returns zero if they're the same.
525                                         struct tm * nowTime = localtime(&now);                                          // So this only happens if the text is out of date.
526                                         strftime(tempstring, 25, "%b %d %Y %I:%M:%S %p", nowTime);
527                                         sprintf(os->chkptText1, "%d \"%s\" %9.5f", os->majorIteration, tempstring, f);
528                                         for(size_t j = 0; j < numParam; j++) {
529                                                 sprintf(tempstring, " %9.5f", x[j]);
530                                                 strncat(os->chkptText1, tempstring, 14);
531                                         }
532                                 }
533
534                                 if(oC->type == OMX_FILE_CHECKPOINT) {
535                                         fprintf(oC->file, "%s", os->chkptText1);
536                                         if(oC->saveHessian)
537                                                 fprintf(oC->file, "%s", os->chkptText2);
538                                         fprintf(oC->file, "\n");
539                                         fflush(oC->file);
540                                 } else if(oC->type == OMX_CONNECTION_CHECKPOINT) {
541                                         warning("NYI: R_connections are not yet implemented.");
542                                         oC->numIterations = 0;
543                                         oC->time = 0;
544                                 }
545                         }
546                 }
547         }
548
549 void omxExamineFitOutput(omxState *state, omxMatrix *fitMatrix, int *mode)
550 {
551         if (!R_FINITE(fitMatrix->data[0])) {
552                 omxRaiseErrorf(state, "Fit function returned %g at iteration %d.%d",
553                                fitMatrix->data[0], state->majorIteration, state->minorIteration);
554                 *mode = -1;
555         }
556 }
557
558 omxFreeVarLocation *omxFreeVar::getLocation(int matrix)
559 {
560         for (size_t lx=0; lx < locations.size(); lx++) {
561                 omxFreeVarLocation *loc = &locations[lx];
562                 if (~loc->matrix == matrix) return loc;
563         }
564         return NULL;
565 }