Commit 9e8501b1c856f1f95ceed5f64a9983d38c6a4b2a

added C Dirichlet log-density and test case

Commit diff

RGS/src/rgs_dirichlet.c

 
1/*
2 * Copyright (C) 2008 Antonio, Fabio Di Narzo <antonio.fabio _at_ gmail.com>
3 *
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation; either version 3 of the License, or (at
7 * your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful, but
10 * WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12 * General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, write to the Free Software
16 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
17 */
18
19#include <Rmath.h>
20#include "rgs_distrib.h"
21#include "rgs_lapack.h"
22
23/*code adapted from 'gsl_ran_dirichlet_lnpdf' in the GSL library */
24SEXP rgs_ddirichlet_work(SEXP x, SEXP in_alpha, SEXP out_ans) {
25 int K = length(x);
26 double *alpha = REAL(in_alpha);
27 double *theta = REAL(x);
28
29 size_t i;
30 double log_p = 0.0;
31 double sum_alpha = 0.0;
32
33 for (i = 0; i < K; i++)
34 log_p += (alpha[i] - 1.0) * log (theta[i]);
35
36 for (i = 0; i < K; i++)
37 sum_alpha += alpha[i];
38
39 log_p += lgammafn(sum_alpha);
40
41 for (i = 0; i < K; i++)
42 log_p -= lgammafn (alpha[i]);
43
44 REAL(out_ans)[0] = log_p;
45 return R_NilValue;
46}
47
48double rgs_ddirichlet (SEXP node) {
49 SEXP ans;
50 PROTECT(ans = allocVector(REALSXP, 1));
51 rgs_ddirichlet_work(RGS_NVALUE(node), RGS_NVALUE(RGS_CPARAMETERS(node, 1)), ans);
52 UNPROTECT(1);
53 return REAL(ans)[0];
54}
toggle raw diff

RGS/tests/test_dirichlet.R

 
1source("Test.R")
2
3cwrapper <- function(x, alpha) {
4 ans <- 0.0
5 .Call("rgs_ddirichlet_work", as.real(x), as.real(alpha), ans)
6 ans
7}
8
9##Adapted from MCMCpack
10ddirichlet <- function (x, alpha) {
11 dirichlet1 <- function(x, alpha) {
12 logD <- sum(lgamma(alpha)) - lgamma(sum(alpha))
13 s <- sum((alpha - 1) * log(x))
14 exp(sum(s) - logD)
15 }
16 x <- t(x)
17 alpha <- matrix(alpha, ncol = length(alpha), nrow = nrow(x), byrow = TRUE)
18 pd <- vector(length = nrow(x))
19 for (i in 1:nrow(x)) pd[i] <- dirichlet1(x[i, ], alpha[i, ])
20 pd[apply(x, 1, function(z) any(z < 0 | z > 1))] <- 0
21 pd[apply(x, 1, function(z) all.equal(sum(z), 1) != TRUE)] <- 0
22 return(pd)
23}
24
25val.true <- log(ddirichlet(c(0.2, 0.8), c(0.8, 0.2)))
26val.check <- cwrapper(c(0.2, 0.8), c(0.8, 0.2))
27stopifnot(signif(val.true,6) == signif(val.check, 6))
toggle raw diff