Switch over to new structured Compute system
[openmx:openmx.git] / src / npsolWrap.cpp
1 /*
2  *  Copyright 2007-2013 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16
#include <stdio.h>
#include <sys/types.h>
#include <errno.h>

// C++ standard headers are included before the R headers: without
// R_NO_REMAP, Rinternals.h defines macros such as `length` that can clash
// with names used inside the standard library headers.
#include <exception>
#include <string>
#include <vector>

#include <R.h>
#include <Rinternals.h>
#include <Rdefines.h>
#include <R_ext/Rdynload.h>
#include <R_ext/BLAS.h>
#include <R_ext/Lapack.h>

#include "omxDefines.h"
#include "types.h"
#include "npsolWrap.h"
#include "omxOpenmpWrap.h"
#include "omxState.h"
#include "omxMatrix.h"
#include "omxAlgebra.h"
#include "omxFitFunction.h"
#include "omxExpectation.h"
#include "omxNPSOLSpecific.h"
#include "omxImportFrontendState.h"
#include "omxExportBackendState.h"
#include "omxOptimizer.h"
#include "Compute.h"
42
43 omp_lock_t GlobalRLock;
44
45 static R_CallMethodDef callMethods[] = {
46         {"omxBackend", (DL_FUNC) omxBackend, 12},
47         {"omxCallAlgebra", (DL_FUNC) omxCallAlgebra, 3},
48         {"findIdenticalRowsData", (DL_FUNC) findIdenticalRowsData, 5},
49         {NULL, NULL, 0}
50 };
51
52 #ifdef  __cplusplus
53 extern "C" {
54 #endif
55
56 void R_init_OpenMx(DllInfo *info) {
57         R_registerRoutines(info, NULL, callMethods, NULL, NULL);
58
59         omx_omp_init_lock(&GlobalRLock);
60
61         // There is no code that will change behavior whether openmp
62         // is set for nested or not. I'm just keeping this in case it
63         // makes a difference with older versions of openmp. 2012-12-24 JNP
64 #if defined(_OPENMP) && _OPENMP <= 200505
65         omp_set_nested(0);
66 #endif
67 }
68
69 void R_unload_OpenMx(DllInfo *info) {
70         omx_omp_destroy_lock(&GlobalRLock);
71 }
72
73 #ifdef  __cplusplus
74 }
75 #endif
76
77 void string_to_try_error( const std::string& str )
78 {
79         error("%s", str.c_str());
80 }
81
82 void exception_to_try_error( const std::exception& ex )
83 {
84         string_to_try_error(ex.what());
85 }
86
87 SEXP asR(MxRList *out)
88 {
89         // change to a set to avoid duplicate keys TODO
90         SEXP names, ans;
91         int len = out->size();
92         PROTECT(names = allocVector(STRSXP, len));
93         PROTECT(ans = allocVector(VECSXP, len));
94         for (int lx=0; lx < len; ++lx) {
95                 SET_STRING_ELT(names, lx, (*out)[lx].first);
96                 SET_VECTOR_ELT(ans,   lx, (*out)[lx].second);
97         }
98         namesgets(ans, names);
99         return ans;
100 }
101
102 /* Main functions */
103 SEXP omxCallAlgebra2(SEXP matList, SEXP algNum, SEXP options) {
104
105         omxManageProtectInsanity protectManager;
106
107         if(OMX_DEBUG) { Rprintf("-----------------------------------------------------------------------\n");}
108         if(OMX_DEBUG) { Rprintf("Explicit call to algebra %d.\n", INTEGER(algNum));}
109
110         int j,k,l;
111         omxMatrix* algebra;
112         int algebraNum = INTEGER(algNum)[0];
113         SEXP ans, nextMat;
114         char output[MAX_STRING_LEN];
115
116         /* Create new omxState for current state storage and initialize it. */
117         
118         globalState = new omxState;
119         omxInitState(globalState, NULL);
120         if(OMX_DEBUG) { Rprintf("Created state object at 0x%x.\n", globalState);}
121
122         /* Retrieve All Matrices From the MatList */
123
124         if(OMX_DEBUG) { Rprintf("Processing %d matrix(ces).\n", length(matList));}
125
126         omxMatrix *args[length(matList)];
127         for(k = 0; k < length(matList); k++) {
128                 PROTECT(nextMat = VECTOR_ELT(matList, k));      // This is the matrix + populations
129                 args[k] = omxNewMatrixFromRPrimitive(nextMat, globalState, 1, - k - 1);
130                 globalState->matrixList.push_back(args[k]);
131                 if(OMX_DEBUG) {
132                         Rprintf("Matrix initialized at 0x%0xd = (%d x %d).\n",
133                                 globalState->matrixList[k], globalState->matrixList[k]->rows, globalState->matrixList[k]->cols);
134                 }
135         }
136
137         algebra = omxNewAlgebraFromOperatorAndArgs(algebraNum, args, length(matList), globalState);
138
139         if(algebra==NULL) {
140                 error(globalState->statusMsg);
141         }
142
143         if(OMX_DEBUG) {Rprintf("Completed Algebras and Matrices.  Beginning Initial Compute.\n");}
144         omxStateNextEvaluation(globalState);
145
146         omxRecompute(algebra);
147
148         PROTECT(ans = allocMatrix(REALSXP, algebra->rows, algebra->cols));
149         for(l = 0; l < algebra->rows; l++)
150                 for(j = 0; j < algebra->cols; j++)
151                         REAL(ans)[j * algebra->rows + l] =
152                                 omxMatrixElement(algebra, l, j);
153
154         if(OMX_DEBUG) { Rprintf("All Algebras complete.\n"); }
155
156         output[0] = 0;
157         if (isErrorRaised(globalState)) {
158                 strncpy(output, globalState->statusMsg, MAX_STRING_LEN);
159         }
160
161         omxFreeAllMatrixData(algebra);
162         omxFreeState(globalState);
163
164         if(output[0]) error(output);
165
166         return ans;
167 }
168
169 SEXP omxCallAlgebra(SEXP matList, SEXP algNum, SEXP options)
170 {
171         try {
172                 return omxCallAlgebra2(matList, algNum, options);
173         } catch( std::exception& __ex__ ) {
174                 exception_to_try_error( __ex__ );
175         } catch(...) {
176                 string_to_try_error( "c++ exception (unknown reason)" );
177         }
178 }
179
180 SEXP omxBackend2(SEXP computeIndex, SEXP startVals, SEXP constraints,
181                  SEXP matList, SEXP varList, SEXP algList, SEXP expectList, SEXP computeList,
182         SEXP data, SEXP intervalList, SEXP checkpointList, SEXP options)
183 {
184         SEXP nextLoc;
185
186         int analyticGradients = 0;
187
188         /* Sanity Check and Parse Inputs */
189         /* TODO: Need to find a way to account for nullness in these.  For now, all checking is done on the front-end. */
190 //      if(!isVector(startVals)) error ("startVals must be a vector");
191 //      if(!isVector(matList)) error ("matList must be a list");
192 //      if(!isVector(algList)) error ("algList must be a list");
193
194         omxManageProtectInsanity protectManager;
195
196         /* Create new omxState for current state storage and initialize it. */
197         globalState = new omxState;
198         omxInitState(globalState, NULL);
199
200         /*      Set NPSOL options */
201         omxSetNPSOLOpts(options, &globalState->ciMaxIterations, &globalState->numThreads, 
202                         &analyticGradients, length(startVals));
203
204         globalState->numFreeParams = length(startVals);
205         globalState->analyticGradients = analyticGradients;
206         if(OMX_DEBUG) { Rprintf("Created state object at 0x%x.\n", globalState);}
207
208         /* Retrieve Data Objects */
209         omxProcessMxDataEntities(data);
210         if (isErrorRaised(globalState)) error(globalState->statusMsg);
211     
212         /* Retrieve All Matrices From the MatList */
213         omxProcessMxMatrixEntities(matList);
214         if (isErrorRaised(globalState)) error(globalState->statusMsg);
215
216         if (length(startVals) != length(varList)) error("varList and startVals must be the same length");
217
218         /* Process Free Var List */
219         omxProcessFreeVarList(varList);
220         if (isErrorRaised(globalState)) error(globalState->statusMsg);
221
222         omxProcessMxExpectationEntities(expectList);
223         if (isErrorRaised(globalState)) error(globalState->statusMsg);
224
225         omxProcessMxAlgebraEntities(algList);
226         if (isErrorRaised(globalState)) error(globalState->statusMsg);
227
228         omxCompleteMxExpectationEntities();
229         if (isErrorRaised(globalState)) error(globalState->statusMsg);
230
231         omxProcessMxFitFunction(algList);
232         if (isErrorRaised(globalState)) error(globalState->statusMsg);
233
234         // This is the chance to check for matrix
235         // conformability, etc.  Any errors encountered should
236         // be reported using R's error() function, not
237         // omxRaiseErrorf.
238
239         omxInitialMatrixAlgebraCompute();
240         omxResetStatus(globalState);
241
242         omxProcessMxComputeEntities(computeList);
243
244         // maybe require a Compute object? TODO
245         omxCompute *topCompute = NULL;
246         if (!isNull(computeIndex)) {
247                 int ox = INTEGER(computeIndex)[0];
248                 topCompute = globalState->computeList[ox];
249         }
250
251         /* Process Matrix and Algebra Population Function */
252         /*
253           Each matrix is a list containing a matrix and the other matrices/algebras that are
254           populated into it at each iteration.  The first element is already processed, above.
255           The rest of the list will be processed here.
256         */
257         for(int j = 0; j < length(matList); j++) {
258                 PROTECT(nextLoc = VECTOR_ELT(matList, j));              // This is the matrix + populations
259                 omxProcessMatrixPopulationList(globalState->matrixList[j], nextLoc);
260         }
261
262         omxProcessConstraints(constraints);
263
264         /* Process Confidence Interval List */
265         omxProcessConfidenceIntervals(intervalList);
266
267         /* Process Checkpoint List */
268         omxProcessCheckpointOptions(checkpointList);
269
270         cacheFreeVarDependencies();
271
272         int n = globalState->numFreeParams;
273
274         if (topCompute && !isErrorRaised(globalState)) {
275                 double *sv = NULL;
276                 if (n) sv = REAL(startVals);
277                 topCompute->compute(sv);
278         }
279
280         SEXP evaluations;
281         PROTECT(evaluations = NEW_NUMERIC(2));
282
283         REAL(evaluations)[0] = globalState->computeCount;
284
285         MxRList result;
286
287         // What if fitfunction has its own repopulateFun? TODO
288         if (topCompute && !isErrorRaised(globalState) && n > 0) {
289                 handleFreeVarListHelper(globalState, topCompute->getEstimate(), n);
290         }
291
292         omxExportResults(globalState, &result); 
293
294         REAL(evaluations)[1] = globalState->computeCount;
295
296         double optStatus = NA_REAL;
297         if (topCompute && !isErrorRaised(globalState)) {
298                 topCompute->reportResults(&result);
299                 optStatus = topCompute->getOptimizerStatus();
300         }
301
302         MxRList backwardCompatStatus;
303         backwardCompatStatus.push_back(std::make_pair(mkChar("code"), ScalarReal(optStatus)));
304         backwardCompatStatus.push_back(std::make_pair(mkChar("status"),
305                                                       ScalarInteger(-isErrorRaised(globalState))));
306
307         if (isErrorRaised(globalState)) {
308                 SEXP msg;
309                 PROTECT(msg = allocVector(STRSXP, 1));
310                 SET_STRING_ELT(msg, 0, mkChar(globalState->statusMsg));
311                 result.push_back(std::make_pair(mkChar("error"), msg));
312                 backwardCompatStatus.push_back(std::make_pair(mkChar("statusMsg"), msg));
313         }
314
315         result.push_back(std::make_pair(mkChar("status"), asR(&backwardCompatStatus)));
316         result.push_back(std::make_pair(mkChar("evaluations"), evaluations));
317
318         omxFreeState(globalState);
319
320         return asR(&result);
321
322 }
323
324 SEXP omxBackend(SEXP computeIndex, SEXP startVals, SEXP constraints,
325                 SEXP matList, SEXP varList, SEXP algList, SEXP expectList, SEXP computeList,
326                 SEXP data, SEXP intervalList, SEXP checkpointList, SEXP options)
327 {
328         try {
329                 return omxBackend2(computeIndex, startVals, constraints,
330                                    matList, varList, algList, expectList, computeList,
331                                    data, intervalList, checkpointList, options);
332         } catch( std::exception& __ex__ ) {
333                 exception_to_try_error( __ex__ );
334         } catch(...) {
335                 string_to_try_error( "c++ exception (unknown reason)" );
336         }
337 }
338