Refrain from copying markMatrices to omxState children
[openmx:openmx.git] / src / npsolWrap.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <stdio.h>
18 #include <sys/types.h>
19 #include <errno.h>
20
21 #include <R.h>
22 #include <Rinternals.h>
23 #include <Rdefines.h>
24 #include <R_ext/Rdynload.h>
25 #include <R_ext/BLAS.h>
26 #include <R_ext/Lapack.h>
27
28 #include "omxDefines.h"
29 #include "types.h"
30 #include "npsolWrap.h"
31 #include "omxOpenmpWrap.h"
32 #include "omxState.h"
33 #include "omxGlobalState.h"
34 #include "omxMatrix.h"
35 #include "omxAlgebra.h"
36 #include "omxFitFunction.h"
37 #include "omxExpectation.h"
38 #include "omxNPSOLSpecific.h"
39 #include "omxImportFrontendState.h"
40 #include "omxExportBackendState.h"
41 #include "omxOptimizer.h"
42 #include "omxHessianCalculation.h"
43 #include "Compute.h"
44
45 omp_lock_t GlobalRLock;
46
47 static R_CallMethodDef callMethods[] = {
48         {"omxBackend", (DL_FUNC) omxBackend, 12},
49         {"omxCallAlgebra", (DL_FUNC) omxCallAlgebra, 3},
50         {"findIdenticalRowsData", (DL_FUNC) findIdenticalRowsData, 5},
51         {NULL, NULL, 0}
52 };
53
54 #ifdef  __cplusplus
55 extern "C" {
56 #endif
57
58 void R_init_OpenMx(DllInfo *info) {
59         R_registerRoutines(info, NULL, callMethods, NULL, NULL);
60
61         omx_omp_init_lock(&GlobalRLock);
62
63         // There is no code that will change behavior whether openmp
64         // is set for nested or not. I'm just keeping this in case it
65         // makes a difference with older versions of openmp. 2012-12-24 JNP
66 #if defined(_OPENMP) && _OPENMP <= 200505
67         omp_set_nested(0);
68 #endif
69 }
70
71 void R_unload_OpenMx(DllInfo *info) {
72         omx_omp_destroy_lock(&GlobalRLock);
73 }
74
75 #ifdef  __cplusplus
76 }
77 #endif
78
79 void string_to_try_error( const std::string& str )
80 {
81         error("%s", str.c_str());
82 }
83
84 void exception_to_try_error( const std::exception& ex )
85 {
86         string_to_try_error(ex.what());
87 }
88
89 SEXP asR(MxRList *out)
90 {
91        SEXP names, ans;
92        int len = out->size();
93        PROTECT(names = allocVector(STRSXP, len));
94        PROTECT(ans = allocVector(VECSXP, len));
95        for (int lx=0; lx < len; ++lx) {
96                SET_STRING_ELT(names, lx, (*out)[lx].first);
97                SET_VECTOR_ELT(ans,   lx, (*out)[lx].second);
98        }
99        namesgets(ans, names);
100        return ans;
101 }
102
103 /* Main functions */
104 SEXP omxCallAlgebra2(SEXP matList, SEXP algNum, SEXP options) {
105
106         omxManageProtectInsanity protectManager;
107
108         if(OMX_DEBUG) { Rprintf("-----------------------------------------------------------------------\n");}
109         if(OMX_DEBUG) { Rprintf("Explicit call to algebra %d.\n", INTEGER(algNum));}
110
111         int j,k,l;
112         omxMatrix* algebra;
113         int algebraNum = INTEGER(algNum)[0];
114         SEXP ans, nextMat;
115         char output[MAX_STRING_LEN];
116
117         /* Create new omxState for current state storage and initialize it. */
118         
119         globalState = new omxState;
120         omxInitState(globalState, NULL);
121         if(OMX_DEBUG) { Rprintf("Created state object at 0x%x.\n", globalState);}
122
123         /* Retrieve All Matrices From the MatList */
124
125         if(OMX_DEBUG) { Rprintf("Processing %d matrix(ces).\n", length(matList));}
126
127         omxMatrix *args[length(matList)];
128         for(k = 0; k < length(matList); k++) {
129                 PROTECT(nextMat = VECTOR_ELT(matList, k));      // This is the matrix + populations
130                 args[k] = omxNewMatrixFromRPrimitive(nextMat, globalState, 1, - k - 1);
131                 globalState->matrixList.push_back(args[k]);
132                 if(OMX_DEBUG) {
133                         Rprintf("Matrix initialized at 0x%0xd = (%d x %d).\n",
134                                 globalState->matrixList[k], globalState->matrixList[k]->rows, globalState->matrixList[k]->cols);
135                 }
136         }
137
138         algebra = omxNewAlgebraFromOperatorAndArgs(algebraNum, args, length(matList), globalState);
139
140         if(algebra==NULL) {
141                 error(globalState->statusMsg);
142         }
143
144         if(OMX_DEBUG) {Rprintf("Completed Algebras and Matrices.  Beginning Initial Compute.\n");}
145         omxStateNextEvaluation(globalState);
146
147         omxRecompute(algebra);
148
149         PROTECT(ans = allocMatrix(REALSXP, algebra->rows, algebra->cols));
150         for(l = 0; l < algebra->rows; l++)
151                 for(j = 0; j < algebra->cols; j++)
152                         REAL(ans)[j * algebra->rows + l] =
153                                 omxMatrixElement(algebra, l, j);
154
155         if(OMX_DEBUG) { Rprintf("All Algebras complete.\n"); }
156
157         output[0] = 0;
158         if (isErrorRaised(globalState)) {
159                 strncpy(output, globalState->statusMsg, MAX_STRING_LEN);
160         }
161
162         omxFreeAllMatrixData(algebra);
163         omxFreeState(globalState);
164
165         if(output[0]) error(output);
166
167         return ans;
168 }
169
170 SEXP omxCallAlgebra(SEXP matList, SEXP algNum, SEXP options)
171 {
172         try {
173                 return omxCallAlgebra2(matList, algNum, options);
174         } catch( std::exception& __ex__ ) {
175                 exception_to_try_error( __ex__ );
176         } catch(...) {
177                 string_to_try_error( "c++ exception (unknown reason)" );
178         }
179 }
180
181 SEXP omxBackend2(SEXP fitfunction, SEXP startVals, SEXP constraints,
182         SEXP matList, SEXP varList, SEXP algList, SEXP expectList,
183         SEXP data, SEXP intervalList, SEXP checkpointList, SEXP options, SEXP state) {
184
185         /* Helpful variables */
186
187         SEXP nextLoc;
188
189         int disableOptimizer = 0;
190         int analyticGradients = 0;
191
192         /* Sanity Check and Parse Inputs */
193         /* TODO: Need to find a way to account for nullness in these.  For now, all checking is done on the front-end. */
194 //      if(!isVector(startVals)) error ("startVals must be a vector");
195 //      if(!isVector(matList)) error ("matList must be a list");
196 //      if(!isVector(algList)) error ("algList must be a list");
197
198         omxManageProtectInsanity protectManager;
199
200         /* Create new omxState for current state storage and initialize it. */
201         globalState = new omxState;
202         omxInitState(globalState, NULL);
203
204         /*      Set NPSOL options */
205         omxSetNPSOLOpts(options, &globalState->numHessians, &globalState->calculateStdErrors, 
206                 &globalState->ciMaxIterations, &disableOptimizer, &globalState->numThreads, 
207                 &analyticGradients, length(startVals));
208
209         globalState->numFreeParams = length(startVals);
210         globalState->analyticGradients = analyticGradients;
211         if(OMX_DEBUG) { Rprintf("Created state object at 0x%x.\n", globalState);}
212
213         /* Retrieve Data Objects */
214         omxProcessMxDataEntities(data);
215         if (globalState->statusMsg[0]) error(globalState->statusMsg);
216     
217         /* Retrieve All Matrices From the MatList */
218         omxProcessMxMatrixEntities(matList);
219         if (globalState->statusMsg[0]) error(globalState->statusMsg);
220
221         if (length(startVals) != length(varList)) error("varList and startVals must be the same length");
222
223         /* Process Free Var List */
224         omxProcessFreeVarList(varList);
225         if (globalState->statusMsg[0]) error(globalState->statusMsg);
226
227         omxProcessMxExpectationEntities(expectList);
228         if (globalState->statusMsg[0]) error(globalState->statusMsg);
229
230         omxProcessMxAlgebraEntities(algList);
231         if (globalState->statusMsg[0]) error(globalState->statusMsg);
232
233         omxCompleteMxExpectationEntities();
234         if (globalState->statusMsg[0]) error(globalState->statusMsg);
235
236         omxProcessMxFitFunction(algList);
237         if (globalState->statusMsg[0]) error(globalState->statusMsg);
238
239         // This is the chance to check for matrix
240         // conformability, etc.  Any errors encountered should
241         // be reported using R's error() function, not
242         // omxRaiseErrorf.
243
244         omxInitialMatrixAlgebraCompute();
245         omxResetStatus(globalState);
246
247         // maybe require a Compute object? TODO
248         omxComputeGD *topCompute = NULL;
249         omxMatrix *fitMatrix = NULL;
250         if(!isNull(fitfunction)) {
251                 if(OMX_DEBUG) { Rprintf("Processing fit function.\n"); }
252                 fitMatrix = omxMatrixLookupFromState1(fitfunction, globalState);
253                 topCompute = new omxComputeGD(fitMatrix);
254         }
255         if (globalState->statusMsg[0]) error(globalState->statusMsg);
256         
257         /* Process Matrix and Algebra Population Function */
258         /*
259           Each matrix is a list containing a matrix and the other matrices/algebras that are
260           populated into it at each iteration.  The first element is already processed, above.
261           The rest of the list will be processed here.
262         */
263         for(int j = 0; j < length(matList); j++) {
264                 PROTECT(nextLoc = VECTOR_ELT(matList, j));              // This is the matrix + populations
265                 omxProcessMatrixPopulationList(globalState->matrixList[j], nextLoc);
266         }
267
268         /* Processing Constraints */
269         omxProcessConstraints(constraints);
270
271         /* Process Confidence Interval List */
272         omxProcessConfidenceIntervals(intervalList);
273
274         /* Process Checkpoint List */
275         omxProcessCheckpointOptions(checkpointList);
276
277         cacheFreeVarDependencies();
278
279         int n = globalState->numFreeParams;
280
281         if (topCompute) topCompute->setStartValues(startVals);
282         
283         if (topCompute) topCompute->compute(disableOptimizer);
284
285         SEXP evaluations, algebras, matrices, expectations;
286
287         PROTECT(evaluations = NEW_NUMERIC(2));
288         PROTECT(matrices = NEW_LIST(globalState->matrixList.size()));
289         PROTECT(algebras = NEW_LIST(globalState->algebraList.size()));
290         PROTECT(expectations = NEW_LIST(globalState->expectationList.size()));
291
292         REAL(evaluations)[0] = globalState->computeCount;
293
294         MxRList result;
295
296         if (!isErrorRaised(globalState) && globalState->numHessians && fitMatrix != NULL &&
297             globalState->numConstraints == 0) {
298                 omxComputeEstimateHessian *eh =
299                         new omxComputeEstimateHessian(fitMatrix, topCompute->getEstimate());
300                 eh->compute(FALSE);
301                 eh->reportResults(&result);
302                 delete eh;
303         }
304
305         // What if fitfunction has its own repopulateFun? TODO
306         if (topCompute) handleFreeVarListHelper(globalState, topCompute->getEstimate(), n);
307
308         omxFinalAlgebraCalculation(globalState, matrices, algebras, expectations); 
309
310         REAL(evaluations)[1] = globalState->computeCount;
311
312         double optStatus = NA_REAL;
313         if (topCompute) {
314                 topCompute->reportResults(&result);
315                 optStatus = topCompute->getOptimizerStatus();
316                 delete topCompute;
317         }
318
319         MxRList backwardCompatStatus;
320         backwardCompatStatus.push_back(std::make_pair(mkChar("code"), ScalarReal(optStatus)));
321         backwardCompatStatus.push_back(std::make_pair(mkChar("status"),
322                                                       ScalarInteger(-isErrorRaised(globalState))));
323
324         if (isErrorRaised(globalState)) {
325                 SEXP msg;
326                 PROTECT(msg = allocVector(STRSXP, 1));
327                 SET_STRING_ELT(msg, 0, mkChar(globalState->statusMsg));
328                 result.push_back(std::make_pair(mkChar("error"), msg));
329                 backwardCompatStatus.push_back(std::make_pair(mkChar("statusMsg"), msg));
330         }
331
332         result.push_back(std::make_pair(mkChar("status"), asR(&backwardCompatStatus)));
333         result.push_back(std::make_pair(mkChar("evaluations"), evaluations));
334         result.push_back(std::make_pair(mkChar("matrices"), matrices));
335         result.push_back(std::make_pair(mkChar("algebras"), algebras));
336         result.push_back(std::make_pair(mkChar("expectations"), expectations));
337
338         /* Free data memory */
339         omxFreeState(globalState);
340
341         return asR(&result);
342
343 }
344
345 SEXP omxBackend(SEXP fitfunction, SEXP startVals, SEXP constraints,
346         SEXP matList, SEXP varList, SEXP algList, SEXP expectList,
347         SEXP data, SEXP intervalList, SEXP checkpointList, SEXP options, SEXP state)
348 {
349         try {
350                 return omxBackend2(fitfunction, startVals, constraints,
351                                    matList, varList, algList, expectList,
352                                    data, intervalList, checkpointList, options, state);
353         } catch( std::exception& __ex__ ) {
354                 exception_to_try_error( __ex__ );
355         } catch(...) {
356                 string_to_try_error( "c++ exception (unknown reason)" );
357         }
358 }
359