Don't hide errors
[openmx:openmx.git] / src / omxState.cpp
1 /*
2  *  Copyright 2007-2014 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <stdarg.h>
18 #include <errno.h>
19
20 #include "omxState.h"
21 #include "Compute.h"
22 #include "omxOpenmpWrap.h"
23
24 struct omxGlobal *Global = NULL;
25
26 FreeVarGroup *omxGlobal::findVarGroup(int id)
27 {
28         size_t numGroups = Global->freeGroup.size();
29         for (size_t vx=0; vx < numGroups; ++vx) {
30                 std::vector<int> &ids = Global->freeGroup[vx]->id;
31                 for (size_t ix=0; ix < ids.size(); ++ix) {
32                         if (ids[ix] == id) return Global->freeGroup[vx];
33                 }
34         }
35         return NULL;
36 }
37
38 FreeVarGroup *omxGlobal::findOrCreateVarGroup(int id)
39 {
40         FreeVarGroup *old = findVarGroup(id);
41         if (old) return old;
42
43         FreeVarGroup *fvg = new FreeVarGroup;
44         fvg->id.push_back(id);
45         Global->freeGroup.push_back(fvg);
46         return fvg;
47 }
48
49 bool FreeVarGroup::hasSameVars(FreeVarGroup *g2)
50 {
51         if (vars.size() != g2->vars.size()) return false;
52
53         for (size_t vx=0; vx < vars.size(); ++vx) {
54                 if (vars[vx] != g2->vars[vx]) return false;
55         }
56         return true;
57 }
58
59 int FreeVarGroup::lookupVar(const char *name)
60 {
61         for (size_t vx=0; vx < vars.size(); ++vx) {
62                 if (strcmp(name, vars[vx]->name) == 0) return vx;
63         }
64         return -1;
65 }
66
67 void FreeVarGroup::cacheDependencies()
68 {
69         omxState *os = globalState;
70         size_t numMats = os->matrixList.size();
71         size_t numAlgs = os->algebraList.size();
72
73         dependencies.assign(numMats + numAlgs, false);
74         locations.assign(numMats, false);
75
76         for (size_t vx = 0; vx < vars.size(); vx++) {
77                 omxFreeVar *fv = vars[vx];
78                 int *deps   = fv->deps;
79                 int numDeps = fv->numDeps;
80                 for (int index = 0; index < numDeps; index++) {
81                         dependencies[deps[index] + numMats] = true;
82                 }
83                 for (size_t lx=0; lx < fv->locations.size(); ++lx) {
84                         locations[fv->locations[lx].matrix] = true;
85                 }
86         }
87
88         // Everything is set up. This is a good place to log.
89         if (OMX_DEBUG) { log(); }
90 }
91
92 void FreeVarGroup::markDirty(omxState *os)
93 {
94         size_t numMats = os->matrixList.size();
95         size_t numAlgs = os->algebraList.size();
96
97         for(size_t i = 0; i < numMats; i++) {
98                 if (!locations[i]) continue;
99                 omxMarkClean(os->matrixList[i]);
100         }
101
102         for(size_t i = 0; i < numMats; i++) {
103                 if (dependencies[i]) {
104                         int offset = ~(i - numMats);
105                         omxMarkDirty(os->matrixList[offset]);
106                 }
107         }
108
109         for(size_t i = 0; i < numAlgs; i++) {
110                 if (dependencies[i + numMats]) {
111                         omxMarkDirty(os->algebraList[i]);
112                 }
113         }
114 }
115
116 void FreeVarGroup::log()
117 {
118         omxState *os = globalState;
119         size_t numMats = os->matrixList.size();
120         size_t numAlgs = os->algebraList.size();
121         std::string str;
122
123         str += string_snprintf("FreeVarGroup(id=%d", id[0]);
124         for (size_t ix=1; ix < id.size(); ++ix) {
125                 str += string_snprintf(",%d", id[ix]);
126         }
127         str += string_snprintf(") with %d variables:", (int) vars.size());
128
129         for (size_t vx=0; vx < vars.size(); ++vx) {
130                 str += " ";
131                 str += vars[vx]->name;
132         }
133         if (vars.size()) str += "\nwill dirty:";
134
135         for(size_t i = 0; i < numMats; i++) {
136                 if (dependencies[i]) {
137                         int offset = ~(i - numMats);
138                         str += " ";
139                         str += os->matrixList[offset]->name;
140                 }
141         }
142
143         for(size_t i = 0; i < numAlgs; i++) {
144                 if (dependencies[i + numMats]) {
145                         str += " ";
146                         str += os->algebraList[i]->name;
147                 }
148         }
149         str += "\n";
150
151         mxLogBig(str);
152 }
153
154 omxGlobal::omxGlobal()
155 {
156         ciMaxIterations = 5;
157         numThreads = 1;
158         analyticGradients = 0;
159         numChildren = 0;
160         llScale = -2.0;
161 }
162
163 void omxGlobal::deduplicateVarGroups()
164 {
165         for (size_t g1=0; g1 < freeGroup.size(); ++g1) {
166                 for (size_t g2=freeGroup.size()-1; g2 > g1; --g2) {
167                         if (freeGroup[g1]->hasSameVars(freeGroup[g2])) {
168                                 freeGroup[g1]->id.insert(freeGroup[g1]->id.end(),
169                                                          freeGroup[g2]->id.begin(), freeGroup[g2]->id.end());
170                                 delete freeGroup[g2];
171                                 freeGroup.erase(freeGroup.begin() + g2);
172                         }
173                 }
174         }
175 }
176
177 /* Initialize and Destroy */
178         void omxInitState(omxState* state) {
179                 state->stale = FALSE;
180                 state->numConstraints = 0;
181                 state->conList = NULL;
182
183                 state->majorIteration = 0;
184                 state->minorIteration = 0;
185                 state->startTime = 0;
186                 state->endTime = 0;
187                 state->numCheckpoints = 0;
188                 state->checkpointList = NULL;
189                 state->chkptText1 = NULL;
190                 state->chkptText2 = NULL;
191
192                 state->computeCount = 0;
193                 state->currentRow = -1;
194         }
195
196         void omxSetMajorIteration(omxState *state, int value) {
197                 state->majorIteration = value;
198                 if (!state->childList.size()) return;
199                 for(int i = 0; i < Global->numChildren; i++) {
200                         omxSetMajorIteration(state->childList[i], value);
201                 }
202         }
203
204         void omxSetMinorIteration(omxState *state, int value) {
205                 state->minorIteration = value;
206                 if (!state->childList.size()) return;
207                 for(int i = 0; i < Global->numChildren; i++) {
208                         omxSetMinorIteration(state->childList[i], value);
209                 }
210         }
211         
212         void omxDuplicateState(omxState* tgt, omxState* src) {
213                 tgt->dataList                   = src->dataList;
214                 
215                 for(size_t mx = 0; mx < src->matrixList.size(); mx++) {
216                         // TODO: Smarter inference for which matrices to duplicate
217                         tgt->matrixList.push_back(omxDuplicateMatrix(src->matrixList[mx], tgt));
218                 }
219
220                 tgt->numConstraints     = src->numConstraints;
221                 tgt->conList                    = (omxConstraint*) R_alloc(tgt->numConstraints, sizeof(omxConstraint));
222                 for(int j = 0; j < tgt->numConstraints; j++) {
223                         tgt->conList[j].size   = src->conList[j].size;
224                         tgt->conList[j].opCode = src->conList[j].opCode;
225                         tgt->conList[j].lbound = src->conList[j].lbound;
226                         tgt->conList[j].ubound = src->conList[j].ubound;
227                         tgt->conList[j].result = omxDuplicateMatrix(src->conList[j].result, tgt);
228                 }
229
230                 for(size_t j = 0; j < src->algebraList.size(); j++) {
231                         // TODO: Smarter inference for which algebras to duplicate
232                         tgt->algebraList.push_back(omxDuplicateMatrix(src->algebraList[j], tgt));
233                 }
234
235                 for(size_t j = 0; j < src->expectationList.size(); j++) {
236                         // TODO: Smarter inference for which expectations to duplicate
237                         tgt->expectationList.push_back(omxDuplicateExpectation(src->expectationList[j], tgt));
238                 }
239
240                 for(size_t j = 0; j < tgt->algebraList.size(); j++) {
241                         omxDuplicateAlgebra(tgt->algebraList[j], src->algebraList[j], tgt);
242                 }
243
244                 for(size_t j = 0; j < src->expectationList.size(); j++) {
245                         // TODO: Smarter inference for which expectations to duplicate
246                         omxCompleteExpectation(tgt->expectationList[j]);
247                 }
248
249                 tgt->majorIteration     = 0;
250                 tgt->minorIteration     = 0;
251                 tgt->startTime                  = src->startTime;
252                 tgt->endTime                    = 0;
253                 
254                 // TODO: adjust checkpointing based on parallelization method
255                 tgt->numCheckpoints     = 0;
256                 tgt->checkpointList     = NULL;
257                 tgt->chkptText1                 = NULL;
258                 tgt->chkptText2                 = NULL;
259                                   
260                 tgt->computeCount               = src->computeCount;
261                 tgt->currentRow                 = src->currentRow;
262         }
263
264         omxMatrix* omxLookupDuplicateElement(omxState* os, omxMatrix* element) {
265                 if(os == NULL || element == NULL) return NULL;
266
267                 if (element->hasMatrixNumber) {
268                         int matrixNumber = element->matrixNumber;
269                         if (matrixNumber >= 0) {
270                                 return(os->algebraList[matrixNumber]);
271                         } else {
272                                 return(os->matrixList[-matrixNumber - 1]);
273                         }
274                 }
275
276                 omxConstraint* parentConList = globalState->conList;
277
278                 for(int i = 0; i < os->numConstraints; i++) {
279                         if(parentConList[i].result == element) {
280                                 if(os->conList[i].result != NULL) {   // Not sure of proper failure behavior here.
281                 return(os->conList[i].result);
282                                 } else {
283                     omxRaiseError(os, -2, "Initialization Copy Error: Constraint required but not yet processed.");
284             }
285                         }
286                 }
287
288                 return NULL;
289         }
290         
291 void omxFreeChildStates(omxState *state)
292 {
293         if (state->childList.size() == 0) return;
294
295         for(int k = 0; k < Global->numChildren; k++) {
296                 // Data are not modified and not copied. The same memory
297                 // is shared across all instances of state. We only need
298                 // to free the data once, so let the parent do it.
299                 state->childList[k]->dataList.clear();
300
301                 omxFreeState(state->childList[k]);
302         }
303         state->childList.clear();
304         Global->numChildren = 0;
305 }
306
307         void omxFreeState(omxState *state) {
308                 omxFreeChildStates(state);
309
310                 if(OMX_DEBUG) { mxLog("Freeing %d Constraints.", (int) state->numConstraints);}
311                 for(int k = 0; k < state->numConstraints; k++) {
312                         omxFreeMatrix(state->conList[k].result);
313                 }
314
315                 for(size_t ax = 0; ax < state->algebraList.size(); ax++) {
316                         // free argument tree
317                         omxFreeMatrix(state->algebraList[ax]);
318                 }
319
320                 for(size_t ax = 0; ax < state->algebraList.size(); ax++) {
321                         state->algebraList[ax]->hasMatrixNumber = false;
322                         omxFreeMatrix(state->algebraList[ax]);
323                 }
324
325                 if(OMX_DEBUG) { mxLog("Freeing %d Matrices.", (int) state->matrixList.size());}
326                 for(size_t mk = 0; mk < state->matrixList.size(); mk++) {
327                         state->matrixList[mk]->hasMatrixNumber = false;
328                         omxFreeMatrix(state->matrixList[mk]);
329                 }
330                 
331                 if(OMX_DEBUG) { mxLog("Freeing %d Model Expectations.", (int) state->expectationList.size());}
332                 for(size_t ex = 0; ex < state->expectationList.size(); ex++) {
333                         omxFreeExpectationArgs(state->expectationList[ex]);
334                 }
335
336                 if(OMX_DEBUG) { mxLog("Freeing %d Data Sets.", (int) state->dataList.size());}
337                 for(size_t dx = 0; dx < state->dataList.size(); dx++) {
338                         omxFreeData(state->dataList[dx]);
339                 }
340
341                 if(OMX_DEBUG) { mxLog("Freeing %d Checkpoints.", state->numCheckpoints);}
342                 for(int k = 0; k < state->numCheckpoints; k++) {
343                         omxCheckpoint oC = state->checkpointList[k];
344                         switch(oC.type) {
345                                 case OMX_FILE_CHECKPOINT:
346                                         fclose(oC.file);
347                                         break;
348                                 case OMX_CONNECTION_CHECKPOINT: // NYI :::DEBUG:::
349                                         // Do nothing: this should be handled by R upon return.
350                                         break;
351                         }
352                         if(state->chkptText1 != NULL) {
353                                 Free(state->chkptText1);
354                         }
355                         if(state->chkptText2 != NULL) {
356                                 Free(state->chkptText2);
357                         }
358                         // Checkpoint list itself is freed by R.
359                 }
360
361                 delete state;
362
363                 if(OMX_DEBUG) { mxLog("State Freed.");}
364         }
365
366 omxGlobal::~omxGlobal()
367 {
368         for (size_t cx=0; cx < computeList.size(); ++cx) {
369                 delete computeList[cx];
370         }
371         for (size_t cx=0; cx < algebraList.size(); ++cx) {
372                 delete algebraList[cx];
373         }
374         if (freeGroup.size()) {
375                 std::vector< omxFreeVar* > &vars = freeGroup[0]->vars;  // has all vars
376                 for (size_t vx=0; vx < vars.size(); ++vx) {
377                         delete vars[vx];
378                 }
379         }
380         for (size_t gx=0; gx < freeGroup.size(); ++gx) {
381                 delete freeGroup[gx];
382         }
383 }
384
// printf-style formatting into a std::string, va_list flavour.
//
// fmt      printf format string.
// orig_ap  argument list; it is copied for every attempt, so the
//          caller's list is left untouched and must still be va_end'ed
//          by the caller.
//
// Returns the formatted text. The buffer starts at 100 bytes and is
// regrown to the exact size vsnprintf asks for. If vsnprintf keeps
// reporting an error (negative return), the buffer is doubled up to a
// fixed limit; past that limit the format string itself is returned,
// instead of doubling until the int size overflows.
std::string string_vsnprintf(const char *fmt, va_list orig_ap)
{
        const int doublingLimit = 1 << 24;
        int size = 100;
        std::string str;
        while (1) {
                str.resize(size);
                va_list ap;
                va_copy(ap, orig_ap);
                // Write through &str[0]: the pointer from c_str() must
                // not be written to.
                int n = vsnprintf(&str[0], size, fmt, ap);
                va_end(ap);
                if (n > -1 && n < size) {
                        str.resize(n);
                        return str;
                }
                if (n > -1) {
                        // vsnprintf told us how much room it needs.
                        size = n + 1;
                } else if (size >= doublingLimit) {
                        // Persistent formatting failure: growing the
                        // buffer further will not help.
                        return std::string(fmt);
                } else {
                        size *= 2;
                }
        }
}

