Enable R_NO_REMAP for a cleaner namespace
[openmx:openmx.git] / src / fitMultigroup.cpp
1 #include "omxExpectation.h"
2 #include "fitMultigroup.h"
3 #include "omxExportBackendState.h"
4 #include <vector>
5
6 // http://openmx.psyc.virginia.edu/issue/2013/01/multigroup-fit-function
7
8 struct FitMultigroup {
9         std::vector< FreeVarGroup* > varGroups;
10         std::vector< omxMatrix* > fits;
11 };
12
13 static void mgDestroy(omxFitFunction* oo)
14 {
15         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
16         delete mg;
17 }
18
19 static void mgCompute(omxFitFunction* oo, int ffcompute, FitContext *fc)
20 {
21         omxMatrix *fitMatrix  = oo->matrix;
22         fitMatrix->data[0] = 0;
23         double mac = 0;
24
25         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
26
27         for (size_t ex=0; ex < mg->fits.size(); ex++) {
28                 omxMatrix* f1 = mg->fits[ex];
29                 if (f1->fitFunction) {
30                         omxFitFunctionCompute(f1->fitFunction, ffcompute, fc);
31                         if (ffcompute & FF_COMPUTE_MAXABSCHANGE) {
32                                 mac = std::max(fc->mac, mac);
33                         }
34                         if (OMX_DEBUG) { mxLog("mg fit %s %d", f1->name, ffcompute); }
35                 } else {
36                         omxRecompute(f1);
37
38                         // This should really be checked elsewhere. TODO
39                         if(f1->rows != 1 || f1->cols != 1) {
40                                 Rf_error("%s algebra %d does not evaluate to a 1x1 matrix", oo->fitType, ex);
41                         }
42                 }
43                 fitMatrix->data[0] += f1->data[0];
44         }
45         if (fc) fc->mac = mac;
46         if(OMX_DEBUG) { mxLog("Fit Function sum of %lu groups is %f.", mg->fits.size(), fitMatrix->data[0]); }
47 }
48
49 void mgSetFreeVarGroup(omxFitFunction *oo, FreeVarGroup *fvg)
50 {
51         if (!oo->argStruct) initFitMultigroup(oo); // ugh TODO
52
53         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
54
55         if (!mg->fits.size()) {
56                 mg->varGroups.push_back(fvg);
57         } else {
58                 for (size_t ex=0; ex < mg->fits.size(); ex++) {
59                         omxMatrix *f1 = mg->fits[ex];
60                         if (!f1->fitFunction) {  // simple algebra
61                                 oo->freeVarGroup = fvg;
62                                 continue;
63                         }
64                         setFreeVarGroup(f1->fitFunction, fvg);
65                         oo->freeVarGroup = f1->fitFunction->freeVarGroup;
66                 }
67         }
68 }
69
70 void mgAddOutput(omxFitFunction* oo, MxRList *out)
71 {
72         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
73
74         for (size_t ex=0; ex < mg->fits.size(); ex++) {
75                 omxMatrix* f1 = mg->fits[ex];
76                 if (!f1->fitFunction) continue;
77                 omxPopulateFitFunction(f1, out);
78         }
79 }
80
81 void initFitMultigroup(omxFitFunction *oo)
82 {
83         oo->expectation = NULL;  // don't care about this
84         oo->computeFun = mgCompute;
85         oo->destructFun = mgDestroy;
86         oo->setVarGroup = mgSetFreeVarGroup;
87         oo->addOutput = mgAddOutput;
88
89         if (!oo->argStruct) oo->argStruct = new FitMultigroup;
90         FitMultigroup *mg = (FitMultigroup *) oo->argStruct;
91
92         SEXP rObj = oo->rObj;
93         if (!rObj) return;
94
95         if (mg->fits.size()) return; // hack to prevent double initialization, remove TOOD
96
97         oo->gradientAvailable = TRUE;
98         oo->hessianAvailable = TRUE;
99         oo->parametersHaveFlavor = TRUE;
100
101         omxState *os = oo->matrix->currentState;
102
103         SEXP slotValue;
104         Rf_protect(slotValue = R_do_slot(rObj, Rf_install("groups")));
105         int *fits = INTEGER(slotValue);
106         for(int gx = 0; gx < Rf_length(slotValue); gx++) {
107                 omxMatrix *mat;
108                 if (fits[gx] >= 0) {
109                         mat = os->algebraList[fits[gx]];
110                 } else {
111                         Rf_error("Can only add algebra and fitfunction");
112                 }
113                 if (mat == oo->matrix) Rf_error("Cannot add multigroup to itself");
114                 mg->fits.push_back(mat);
115                 if (mat->fitFunction) {
116                         for (size_t vg=0; vg < mg->varGroups.size(); ++vg) {
117                                 setFreeVarGroup(mat->fitFunction, mg->varGroups[vg]);
118                                 oo->freeVarGroup = mat->fitFunction->freeVarGroup;
119                         }
120                         omxCompleteFitFunction(mat);
121                         oo->gradientAvailable = (oo->gradientAvailable && mat->fitFunction->gradientAvailable);
122                         oo->hessianAvailable = (oo->hessianAvailable && mat->fitFunction->hessianAvailable);
123                         oo->parametersHaveFlavor = (oo->parametersHaveFlavor && mat->fitFunction->parametersHaveFlavor);
124                 } else {
125                         // TODO derivs for algebra
126                         oo->gradientAvailable = FALSE;
127                         oo->hessianAvailable = FALSE;
128                 }
129         }
130 }
131
132 /* TODO
133 void omxMultigroupAdd(omxFitFunction *oo, omxFitFunction *fit)
134 {
135         if (oo->initFun != initFitMultigroup) Rf_error("%s is not the multigroup fit", oo->fitType);
136         if (!oo->initialized) Rf_error("Fit not initialized", oo);
137
138         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
139         mg->fits.push_back(fit->matrix);
140         //addFreeVarDependency(oo->matrix->currentState, oo->matrix, fit->matrix);
141 }
142 */