Add ComputeAssign
[openmx:openmx.git] / src / npsolWrap.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
17 #include <stdio.h>
18 #include <sys/types.h>
19 #include <errno.h>
20
21 #include <R.h>
22 #include <Rinternals.h>
23 #include <Rdefines.h>
24 #include <R_ext/Rdynload.h>
25 #include <R_ext/BLAS.h>
26 #include <R_ext/Lapack.h>
27
28 #include "omxDefines.h"
29 #include "types.h"
30 #include "npsolWrap.h"
31 #include "omxOpenmpWrap.h"
32 #include "omxState.h"
33 #include "omxMatrix.h"
34 #include "omxAlgebra.h"
35 #include "omxFitFunction.h"
36 #include "omxExpectation.h"
37 #include "omxNPSOLSpecific.h"
38 #include "omxImportFrontendState.h"
39 #include "omxExportBackendState.h"
40 #include "Compute.h"
41 #include "dmvnorm.h"
42
43 static R_CallMethodDef callMethods[] = {
44         {"omxBackend", (DL_FUNC) omxBackend, 12},
45         {"omxCallAlgebra", (DL_FUNC) omxCallAlgebra, 3},
46         {"findIdenticalRowsData", (DL_FUNC) findIdenticalRowsData, 5},
47         {"imxDmvnorm_wrapper", (DL_FUNC) dmvnorm_wrapper, 3},
48         {NULL, NULL, 0}
49 };
50
51 #ifdef  __cplusplus
52 extern "C" {
53 #endif
54
55 void R_init_OpenMx(DllInfo *info) {
56         R_registerRoutines(info, NULL, callMethods, NULL, NULL);
57
58         // There is no code that will change behavior whether openmp
59         // is set for nested or not. I'm just keeping this in case it
60         // makes a difference with older versions of openmp. 2012-12-24 JNP
61 #if defined(_OPENMP) && _OPENMP <= 200505
62         omp_set_nested(0);
63 #endif
64 }
65
66 void R_unload_OpenMx(DllInfo *info) {
67         // keep this stub in case we need it
68 }
69
70 #ifdef  __cplusplus
71 }
72 #endif
73
74 void string_to_try_error( const std::string& str )
75 {
76         error("%s", str.c_str());
77 }
78
79 void exception_to_try_error( const std::exception& ex )
80 {
81         string_to_try_error(ex.what());
82 }
83
84 SEXP asR(MxRList *out)
85 {
86         // change to a set to avoid duplicate keys TODO
87         SEXP names, ans;
88         int len = out->size();
89         PROTECT(names = allocVector(STRSXP, len));
90         PROTECT(ans = allocVector(VECSXP, len));
91         for (int lx=0; lx < len; ++lx) {
92                 SET_STRING_ELT(names, lx, (*out)[lx].first);
93                 SET_VECTOR_ELT(ans,   lx, (*out)[lx].second);
94         }
95         namesgets(ans, names);
96         return ans;
97 }
98
99 /* Main functions */
100 SEXP omxCallAlgebra2(SEXP matList, SEXP algNum, SEXP options) {
101
102         omxManageProtectInsanity protectManager;
103
104         if(OMX_DEBUG) { mxLog("-----------------------------------------------------------------------");}
105         if(OMX_DEBUG) { mxLog("Explicit call to algebra %d.", INTEGER(algNum));}
106
107         int j,k,l;
108         omxMatrix* algebra;
109         int algebraNum = INTEGER(algNum)[0];
110         SEXP ans, nextMat;
111         char output[MAX_STRING_LEN];
112
113         FitContext::setRFitFunction(NULL);
114         Global = new omxGlobal;
115
116         globalState = new omxState;
117         omxInitState(globalState);
118         if(OMX_DEBUG) { mxLog("Created state object at 0x%x.", globalState);}
119
120         /* Retrieve All Matrices From the MatList */
121
122         if(OMX_DEBUG) { mxLog("Processing %d matrix(ces).", length(matList));}
123
124         omxMatrix *args[length(matList)];
125         for(k = 0; k < length(matList); k++) {
126                 PROTECT(nextMat = VECTOR_ELT(matList, k));      // This is the matrix + populations
127                 args[k] = omxNewMatrixFromRPrimitive(nextMat, globalState, 1, - k - 1);
128                 globalState->matrixList.push_back(args[k]);
129                 if(OMX_DEBUG) {
130                         mxLog("Matrix initialized at 0x%0xd = (%d x %d).",
131                                 globalState->matrixList[k], globalState->matrixList[k]->rows, globalState->matrixList[k]->cols);
132                 }
133         }
134
135         algebra = omxNewAlgebraFromOperatorAndArgs(algebraNum, args, length(matList), globalState);
136
137         if(algebra==NULL) {
138                 error(globalState->statusMsg);
139         }
140
141         if(OMX_DEBUG) {mxLog("Completed Algebras and Matrices.  Beginning Initial Compute.");}
142         omxStateNextEvaluation(globalState);
143
144         omxRecompute(algebra);
145
146         PROTECT(ans = allocMatrix(REALSXP, algebra->rows, algebra->cols));
147         for(l = 0; l < algebra->rows; l++)
148                 for(j = 0; j < algebra->cols; j++)
149                         REAL(ans)[j * algebra->rows + l] =
150                                 omxMatrixElement(algebra, l, j);
151
152         if(OMX_DEBUG) { mxLog("All Algebras complete."); }
153
154         output[0] = 0;
155         if (isErrorRaised(globalState)) {
156                 strncpy(output, globalState->statusMsg, MAX_STRING_LEN);
157         }
158
159         omxFreeAllMatrixData(algebra);
160         omxFreeState(globalState);
161         delete Global;
162
163         if(output[0]) error(output);
164
165         return ans;
166 }
167
168 SEXP omxCallAlgebra(SEXP matList, SEXP algNum, SEXP options)
169 {
170         try {
171                 return omxCallAlgebra2(matList, algNum, options);
172         } catch( std::exception& __ex__ ) {
173                 exception_to_try_error( __ex__ );
174         } catch(...) {
175                 string_to_try_error( "c++ exception (unknown reason)" );
176         }
177 }
178
179 SEXP omxBackend2(SEXP computeIndex, SEXP constraints, SEXP matList, SEXP fgNames,
180                  SEXP varList, SEXP algList, SEXP expectList, SEXP computeList,
181                  SEXP data, SEXP intervalList, SEXP checkpointList, SEXP options)
182 {
183         SEXP nextLoc;
184
185         /* Sanity Check and Parse Inputs */
186         /* TODO: Need to find a way to account for nullness in these.  For now, all checking is done on the front-end. */
187 //      if(!isVector(matList)) error ("matList must be a list");
188 //      if(!isVector(algList)) error ("algList must be a list");
189
190         omxManageProtectInsanity protectManager;
191
192         FitContext::setRFitFunction(NULL);
193         Global = new omxGlobal;
194
195         /* Create new omxState for current state storage and initialize it. */
196         globalState = new omxState;
197         omxInitState(globalState);
198         if(OMX_DEBUG) { mxLog("Created state object at 0x%x.", globalState);}
199
200         Global->ciMaxIterations = 5;
201         Global->numThreads = 1;
202         Global->analyticGradients = 0;
203         Global->numChildren = 0;
204         omxSetNPSOLOpts(options, &Global->ciMaxIterations, &Global->numThreads, 
205                         &Global->analyticGradients);
206
207         omxProcessMxDataEntities(data);
208         if (isErrorRaised(globalState)) error(globalState->statusMsg);
209     
210         omxProcessMxMatrixEntities(matList);
211         if (isErrorRaised(globalState)) error(globalState->statusMsg);
212
213         omxProcessFreeVarList(fgNames, varList);
214         if (isErrorRaised(globalState)) error(globalState->statusMsg);
215
216         omxProcessMxExpectationEntities(expectList);
217         if (isErrorRaised(globalState)) error(globalState->statusMsg);
218
219         omxProcessMxAlgebraEntities(algList);
220         if (isErrorRaised(globalState)) error(globalState->statusMsg);
221
222         omxProcessMxFitFunction(algList);
223         if (isErrorRaised(globalState)) error(globalState->statusMsg);
224
225         omxProcessMxComputeEntities(computeList);
226         if (isErrorRaised(globalState)) error(globalState->statusMsg);
227
228         omxCompleteMxExpectationEntities();
229         if (isErrorRaised(globalState)) error(globalState->statusMsg);
230
231         omxCompleteMxFitFunction(algList);
232         if (isErrorRaised(globalState)) error(globalState->statusMsg);
233
234         // This is the chance to check for matrix
235         // conformability, etc.  Any errors encountered should
236         // be reported using R's error() function, not
237         // omxRaiseErrorf.
238
239         omxInitialMatrixAlgebraCompute();
240         omxResetStatus(globalState);
241
242         // maybe require a Compute object? TODO
243         omxCompute *topCompute = NULL;
244         if (!isNull(computeIndex)) {
245                 int ox = INTEGER(computeIndex)[0];
246                 topCompute = Global->computeList[ox];
247         }
248
249         /* Process Matrix and Algebra Population Function */
250         /*
251           Each matrix is a list containing a matrix and the other matrices/algebras that are
252           populated into it at each iteration.  The first element is already processed, above.
253           The rest of the list will be processed here.
254         */
255         for(int j = 0; j < length(matList); j++) {
256                 PROTECT(nextLoc = VECTOR_ELT(matList, j));              // This is the matrix + populations
257                 omxProcessMatrixPopulationList(globalState->matrixList[j], nextLoc);
258         }
259
260         omxProcessConstraints(constraints);
261
262         omxProcessConfidenceIntervals(intervalList);
263
264         omxProcessCheckpointOptions(checkpointList);
265
266         FitContext::cacheFreeVarDependencies();
267
268         FitContext fc;
269
270         if (topCompute && !isErrorRaised(globalState)) {
271                 // switch varGroup, if necessary TODO
272                 topCompute->compute(&fc);
273         }
274
275         SEXP evaluations;
276         PROTECT(evaluations = NEW_NUMERIC(2));
277
278         REAL(evaluations)[0] = globalState->computeCount;
279
280         MxRList result;
281
282         if (topCompute && !isErrorRaised(globalState)) {
283                 fc.copyParamToModel(globalState); // probably unnecessary to do this again? TODO
284         }
285
286         omxExportResults(globalState, &result); 
287
288         REAL(evaluations)[1] = globalState->computeCount;
289
290         double optStatus = NA_REAL;
291         if (topCompute && !isErrorRaised(globalState)) {
292                 topCompute->reportResults(&fc, &result);
293                 optStatus = topCompute->getOptimizerStatus();
294         }
295
296         MxRList backwardCompatStatus;
297         backwardCompatStatus.push_back(std::make_pair(mkChar("code"), ScalarReal(optStatus)));
298         backwardCompatStatus.push_back(std::make_pair(mkChar("status"),
299                                                       ScalarInteger(-isErrorRaised(globalState))));
300
301         if (isErrorRaised(globalState)) {
302                 SEXP msg;
303                 PROTECT(msg = allocVector(STRSXP, 1));
304                 SET_STRING_ELT(msg, 0, mkChar(globalState->statusMsg));
305                 result.push_back(std::make_pair(mkChar("error"), msg));
306                 backwardCompatStatus.push_back(std::make_pair(mkChar("statusMsg"), msg));
307         }
308
309         result.push_back(std::make_pair(mkChar("status"), asR(&backwardCompatStatus)));
310         result.push_back(std::make_pair(mkChar("evaluations"), evaluations));
311
312         omxFreeState(globalState);
313         delete Global;
314
315         return asR(&result);
316 }
317
318 SEXP omxBackend(SEXP computeIndex, SEXP constraints, SEXP matList, SEXP fgNames,
319                 SEXP varList, SEXP algList, SEXP expectList, SEXP computeList,
320                 SEXP data, SEXP intervalList, SEXP checkpointList, SEXP options)
321 {
322         try {
323                 return omxBackend2(computeIndex, constraints, matList, fgNames,
324                                    varList, algList, expectList, computeList,
325                                    data, intervalList, checkpointList, options);
326         } catch( std::exception& __ex__ ) {
327                 exception_to_try_error( __ex__ );
328         } catch(...) {
329                 string_to_try_error( "c++ exception (unknown reason)" );
330         }
331 }
332