Ease output data assembly
[openmx:openmx.git] / src / npsolWrap.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
#include <stdio.h>
#include <sys/types.h>
#include <errno.h>

// C++ standard headers go before the R headers: R remaps identifiers such as
// `length` and `error` with macros that can break later C++ library headers.
#include <vector>

#include <R.h>
#include <Rinternals.h>
#include <Rdefines.h>
#include <R_ext/Rdynload.h>
#include <R_ext/BLAS.h>
#include <R_ext/Lapack.h>

#include "omxDefines.h"
#include "types.h"
#include "npsolWrap.h"
#include "omxOpenmpWrap.h"
#include "omxState.h"
#include "omxGlobalState.h"
#include "omxMatrix.h"
#include "omxAlgebra.h"
#include "omxFitFunction.h"
#include "omxExpectation.h"
#include "omxNPSOLSpecific.h"
#include "omxImportFrontendState.h"
#include "omxExportBackendState.h"
#include "omxHessianCalculation.h"
#include "omxOptimizer.h"
43
44 omp_lock_t GlobalRLock;
45
46 static R_CallMethodDef callMethods[] = {
47         {"omxBackend", (DL_FUNC) omxBackend, 12},
48         {"omxCallAlgebra", (DL_FUNC) omxCallAlgebra, 3},
49         {"findIdenticalRowsData", (DL_FUNC) findIdenticalRowsData, 5},
50         {NULL, NULL, 0}
51 };
52
53 #ifdef  __cplusplus
54 extern "C" {
55 #endif
56
57 void R_init_OpenMx(DllInfo *info) {
58         R_registerRoutines(info, NULL, callMethods, NULL, NULL);
59
60         omx_omp_init_lock(&GlobalRLock);
61
62         // There is no code that will change behavior whether openmp
63         // is set for nested or not. I'm just keeping this in case it
64         // makes a difference with older versions of openmp. 2012-12-24 JNP
65 #if defined(_OPENMP) && _OPENMP <= 200505
66         omp_set_nested(0);
67 #endif
68 }
69
70 void R_unload_OpenMx(DllInfo *info) {
71         omx_omp_destroy_lock(&GlobalRLock);
72 }
73
74 #ifdef  __cplusplus
75 }
76 #endif
77
78 void string_to_try_error( const std::string& str )
79 {
80         error("%s", str.c_str());
81 }
82
83 void exception_to_try_error( const std::exception& ex )
84 {
85         string_to_try_error(ex.what());
86 }
87
88 SEXP asR(MxRList *out)
89 {
90        SEXP names, ans;
91        int len = out->size();
92        PROTECT(names = allocVector(STRSXP, len));
93        PROTECT(ans = allocVector(VECSXP, len));
94        for (int lx=0; lx < len; ++lx) {
95                SET_STRING_ELT(names, lx, (*out)[lx].first);
96                SET_VECTOR_ELT(ans,   lx, (*out)[lx].second);
97        }
98        namesgets(ans, names);
99        return ans;
100 }
101
102 /* Main functions */
103 SEXP omxCallAlgebra2(SEXP matList, SEXP algNum, SEXP options) {
104
105         omxManageProtectInsanity protectManager;
106
107         if(OMX_DEBUG) { Rprintf("-----------------------------------------------------------------------\n");}
108         if(OMX_DEBUG) { Rprintf("Explicit call to algebra %d.\n", INTEGER(algNum));}
109
110         int j,k,l;
111         omxMatrix* algebra;
112         int algebraNum = INTEGER(algNum)[0];
113         SEXP ans, nextMat;
114         char output[250];
115         int errOut = 0;
116
117         /* Create new omxState for current state storage and initialize it. */
118         
119         globalState = new omxState;
120         omxInitState(globalState, NULL);
121         if(OMX_DEBUG) { Rprintf("Created state object at 0x%x.\n", globalState);}
122
123         /* Retrieve All Matrices From the MatList */
124
125         if(OMX_DEBUG) { Rprintf("Processing %d matrix(ces).\n", length(matList));}
126
127         omxMatrix *args[length(matList)];
128         for(k = 0; k < length(matList); k++) {
129                 PROTECT(nextMat = VECTOR_ELT(matList, k));      // This is the matrix + populations
130                 args[k] = omxNewMatrixFromRPrimitive(nextMat, globalState, 1, - k - 1);
131                 globalState->matrixList.push_back(args[k]);
132                 if(OMX_DEBUG) {
133                         Rprintf("Matrix initialized at 0x%0xd = (%d x %d).\n",
134                                 globalState->matrixList[k], globalState->matrixList[k]->rows, globalState->matrixList[k]->cols);
135                 }
136         }
137
138         algebra = omxNewAlgebraFromOperatorAndArgs(algebraNum, args, length(matList), globalState);
139
140         if(algebra==NULL) {
141                 error(globalState->statusMsg);
142         }
143
144         if(OMX_DEBUG) {Rprintf("Completed Algebras and Matrices.  Beginning Initial Compute.\n");}
145         omxStateNextEvaluation(globalState);
146
147         omxRecompute(algebra);
148
149         PROTECT(ans = allocMatrix(REALSXP, algebra->rows, algebra->cols));
150         for(l = 0; l < algebra->rows; l++)
151                 for(j = 0; j < algebra->cols; j++)
152                         REAL(ans)[j * algebra->rows + l] =
153                                 omxMatrixElement(algebra, l, j);
154
155         if(OMX_DEBUG) { Rprintf("All Algebras complete.\n"); }
156
157         if(globalState->statusCode != 0) {
158                 errOut = globalState->statusCode;
159                 strncpy(output, globalState->statusMsg, 250);
160         }
161
162         omxFreeAllMatrixData(algebra);
163         omxFreeState(globalState);
164
165         if(errOut != 0) {
166                 error(output);
167         }
168
169         return ans;
170 }
171
172 SEXP omxCallAlgebra(SEXP matList, SEXP algNum, SEXP options)
173 {
174         try {
175                 return omxCallAlgebra2(matList, algNum, options);
176         } catch( std::exception& __ex__ ) {
177                 exception_to_try_error( __ex__ );
178         } catch(...) {
179                 string_to_try_error( "c++ exception (unknown reason)" );
180         }
181 }
182
183 SEXP omxBackend2(SEXP fitfunction, SEXP startVals, SEXP constraints,
184         SEXP matList, SEXP varList, SEXP algList, SEXP expectList,
185         SEXP data, SEXP intervalList, SEXP checkpointList, SEXP options, SEXP state) {
186
187         /* Helpful variables */
188
189         SEXP nextLoc;
190
191         int calculateStdErrors = FALSE;
192         int numHessians = 0;
193         int ciMaxIterations = 5;
194         int disableOptimizer = 0;
195         int numThreads = 1;
196         int analyticGradients = 0;
197
198         /* Sanity Check and Parse Inputs */
199         /* TODO: Need to find a way to account for nullness in these.  For now, all checking is done on the front-end. */
200 //      if(!isVector(startVals)) error ("startVals must be a vector");
201 //      if(!isVector(matList)) error ("matList must be a list");
202 //      if(!isVector(algList)) error ("algList must be a list");
203
204         omxManageProtectInsanity protectManager;
205
206         /*      Set NPSOL options */
207         omxSetNPSOLOpts(options, &numHessians, &calculateStdErrors, 
208                 &ciMaxIterations, &disableOptimizer, &numThreads, 
209                 &analyticGradients, length(startVals));
210
211         /* Create new omxState for current state storage and initialize it. */
212         globalState = new omxState;
213         omxInitState(globalState, NULL);
214         globalState->numThreads = numThreads;
215         globalState->numFreeParams = length(startVals);
216         globalState->analyticGradients = analyticGradients;
217         if(OMX_DEBUG) { Rprintf("Created state object at 0x%x.\n", globalState);}
218
219         /* Retrieve Data Objects */
220         omxProcessMxDataEntities(data);
221         if (globalState->statusMsg[0]) error(globalState->statusMsg);
222     
223         /* Retrieve All Matrices From the MatList */
224         omxProcessMxMatrixEntities(matList);
225         if (globalState->statusMsg[0]) error(globalState->statusMsg);
226
227         globalState->numAlgs = length(algList);
228         
229         if (length(startVals) != length(varList)) error("varList and startVals must be the same length");
230
231         /* Process Free Var List */
232         omxProcessFreeVarList(varList);
233         if (globalState->statusMsg[0]) error(globalState->statusMsg);
234
235         omxProcessMxExpectationEntities(expectList);
236         if (globalState->statusMsg[0]) error(globalState->statusMsg);
237
238         omxProcessMxAlgebraEntities(algList);
239         if (globalState->statusMsg[0]) error(globalState->statusMsg);
240
241         omxCompleteMxExpectationEntities();
242         if (globalState->statusMsg[0]) error(globalState->statusMsg);
243
244         omxProcessMxFitFunction(algList);
245         if (globalState->statusMsg[0]) error(globalState->statusMsg);
246
247         // This is the chance to check for matrix
248         // conformability, etc.  Any errors encountered should
249         // be reported using R's error() function, not
250         // omxRaiseErrorf.
251
252         omxInitialMatrixAlgebraCompute();
253         omxResetStatus(globalState);
254
255         if(!isNull(fitfunction)) {
256                 if(OMX_DEBUG) { Rprintf("Processing fit function.\n"); }
257                 globalState->fitMatrix = omxMatrixLookupFromState1(fitfunction, globalState);
258         }
259         if (globalState->statusMsg[0]) error(globalState->statusMsg);
260         
261         // TODO: Make calculateHessians an option instead.
262
263         /* Process Matrix and Algebra Population Function */
264         /*
265           Each matrix is a list containing a matrix and the other matrices/algebras that are
266           populated into it at each iteration.  The first element is already processed, above.
267           The rest of the list will be processed here.
268         */
269         for(int j = 0; j < length(matList); j++) {
270                 PROTECT(nextLoc = VECTOR_ELT(matList, j));              // This is the matrix + populations
271                 omxProcessMatrixPopulationList(globalState->matrixList[j], nextLoc);
272         }
273
274         /* Processing Constraints */
275         omxProcessConstraints(constraints);
276
277         /* Process Confidence Interval List */
278         omxProcessConfidenceIntervals(intervalList);
279
280         /* Process Checkpoint List */
281         omxProcessCheckpointOptions(checkpointList);
282
283         // Probably, this is always the same for all children and
284         // doesn't need to be copied to child states.
285         cacheFreeVarDependencies(globalState);
286
287         omxFitFunctionCreateChildren(globalState, numThreads);
288
289         int n = globalState->numFreeParams;
290
291         SEXP minimum, estimate, gradient, hessian;
292         PROTECT(minimum = NEW_NUMERIC(1));
293         PROTECT(estimate = allocVector(REALSXP, n));
294         PROTECT(gradient = allocVector(REALSXP, n));
295         PROTECT(hessian = allocMatrix(REALSXP, n, n));
296
297         if (n>0) { memcpy(REAL(estimate), REAL(startVals), sizeof(double)*n); }
298         
299         omxInvokeNPSOL(globalState->fitMatrix, REAL(minimum), REAL(estimate),
300                        REAL(gradient), REAL(hessian), disableOptimizer);
301
302         SEXP code, status, statusMsg, iterations;
303         SEXP evaluations, algebras, matrices, expectations;
304         SEXP intervals, intervalCodes, calculatedHessian, stdErrors;
305
306         PROTECT(code = NEW_NUMERIC(1));
307         PROTECT(status = allocVector(VECSXP, 3));
308         PROTECT(iterations = NEW_NUMERIC(1));
309         PROTECT(evaluations = NEW_NUMERIC(2));
310         PROTECT(matrices = NEW_LIST(globalState->matrixList.size()));
311         PROTECT(algebras = NEW_LIST(globalState->numAlgs));
312         PROTECT(expectations = NEW_LIST(globalState->numExpects));
313
314         PROTECT(calculatedHessian = allocMatrix(REALSXP, n, n));
315         PROTECT(stdErrors = allocMatrix(REALSXP, n, 1)); // for optimizer
316         PROTECT(intervals = allocMatrix(REALSXP, globalState->numIntervals, 2)); // for optimizer
317         PROTECT(intervalCodes = allocMatrix(INTSXP, globalState->numIntervals, 2)); // for optimizer
318
319         omxSaveState(globalState, REAL(estimate), REAL(minimum)[0]);
320
321         REAL(code)[0] = globalState->inform;
322         REAL(iterations)[0] = globalState->iter;
323         REAL(evaluations)[0] = globalState->computeCount;
324
325         /* Fill Status code. */
326         SET_VECTOR_ELT(status, 0, code);
327         PROTECT(code = NEW_NUMERIC(1));
328         REAL(code)[0] = globalState->statusCode;
329         SET_VECTOR_ELT(status, 1, code);
330         PROTECT(statusMsg = allocVector(STRSXP, 1));
331         SET_STRING_ELT(statusMsg, 0, mkChar(globalState->statusMsg));
332         SET_VECTOR_ELT(status, 2, statusMsg);
333
334         if(numHessians && globalState->fitMatrix != NULL && globalState->optimumStatus >= 0) {          // No hessians or standard errors if the optimum is invalid
335                 if(globalState->numConstraints == 0) {
336                         if(OMX_DEBUG) { Rprintf("Calculating Hessian for Fit Function.\n");}
337                         int gotHessians = omxEstimateHessian(numHessians, .0001, 4, globalState);
338                         if(gotHessians) {
339                                 if(calculateStdErrors) {
340                                         for(int j = 0; j < numHessians; j++) {          //TODO: Fix Hessian calculation to allow more if requested
341                                                 if(OMX_DEBUG) { Rprintf("Calculating Standard Errors for Fit Function.\n");}
342                                                 omxFitFunction* oo = globalState->fitMatrix->fitFunction;
343                                                 omxCalculateStdErrorFromHessian(2.0, oo);
344                                         }
345                                 }
346                         } else {
347                                 numHessians = 0;
348                         }
349                 } else {
350                         numHessians = 0;
351                 }
352         } else {
353                 numHessians = 0;
354         }
355
356         /* Likelihood-based Confidence Interval Calculation */
357         if(globalState->numIntervals) {
358                 omxNPSOLConfidenceIntervals(globalState->fitMatrix, REAL(minimum), REAL(estimate),
359                                             REAL(gradient), REAL(hessian), ciMaxIterations);
360         }  
361
362         // What if fitfunction has its own repopulateFun? TODO
363         handleFreeVarListHelper(globalState, globalState->optimalValues, n);
364
365         omxFinalAlgebraCalculation(globalState, matrices, algebras, expectations); 
366
367         MxRList result;
368
369         omxPopulateFitFunction(globalState, &result);
370
371         if(numHessians) {
372                 omxPopulateHessians(numHessians, globalState->fitMatrix, 
373                         calculatedHessian, stdErrors, calculateStdErrors, n);
374         }
375
376         if(globalState->numIntervals) { // Populate CIs
377                 omxPopulateConfidenceIntervals(globalState, intervals, intervalCodes);
378         }
379         
380         REAL(evaluations)[1] = globalState->computeCount;
381
382         result.push_back(std::make_pair(mkChar("minimum"), minimum));
383         result.push_back(std::make_pair(mkChar("estimate"), estimate));
384         result.push_back(std::make_pair(mkChar("gradient"), gradient));
385         result.push_back(std::make_pair(mkChar("hessianCholesky"), hessian));
386         result.push_back(std::make_pair(mkChar("status"), status));
387         result.push_back(std::make_pair(mkChar("iterations"), iterations));
388         result.push_back(std::make_pair(mkChar("evaluations"), evaluations));
389         result.push_back(std::make_pair(mkChar("matrices"), matrices));
390         result.push_back(std::make_pair(mkChar("algebras"), algebras));
391         result.push_back(std::make_pair(mkChar("expectations"), expectations));
392         result.push_back(std::make_pair(mkChar("confidenceIntervals"), intervals));
393         result.push_back(std::make_pair(mkChar("confidenceIntervalCodes"), intervalCodes));
394
395         if (numHessians != 0) {
396                 result.push_back(std::make_pair(mkChar("calculatedHessian"), calculatedHessian));
397         }
398         if (calculateStdErrors) {
399                 result.push_back(std::make_pair(mkChar("standardErrors"), stdErrors));
400         }
401
402         /* Free data memory */
403         omxFreeState(globalState);
404
405         return asR(&result);
406
407 }
408
409 SEXP omxBackend(SEXP fitfunction, SEXP startVals, SEXP constraints,
410         SEXP matList, SEXP varList, SEXP algList, SEXP expectList,
411         SEXP data, SEXP intervalList, SEXP checkpointList, SEXP options, SEXP state)
412 {
413         try {
414                 return omxBackend2(fitfunction, startVals, constraints,
415                                    matList, varList, algList, expectList,
416                                    data, intervalList, checkpointList, options, state);
417         } catch( std::exception& __ex__ ) {
418                 exception_to_try_error( __ex__ );
419         } catch(...) {
420                 string_to_try_error( "c++ exception (unknown reason)" );
421         }
422 }
423