Switch over to new structured Compute system
[openmx:openmx.git] / src / omxState.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <stdarg.h>
18 #include "omxState.h"
19 #include "Compute.h"
20
21 /* Initialize and Destroy */
22         void omxInitState(omxState* state, omxState *parentState) {
23                 state->ciMaxIterations = 5;
24                 state->numThreads = 1;
25                 state->numHessians = 0;
26
27                 state->numConstraints = 0;
28                 state->numFreeParams = 0;
29                 state->numChildren = 0;
30                 state->childList = NULL;
31                 state->parentState = parentState;
32                 state->conList = NULL;
33                 state->freeVarList = NULL;
34
35                 state->majorIteration = 0;
36                 state->minorIteration = 0;
37                 state->startTime = 0;
38                 state->endTime = 0;
39                 state->numCheckpoints = 0;
40                 state->checkpointList = NULL;
41                 state->chkptText1 = NULL;
42                 state->chkptText2 = NULL;
43
44                 state->computeCount = 0;
45                 state->currentRow = -1;
46
47                 strncpy(state->statusMsg, "", 1);
48         }
49
50         omxState* omxGetState(omxState* os, int stateNumber) {
51                 // TODO: Need to implement a smarter way to enumerate children
52                 if(stateNumber == 0) return os;
53                 if((stateNumber-1) < os->numChildren) {
54                         return(os->childList[stateNumber-1]);
55                 } else {
56                         error("Not implemented");
57                         // TODO: Account for unequal numbers of grandchild states
58                         int subState = (stateNumber - os->numChildren - 1);
59                         return omxGetState(os->childList[subState % os->numChildren], subState / os->numChildren);
60                 }
61         }
62
63         void omxSetMajorIteration(omxState *state, int value) {
64                 state->majorIteration = value;
65                 for(int i = 0; i < state->numChildren; i++) {
66                         omxSetMajorIteration(state->childList[i], value);
67                 }
68         }
69
70         void omxSetMinorIteration(omxState *state, int value) {
71                 state->minorIteration = value;
72                 for(int i = 0; i < state->numChildren; i++) {
73                         omxSetMinorIteration(state->childList[i], value);
74                 }
75         }
76         
77         void omxDuplicateState(omxState* tgt, omxState* src) {
78                 tgt->dataList                   = src->dataList;
79                 tgt->numChildren                = 0;
80                 
81                 // Duplicate matrices and algebras and build parentLists.
82                 tgt->parentState                = src;
83                                 
84                 for(size_t mx = 0; mx < src->matrixList.size(); mx++) {
85                         // TODO: Smarter inference for which matrices to duplicate
86                         tgt->matrixList.push_back(omxDuplicateMatrix(src->matrixList[mx], tgt));
87                 }
88
89                 tgt->numConstraints     = src->numConstraints;
90                 tgt->conList                    = (omxConstraint*) R_alloc(tgt->numConstraints, sizeof(omxConstraint));
91                 for(int j = 0; j < tgt->numConstraints; j++) {
92                         tgt->conList[j].size   = src->conList[j].size;
93                         tgt->conList[j].opCode = src->conList[j].opCode;
94                         tgt->conList[j].lbound = src->conList[j].lbound;
95                         tgt->conList[j].ubound = src->conList[j].ubound;
96                         tgt->conList[j].result = omxDuplicateMatrix(src->conList[j].result, tgt);
97                 }
98
99                 for(size_t j = 0; j < src->algebraList.size(); j++) {
100                         // TODO: Smarter inference for which algebras to duplicate
101                         tgt->algebraList.push_back(omxDuplicateMatrix(src->algebraList[j], tgt));
102                 }
103
104                 for(size_t j = 0; j < src->expectationList.size(); j++) {
105                         // TODO: Smarter inference for which expectations to duplicate
106                         tgt->expectationList.push_back(omxDuplicateExpectation(src->expectationList[j], tgt));
107                 }
108
109                 for(size_t j = 0; j < tgt->algebraList.size(); j++) {
110                         omxDuplicateAlgebra(tgt->algebraList[j], src->algebraList[j], tgt);
111                 }
112
113                 for(size_t j = 0; j < src->expectationList.size(); j++) {
114                         // TODO: Smarter inference for which expectations to duplicate
115                         omxCompleteExpectation(tgt->expectationList[j]);
116                 }
117
118                 tgt->childList                  = NULL;
119
120                 tgt->numFreeParams                      = src->numFreeParams;
121                 tgt->freeVarList                = new omxFreeVar[tgt->numFreeParams];
122                 for(int j = 0; j < tgt->numFreeParams; j++) {
123                         int numDeps                                                     = src->freeVarList[j].numDeps;
124
125                         tgt->freeVarList[j].lbound                      = src->freeVarList[j].lbound;
126                         tgt->freeVarList[j].ubound                      = src->freeVarList[j].ubound;
127                         tgt->freeVarList[j].locations                   = src->freeVarList[j].locations;
128                         tgt->freeVarList[j].numDeps                     = numDeps;
129                         
130                         tgt->freeVarList[j].deps                        = (int*) R_alloc(numDeps, sizeof(int));
131
132                         tgt->freeVarList[j].name                = src->freeVarList[j].name;
133
134                         for(int k = 0; k < numDeps; k++) {
135                                 tgt->freeVarList[j].deps[k] = src->freeVarList[j].deps[k];
136                         }
137                 }
138                 
139                 tgt->majorIteration     = 0;
140                 tgt->minorIteration     = 0;
141                 tgt->startTime                  = src->startTime;
142                 tgt->endTime                    = 0;
143                 
144                 // TODO: adjust checkpointing based on parallelization method
145                 tgt->numCheckpoints     = 0;
146                 tgt->checkpointList     = NULL;
147                 tgt->chkptText1                 = NULL;
148                 tgt->chkptText2                 = NULL;
149                                   
150                 tgt->computeCount               = src->computeCount;
151                 tgt->currentRow                 = src->currentRow;
152
153                 strncpy(tgt->statusMsg, "", 1);
154         }
155
156         omxMatrix* omxLookupDuplicateElement(omxState* os, omxMatrix* element) {
157                 if(os == NULL || element == NULL) return NULL;
158
159                 if (element->hasMatrixNumber) {
160                         int matrixNumber = element->matrixNumber;
161                         if (matrixNumber >= 0) {
162                                 return(os->algebraList[matrixNumber]);
163                         } else {
164                                 return(os->matrixList[-matrixNumber - 1]);
165                         }
166                 }
167
168                 omxConstraint* parentConList = os->parentState->conList;
169
170                 for(int i = 0; i < os->numConstraints; i++) {
171                         if(parentConList[i].result == element) {
172                                 if(os->conList[i].result != NULL) {   // Not sure of proper failure behavior here.
173                 return(os->conList[i].result);
174                                 } else {
175                     omxRaiseError(os, -2, "Initialization Copy Error: Constraint required but not yet processed.");
176             }
177                         }
178                 }
179
180                 return NULL;
181         }
182         
183         omxExpectation* omxLookupDuplicateExpectation(omxState* os, omxExpectation* ox) {
184                 if(os == NULL || ox == NULL) return NULL;
185
186                 return(os->expectationList[ox->expNum]);
187         }
188
189         void omxFreeState(omxState *state) {
190                 int k;
191
192                 if (state->numChildren > 0) {
193                         for(k = 0; k < state->numChildren; k++) {
194                                 // Data are not modified and not copied. The same memory
195                                 // is shared across all instances of state. We only need
196                                 // to free the data once, so let the parent do it.
197                                 state->childList[k]->dataList.clear();
198
199                                 omxFreeState(state->childList[k]);
200                         }
201                         Free(state->childList);
202                         state->childList = NULL;
203                         state->numChildren = 0;
204                 }
205
206                 for(size_t ax = 0; ax < state->algebraList.size(); ax++) {
207                         if(OMX_DEBUG) { Rprintf("Freeing Algebra %d at 0x%x.\n", ax, state->algebraList[ax]); }
208                         omxFreeAllMatrixData(state->algebraList[ax]);
209                 }
210
211                 if(OMX_DEBUG) { Rprintf("Freeing %d Matrices.\n", state->matrixList.size());}
212                 for(size_t mk = 0; mk < state->matrixList.size(); mk++) {
213                         if(OMX_DEBUG) { Rprintf("Freeing Matrix %d at 0x%x.\n", mk, state->matrixList[mk]); }
214                         omxFreeAllMatrixData(state->matrixList[mk]);
215                 }
216                 
217                 if(OMX_DEBUG) { Rprintf("Freeing %d Model Expectations.\n", state->expectationList.size());}
218                 for(size_t ex = 0; ex < state->expectationList.size(); ex++) {
219                         if(OMX_DEBUG) { Rprintf("Freeing Expectation %d at 0x%x.\n", ex, state->expectationList[ex]); }
220                         omxFreeExpectationArgs(state->expectationList[ex]);
221                 }
222
223                 if(OMX_DEBUG) { Rprintf("Freeing %d Constraints.\n", state->numConstraints);}
224                 for(k = 0; k < state->numConstraints; k++) {
225                         if(OMX_DEBUG) { Rprintf("Freeing Constraint %d at 0x%x.\n", k, state->conList[k]); }
226                         omxFreeAllMatrixData(state->conList[k].result);
227                 }
228
229                 if(OMX_DEBUG) { Rprintf("Freeing %d Data Sets.\n", state->dataList.size());}
230                 for(size_t dx = 0; dx < state->dataList.size(); dx++) {
231                         if(OMX_DEBUG) { Rprintf("Freeing Data Set %d at 0x%x.\n", dx, state->dataList[dx]); }
232                         omxFreeData(state->dataList[dx]);
233                 }
234
235                 delete [] state->freeVarList;
236
237         if(OMX_DEBUG) {Rprintf("Freeing %d Children.\n", state->numChildren);}
238         for(k = 0; k < state->numChildren; k++) {
239                         if(OMX_DEBUG) { Rprintf("Freeing Child State %d at 0x%x.\n", k, state->childList[k]); }
240                         omxFreeState(state->childList[k]);            
241         }
242
243                 if(OMX_DEBUG) { Rprintf("Freeing %d Checkpoints.\n", state->numCheckpoints);}
244                 for(k = 0; k < state->numCheckpoints; k++) {
245                         if(OMX_DEBUG) { Rprintf("Freeing Data Set %d at 0x%x.\n", k, state->checkpointList[k]); }
246                         omxCheckpoint oC = state->checkpointList[k];
247                         switch(oC.type) {
248                                 case OMX_FILE_CHECKPOINT:
249                                         fclose(oC.file);
250                                         break;
251                                 case OMX_CONNECTION_CHECKPOINT: // NYI :::DEBUG:::
252                                         // Do nothing: this should be handled by R upon return.
253                                         break;
254                         }
255                         if(state->chkptText1 != NULL) {
256                                 Free(state->chkptText1);
257                         }
258                         if(state->chkptText2 != NULL) {
259                                 Free(state->chkptText2);
260                         }
261                         // Checkpoint list itself is freed by R.
262                 }
263
264                 for (size_t ex = 0; ex < state->computeList.size(); ex++) {
265                         delete state->computeList[ex];
266                 }
267
268                 delete state;
269
270                 if(OMX_DEBUG) { Rprintf("State Freed.\n");}
271         }
272
273         void omxResetStatus(omxState *state) {
274                 int numChildren = state->numChildren;
275                 state->statusMsg[0] = '\0';
276                 for(int i = 0; i < numChildren; i++) {
277                         omxResetStatus(state->childList[i]);
278                 }
279         }
280
281 void omxRaiseErrorf(omxState *state, const char* errorMsg, ...)
282 {
283         va_list ap;
284         va_start(ap, errorMsg);
285         int fit = vsnprintf(state->statusMsg, MAX_STRING_LEN, errorMsg, ap);
286         va_end(ap);
287         if(OMX_DEBUG) {
288                 if (!(fit > -1 && fit < MAX_STRING_LEN)) {
289                         Rprintf("Error exceeded maximum length: %s\n", errorMsg);
290                 } else {
291                         Rprintf("Error raised: %s\n", state->statusMsg);
292                 }
293         }
294 }
295
296         void omxRaiseError(omxState *state, int errorCode, const char* errorMsg) { // DEPRECATED
297                 if(OMX_DEBUG && errorCode) { Rprintf("Error %d raised: %s\n", errorCode, errorMsg);}
298                 if(OMX_DEBUG && !errorCode) { Rprintf("Error status cleared."); }
299                 strncpy(state->statusMsg, errorMsg, 249);
300                 state->statusMsg[249] = '\0';
301         }
302
303         void omxStateNextRow(omxState *state) {
304                 state->currentRow++;
305         };
306
307         void omxStateNextEvaluation(omxState *state) {
308                 state->currentRow = -1;
309                 state->computeCount++;
310         };
311
312         void omxWriteCheckpointHeader(omxState *os, omxCheckpoint* oC) {
313                 // FIXME: Is it faster to allocate this on the stack?
314                 os->chkptText1 = (char*) Calloc((24 + 15 * os->numFreeParams), char);
315                 os->chkptText2 = (char*) Calloc(1.0 + 15.0 * os->numFreeParams*
316                         (os->numFreeParams + 1.0) / 2.0, char);
317                 if (oC->type == OMX_FILE_CHECKPOINT) {
318                         fprintf(oC->file, "iterations\ttimestamp\tobjective\t");
319                         for(int j = 0; j < os->numFreeParams; j++) {
320                                 if(strcmp(os->freeVarList[j].name, CHAR(NA_STRING)) == 0) {
321                                         fprintf(oC->file, "%s", os->freeVarList[j].name);
322                                 } else {
323                                         fprintf(oC->file, "\"%s\"", os->freeVarList[j].name);
324                                 }
325                                 if (j != os->numFreeParams - 1) fprintf(oC->file, "\t");
326                         }
327                         fprintf(oC->file, "\n");
328                         fflush(oC->file);
329                 }
330         }
331  
332         void omxWriteCheckpointMessage(omxState *os, char *msg) {
333                 for(int i = 0; i < os->numCheckpoints; i++) {
334                         omxCheckpoint* oC = &(os->checkpointList[i]);
335                         if(os->chkptText1 == NULL) {    // First one: set up output
336                                 omxWriteCheckpointHeader(os, oC);
337                         }
338                         if (oC->type == OMX_FILE_CHECKPOINT) {
339                                 fprintf(oC->file, "%d \"%s\" NA ", os->majorIteration, msg);
340                                 for(int j = 0; j < os->numFreeParams; j++) {
341                                         fprintf(oC->file, "NA ");
342                                 }
343                                 fprintf(oC->file, "\n");
344                         }
345                 }
346         }
347
348         void omxSaveCheckpoint(omxState *os, double* x, double* f, int force) {
349                 time_t now = time(NULL);
350                 int soFar = now - os->startTime;                // Translated into minutes
351                 int n;
352                 for(int i = 0; i < os->numCheckpoints; i++) {
353                         n = 0;
354                         omxCheckpoint* oC = &(os->checkpointList[i]);
355                         // Check based on time            
356                         if((oC->time > 0 && (soFar - oC->lastCheckpoint) >= oC->time) || force) {
357                                 oC->lastCheckpoint = soFar;
358                                 n = 1;
359                         }
360                         // Or iterations
361                         if((oC->numIterations > 0 && (os->majorIteration - oC->lastCheckpoint) >= oC->numIterations) || force) {
362                                 oC->lastCheckpoint = os->majorIteration;
363                                 n = 1;
364                         }
365
366                         if(n) {         //In either case, save a checkpoint.
367                                 if(os->chkptText1 == NULL) {    // First one: set up output
368                                         omxWriteCheckpointHeader(os, oC);
369                                 }
370                                 char tempstring[25];
371                                 sprintf(tempstring, "%d", os->majorIteration);
372
373                                 if(strncmp(os->chkptText1, tempstring, strlen(tempstring))) {   // Returns zero if they're the same.
374                                         struct tm * nowTime = localtime(&now);                                          // So this only happens if the text is out of date.
375                                         strftime(tempstring, 25, "%b %d %Y %I:%M:%S %p", nowTime);
376                                         sprintf(os->chkptText1, "%d \"%s\" %9.5f", os->majorIteration, tempstring, f[0]);
377                                         for(int j = 0; j < os->numFreeParams; j++) {
378                                                 sprintf(tempstring, " %9.5f", x[j]);
379                                                 strncat(os->chkptText1, tempstring, 14);
380                                         }
381                                 }
382
383                                 if(oC->type == OMX_FILE_CHECKPOINT) {
384                                         fprintf(oC->file, "%s", os->chkptText1);
385                                         if(oC->saveHessian)
386                                                 fprintf(oC->file, "%s", os->chkptText2);
387                                         fprintf(oC->file, "\n");
388                                         fflush(oC->file);
389                                 } else if(oC->type == OMX_CONNECTION_CHECKPOINT) {
390                                         warning("NYI: R_connections are not yet implemented.");
391                                         oC->numIterations = 0;
392                                         oC->time = 0;
393                                 }
394                         }
395                 }
396         }
397
398 void omxExamineFitOutput(omxState *state, omxMatrix *fitMatrix, int *mode)
399 {
400         if (!R_FINITE(fitMatrix->data[0])) {
401                 omxRaiseErrorf(state, "Fit function returned %g at iteration %d.%d",
402                                fitMatrix->data[0], state->majorIteration, state->minorIteration);
403                 *mode = -1;
404         }
405 }