Fix mxComputeOnce per Tim's suggestion
[openmx:openmx.git] / R / MxCompute.R
1 #
2 #   Copyright 2013 The OpenMx Project
3 #
4 #   Licensed under the Apache License, Version 2.0 (the "License");
5 #   you may not use this file except in compliance with the License.
6 #   You may obtain a copy of the License at
7
8 #        http://www.apache.org/licenses/LICENSE-2.0
9
10 #   Unless required by applicable law or agreed to in writing, software
11 #   distributed under the License is distributed on an "AS IS" BASIS,
12 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 #   See the License for the specific language governing permissions and
14 #   limitations under the License.
15
# Virtual base class for all compute objects; extends MxBaseNamed and
# declares no slots of its own.
setClass(
        Class = "MxBaseCompute",
        contains = "MxBaseNamed",
        representation = representation("VIRTUAL"))
20
# Class union: a value of type "MxCompute" is either NULL or any
# MxBaseCompute object.
setClassUnion(name = "MxCompute", members = c("NULL", "MxBaseCompute"))
22
# Generic: convertForBackend(.Object, flatModel, model).
# The methods in this file return a copy of .Object in which names have
# been replaced by numeric indices looked up in flatModel.
setGeneric(
        name = "convertForBackend",
        def = function(.Object, flatModel, model) {
                return(standardGeneric("convertForBackend"))
        })
27
28 #----------------------------------------------------
29
# Base class for compute steps that are tied to a free group.
#   free.group : name of the free group; convertForBackend replaces it
#                with a zero-based index.
setClass(
        Class = "MxComputeOperation",
        representation = representation(free.group = "MxCharOrNumber"),
        contains = "MxBaseCompute")
34
# Qualify the object's own name with the model name via imxIdentifier.
# 'namespace' is accepted for signature compatibility and is not used here.
setMethod("qualifyNames", signature("MxComputeOperation"),
        function(.Object, modelname, namespace) {
                qualified <- imxIdentifier(modelname, .Object@name)
                .Object@name <- qualified
                return(.Object)
        })
40
# Replace the free-group name stored in @free.group by its zero-based
# position in flatModel@freeGroupNames.
#   .Object   : the compute operation to convert
#   flatModel : flattened model; only @freeGroupNames is read here
#   model     : unused in this method
# Stops with an error naming every free group that cannot be found.
setMethod("convertForBackend", signature("MxComputeOperation"),
        function(.Object, flatModel, model) {
                fg <- match(.Object@free.group, flatModel@freeGroupNames)
                # any() keeps the condition scalar even if more than one
                # free group name was supplied.
                notFound <- is.na(fg)
                if (any(notFound)) {
                        stop(paste("Cannot find free group",
                                   paste(.Object@free.group[notFound], collapse=", "),
                                   "in list of free groups:",
                                   omxQuotes(flatModel@freeGroupNames)))
                }
                # match() is one-based; the stored index is zero-based.
                .Object@free.group <- fg - 1L
                .Object
        })
54
55 #----------------------------------------------------
56
# Compute step holding parallel 'from' and 'to' references.
# (Open question from the original author: good name? or ComputeCopy?)
#   from, to : matrix names; convertForBackend replaces them with negated
#              positions in the flat model's matrix list.
setClass(
        Class = "MxComputeAssign",
        representation = representation(
                from = "MxCharOrNumber",
                to = "MxCharOrNumber"),
        contains = "MxComputeOperation")
62
# Constructor hook: the name is always 'compute'; the remaining slots are
# copied verbatim from the arguments.
setMethod("initialize", "MxComputeAssign",
          function(.Object, from, to, free.group) {
                  .Object@name <- 'compute'   # fixed name for every compute object
                  .Object@from <- from
                  .Object@to <- to
                  .Object@free.group <- free.group
                  return(.Object)
          })
71
# Qualify the object's name (through the inherited method) and then the
# 'from' and 'to' references with the model name.
setMethod("qualifyNames", signature("MxComputeAssign"),
        function(.Object, modelname, namespace) {
                .Object <- callNextMethod()
                for (sl in c('from', 'to')) {
                        slot(.Object, sl) <- imxIdentifier(modelname, slot(.Object, sl))
                }
                .Object
        })
