Add option to checkpoint every evaluation
[openmx:openmx.git] / src / omxState.cpp
1 /*
2  *  Copyright 2007-2014 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <stdarg.h>
18 #include <errno.h>
19
20 #include "omxState.h"
21 #include "Compute.h"
22 #include "omxOpenmpWrap.h"
23
24 struct omxGlobal *Global = NULL;
25
26 FreeVarGroup *omxGlobal::findVarGroup(int id)
27 {
28         size_t numGroups = Global->freeGroup.size();
29         for (size_t vx=0; vx < numGroups; ++vx) {
30                 std::vector<int> &ids = Global->freeGroup[vx]->id;
31                 for (size_t ix=0; ix < ids.size(); ++ix) {
32                         if (ids[ix] == id) return Global->freeGroup[vx];
33                 }
34         }
35         return NULL;
36 }
37
38 FreeVarGroup *omxGlobal::findOrCreateVarGroup(int id)
39 {
40         FreeVarGroup *old = findVarGroup(id);
41         if (old) return old;
42
43         FreeVarGroup *fvg = new FreeVarGroup;
44         fvg->id.push_back(id);
45         Global->freeGroup.push_back(fvg);
46         return fvg;
47 }
48
49 bool FreeVarGroup::hasSameVars(FreeVarGroup *g2)
50 {
51         if (vars.size() != g2->vars.size()) return false;
52
53         for (size_t vx=0; vx < vars.size(); ++vx) {
54                 if (vars[vx] != g2->vars[vx]) return false;
55         }
56         return true;
57 }
58
59 int FreeVarGroup::lookupVar(const char *name)
60 {
61         for (size_t vx=0; vx < vars.size(); ++vx) {
62                 if (strcmp(name, vars[vx]->name) == 0) return vx;
63         }
64         return -1;
65 }
66
67 void FreeVarGroup::cacheDependencies()
68 {
69         omxState *os = globalState;
70         size_t numMats = os->matrixList.size();
71         size_t numAlgs = os->algebraList.size();
72
73         dependencies.assign(numMats + numAlgs, false);
74         locations.assign(numMats, false);
75
76         for (size_t vx = 0; vx < vars.size(); vx++) {
77                 omxFreeVar *fv = vars[vx];
78                 int *deps   = fv->deps;
79                 int numDeps = fv->numDeps;
80                 for (int index = 0; index < numDeps; index++) {
81                         dependencies[deps[index] + numMats] = true;
82                 }
83                 for (size_t lx=0; lx < fv->locations.size(); ++lx) {
84                         locations[fv->locations[lx].matrix] = true;
85                 }
86         }
87
88         // Everything is set up. This is a good place to log.
89         if (OMX_DEBUG) { log(); }
90 }
91
92 void FreeVarGroup::markDirty(omxState *os)
93 {
94         size_t numMats = os->matrixList.size();
95         size_t numAlgs = os->algebraList.size();
96
97         for(size_t i = 0; i < numMats; i++) {
98                 if (!locations[i]) continue;
99                 omxMarkClean(os->matrixList[i]);
100         }
101
102         for(size_t i = 0; i < numMats; i++) {
103                 if (dependencies[i]) {
104                         int offset = ~(i - numMats);
105                         omxMarkDirty(os->matrixList[offset]);
106                 }
107         }
108
109         for(size_t i = 0; i < numAlgs; i++) {
110                 if (dependencies[i + numMats]) {
111                         omxMarkDirty(os->algebraList[i]);
112                 }
113         }
114 }
115
116 void FreeVarGroup::log()
117 {
118         omxState *os = globalState;
119         size_t numMats = os->matrixList.size();
120         size_t numAlgs = os->algebraList.size();
121         std::string str;
122
123         str += string_snprintf("FreeVarGroup(id=%d", id[0]);
124         for (size_t ix=1; ix < id.size(); ++ix) {
125                 str += string_snprintf(",%d", id[ix]);
126         }
127         str += string_snprintf(") with %d variables:", (int) vars.size());
128
129         for (size_t vx=0; vx < vars.size(); ++vx) {
130                 str += " ";
131                 str += vars[vx]->name;
132         }
133         if (vars.size()) str += "\nwill dirty:";
134
135         for(size_t i = 0; i < numMats; i++) {
136                 if (dependencies[i]) {
137                         int offset = ~(i - numMats);
138                         str += " ";
139                         str += os->matrixList[offset]->name;
140                 }
141         }
142
143         for(size_t i = 0; i < numAlgs; i++) {
144                 if (dependencies[i + numMats]) {
145                         str += " ";
146                         str += os->algebraList[i]->name;
147                 }
148         }
149         str += "\n";
150
151         mxLogBig(str);
152 }
153
154 omxGlobal::omxGlobal()
155 {
156         ciMaxIterations = 5;
157         numThreads = 1;
158         analyticGradients = 0;
159         numChildren = 0;
160         llScale = -2.0;
161         computeCount = 0;
162 }
163
164 void omxGlobal::deduplicateVarGroups()
165 {
166         for (size_t g1=0; g1 < freeGroup.size(); ++g1) {
167                 for (size_t g2=freeGroup.size()-1; g2 > g1; --g2) {
168                         if (freeGroup[g1]->hasSameVars(freeGroup[g2])) {
169                                 freeGroup[g1]->id.insert(freeGroup[g1]->id.end(),
170                                                          freeGroup[g2]->id.begin(), freeGroup[g2]->id.end());
171                                 delete freeGroup[g2];
172                                 freeGroup.erase(freeGroup.begin() + g2);
173                         }
174                 }
175         }
176 }
177
178 /* Initialize and Destroy */
179         void omxInitState(omxState* state) {
180                 state->stale = FALSE;
181                 state->numConstraints = 0;
182                 state->conList = NULL;
183                 state->currentRow = -1;
184         }
185
186         void omxDuplicateState(omxState* tgt, omxState* src) {
187                 tgt->dataList                   = src->dataList;
188                 
189                 for(size_t mx = 0; mx < src->matrixList.size(); mx++) {
190                         // TODO: Smarter inference for which matrices to duplicate
191                         tgt->matrixList.push_back(omxDuplicateMatrix(src->matrixList[mx], tgt));
192                 }
193
194                 tgt->numConstraints     = src->numConstraints;
195                 tgt->conList                    = (omxConstraint*) R_alloc(tgt->numConstraints, sizeof(omxConstraint));
196                 for(int j = 0; j < tgt->numConstraints; j++) {
197                         tgt->conList[j].size   = src->conList[j].size;
198                         tgt->conList[j].opCode = src->conList[j].opCode;
199                         tgt->conList[j].lbound = src->conList[j].lbound;
200                         tgt->conList[j].ubound = src->conList[j].ubound;
201                         tgt->conList[j].result = omxDuplicateMatrix(src->conList[j].result, tgt);
202                 }
203
204                 for(size_t j = 0; j < src->algebraList.size(); j++) {
205                         // TODO: Smarter inference for which algebras to duplicate
206                         tgt->algebraList.push_back(omxDuplicateMatrix(src->algebraList[j], tgt));
207                 }
208
209                 for(size_t j = 0; j < src->expectationList.size(); j++) {
210                         // TODO: Smarter inference for which expectations to duplicate
211                         tgt->expectationList.push_back(omxDuplicateExpectation(src->expectationList[j], tgt));
212                 }
213
214                 for(size_t j = 0; j < tgt->algebraList.size(); j++) {
215                         omxDuplicateAlgebra(tgt->algebraList[j], src->algebraList[j], tgt);
216                 }
217
218                 for(size_t j = 0; j < src->expectationList.size(); j++) {
219                         // TODO: Smarter inference for which expectations to duplicate
220                         omxCompleteExpectation(tgt->expectationList[j]);
221                 }
222
223                 tgt->currentRow                 = src->currentRow;
224         }
225
226         omxMatrix* omxLookupDuplicateElement(omxState* os, omxMatrix* element) {
227                 if(os == NULL || element == NULL) return NULL;
228
229                 if (element->hasMatrixNumber) {
230                         int matrixNumber = element->matrixNumber;
231                         if (matrixNumber >= 0) {
232                                 return(os->algebraList[matrixNumber]);
233                         } else {
234                                 return(os->matrixList[-matrixNumber - 1]);
235                         }
236                 }
237
238                 omxConstraint* parentConList = globalState->conList;
239
240                 for(int i = 0; i < os->numConstraints; i++) {
241                         if(parentConList[i].result == element) {
242                                 if(os->conList[i].result != NULL) {   // Not sure of proper failure behavior here.
243                 return(os->conList[i].result);
244                                 } else {
245                     omxRaiseError("Initialization Copy Error: Constraint required but not yet processed.");
246             }
247                         }
248                 }
249
250                 return NULL;
251         }
252         
253 void omxFreeChildStates(omxState *state)
254 {
255         if (state->childList.size() == 0) return;
256
257         for(int k = 0; k < Global->numChildren; k++) {
258                 // Data are not modified and not copied. The same memory
259                 // is shared across all instances of state. We only need
260                 // to free the data once, so let the parent do it.
261                 state->childList[k]->dataList.clear();
262
263                 omxFreeState(state->childList[k]);
264         }
265         state->childList.clear();
266         Global->numChildren = 0;
267 }
268
269         void omxFreeState(omxState *state) {
270                 omxFreeChildStates(state);
271
272                 if(OMX_DEBUG) { mxLog("Freeing %d Constraints.", (int) state->numConstraints);}
273                 for(int k = 0; k < state->numConstraints; k++) {
274                         omxFreeMatrix(state->conList[k].result);
275                 }
276
277                 for(size_t ax = 0; ax < state->algebraList.size(); ax++) {
278                         // free argument tree
279                         omxFreeMatrix(state->algebraList[ax]);
280                 }
281
282                 for(size_t ax = 0; ax < state->algebraList.size(); ax++) {
283                         state->algebraList[ax]->hasMatrixNumber = false;
284                         omxFreeMatrix(state->algebraList[ax]);
285                 }
286
287                 if(OMX_DEBUG) { mxLog("Freeing %d Matrices.", (int) state->matrixList.size());}
288                 for(size_t mk = 0; mk < state->matrixList.size(); mk++) {
289                         state->matrixList[mk]->hasMatrixNumber = false;
290                         omxFreeMatrix(state->matrixList[mk]);
291                 }
292                 
293                 if(OMX_DEBUG) { mxLog("Freeing %d Model Expectations.", (int) state->expectationList.size());}
294                 for(size_t ex = 0; ex < state->expectationList.size(); ex++) {
295                         omxFreeExpectationArgs(state->expectationList[ex]);
296                 }
297
298                 if(OMX_DEBUG) { mxLog("Freeing %d Data Sets.", (int) state->dataList.size());}
299                 for(size_t dx = 0; dx < state->dataList.size(); dx++) {
300                         omxFreeData(state->dataList[dx]);
301                 }
302
303                 delete state;
304
305                 if(OMX_DEBUG) { mxLog("State Freed.");}
306         }
307
308 omxGlobal::~omxGlobal()
309 {
310         for (size_t cx=0; cx < computeList.size(); ++cx) {
311                 delete computeList[cx];
312         }
313         for (size_t cx=0; cx < algebraList.size(); ++cx) {
314                 delete algebraList[cx];
315         }
316         for (size_t cx=0; cx < checkpointList.size(); ++cx) {
317                 delete checkpointList[cx];
318         }
319         if (freeGroup.size()) {
320                 std::vector< omxFreeVar* > &vars = freeGroup[0]->vars;  // has all vars
321                 for (size_t vx=0; vx < vars.size(); ++vx) {
322                         delete vars[vx];
323                 }
324         }
325         for (size_t gx=0; gx < freeGroup.size(); ++gx) {
326                 delete freeGroup[gx];
327         }
328 }
329
330 std::string string_vsnprintf(const char *fmt, va_list orig_ap)
331 {
332     int size = 100;
333     std::string str;
334     while (1) {
335         str.resize(size);
336         va_list ap;
337         va_copy(ap, orig_ap);
338         int n = vsnprintf((char *)str.c_str(), size, fmt, ap);
339         va_end(ap);
340         if (n > -1 && n < size) {
341             str.resize(n);
342             return str;
343         }
344         if (n > -1)
345             size = n + 1;
346         else
347             size *= 2;
348     }
349     return str;
350 }
351
352 std::string string_snprintf(const char *fmt, ...)
353 {
354         va_list ap;
355         va_start(ap, fmt);
356         std::string str = string_vsnprintf(fmt, ap);
357         va_end(ap);
358         return str;
359 }
360
361 void mxLogBig(const std::string str)   // thread-safe
362 {
363         ssize_t len = ssize_t(str.size());
364         ssize_t wrote = 0;
365         int maxRetries = 20;
366         ssize_t got;
367 #pragma omp critical(stderp)
368         {
369                 while (--maxRetries > 0) {
370                         got = write(2, str.data() + wrote, len - wrote);
371                         if (got == -EINTR) continue;
372                         if (got <= 0) break;
373                         wrote += got;
374                         if (wrote == len) break;
375                 }
376         }
377         if (got <= 0) Rf_error("mxLogBig failed with errno=%d", got);
378
379 }
380
381 void mxLog(const char* msg, ...)   // thread-safe
382 {
383         const int maxLen = 240;
384         char buf1[maxLen];
385         char buf2[maxLen];
386
387         va_list ap;
388         va_start(ap, msg);
389         vsnprintf(buf1, maxLen, msg, ap);
390         va_end(ap);
391
392         int len = snprintf(buf2, maxLen, "[%d] %s\n", omx_absolute_thread_num(), buf1);
393
394         int maxRetries = 20;
395         ssize_t wrote = 0;
396         ssize_t got;
397 #pragma omp critical(stderp)
398         {
399                 while (--maxRetries > 0) {
400                         got = write(2, buf2 + wrote, len - wrote);
401                         if (got == -EINTR) continue;
402                         if (got <= 0) break;
403                         wrote += got;
404                         if (wrote == len) break;
405                 }
406         }
407         if (got <= 0) Rf_error("mxLog failed with errno=%d", got);
408 }
409
410 void _omxRaiseError()
411 {
412         // keep for debugger breakpoints
413 }
414
415 void omxRaiseErrorf(const char* msg, ...)
416 {
417         va_list ap;
418         va_start(ap, msg);
419         std::string str = string_vsnprintf(msg, ap);
420         va_end(ap);
421         _omxRaiseError();
422
423         if(OMX_DEBUG) {
424                 mxLog("Error raised: %s", str.c_str());
425         }
426
427         bool overflow = false;
428 #pragma omp critical(bads)
429         {
430                 if (Global->bads.size() > 100) {
431                         overflow = true;
432                 } else {
433                         Global->bads.push_back(str);
434                 }
435         }
436
437         // mxLog takes a lock too, so call it outside of critical section
438         if (overflow) mxLog("Too many errors: %s", str.c_str());
439 }
440
441 const char *omxGlobal::getBads()
442 {
443         if (bads.size() == 0) return NULL;
444
445         std::string str;
446         for (size_t mx=0; mx < bads.size(); ++mx) {
447                 if (bads.size() > 1) str += string_snprintf("%d:", (int)mx+1);
448                 str += bads[mx];
449                 if (str.size() > (1<<14)) break;
450                 if (mx < bads.size() - 1) str += "\n";
451         }
452
453         size_t sz = str.size();
454         char *mem = R_alloc(sz+1, 1);  // use R's memory
455         memcpy(mem, str.c_str(), sz);
456         mem[sz] = 0;
457         return mem;
458 }
459
460 void omxRaiseError(const char* msg) { // DEPRECATED
461         omxRaiseErrorf("%s", msg);
462 }
463
464         void omxStateNextRow(omxState *state) {
465                 state->currentRow++;
466         };
467
468 void omxGlobal::checkpointMessage(FitContext *fc, double *est, const char *fmt, ...)
469 {
470         va_list ap;
471         va_start(ap, fmt);
472         std::string str = string_vsnprintf(fmt, ap);
473         va_end(ap);
474
475         for(size_t i = 0; i < checkpointList.size(); i++) {
476                 checkpointList[i]->message(fc, est, str.c_str());
477         }
478 }
479
480 void omxGlobal::checkpointPrefit(FitContext *fc, double *est, bool force)
481 {
482         for(size_t i = 0; i < checkpointList.size(); i++) {
483                 checkpointList[i]->prefit(fc, est, force);
484         }
485 }
486
487 void omxGlobal::checkpointPostfit(FitContext *fc)
488 {
489         for(size_t i = 0; i < checkpointList.size(); i++) {
490                 checkpointList[i]->postfit(fc);
491         }
492 }
493
494 omxCheckpoint::omxCheckpoint() : wroteHeader(false), lastCheckpoint(0), lastIterations(0),
495                                  lastEvaluation(0), fitPending(false),
496                                  timePerCheckpoint(0), iterPerCheckpoint(0), evalsPerCheckpoint(0), file(NULL)
497 {}
498
499 omxCheckpoint::~omxCheckpoint()
500 {
501         if (file) fclose(file);
502 }
503
504 /* We need to re-design checkpointing when it is possible to run
505    more than 1 optimization in parallel. */
506 void omxCheckpoint::omxWriteCheckpointHeader()
507 {
508         if (wroteHeader) return;
509         std::vector< omxFreeVar* > &vars = Global->freeGroup[0]->vars;
510         size_t numParam = vars.size();
511
512         // New columns should use the OpenMx prefit to avoid clashing with
513         // free parameter names.
514         fprintf(file, "OpenMxContext\tOpenMxNumFree\tOpenMxEvals\titerations\ttimestamp");
515         for(size_t j = 0; j < numParam; j++) {
516                 fprintf(file, "\t\"%s\"", vars[j]->name);
517         }
518         fprintf(file, "\tobjective\n");
519         fflush(file);
520         wroteHeader = true;
521 }
522  
523 void omxCheckpoint::message(FitContext *fc, double *est, const char *msg)
524 {
525         _prefit(fc, est, true, msg);
526         postfit(fc);
527 }
528
529 void omxCheckpoint::_prefit(FitContext *fc, double *est, bool force, const char *context)
530 {
531         const int timeBufSize = 32;
532         char timeBuf[timeBufSize];
533         time_t now = time(NULL); // avoid checking unless we need it
534
535         bool doit = force;
536         if ((timePerCheckpoint && timePerCheckpoint <= now - lastCheckpoint) ||
537             (iterPerCheckpoint && iterPerCheckpoint <= fc->iterations - lastIterations) ||
538             (evalsPerCheckpoint && evalsPerCheckpoint <= Global->computeCount - lastEvaluation)) {
539                 doit = true;
540         }
541         if (!doit) return;
542
543         omxWriteCheckpointHeader();
544
545         std::vector< omxFreeVar* > &vars = fc->varGroup->vars;
546         struct tm *nowTime = localtime(&now);
547         strftime(timeBuf, timeBufSize, "%b %d %Y %I:%M:%S %p", nowTime);
548         fprintf(file, "%s\t%d\t%d\t%d\t%s", context, int(vars.size()), lastEvaluation, lastIterations, timeBuf);
549
550         size_t lx=0;
551         size_t numParam = Global->freeGroup[0]->vars.size();
552         for (size_t px=0; px < numParam; ++px) {
553                 if (lx < vars.size() && vars[lx]->id == (int)px) {
554                         fprintf(file, "\t%.10g", est[lx]);
555                         ++lx;
556                 } else {
557                         fprintf(file, "\tNA");
558                 }
559         }
560         fflush(file);
561         if (fitPending) Rf_error("Checkpoint not reentrant");
562         fitPending = true;
563         lastCheckpoint = now;
564         lastIterations = fc->iterations;
565         lastEvaluation = Global->computeCount;
566 }
567
568 void omxCheckpoint::prefit(FitContext *fc, double *est, bool force)
569 {
570         _prefit(fc, est, force, "opt");
571 }
572
573 void omxCheckpoint::postfit(FitContext *fc)
574 {
575         if (!fitPending) return;
576         fprintf(file, "\t%.10g\n", fc->fit);
577         fflush(file);
578         fitPending = false;
579 }
580
581 omxFreeVarLocation *omxFreeVar::getLocation(int matrix)
582 {
583         for (size_t lx=0; lx < locations.size(); lx++) {
584                 omxFreeVarLocation *loc = &locations[lx];
585                 if (~loc->matrix == matrix) return loc;
586         }
587         return NULL;
588 }