Expectations don't have setFinalReturns
[openmx:openmx.git] / src / omxStateSpaceExpectation.c
1 /*
2  *  Copyright 2007-2012 The OpenMx Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *       http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *   Unless required by applicable law or agreed to in writing, software
11  *   distributed under the License is distributed on an "AS IS" BASIS,
12  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  *
16  */
17
18
19 /***********************************************************
20 *
21 *  omxStateSpaceExpectation.c
22 *
23 *  Created: Michael D. Hunter   Date: 2012-10-28 20:07:36
24 *
25 *  Contains code to calculate the objective function for a
26 *   state space model.  Currently, this is done with a 
27 *   Kalman filter in separate Predict and Update steps.
28 *   Later this could be done with one of several Kalman 
29 *   filter-smoothers (a forward-backward algorithm).
30 *
31 **********************************************************/
32
33 #include "omxExpectation.h"
34 #include "omxBLAS.h"
35 #include "omxFIMLFitFunction.h"
36 #include "omxStateSpaceExpectation.h"
37
38
39 void omxCallStateSpaceExpectation(omxExpectation* ox) {
40     if(OMX_DEBUG) { Rprintf("State Space Expectation Called.\n"); }
41         omxStateSpaceExpectation* ose = (omxStateSpaceExpectation*)(ox->argStruct);
42         
43         omxRecompute(ose->A);
44         omxRecompute(ose->B);
45         omxRecompute(ose->C);
46         omxRecompute(ose->D);
47         omxRecompute(ose->Q);
48         omxRecompute(ose->R);
49         
50         // Probably should loop through all the data here!!!
51         omxKalmanPredict(ose);
52         omxKalmanUpdate(ose);
53 }
54
55
56
57 void omxDestroyStateSpaceExpectation(omxExpectation* ox) {
58         
59         if(OMX_DEBUG) { Rprintf("Destroying State Space Expectation.\n"); }
60         
61         omxStateSpaceExpectation* argStruct = (omxStateSpaceExpectation*)(ox->argStruct);
62         
63         /* We allocated 'em, so we destroy 'em. */
64         omxFreeMatrixData(argStruct->r);
65         omxFreeMatrixData(argStruct->s);
66         omxFreeMatrixData(argStruct->z);
67         //omxFreeMatrixData(argStruct->u); // This is data, destroy it?
68         //omxFreeMatrixData(argStruct->x); // This is latent data, destroy it?
69         //omxFreeMatrixData(argStruct->y); // This is data, destroy it?
70         omxFreeMatrixData(argStruct->K); // This is the Kalman gain, destroy it?
71         //omxFreeMatrixData(argStruct->P); // This is latent cov, destroy it?
72         omxFreeMatrixData(argStruct->S); // This is data error cov, destroy it?
73         omxFreeMatrixData(argStruct->Y);
74         omxFreeMatrixData(argStruct->Z);
75 }
76
77
78 void omxPopulateSSMAttributes(omxExpectation *ox, SEXP algebra) {
79     if(OMX_DEBUG) { Rprintf("Populating State Space Attributes.  Currently this does very little!\n"); }
80         
81 }
82
83
84
85
86 void omxKalmanPredict(omxStateSpaceExpectation* ose) {
87     if(OMX_DEBUG) { Rprintf("Kalman Predict Called.\n"); }
88         /* Creat local copies of State Space Matrices */
89         omxMatrix* A = ose->A;
90         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(A, "....State Space: A"); }
91         omxMatrix* B = ose->B;
92         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(B, "....State Space: B"); }
93         //omxMatrix* C = ose->C;
94         //omxMatrix* D = ose->D;
95         omxMatrix* Q = ose->Q;
96         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(Q, "....State Space: Q"); }
97         //omxMatrix* R = ose->R;
98         //omxMatrix* r = ose->r;
99         //omxMatrix* s = ose->s;
100         omxMatrix* u = ose->u;
101         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(u, "....State Space: u"); }
102         omxMatrix* x = ose->x;
103         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(x, "....State Space: x"); }
104         //omxMatrix* y = ose->y;
105         omxMatrix* z = ose->z;
106         //omxMatrix* K = ose->K;
107         omxMatrix* P = ose->P;
108         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(P, "....State Space: P"); }
109         //omxMatrix* S = ose->S;
110         //omxMatrix* Y = ose->Y;
111         omxMatrix* Z = ose->Z;
112
113         /* x = A x + B u */
114         omxDGEMV(FALSE, 1.0, A, x, 0.0, z); // x = A x
115         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(z, "....State Space: z = A x"); }
116         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(A, "....State Space: A"); }
117         omxDGEMV(FALSE, 1.0, B, u, 1.0, z); // x = B u + x THAT IS x = A x + B u
118         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(z, "....State Space: z = A x + B u"); }
119         omxCopyMatrix(x, z); // x = z THAT IS x = A x + B u
120         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(x, "....State Space: x = A x + B u"); }
121         
122         /* P = A P A^T + Q */
123         omxDSYMM(FALSE, 1.0, P, A, 0.0, Z); // Z = A P
124         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(Z, "....State Space: Z = A P"); }
125         omxCopyMatrix(P, Q); // P = Q
126         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(P, "....State Space: P = Q"); }
127         omxDGEMM(FALSE, TRUE, 1.0, Z, A, 1.0, P); // P = Z A^T + P THAT IS P = A P A^T + Q
128         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(P, "....State Space: P = A P A^T + Q"); }
129 }
130
131
132 void omxKalmanUpdate(omxStateSpaceExpectation* ose) {
133     if(OMX_DEBUG) { Rprintf("Kalman Update Called.\n"); }
134         /* Creat local copies of State Space Matrices */
135         //omxMatrix* A = ose->A;
136         //omxMatrix* B = ose->B;
137         omxMatrix* C = ose->C;
138         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(C, "....State Space: C"); }
139         omxMatrix* D = ose->D;
140         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(D, "....State Space: D"); }
141         //omxMatrix* Q = ose->Q;
142         omxMatrix* R = ose->R;
143         omxMatrix* r = ose->r;
144         omxMatrix* s = ose->s;
145         omxMatrix* u = ose->u;
146         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(u, "....State Space: u"); }
147         omxMatrix* x = ose->x;
148         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(x, "....State Space: x"); }
149         omxMatrix* y = ose->y;
150         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(y, "....State Space: y"); }
151         omxMatrix* K = ose->K;
152         omxMatrix* P = ose->P;
153         omxMatrix* S = ose->S;
154         omxMatrix* Y = ose->Y;
155         //omxMatrix* Z = ose->Z;
156         omxMatrix* Cov = ose->cov;
157         omxMatrix* Means = ose->means;
158         
159         int info = 0; // Used for computing inverse for Kalman gain
160         
161         /* r = y - C x - D u */
162         omxCopyMatrix(r, y); // r = y
163         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(r, "....State Space: r = y"); }
164         omxDGEMV(FALSE, -1.0, C, x, 1.0, r); // r = -C x + r THAT IS r = -C x + y
165         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(r, "....State Space: r = -C x + y"); }
166         omxDGEMV(FALSE, -1.0, D, u, 1.0, r); // r = -D u + r THAT IS r = y - (C x + D u)
167         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(r, "....State Space: r = y - (C x + D u)"); }
168         
169         /* Alternatively, create just the expected value for the data row, x. */
170         omxDGEMV(FALSE, 1.0, C, x, 0.0, s);
171         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(s, "....State Space: s = C x"); }
172         omxDGEMV(FALSE, 1.0, D, u, 1.0, s);
173         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(s, "....State Space: s = C x + D u"); }
174         omxCopyMatrix(Means, s);
175         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(Means, "....State Space: Means"); }
176         omxTransposeMatrix(Means);
177         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(Means, "....State Space: Means"); }
178         
179         /* S = C P C^T + R */
180         omxDSYMM(FALSE, 1.0, P, C, 0.0, Y); // Y = C P
181         omxCopyMatrix(S, R); // S = R
182         omxDGEMM(FALSE, TRUE, 1.0, Y, C, 1.0, S); // S = Y C^T + S THAT IS C P C^T + R
183         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(S, "....State Space: S = C P C^T + R"); }
184         
185         omxCopyMatrix(Cov, S); //Note: I know this is inefficient memory use, but for now it is more clear.-MDH
186         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(Cov, "....State Space: Cov"); }
187         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(Means, "....State Space: Means"); }
188         
189         /* Now compute the Kalman Gain and update the error covariance matrix */
190         /* S = S^-1 */
191         omxDPOTRF(S, &info); // S replaced by the lower triangular matrix of the Cholesky factorization
192         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(S, "....State Space: Cholesky of S"); }
193         //for(int i = 0; i < S->cols; i++) {
194         //      det += log(fabs(S->data[i+S->rows*i]));
195         // alternatively log(fabs(omxMatrixElement(S, i, i)));
196         //}
197         //det *= 2.0; //sum( log( abs( diag( chol(S) ) ) ) )*2
198         omxDPOTRI(S, &info); // S = S^-1 via Cholesky factorization
199         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(S, "....State Space: Inverse of S"); }
200         
201         /* K = P C^T S^-1 */
202         /* Computed as K^T = S^-1 C P */
203         omxDSYMM(TRUE, 1.0, S, Y, 0.0, K); // K = Y^T S THAT IS K = P C^T S^-1
204         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(K, "....State Space: K^T = S^-1 C P"); }
205         
206         /* x = x + K r */
207         omxDGEMV(TRUE, 1.0, K, r, 1.0, x); // x = K r + x
208         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(x, "....State Space: x = K r + x"); }
209         
210         /* P = (I - K C) P */
211         /* P = P - K C P */
212         omxDGEMM(TRUE, FALSE, -1.0, K, Y, 1.0, P); // P = -K Y + P THAT IS P = P - K C P
213         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(P, "....State Space: P = P - K C P"); }
214         
215         /*m2ll = r^T S r */
216         //omxDSYMV(1.0, S, r, 0.0, s); // s = S r
217         //m2ll = omxDDOT(r, s); // m2ll = r s THAT IS r^T S r
218         //m2ll += det; // m2ll = m2ll + det THAT IS m2ll = log(det(S)) + r^T S r
219         // Note: this leaves off the S->cols * log(2*pi) THAT IS k*log(2*pi)
220 }
221
222
223
224 void omxInitStateSpaceExpectation(omxExpectation* ox, SEXP rObj) {
225         
226         if(OMX_DEBUG) { Rprintf("Initializing State Space Expectation.\n"); }
227                 
228         int nx, ny, nu;
229         
230         //SEXP slotValue;   //Used by PPML
231         
232         /* Create and fill expectation */
233         ox->expType = "omxStateSpaceExpectation";
234         omxStateSpaceExpectation *SSMexp = (omxStateSpaceExpectation*) R_alloc(1, sizeof(omxStateSpaceExpectation));
235         omxState* currentState = ox->currentState;
236         
237         /* Set Expectation Calls and Structures */
238         ox->computeFun = omxCallStateSpaceExpectation;
239         ox->destructFun = omxDestroyStateSpaceExpectation;
240         ox->componentFun = omxGetStateSpaceExpectationComponent;
241         ox->mutateFun = omxSetStateSpaceExpectationComponent;
242         ox->populateAttrFun = omxPopulateSSMAttributes;
243         ox->argStruct = (void*) SSMexp;
244         
245         /* Set up expectation structures */
246         if(OMX_DEBUG) { Rprintf("Initializing State Space Meta Data for expectation.\n"); }
247         
248         if(OMX_DEBUG) { Rprintf("Processing A.\n"); }
249         SSMexp->A = omxNewMatrixFromIndexSlot(rObj, currentState, "A");
250         
251         if(OMX_DEBUG) { Rprintf("Processing B.\n"); }
252         SSMexp->B = omxNewMatrixFromIndexSlot(rObj, currentState, "B");
253         
254         if(OMX_DEBUG) { Rprintf("Processing C.\n"); }
255         SSMexp->C = omxNewMatrixFromIndexSlot(rObj, currentState, "C");
256         
257         if(OMX_DEBUG) { Rprintf("Processing D.\n"); }
258         SSMexp->D = omxNewMatrixFromIndexSlot(rObj, currentState, "D");
259         
260         if(OMX_DEBUG) { Rprintf("Processing Q.\n"); }
261         SSMexp->Q = omxNewMatrixFromIndexSlot(rObj, currentState, "Q");
262         
263         if(OMX_DEBUG) { Rprintf("Processing R.\n"); }
264         SSMexp->R = omxNewMatrixFromIndexSlot(rObj, currentState, "R");
265         
266         if(OMX_DEBUG) { Rprintf("Processing initial x.\n"); }
267         SSMexp->x = omxNewMatrixFromIndexSlot(rObj, currentState, "x");
268         
269         if(OMX_DEBUG) { Rprintf("Processing initial P.\n"); }
270         SSMexp->P = omxNewMatrixFromIndexSlot(rObj, currentState, "P");
271         
272         
273         /* Initialize the place holder matrices used in calculations */
274         nx = SSMexp->C->cols;
275         ny = SSMexp->C->rows;
276         nu = SSMexp->D->cols;
277         
278         if(OMX_DEBUG) { Rprintf("Processing first data row for y.\n"); }
279         SSMexp->y = omxInitMatrix(NULL, ny, 1, TRUE, currentState);
280         for(int i = 0; i < ny; i++) {
281                 omxSetMatrixElement(SSMexp->y, i, 0, omxMatrixElement(ox->data->dataMat, 0, i));
282         }
283         if(OMX_DEBUG_ALGEBRA) {omxPrintMatrix(SSMexp->y, "....State Space: y"); }
284         
285         // TODO Make x0 and P0 static (if possible) to save memory
286         // TODO Look into omxMatrix.c/h for a possible new matrix from omxMatrix function
287         if(OMX_DEBUG) { Rprintf("Generating static internals for resetting initial values.\n"); }
288         SSMexp->x0 =    omxInitMatrix(NULL, nx, 1, TRUE, currentState);
289         SSMexp->P0 =    omxInitMatrix(NULL, nx, nx, TRUE, currentState);
290         omxCopyMatrix(SSMexp->x0, SSMexp->x);
291         omxCopyMatrix(SSMexp->P0, SSMexp->P);
292         
293         if(OMX_DEBUG) { Rprintf("Generating internals for computation.\n"); }
294         
295         SSMexp->r =     omxInitMatrix(NULL, ny, 1, TRUE, currentState);
296         SSMexp->s =     omxInitMatrix(NULL, ny, 1, TRUE, currentState);
297         SSMexp->u =     omxInitMatrix(NULL, nu, 1, TRUE, currentState);
298         SSMexp->z =     omxInitMatrix(NULL, nx, 1, TRUE, currentState);
299         SSMexp->K =     omxInitMatrix(NULL, ny, nx, TRUE, currentState); // Actually the tranpose of the Kalman gain
300         SSMexp->S =     omxInitMatrix(NULL, ny, ny, TRUE, currentState);
301         SSMexp->Y =     omxInitMatrix(NULL, ny, nx, TRUE, currentState);
302         SSMexp->Z =     omxInitMatrix(NULL, nx, nx, TRUE, currentState);
303         
304         SSMexp->cov =           omxInitMatrix(NULL, ny, ny, TRUE, currentState);
305         SSMexp->means =         omxInitMatrix(NULL, 1, nx, TRUE, currentState);
306 }
307
308
309 omxMatrix* omxGetStateSpaceExpectationComponent(omxExpectation* ox, omxFitFunction* off, const char* component) {
310         omxStateSpaceExpectation* ose = (omxStateSpaceExpectation*)(ox->argStruct);
311         omxMatrix* retval = NULL;
312
313         if(!strncmp("cov", component, 3)) {
314                 retval = ose->cov;
315         } else if(!strncmp("mean", component, 4)) {
316                 retval = ose->means;
317         } else if(!strncmp("pvec", component, 4)) {
318                 // Once implemented, change compute function and return pvec
319         }
320         
321         if(OMX_DEBUG) { Rprintf("Returning 0x%x.\n", retval); }
322
323         return retval;
324 }
325
326 void omxSetStateSpaceExpectationComponent(omxExpectation* ox, omxFitFunction* off, const char* component, omxMatrix* om) {
327         omxStateSpaceExpectation* ose = (omxStateSpaceExpectation*)(ox->argStruct);
328         
329         if(!strcmp("y", component)) {
330                 ose->y = om;
331         }
332         if(!strcmp("Reset", component)) {
333                 omxCopyMatrix(ose->x, ose->x0);
334                 omxCopyMatrix(ose->P, ose->P0);
335         }
336 }
337
338
339