// printf-style formatting into a std::string, variadic flavour.
// Forwards to string_vsnprintf.
std::string string_snprintf(const char *fmt, ...)
{
        va_list ap;
        va_start(ap, fmt);
        std::string str = string_vsnprintf(fmt, ap);
        va_end(ap);
        return str;
}
415
416 void mxLogBig(const std::string str)   // thread-safe
417 {
418         ssize_t len = ssize_t(str.size());
419         ssize_t wrote = 0;
420         int maxRetries = 20;
421         ssize_t got;
422 #pragma omp critical(stderp)
423         {
424                 while (--maxRetries > 0) {
425                         got = write(2, str.data() + wrote, len - wrote);
426                         if (got == -EINTR) continue;
427                         if (got <= 0) break;
428                         wrote += got;
429                         if (wrote == len) break;
430                 }
431         }
432         if (got <= 0) Rf_error("mxLogBig failed with errno=%d", got);
433
434 }
435
436 void mxLog(const char* msg, ...)   // thread-safe
437 {
438         const int maxLen = 240;
439         char buf1[maxLen];
440         char buf2[maxLen];
441
442         va_list ap;
443         va_start(ap, msg);
444         vsnprintf(buf1, maxLen, msg, ap);
445         va_end(ap);
446
447         int len = snprintf(buf2, maxLen, "[%d] %s\n", omx_absolute_thread_num(), buf1);
448
449         int maxRetries = 20;
450         ssize_t wrote = 0;
451         ssize_t got;
452 #pragma omp critical(stderp)
453         {
454                 while (--maxRetries > 0) {
455                         got = write(2, buf2 + wrote, len - wrote);
456                         if (got == -EINTR) continue;
457                         if (got <= 0) break;
458                         wrote += got;
459                         if (wrote == len) break;
460                 }
461         }
462         if (got <= 0) Rf_error("mxLog failed with errno=%d", got);
463 }
464
// Deliberately empty. omxRaiseErrorf calls this for every error, which
// gives a single place to set a debugger breakpoint.
void _omxRaiseError()
{
        // keep for debugger breakpoints
}
469
470 void omxRaiseErrorf(omxState *, const char* msg, ...)
471 {
472         va_list ap;
473         va_start(ap, msg);
474         std::string str = string_vsnprintf(msg, ap);
475         va_end(ap);
476         _omxRaiseError();
477
478         if(OMX_DEBUG) {
479                 mxLog("Error raised: %s", str.c_str());
480         }
481
482         bool overflow = false;
483 #pragma omp critical(bads)
484         {
485                 if (Global->bads.size() > 100) {
486                         overflow = true;
487                 } else {
488                         Global->bads.push_back(str);
489                 }
490         }
491
492         // mxLog takes a lock too, so call it outside of critical section
493         if (overflow) mxLog("Too many errors: %s", str.c_str());
494 }
495
496 const char *omxGlobal::getBads()
497 {
498         if (bads.size() == 0) return NULL;
499
500         std::string str;
501         for (size_t mx=0; mx < bads.size(); ++mx) {
502                 if (bads.size() > 1) str += string_snprintf("%d:", (int)mx+1);
503                 str += bads[mx];
504                 if (str.size() > (1<<14)) break;
505                 if (mx < bads.size() - 1) str += "\n";
506         }
507
508         size_t sz = str.size();
509         char *mem = R_alloc(sz+1, 1);  // use R's memory
510         memcpy(mem, str.c_str(), sz);
511         mem[sz] = 0;
512         return mem;
513 }
514
515 void omxRaiseError(omxState *, int, const char* msg) { // DEPRECATED
516         omxRaiseErrorf(NULL, "%s", msg);
517 }
518
519         void omxStateNextRow(omxState *state) {
520                 state->currentRow++;
521         };
522
523         void omxStateNextEvaluation(omxState *state) {
524                 state->currentRow = -1;
525                 state->computeCount++;
526         };
527
528 static void omxWriteCheckpointHeader(omxState *os, omxCheckpoint* oC) {
529         // rewrite with std::string TODO
530         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
531         size_t numParam = vars.size();
532
533                 os->chkptText1 = (char*) Calloc((24 + 15 * numParam), char);
534                 os->chkptText2 = (char*) Calloc(1.0 + 15.0 * numParam*
535                         (numParam + 1.0) / 2.0, char);
536                 if (oC->type == OMX_FILE_CHECKPOINT) {
537                         fprintf(oC->file, "iterations\ttimestamp\tobjective\t");
538                         for(size_t j = 0; j < numParam; j++) {
539                                 if(strcmp(vars[j]->name, CHAR(NA_STRING)) == 0) {
540                                         fprintf(oC->file, "%s", vars[j]->name);
541                                 } else {
542                                         fprintf(oC->file, "\"%s\"", vars[j]->name);
543                                 }
544                                 if (j != numParam - 1) fprintf(oC->file, "\t");
545                         }
546                         fprintf(oC->file, "\n");
547                         fflush(oC->file);
548                 }
549         }
550  
551 void omxWriteCheckpointMessage(char *msg) {
552         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
553         size_t numParam = vars.size();
554
555                 omxState *os = globalState;
556                 for(int i = 0; i < os->numCheckpoints; i++) {
557                         omxCheckpoint* oC = &(os->checkpointList[i]);
558                         if(os->chkptText1 == NULL) {    // First one: set up output
559                                 omxWriteCheckpointHeader(os, oC);
560                         }
561                         if (oC->type == OMX_FILE_CHECKPOINT) {
562                                 fprintf(oC->file, "%d \"%s\" NA ", os->majorIteration, msg);
563                                 for(size_t j = 0; j < numParam; j++) {
564                                         fprintf(oC->file, "NA ");
565                                 }
566                                 fprintf(oC->file, "\n");
567                         }
568                 }
569         }
570
571 /*
572  * Next time we rewrite this code, Mike Neale suggested that the
573  * checkpoint be taken before evaluating the fit function and then
574  * again after obtaining the fit statistic. This will help with
575  * debugging fit functions that fail to evaluate in some spots.
576  */
577 void omxSaveCheckpoint(double* x, double f, int force) {
578         // rewrite with std::string TODO
579         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
580         size_t numParam = vars.size();
581
582                 omxState *os = globalState;
583                 time_t now = time(NULL);
584                 int soFar = now - os->startTime;                // Translated into minutes
585                 int n;
586                 for(int i = 0; i < os->numCheckpoints; i++) {
587                         n = 0;
588                         omxCheckpoint* oC = &(os->checkpointList[i]);
589                         // Check based on time            
590                         if((oC->time > 0 && (soFar - oC->lastCheckpoint) >= oC->time) || force) {
591                                 oC->lastCheckpoint = soFar;
592                                 n = 1;
593                         }
594                         // Or iterations
595                         if((oC->numIterations > 0 && (os->majorIteration - oC->lastCheckpoint) >= oC->numIterations) || force) {
596                                 oC->lastCheckpoint = os->majorIteration;
597                                 n = 1;
598                         }
599
600                         if(n) {         //In either case, save a checkpoint.
601                                 if(os->chkptText1 == NULL) {    // First one: set up output
602                                         omxWriteCheckpointHeader(os, oC);
603                                 }
604                                 char tempstring[25];
605                                 sprintf(tempstring, "%d", os->majorIteration);
606
607                                 if(strncmp(os->chkptText1, tempstring, strlen(tempstring))) {   // Returns zero if they're the same.
608                                         struct tm * nowTime = localtime(&now);                                          // So this only happens if the text is out of date.
609                                         strftime(tempstring, 25, "%b %d %Y %I:%M:%S %p", nowTime);
610                                         sprintf(os->chkptText1, "%d \"%s\" %9.5f", os->majorIteration, tempstring, f);
611                                         for(size_t j = 0; j < numParam; j++) {
612                                                 sprintf(tempstring, " %9.5f", x[j]);
613                                                 strncat(os->chkptText1, tempstring, 14);
614                                         }
615                                 }
616
617                                 if(oC->type == OMX_FILE_CHECKPOINT) {
618                                         fprintf(oC->file, "%s", os->chkptText1);
619                                         if(oC->saveHessian)
620                                                 fprintf(oC->file, "%s", os->chkptText2);
621                                         fprintf(oC->file, "\n");
622                                         fflush(oC->file);
623                                 } else if(oC->type == OMX_CONNECTION_CHECKPOINT) {
624                                         Rf_warning("NYI: R_connections are not yet implemented.");
625                                         oC->numIterations = 0;
626                                         oC->time = 0;
627                                 }
628                         }
629                 }
630         }
631
632 omxFreeVarLocation *omxFreeVar::getLocation(int matrix)
633 {
634         for (size_t lx=0; lx < locations.size(); lx++) {
635                 omxFreeVarLocation *loc = &locations[lx];
636                 if (~loc->matrix == matrix) return loc;
637         }
638         return NULL;
639 }