Fix/remove improper printf style formats
[openmx:openmx.git] / src / fitMultigroup.cpp
1 #include "omxExpectation.h"
2 #include "fitMultigroup.h"
3 #include "omxExportBackendState.h"
4 #include <vector>
5
6 // http://openmx.psyc.virginia.edu/issue/2013/01/multigroup-fit-function
7
8 struct FitMultigroup {
9         std::vector< FreeVarGroup* > varGroups;
10         std::vector< omxMatrix* > fits;
11 };
12
13 static void mgDestroy(omxFitFunction* oo)
14 {
15         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
16         delete mg;
17 }
18
19 static void mgCompute(omxFitFunction* oo, int ffcompute, FitContext *fc)
20 {
21         omxMatrix *fitMatrix  = oo->matrix;
22         fitMatrix->data[0] = 0;
23
24         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
25
26         for (size_t ex=0; ex < mg->fits.size(); ex++) {
27                 omxMatrix* f1 = mg->fits[ex];
28                 if (f1->fitFunction) {
29                         omxFitFunctionCompute(f1->fitFunction, ffcompute, fc);
30                         if (OMX_DEBUG) { mxLog("mg fit %s %d", f1->name, ffcompute); }
31                 } else {
32                         omxRecompute(f1);
33
34                         // This should really be checked elsewhere. TODO
35                         if(f1->rows != 1 || f1->cols != 1) {
36                                 error("%s algebra %d does not evaluate to a 1x1 matrix", oo->fitType, ex);
37                         }
38                 }
39                 fitMatrix->data[0] += f1->data[0];
40         }
41         if(OMX_DEBUG) { mxLog("Fit Function sum of %lu groups is %f.", mg->fits.size(), fitMatrix->data[0]); }
42 }
43
44 void mgSetFreeVarGroup(omxFitFunction *oo, FreeVarGroup *fvg)
45 {
46         if (!oo->argStruct) initFitMultigroup(oo); // ugh TODO
47
48         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
49
50         if (!mg->fits.size()) {
51                 mg->varGroups.push_back(fvg);
52         } else {
53                 for (size_t ex=0; ex < mg->fits.size(); ex++) {
54                         omxMatrix *f1 = mg->fits[ex];
55                         if (!f1->fitFunction) {  // simple algebra
56                                 oo->freeVarGroup = fvg;
57                                 continue;
58                         }
59                         setFreeVarGroup(f1->fitFunction, fvg);
60                         oo->freeVarGroup = f1->fitFunction->freeVarGroup;
61                 }
62         }
63 }
64
65 void mgAddOutput(omxFitFunction* oo, MxRList *out)
66 {
67         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
68
69         for (size_t ex=0; ex < mg->fits.size(); ex++) {
70                 omxMatrix* f1 = mg->fits[ex];
71                 if (!f1->fitFunction) continue;
72                 omxPopulateFitFunction(f1, out);
73         }
74 }
75
76 void initFitMultigroup(omxFitFunction *oo)
77 {
78         oo->expectation = NULL;  // don't care about this
79         oo->computeFun = mgCompute;
80         oo->destructFun = mgDestroy;
81         oo->setVarGroup = mgSetFreeVarGroup;
82         oo->addOutput = mgAddOutput;
83
84         if (!oo->argStruct) oo->argStruct = new FitMultigroup;
85         FitMultigroup *mg = (FitMultigroup *) oo->argStruct;
86
87         SEXP rObj = oo->rObj;
88         if (!rObj) return;
89
90         if (mg->fits.size()) return; // hack to prevent double initialization, remove TOOD
91
92         oo->gradientAvailable = TRUE;
93         oo->hessianAvailable = TRUE;
94
95         omxState *os = oo->matrix->currentState;
96
97         SEXP slotValue;
98         PROTECT(slotValue = GET_SLOT(rObj, install("groups")));
99         int *fits = INTEGER(slotValue);
100         for(int gx = 0; gx < length(slotValue); gx++) {
101                 omxMatrix *mat;
102                 if (fits[gx] >= 0) {
103                         mat = os->algebraList[fits[gx]];
104                 } else {
105                         error("Can only add algebra and fitfunction");
106                 }
107                 if (mat == oo->matrix) error("Cannot add multigroup to itself");
108                 mg->fits.push_back(mat);
109                 if (mat->fitFunction) {
110                         for (size_t vg=0; vg < mg->varGroups.size(); ++vg) {
111                                 setFreeVarGroup(mat->fitFunction, mg->varGroups[vg]);
112                                 oo->freeVarGroup = mat->fitFunction->freeVarGroup;
113                         }
114                         omxCompleteFitFunction(mat);
115                         oo->gradientAvailable = (oo->gradientAvailable && mat->fitFunction->gradientAvailable);
116                         oo->hessianAvailable = (oo->hessianAvailable && mat->fitFunction->hessianAvailable);
117                 } else {
118                         // TODO derivs for algebra
119                         oo->gradientAvailable = FALSE;
120                         oo->hessianAvailable = FALSE;
121                 }
122         }
123 }
124
125 /* TODO
126 void omxMultigroupAdd(omxFitFunction *oo, omxFitFunction *fit)
127 {
128         if (oo->initFun != initFitMultigroup) error("%s is not the multigroup fit", oo->fitType);
129         if (!oo->initialized) error("Fit %p not initialized", oo);
130
131         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
132         mg->fits.push_back(fit->matrix);
133         //addFreeVarDependency(oo->matrix->currentState, oo->matrix, fit->matrix);
134 }
135 */