Revert "Split fit function initialization similar to expectations"
[openmx:openmx.git] / src / fitMultigroup.cpp
1 #include "omxExpectation.h"
2 #include "omxOptimizer.h"
3 #include "fitMultigroup.h"
4 #include <vector>
5
6 // http://openmx.psyc.virginia.edu/issue/2013/01/multigroup-fit-function
7
8 struct FitMultigroup {
9         std::vector< int > fits;  // store pointers or index numbers? TODO
10         bool checkedRepopulate;
11         FitMultigroup() : checkedRepopulate(0) {}
12 };
13
14 static void mgDestroy(omxFitFunction* oo)
15 {
16         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
17         delete mg;
18 }
19
20 static void checkRepopulate(omxFitFunction* oo)
21 {
22         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
23         omxState *os = oo->matrix->currentState;
24         for (size_t ex=0; ex < mg->fits.size(); ex++) {
25                 omxMatrix* f1 = os->algebraList[mg->fits[ex]];
26                 omxFitFunction *ff = f1->fitFunction;
27                 if (!ff || ff->repopulateFun == handleFreeVarList) continue;
28                 error("Cannot add %s to multigroup fit", f1->name);
29         }
30 }
31
32 static void mgCompute(omxFitFunction* oo, int ffcompute, double* grad)
33 {
34         omxMatrix *fitMatrix  = oo->matrix;
35         omxState *os = fitMatrix->currentState;
36         fitMatrix->data[0] = 0;
37
38         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
39
40         if (!mg->checkedRepopulate) {
41                 checkRepopulate(oo);
42                 mg->checkedRepopulate = TRUE;
43         }
44
45         for (size_t ex=0; ex < mg->fits.size(); ex++) {
46                 omxMatrix* f1 = os->algebraList[mg->fits[ex]];
47                 if (f1->fitFunction) {
48                         // possibly invalidate gradients TODO
49                         omxFitFunctionCompute(f1->fitFunction, ffcompute, grad);
50                 } else {
51                         // invalidate gradients TODO
52                         omxRecompute(f1);
53
54                         // This should really be checked elsewhere. TODO
55                         if(f1->rows != 1 || f1->cols != 1) {
56                                 error("%s algebra %d does not evaluate to a 1x1 matrix", oo->fitType, ex);
57                         }
58                 }
59                 fitMatrix->data[0] += f1->data[0];
60         }
61         if(OMX_DEBUG) { Rprintf("Fit Function sum of %d groups is %f.\n", mg->fits.size(), fitMatrix->data[0]); }
62 }
63
64 void initFitMultigroup(omxFitFunction *oo, SEXP ign)
65 {
66         oo->expectation = NULL;  // don't care about this
67         oo->computeFun = mgCompute;
68         oo->destructFun = mgDestroy;
69         oo->repopulateFun = handleFreeVarList;
70
71         FitMultigroup *mg = new FitMultigroup;
72         oo->argStruct = mg;
73
74         SEXP rObj = oo->rObj;
75         if (!rObj) return;
76
77         int myIndex = oo->matrix->matrixNumber;
78
79         SEXP slotValue;
80         PROTECT(slotValue = GET_SLOT(rObj, install("groups")));
81         int *fits = INTEGER(slotValue);
82         for(int gx = 0; gx < length(slotValue); gx++) {
83                 if (fits[gx] == myIndex) error("Cannot add multigroup to itself");
84                 mg->fits.push_back(fits[gx]);
85         }
86 }
87
88 void omxMultigroupAdd(omxFitFunction *oo, omxFitFunction *grp)
89 {
90         if (oo->initFun != initFitMultigroup) error("%s is not the multigroup fit", oo->fitType);
91         if (!oo->initialized) error("Fit %p not initialized", oo);
92
93         int myIndex = oo->matrix->matrixNumber;
94         int grpIndex = grp->matrix->matrixNumber;
95         if (grpIndex == myIndex) error("Cannot add multigroup to itself");
96
97         omxState *os = oo->matrix->currentState;
98         if (os->algebraList.at(grpIndex) != grp->matrix) {
99                 error("Attempt to add group %d missing from algebraList", grpIndex);
100         }
101
102         FitMultigroup *mg = (FitMultigroup*) oo->argStruct;
103         mg->fits.push_back(grpIndex);
104         //addFreeVarDependency(oo->matrix->currentState, oo->matrix, grp->matrix);
105 }