Initialize fit functions after expectations
[openmx:openmx.git] / src / npsolWrap.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <R.h>
18 #include <Rinternals.h>
19 #include <Rdefines.h>
20 #include <R_ext/Rdynload.h>
21 #include <R_ext/BLAS.h>
22 #include <R_ext/Lapack.h>
23 #include "omxDefines.h"
24 #include "npsolWrap.h"
25 #include "omxOpenmpWrap.h"
26
27 #include <stdio.h>
28 #include <sys/types.h>
29 #include <errno.h>
30 #include "omxState.h"
31 #include "omxGlobalState.h"
32 #include "omxMatrix.h"
33 #include "omxAlgebra.h"
34 #include "omxFitFunction.h"
35 #include "omxExpectation.h"
36 #include "omxNPSOLSpecific.h"
37 #include "omxImportFrontendState.h"
38 #include "omxExportBackendState.h"
39 #include "omxHessianCalculation.h"
40 #include "omxOptimizer.h"
41
42 omp_lock_t GlobalRLock;
43
44 static R_CallMethodDef callMethods[] = {
45         {"omxBackend", (DL_FUNC) omxBackend, 12},
46         {"omxCallAlgebra", (DL_FUNC) omxCallAlgebra, 3},
47         {"findIdenticalRowsData", (DL_FUNC) findIdenticalRowsData, 5},
48         {NULL, NULL, 0}
49 };
50
51 #ifdef  __cplusplus
52 extern "C" {
53 #endif
54
55 void R_init_OpenMx(DllInfo *info) {
56         R_registerRoutines(info, NULL, callMethods, NULL, NULL);
57
58         omx_omp_init_lock(&GlobalRLock);
59
60         // There is no code that will change behavior whether openmp
61         // is set for nested or not. I'm just keeping this in case it
62         // makes a difference with older versions of openmp. 2012-12-24 JNP
63 #if defined(_OPENMP) && _OPENMP <= 200505
64         omp_set_nested(0);
65 #endif
66 }
67
68 void R_unload_OpenMx(DllInfo *info) {
69         omx_omp_destroy_lock(&GlobalRLock);
70 }
71
72 #ifdef  __cplusplus
73 }
74 #endif
75
76 void string_to_try_error( const std::string& str )
77 {
78         error("%s", str.c_str());
79 }
80
81 void exception_to_try_error( const std::exception& ex )
82 {
83         string_to_try_error(ex.what());
84 }
85
86 /* Main functions */
87 SEXP omxCallAlgebra2(SEXP matList, SEXP algNum, SEXP options) {
88
89         omxManageProtectInsanity protectManager;
90
91         if(OMX_DEBUG) { Rprintf("-----------------------------------------------------------------------\n");}
92         if(OMX_DEBUG) { Rprintf("Explicit call to algebra %d.\n", INTEGER(algNum));}
93
94         int j,k,l;
95         omxMatrix* algebra;
96         int algebraNum = INTEGER(algNum)[0];
97         SEXP ans, nextMat;
98         char output[250];
99         int errOut = 0;
100
101         /* Create new omxState for current state storage and initialize it. */
102         
103         globalState = new omxState;
104         omxInitState(globalState, NULL);
105         if(OMX_DEBUG) { Rprintf("Created state object at 0x%x.\n", globalState);}
106
107         /* Retrieve All Matrices From the MatList */
108
109         if(OMX_DEBUG) { Rprintf("Processing %d matrix(ces).\n", length(matList));}
110
111         omxMatrix *args[length(matList)];
112         for(k = 0; k < length(matList); k++) {
113                 PROTECT(nextMat = VECTOR_ELT(matList, k));      // This is the matrix + populations
114                 args[k] = omxNewMatrixFromRPrimitive(nextMat, globalState, 1, - k - 1);
115                 globalState->matrixList.push_back(args[k]);
116                 if(OMX_DEBUG) {
117                         Rprintf("Matrix initialized at 0x%0xd = (%d x %d).\n",
118                                 globalState->matrixList[k], globalState->matrixList[k]->rows, globalState->matrixList[k]->cols);
119                 }
120         }
121
122         algebra = omxNewAlgebraFromOperatorAndArgs(algebraNum, args, length(matList), globalState);
123
124         if(algebra==NULL) {
125                 error(globalState->statusMsg);
126         }
127
128         if(OMX_DEBUG) {Rprintf("Completed Algebras and Matrices.  Beginning Initial Compute.\n");}
129         omxStateNextEvaluation(globalState);
130
131         omxRecompute(algebra);
132
133         PROTECT(ans = allocMatrix(REALSXP, algebra->rows, algebra->cols));
134         for(l = 0; l < algebra->rows; l++)
135                 for(j = 0; j < algebra->cols; j++)
136                         REAL(ans)[j * algebra->rows + l] =
137                                 omxMatrixElement(algebra, l, j);
138
139         if(OMX_DEBUG) { Rprintf("All Algebras complete.\n"); }
140
141         if(globalState->statusCode != 0) {
142                 errOut = globalState->statusCode;
143                 strncpy(output, globalState->statusMsg, 250);
144         }
145
146         omxFreeAllMatrixData(algebra);
147         omxFreeState(globalState);
148
149         if(errOut != 0) {
150                 error(output);
151         }
152
153         return ans;
154 }
155
156 SEXP omxCallAlgebra(SEXP matList, SEXP algNum, SEXP options)
157 {
158         try {
159                 return omxCallAlgebra2(matList, algNum, options);
160         } catch( std::exception& __ex__ ) {
161                 exception_to_try_error( __ex__ );
162         } catch(...) {
163                 string_to_try_error( "c++ exception (unknown reason)" );
164         }
165 }
166
167 SEXP omxBackend2(SEXP fitfunction, SEXP startVals, SEXP constraints,
168         SEXP matList, SEXP varList, SEXP algList, SEXP expectList,
169         SEXP data, SEXP intervalList, SEXP checkpointList, SEXP options, SEXP state) {
170
171         /* Helpful variables */
172
173         int errOut = 0;                 // Error state: Clear
174
175         SEXP nextLoc;
176
177         int calculateStdErrors = FALSE;
178         int numHessians = 0;
179         int ciMaxIterations = 5;
180         int disableOptimizer = 0;
181         int numThreads = 1;
182         int analyticGradients = 0;
183
184         /* Sanity Check and Parse Inputs */
185         /* TODO: Need to find a way to account for nullness in these.  For now, all checking is done on the front-end. */
186 //      if(!isVector(startVals)) error ("startVals must be a vector");
187 //      if(!isVector(matList)) error ("matList must be a list");
188 //      if(!isVector(algList)) error ("algList must be a list");
189
190         omxManageProtectInsanity protectManager;
191
192         /*      Set NPSOL options */
193         omxSetNPSOLOpts(options, &numHessians, &calculateStdErrors, 
194                 &ciMaxIterations, &disableOptimizer, &numThreads, 
195                 &analyticGradients, length(startVals));
196
197         /* Create new omxState for current state storage and initialize it. */
198         globalState = new omxState;
199         omxInitState(globalState, NULL);
200         globalState->numThreads = numThreads;
201         globalState->numFreeParams = length(startVals);
202         globalState->analyticGradients = analyticGradients;
203         if(OMX_DEBUG) { Rprintf("Created state object at 0x%x.\n", globalState);}
204
205         /* Retrieve Data Objects */
206         if(!errOut) errOut = omxProcessMxDataEntities(data);
207     
208         /* Retrieve All Matrices From the MatList */
209         if(!errOut) omxProcessMxMatrixEntities(matList);
210
211         globalState->numAlgs = length(algList);
212         
213         if (length(startVals) != length(varList)) error("varList and startVals must be the same length");
214
215         /* Process Free Var List */
216         omxProcessFreeVarList(varList);
217
218         if(!errOut) {
219                 omxProcessMxExpectationEntities(expectList);
220                 errOut = globalState->statusMsg[0];
221         }
222
223         if(!errOut) {
224                 omxProcessMxAlgebraEntities(algList);
225                 errOut = globalState->statusMsg[0];
226         }
227
228         if(!errOut) {
229                 omxCompleteMxExpectationEntities();
230                 errOut = globalState->statusMsg[0];
231         }
232
233         if(!errOut) {
234                 omxProcessMxFitFunction(algList);
235                 errOut = globalState->statusMsg[0];
236         }
237
238         if(!errOut) {
239                 // This is the chance to check for matrix
240                 // conformability, etc.  Any errors encountered should
241                 // be reported using R's error() function, not
242                 // omxRaiseErrorf.
243
244                 omxInitialMatrixAlgebraCompute();
245                 omxResetStatus(globalState);
246         }
247
248         if(!errOut && !isNull(fitfunction)) {
249                 if(OMX_DEBUG) { Rprintf("Processing fit function.\n"); }
250                 globalState->fitMatrix = omxMatrixLookupFromState1(fitfunction, globalState);
251                 errOut = globalState->statusMsg[0];
252         }
253         
254         // TODO: Make calculateHessians an option instead.
255
256         if(errOut) error(globalState->statusMsg);
257
258         /* Process Matrix and Algebra Population Function */
259         /*
260           Each matrix is a list containing a matrix and the other matrices/algebras that are
261           populated into it at each iteration.  The first element is already processed, above.
262           The rest of the list will be processed here.
263         */
264         for(int j = 0; j < length(matList); j++) {
265                 PROTECT(nextLoc = VECTOR_ELT(matList, j));              // This is the matrix + populations
266                 omxProcessMatrixPopulationList(globalState->matrixList[j], nextLoc);
267         }
268
269         /* Processing Constraints */
270         omxProcessConstraints(constraints);
271
272         /* Process Confidence Interval List */
273         omxProcessConfidenceIntervals(intervalList);
274
275         /* Process Checkpoint List */
276         omxProcessCheckpointOptions(checkpointList);
277
278         // Probably, this is always the same for all children and
279         // doesn't need to be copied to child states.
280         cacheFreeVarDependencies(globalState);
281
282         omxFitFunctionCreateChildren(globalState, numThreads);
283
284         int n = globalState->numFreeParams;
285
286         SEXP minimum, estimate, gradient, hessian;
287         PROTECT(minimum = NEW_NUMERIC(1));
288         PROTECT(estimate = allocVector(REALSXP, n));
289         PROTECT(gradient = allocVector(REALSXP, n));
290         PROTECT(hessian = allocMatrix(REALSXP, n, n));
291
292         if (n>0) { memcpy(REAL(estimate), REAL(startVals), sizeof(double)*n); }
293         
294         omxInvokeNPSOL(REAL(minimum), REAL(estimate), REAL(gradient), REAL(hessian), disableOptimizer);
295
296         SEXP code, status, statusMsg, iterations;
297         SEXP evaluations, ans=NULL, names=NULL, algebras, matrices, expectations, optimizer;
298         SEXP intervals, NAmat, intervalCodes, calculatedHessian, stdErrors;
299
300         int numReturns = 14;
301
302         PROTECT(code = NEW_NUMERIC(1));
303         PROTECT(status = allocVector(VECSXP, 3));
304         PROTECT(iterations = NEW_NUMERIC(1));
305         PROTECT(evaluations = NEW_NUMERIC(2));
306         PROTECT(matrices = NEW_LIST(globalState->matrixList.size()));
307         PROTECT(algebras = NEW_LIST(globalState->numAlgs));
308         PROTECT(expectations = NEW_LIST(globalState->numExpects));
309
310         PROTECT(optimizer = allocVector(VECSXP, 2));
311         PROTECT(calculatedHessian = allocMatrix(REALSXP, n, n));
312         PROTECT(stdErrors = allocMatrix(REALSXP, n, 1)); // for optimizer
313         PROTECT(names = allocVector(STRSXP, 2)); // for optimizer
314         PROTECT(intervals = allocMatrix(REALSXP, globalState->numIntervals, 2)); // for optimizer
315         PROTECT(intervalCodes = allocMatrix(INTSXP, globalState->numIntervals, 2)); // for optimizer
316         PROTECT(NAmat = allocMatrix(REALSXP, 1, 1)); // In case of missingness
317         REAL(NAmat)[0] = R_NaReal;
318
319         omxSaveState(globalState, REAL(estimate), REAL(minimum)[0]);
320
321         /* Fill in details from the optimizer */
322         SET_VECTOR_ELT(optimizer, 0, gradient);
323         SET_VECTOR_ELT(optimizer, 1, hessian);
324
325         SET_STRING_ELT(names, 0, mkChar("minimum"));
326         SET_STRING_ELT(names, 1, mkChar("estimate"));
327         namesgets(optimizer, names);
328
329         REAL(code)[0] = globalState->inform;
330         REAL(iterations)[0] = globalState->iter;
331         REAL(evaluations)[0] = globalState->computeCount;
332
333         /* Fill Status code. */
334         SET_VECTOR_ELT(status, 0, code);
335         PROTECT(code = NEW_NUMERIC(1));
336         REAL(code)[0] = globalState->statusCode;
337         SET_VECTOR_ELT(status, 1, code);
338         PROTECT(statusMsg = allocVector(STRSXP, 1));
339         SET_STRING_ELT(statusMsg, 0, mkChar(globalState->statusMsg));
340         SET_VECTOR_ELT(status, 2, statusMsg);
341
342         if(numHessians && globalState->fitMatrix != NULL && globalState->optimumStatus >= 0) {          // No hessians or standard errors if the optimum is invalid
343                 if(globalState->numConstraints == 0) {
344                         if(OMX_DEBUG) { Rprintf("Calculating Hessian for Fit Function.\n");}
345                         int gotHessians = omxEstimateHessian(numHessians, .0001, 4, globalState);
346                         if(gotHessians) {
347                                 if(calculateStdErrors) {
348                                         for(int j = 0; j < numHessians; j++) {          //TODO: Fix Hessian calculation to allow more if requested
349                                                 if(OMX_DEBUG) { Rprintf("Calculating Standard Errors for Fit Function.\n");}
350                                                 omxFitFunction* oo = globalState->fitMatrix->fitFunction;
351                                                 if(oo->getStandardErrorFun != NULL) {
352                                                         oo->getStandardErrorFun(oo);
353                                                 } else {
354                                                         omxCalculateStdErrorFromHessian(2.0, oo);
355                                                 }
356                                         }
357                                 }
358                         } else {
359                                 numHessians = 0;
360                         }
361                 } else {
362                         numHessians = 0;
363                 }
364         } else {
365                 numHessians = 0;
366         }
367
368         /* Likelihood-based Confidence Interval Calculation */
369         if(globalState->numIntervals) {
370                 omxNPSOLConfidenceIntervals(REAL(minimum), REAL(estimate), REAL(gradient), REAL(hessian), ciMaxIterations);
371         }  
372
373         handleFreeVarList(globalState, globalState->optimalValues, n);  // Restore to optima for final compute
374         if(!errOut) omxFinalAlgebraCalculation(globalState, matrices, algebras, expectations); 
375
376         omxPopulateFitFunction(globalState, numReturns, &ans, &names);
377
378         if(numHessians) {
379                 omxPopulateHessians(numHessians, globalState->fitMatrix, 
380                         calculatedHessian, stdErrors, calculateStdErrors, n);
381         }
382
383         if(globalState->numIntervals) { // Populate CIs
384                 omxPopulateConfidenceIntervals(globalState, intervals, intervalCodes);
385         }
386         
387         REAL(evaluations)[1] = globalState->computeCount;
388
389         int nextEl = 0;
390
391         SET_STRING_ELT(names, nextEl++, mkChar("minimum"));
392         SET_STRING_ELT(names, nextEl++, mkChar("estimate"));
393         SET_STRING_ELT(names, nextEl++, mkChar("gradient"));
394         SET_STRING_ELT(names, nextEl++, mkChar("hessianCholesky"));
395         SET_STRING_ELT(names, nextEl++, mkChar("status"));
396         SET_STRING_ELT(names, nextEl++, mkChar("iterations"));
397         SET_STRING_ELT(names, nextEl++, mkChar("evaluations"));
398         SET_STRING_ELT(names, nextEl++, mkChar("matrices"));
399         SET_STRING_ELT(names, nextEl++, mkChar("algebras"));
400         SET_STRING_ELT(names, nextEl++, mkChar("expectations"));
401         SET_STRING_ELT(names, nextEl++, mkChar("confidenceIntervals"));
402         SET_STRING_ELT(names, nextEl++, mkChar("confidenceIntervalCodes"));
403         SET_STRING_ELT(names, nextEl++, mkChar("calculatedHessian"));
404         SET_STRING_ELT(names, nextEl++, mkChar("standardErrors"));
405
406         nextEl = 0;
407
408         SET_VECTOR_ELT(ans, nextEl++, minimum);
409         SET_VECTOR_ELT(ans, nextEl++, estimate);
410         SET_VECTOR_ELT(ans, nextEl++, gradient);
411         SET_VECTOR_ELT(ans, nextEl++, hessian);
412         SET_VECTOR_ELT(ans, nextEl++, status);
413         SET_VECTOR_ELT(ans, nextEl++, iterations);
414         SET_VECTOR_ELT(ans, nextEl++, evaluations);
415         SET_VECTOR_ELT(ans, nextEl++, matrices);
416         SET_VECTOR_ELT(ans, nextEl++, algebras);
417         SET_VECTOR_ELT(ans, nextEl++, expectations);
418         SET_VECTOR_ELT(ans, nextEl++, intervals);
419         SET_VECTOR_ELT(ans, nextEl++, intervalCodes);
420         if(numHessians == 0) {
421                 SET_VECTOR_ELT(ans, nextEl++, NAmat);
422         } else {
423                 SET_VECTOR_ELT(ans, nextEl++, calculatedHessian);
424         }
425         if(!calculateStdErrors) {
426                 SET_VECTOR_ELT(ans, nextEl++, NAmat);
427         } else {
428                 SET_VECTOR_ELT(ans, nextEl++, stdErrors);
429         }
430         namesgets(ans, names);
431
432         if(OMX_VERBOSE) {
433                 Rprintf("Inform Value: %d\n", globalState->optimumStatus);
434                 Rprintf("--------------------------\n");
435         }
436
437         /* Free data memory */
438         omxFreeState(globalState);
439
440         if(OMX_DEBUG) {Rprintf("All vectors freed.\n");}
441
442         return(ans);
443
444 }
445
446 SEXP omxBackend(SEXP fitfunction, SEXP startVals, SEXP constraints,
447         SEXP matList, SEXP varList, SEXP algList, SEXP expectList,
448         SEXP data, SEXP intervalList, SEXP checkpointList, SEXP options, SEXP state)
449 {
450         try {
451                 return omxBackend2(fitfunction, startVals, constraints,
452                                    matList, varList, algList, expectList,
453                                    data, intervalList, checkpointList, options, state);
454         } catch( std::exception& __ex__ ) {
455                 exception_to_try_error( __ex__ );
456         } catch(...) {
457                 string_to_try_error( "c++ exception (unknown reason)" );
458         }
459 }
460