79
# Convert the 'from' and 'to' matrix names into backend indices.
# The inherited method first converts @free.group; afterwards each name is
# looked up in names(flatModel@matrices) and stored as the negated
# one-based position.
# Stops with an error listing every name that is not a known matrix.
setMethod("convertForBackend", signature("MxComputeAssign"),
        function(.Object, flatModel, model) {
                .Object <- callNextMethod()
                matrixNames <- names(flatModel@matrices)
                for (sl in c('from', 'to')) {
                        mat <- match(slot(.Object, sl), matrixNames)
                        if (any(is.na(mat))) {
                                # paste0 so the slot name is quoted without
                                # stray spaces inside the quotes.
                                stop(paste0("MxComputeAssign: cannot find ",
                                            omxQuotes(slot(.Object, sl)[is.na(mat)]),
                                            " mentioned in slot '", sl, "'"))
                        }
                        slot(.Object, sl) <- -mat
                }
                .Object
        })
95
# Build an MxComputeAssign object.
#   from, to   : references of equal length (paired element-wise)
#   free.group : free group name, "default" unless specified
# Stops when 'from' and 'to' differ in length.
mxComputeAssign <- function(from, to, free.group="default") {
        sameLength <- length(from) == length(to)
        if (!sameLength) {
                stop("Arguments 'from' and 'to' must be the same length")
        }
        new("MxComputeAssign", from=from, to=to, free.group=free.group)
}
102
103 #----------------------------------------------------
104
# Compute step that evaluates a single object once.
#   what     : reference to the object (name, or index after conversion)
#   context  : character vector
#   gradient : logical flag
#   hessian  : logical flag
setClass(
        Class = "MxComputeOnce",
        representation = representation(
                what = "MxOptionalCharOrNumber",
                context = "character",
                gradient = "logical",
                hessian = "logical"),
        contains = "MxComputeOperation")
112
# Qualify the object's own name with the model name, and resolve 'what'
# through imxConvertIdentifier using the supplied namespace.
setMethod("qualifyNames", signature("MxComputeOnce"),
        function(.Object, modelname, namespace) {
                ownName <- imxIdentifier(modelname, .Object@name)
                target <- imxConvertIdentifier(.Object@what, modelname, namespace)
                .Object@name <- ownName
                .Object@what <- target
                return(.Object)
        })
119
# Convert 'what' from a name into a backend index.
# The inherited method first converts @free.group. A value that is already
# an integer is passed through unchanged. Otherwise the name is looked up:
#   * among the expectations -> stored as the negated one-based position
#   * among the algebras followed by the fit functions -> stored as the
#     zero-based position in that combined list
# An expectation match takes precedence over an algebra/fit function match.
# Stops when 'what' does not have length 1 or cannot be found.
setMethod("convertForBackend", signature("MxComputeOnce"),
        function(.Object, flatModel, model) {
                .Object <- callNextMethod()
                if (length(.Object@what) != 1) stop("Can only apply MxComputeOnce to one object")
                if (!is.integer(.Object@what)) {
                        expNum <- match(.Object@what, names(flatModel@expectations))
                        algNum <- match(.Object@what, append(names(flatModel@algebras),
                                                             names(flatModel@fitfunctions)))
                        if (is.na(expNum) && is.na(algNum)) {
                                # Name the offending reference so the user can
                                # find it in the compute plan.
                                stop(paste("Can only apply MxComputeOnce to MxAlgebra or MxExpectation not",
                                           deparse(.Object@what)))
                        }
                        if (!is.na(expNum)) {
                                .Object@what <- - expNum  # usually negative numbers indicate matrices
                        } else {
                                .Object@what <- algNum - 1L
                        }
                }
                .Object
        })
140
# Constructor hook: the name is always 'compute'; every other slot is
# copied verbatim from the argument of the same name.
setMethod("initialize", "MxComputeOnce",
          function(.Object, what, free.group, context, gradient, hessian) {
                  .Object@name <- 'compute'   # fixed name for every compute object
                  .Object@what <- what
                  .Object@free.group <- free.group
                  .Object@context <- context
                  .Object@gradient <- gradient
                  .Object@hessian <- hessian
                  return(.Object)
          })
