ifa: Gradients
[openmx:openmx.git] / src / omxFitFunctionBA81.c
1 /*
2  * This is throw-away proof-of-concept code that will likely be
3  * replaced by something else.
4  *
5  * JNP 2012Dec10
6  */
7
8 #include "omxFitFunction.h"
9 #include "omxExpectationBA81.h"
10
11 static const char *NAME = "FitFunctionBA81";
12
13 typedef struct {
14
15         omxData *data;
16
17 } omxBA81State;
18
19
20 static void ba81Destroy(omxFitFunction *oo) {
21         if(OMX_DEBUG) {
22                 Rprintf("Freeing %s function.\n", NAME);
23         }
24         //omxBA81State *mml = (omxBA81State *) oo->argStruct;
25         // nothing to do yet
26 }
27
28 static omxRListElement* ba81SetFinalReturns(omxFitFunction *off, int *numReturns) {
29
30         omxRListElement* retVal;
31
32         *numReturns = 1;
33
34         retVal = (omxRListElement*) R_alloc(1, sizeof(omxRListElement));
35
36         retVal[0].numValues = 1;
37         retVal[0].values = (double*) R_alloc(1, sizeof(double));
38         strcpy(retVal[0].label, "Minus2LogLikelihood");
39         retVal[0].values[0] = omxMatrixElement(off->matrix, 0, 0);
40
41         return retVal;
42 }
43
44 // TODO: Don't trample the Expectation/FitFunction separation.
45
46 static void ba81GradientHook(omxFitFunction* oo, double *out)
47 {
48         ba81Gradient(oo->expectation, out);
49 }
50
51 static void ba81Compute(omxFitFunction *oo) {
52         if(OMX_DEBUG_MML) {Rprintf("Beginning %s Computation.\n", NAME);}
53
54         omxExpectation* expectation = oo->expectation;
55   
56         oo->matrix->data[0] = ba81ComputeFit(expectation);
57 }
58
59 void omxInitFitFunctionBA81(omxFitFunction* oo, SEXP rObj) {
60         //omxExpectation* expectation = oo->expectation;
61
62         //omxState* currentState = oo->matrix->currentState;
63         
64         if(OMX_DEBUG) {
65           Rprintf("Initializing %s.\n", NAME);
66         }
67         
68         //omxBA81State *newObj = (omxBA81State*) R_alloc(1, sizeof(omxBA81State));
69         
70         //newObj->data = oo->expectation->data;
71
72         omxExpectationCompute(oo->expectation);
73
74         oo->computeFun = ba81Compute;
75         oo->setFinalReturns = ba81SetFinalReturns;
76         oo->destructFun = ba81Destroy;
77
78         if (ba81ExpectationHasGradients(oo->expectation)) {
79                 oo->gradientFun = ba81GradientHook;
80         }
81 }