151
# Build an MxComputeOnce object.
#   what       : reference to the object to evaluate
#   free.group : free group name ('default' unless specified)
#   context    : character vector (empty by default)
#   gradient, hessian : logical flags, both FALSE by default
mxComputeOnce <- function(what, free.group='default', context=character(0), gradient=FALSE, hessian=FALSE) {
        new("MxComputeOnce",
            what = what,
            free.group = free.group,
            context = context,
            gradient = gradient,
            hessian = hessian)
}
155
156 #----------------------------------------------------
157
# Compute step for gradient-descent optimisation.
#   fitfunction : reference to the fit function (name, or index after
#                 conversion)
#   engine      : character; the constructor stores NA when none is given
setClass(
        Class = "MxComputeGradientDescent",
        representation = representation(
                fitfunction = "MxCharOrNumber",
                engine = "character"),
        contains = "MxComputeOperation")
163
# Qualify the object's own name with the model name, and resolve the fit
# function reference through imxConvertIdentifier.
setMethod("qualifyNames", signature("MxComputeGradientDescent"),
        function(.Object, modelname, namespace) {
                ownName <- imxIdentifier(modelname, .Object@name)
                fitRef <- imxConvertIdentifier(.Object@fitfunction, modelname, namespace)
                .Object@name <- ownName
                .Object@fitfunction <- fitRef
                return(.Object)
        })
170
# Convert the fit function reference into a backend index.
# The inherited method first converts @free.group; a character fit
# function reference is then resolved with imxLocateIndex. Non-character
# references are left untouched.
setMethod("convertForBackend", signature("MxComputeGradientDescent"),
        function(.Object, flatModel, model) {
                .Object <- callNextMethod()
                if (!is.character(.Object@fitfunction)) {
                        return(.Object)
                }
                .Object@fitfunction <- imxLocateIndex(flatModel, .Object@fitfunction,
                                                      .Object@name)
                .Object
        })
180
# Constructor hook: the name is always 'compute'; 'fit' is stored in the
# fitfunction slot, the other arguments in the slots of the same name.
setMethod("initialize", "MxComputeGradientDescent",
          function(.Object, free.group, engine, fit) {
                  .Object@name <- 'compute'   # fixed name for every compute object
                  .Object@free.group <- free.group
                  .Object@fitfunction <- fit
                  .Object@engine <- engine
                  return(.Object)
          })
189
# Build an MxComputeGradientDescent object.
#   type        : currently ignored (open question: what to do with 'type'?)
#   free.group  : free group name ('default' unless specified)
#   engine      : optimiser engine name; NULL is stored as a character NA
#   fitfunction : reference to the fit function
mxComputeGradientDescent <- function(type, free.group='default',
                                     engine=NULL, fitfunction='fitfunction') {
        engineName <- if (is.null(engine)) NA_character_ else engine
        new("MxComputeGradientDescent",
            free.group = free.group,
            engine = engineName,
            fit = fitfunction)
}
199
200 #----------------------------------------------------
201
# Compute step that holds a list of sub-steps plus iteration controls.
#   steps     : list of compute objects
#   maxIter   : integer
#   tolerance : numeric
#   verbose   : logical
setClass(
        Class = "MxComputeIterate",
        representation = representation(
                steps = "list",
                maxIter = "integer",
                tolerance = "numeric",
                verbose = "logical"),
        contains = "MxBaseCompute")
209
# Constructor hook: the name is always 'compute'; every other slot is
# copied verbatim from the argument of the same name.
setMethod("initialize", "MxComputeIterate",
          function(.Object, steps, maxIter, tolerance, verbose) {
                  .Object@name <- 'compute'   # fixed name for every compute object
                  .Object@steps <- steps
                  .Object@maxIter <- maxIter
                  .Object@tolerance <- tolerance
                  .Object@verbose <- verbose
                  return(.Object)
          })
219
# Qualify the object's own name, then qualify every contained step with
# the same model name and namespace.
setMethod("qualifyNames", signature("MxComputeIterate"),
        function(.Object, modelname, namespace) {
                .Object@name <- imxIdentifier(modelname, .Object@name)
                qualifyStep <- function(step) {
                        qualifyNames(step, modelname, namespace)
                }
                .Object@steps <- lapply(.Object@steps, qualifyStep)
                return(.Object)
        })
226
# Convert every contained step for the backend; the container's own slots
# are left unchanged.
setMethod("convertForBackend", signature("MxComputeIterate"),
        function(.Object, flatModel, model) {
                convertStep <- function(step) {
                        convertForBackend(step, flatModel, model)
                }
                .Object@steps <- lapply(.Object@steps, convertStep)
                return(.Object)
        })
232
# Build an MxComputeIterate object.
#   steps     : list of compute objects
#   maxIter   : iteration limit; coerced to integer because the slot is
#               declared "integer" (so maxIter=100 works as well as 100L)
#   tolerance : numeric tolerance
#   verbose   : logical flag
mxComputeIterate <- function(steps, maxIter=500L, tolerance=1e-4, verbose=FALSE) {
        new("MxComputeIterate", steps=steps, maxIter=as.integer(maxIter),
            tolerance=tolerance, verbose=verbose)
}
236
# Print a short summary of an MxComputeIterate object: class and name,
# tolerance, maxIter, and the class of every step.
# Returns its argument invisibly.
displayMxComputeIterate <- function(opt) {
        cat(class(opt), omxQuotes(opt@name), '\n')
        cat("@tolerance :", omxQuotes(opt@tolerance), '\n')
        cat("@maxIter :", omxQuotes(opt@maxIter), '\n')
        # seq_along() yields an empty sequence for an empty step list,
        # whereas 1:length() would iterate over c(1, 0) and fail.
        for (step in seq_along(opt@steps)) {
                cat("[[", step, "]] :", class(opt@steps[[step]]), '\n')
        }
        invisible(opt)
}
246
# print() and show() both delegate to displayMxComputeIterate.
setMethod("print", "MxComputeIterate",
          function(x, ...) {
                  displayMxComputeIterate(x)
          })
setMethod("show", "MxComputeIterate",
          function(object) {
                  displayMxComputeIterate(object)
          })
249
250 #----------------------------------------------------
251
# Compute step for an estimated Hessian.
#   fitfunction : reference to the fit function (name, or index after
#                 conversion)
#   se          : logical flag set from the constructor's 'want.se'
setClass(
        Class = "MxComputeEstimatedHessian",
        representation = representation(
                fitfunction = "MxCharOrNumber",
                se = "logical"),
        contains = "MxComputeOperation")
257
# Qualify the object's own name with the model name, and resolve the fit
# function reference through imxConvertIdentifier.
setMethod("qualifyNames", signature("MxComputeEstimatedHessian"),
        function(.Object, modelname, namespace) {
                ownName <- imxIdentifier(modelname, .Object@name)
                fitRef <- imxConvertIdentifier(.Object@fitfunction, modelname, namespace)
                .Object@name <- ownName
                .Object@fitfunction <- fitRef
                return(.Object)
        })
264
# Convert the fit function reference into a backend index.
# The inherited method first converts @free.group; a character fit
# function reference is then resolved with imxLocateIndex. Non-character
# references are left untouched.
setMethod("convertForBackend", signature("MxComputeEstimatedHessian"),
        function(.Object, flatModel, model) {
                .Object <- callNextMethod()
                if (!is.character(.Object@fitfunction)) {
                        return(.Object)
                }
                .Object@fitfunction <- imxLocateIndex(flatModel, .Object@fitfunction,
                                                      .Object@name)
                .Object
        })
274
# Constructor hook: the name is always 'compute'; 'fit' goes to the
# fitfunction slot and 'want.se' to the se slot.
setMethod("initialize", "MxComputeEstimatedHessian",
          function(.Object, free.group, fit, want.se) {
                  .Object@name <- 'compute'   # fixed name for every compute object
                  .Object@free.group <- free.group
                  .Object@fitfunction <- fit
                  .Object@se <- want.se
                  return(.Object)
          })
283
# Build an MxComputeEstimatedHessian object.
#   free.group  : free group name ('default' unless specified)
#   fitfunction : reference to the fit function
#   want.se     : logical flag stored in the 'se' slot (TRUE by default)
mxComputeEstimatedHessian <- function(free.group='default', fitfunction='fitfunction', want.se=TRUE) {
        new("MxComputeEstimatedHessian",
            free.group = free.group,
            fit = fitfunction,
            want.se = want.se)
}
287
288 #----------------------------------------------------
289
# Compute step that holds an ordered list of sub-steps.
#   steps : list of compute objects
setClass(
        Class = "MxComputeSequence",
        representation = representation(steps = "list"),
        contains = "MxBaseCompute")
294
# Constructor hook: the name is always 'compute'; 'steps' is stored as is.
setMethod("initialize", "MxComputeSequence",
          function(.Object, steps) {
                  .Object@name <- 'compute'   # fixed name for every compute object
                  .Object@steps <- steps
                  return(.Object)
          })
301
# Qualify the object's own name, then qualify every contained step with
# the same model name and namespace.
setMethod("qualifyNames", signature("MxComputeSequence"),
        function(.Object, modelname, namespace) {
                .Object@name <- imxIdentifier(modelname, .Object@name)
                qualifyStep <- function(step) {
                        qualifyNames(step, modelname, namespace)
                }
                .Object@steps <- lapply(.Object@steps, qualifyStep)
                return(.Object)
        })
308
# Convert every contained step for the backend; the container's own slots
# are left unchanged.
setMethod("convertForBackend", signature("MxComputeSequence"),
        function(.Object, flatModel, model) {
                convertStep <- function(step) {
                        convertForBackend(step, flatModel, model)
                }
                .Object@steps <- lapply(.Object@steps, convertStep)
                return(.Object)
        })
314
# Build an MxComputeSequence object from a list of compute steps.
mxComputeSequence <- function(steps) {
        sequence <- new("MxComputeSequence", steps = steps)
        sequence
}
318
# Print a short summary of an MxComputeSequence object: class and name,
# followed by the class of every step.
# Returns its argument invisibly.
displayMxComputeSequence <- function(opt) {
        cat(class(opt), omxQuotes(opt@name), '\n')
        # seq_along() yields an empty sequence for an empty step list,
        # whereas 1:length() would iterate over c(1, 0) and fail.
        for (step in seq_along(opt@steps)) {
                cat("[[", step, "]] :", class(opt@steps[[step]]), '\n')
        }
        invisible(opt)
}
326
# print() and show() both delegate to displayMxComputeSequence.
setMethod("print", "MxComputeSequence",
          function(x, ...) {
                  displayMxComputeSequence(x)
          })
setMethod("show", "MxComputeSequence",
          function(object) {
                  displayMxComputeSequence(object)
          })
329
330 #----------------------------------------------------
331
# Print the class, quoted name and free group of a compute operation.
# Returns its argument invisibly.
displayMxComputeOperation <- function(opt) {
        quotedName <- omxQuotes(opt@name)
        cat(class(opt), quotedName, '\n')
        quotedGroup <- omxQuotes(opt@free.group)
        cat("@free.group :", quotedGroup, '\n')
        invisible(opt)
}
337
# print() and show() both delegate to displayMxComputeOperation.
setMethod("print", "MxComputeOperation",
          function(x, ...) {
                  displayMxComputeOperation(x)
          })
setMethod("show", "MxComputeOperation",
          function(object) {
                  displayMxComputeOperation(object)
          })
340
# Print the slots specific to MxComputeGradientDescent (engine and fit
# function). Returns its argument invisibly.
# The class declares no 'type' slot, so it is not printed here: reading
# opt@type would raise an error every time the object is shown.
displayMxComputeGradientDescent <- function(opt) {
        cat("@engine :", omxQuotes(opt@engine), '\n')
        cat("@fitfunction :", omxQuotes(opt@fitfunction), '\n')
        invisible(opt)
}
347
# print() and show() first run the inherited method, then append the
# slots specific to gradient descent.
setMethod("print", "MxComputeGradientDescent",
          function(x, ...) {
                  callNextMethod()
                  displayMxComputeGradientDescent(x)
          })
setMethod("show", "MxComputeGradientDescent",
          function(object) {
                  callNextMethod()
                  displayMxComputeGradientDescent(object)
          })
352
# Convert every compute object in flatModel@computes for the backend and
# return the converted objects as a list.
convertComputes <- function(flatModel, model) {
        convertOne <- function(compute) {
                convertForBackend(compute, flatModel, model)
        }
        lapply(flatModel@computes, convertOne)
}