In this short document, I want to demonstrate the James-Stein estimator for a particular hierarchical model, and show how it reduces the total squared error across all parameters compared with applying maximum likelihood estimation to each parameter separately. The model comes from the following article:
Stein’s estimation rule and its competitors - an empirical Bayes approach by Bradley Efron and Carl Morris
The model is as follows. Suppose there are \(k \ge 3\) parameters \(\theta_1, \dots, \theta_k\). For each \(\theta_i\), we have a single observation \(X_i\) drawn from the following distribution:
\[ X_i \sim N(\theta_i, 1) \]
As stated in the article above, you can think of \(X_i\) as the sample mean of \(n\) observations \(Y_{ij} \sim N(\theta_i, \sigma^2)\), with the model rescaled so that \(\sigma^2 / n = 1\). Furthermore, we will assume that the \(\theta_i\) are normally distributed with mean 0 and some unknown variance. The following chunk of code computes the James-Stein estimator given a value of k and the standard deviation sigma.theta of the normal distribution for \(\theta_i\). The estimator multiplies each observation by the same factor, \((1 - B) x_i\), where \(B = (k - 2) / \sum_i x_i^2\) is estimated from the data. You can see that the James-Stein estimates move from the observed \(x_i\) toward the middle, and that, over all possible values of B, the James-Stein value tends to be close to the one that minimizes the sum of squared errors.
However, seeing how this behaves over a range of values for sigma.theta is much more useful than looking at a single fixed value. For this, the final code chunk uses the manipulate package, which works if you run it within RStudio. If you think of 1, the variance of \(X_i\) around \(\theta_i\), as “noise”, and sigma.theta as the size of the signal, you can see that when the noise is small relative to sigma.theta there is little shrinkage, while when sigma.theta is small and the noise overwhelms the signal, the shrinkage toward 0 is large.
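The form of the shrinkage factor can be motivated by a short empirical Bayes argument, sketched here as a reading aid rather than a full proof. Write \(A\) for the variance of the \(\theta_i\) (sigma.theta squared in the code, following the notation of Efron and Morris) and \(S = \sum_i X_i^2\); the symbol \(S\) is introduced here only for convenience. Under the model, each observation is marginally normal and the posterior mean shrinks it toward 0:

\[ \theta_i \sim N(0, A), \quad X_i \mid \theta_i \sim N(\theta_i, 1) \quad \Longrightarrow \quad X_i \sim N(0, 1 + A), \quad E[\theta_i \mid X_i] = (1 - B) X_i, \quad B = \frac{1}{1 + A} \]

Since \(S / (1 + A)\) follows a \(\chi^2_k\) distribution, and \(E[1 / \chi^2_k] = 1 / (k - 2)\) for \(k \ge 3\), we get

\[ E\left[ \frac{k - 2}{S} \right] = \frac{1}{1 + A} = B \]

So \((k - 2) / S\) is an unbiased estimate of the oracle factor \(B\), and plugging it into the posterior mean gives the James-Stein estimate \((1 - (k - 2)/S) X_i\). This also matches the intuition above: large \(A\) (strong signal) gives \(B\) near 0 and little shrinkage, while small \(A\) gives \(B\) near 1 and heavy shrinkage toward 0.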
k <- 100
sigma.theta <- 1
plotJS <- function(k, sigma.theta) {
  # draw the true parameters from N(0, sigma.theta^2)
  theta <- rnorm(k, 0, sigma.theta)
  # oracle estimator: what if we knew sigma_theta^2,
  # which Efron and Morris denote as 'A'
  B.oracle <- 1/(1 + sigma.theta^2)
  # one observation per parameter, with unit variance
  x <- rnorm(k, theta, 1)
  # James-Stein shrinkage factor, estimated from the data
  B <- (k - 2)/sum(x^2)
  eb <- (1 - B) * x
  oracle <- (1 - B.oracle) * x
  # top panel: observations (black) and James-Stein estimates (red) against theta
  par(mfrow=c(2,1), mar=c(5,5,3,1))
  plot(theta, x, xlim=c(-2*sigma.theta,2*sigma.theta), ylim=c(-6,6))
  points(theta, eb, col="red")
  abline(0,1)
  legend("topleft","James-Stein estimators",col="red",pch=1)
  # bottom panel: sum of squared errors for each shrinkage factor in [0, 1]
  s <- seq(from=0, to=1, length.out=100)
  par(mar=c(5,5,1,1))
  sse.f <- function(b) sum((theta - (1 - b)*x)^2)
  sse <- sapply(s, sse.f)
  plot(s, sse, ylim=c(0, max(sse)), type="l",
       xlab="possible values for B", ylab="sum squared error")
  points(B, sum((theta - eb)^2), col="red", pch=16)
  points(B.oracle, sum((theta - oracle)^2), col="blue", pch=16)
  # B = 0 means no shrinkage, which is the MLE
  points(0, sse[1], col="black", pch=16)
  legend("top",c("James-Stein B","oracle","MLE"), col=c("red","blue","black"), pch=16)
}
set.seed(1)
plotJS(k, sigma.theta)
library(manipulate)
manipulate(plotJS(k, sigma.theta), sigma.theta = slider(.1,2))
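Each plot shows a single random draw, so the comparison between estimators varies from run to run. As a rough numerical check of the claim about total error, the sketch below repeats the simulation many times and averages the sum of squared errors for the MLE, the James-Stein estimator, and the oracle over a few values of sigma.theta. The function name simulateRisk, the argument n.rep, and the chosen grid of sigma.theta values are illustrative and not part of the original code.

```r
# Average sum of squared errors over repeated draws from the model
# theta_i ~ N(0, sigma.theta^2), x_i ~ N(theta_i, 1).
# simulateRisk and n.rep are illustrative names, not from the original code.
simulateRisk <- function(k, sigma.theta, n.rep = 2000) {
  B.oracle <- 1 / (1 + sigma.theta^2)
  sse <- replicate(n.rep, {
    theta <- rnorm(k, 0, sigma.theta)
    x <- rnorm(k, theta, 1)
    B <- (k - 2) / sum(x^2)  # James-Stein shrinkage factor
    c(MLE = sum((theta - x)^2),
      JS = sum((theta - (1 - B) * x)^2),
      oracle = sum((theta - (1 - B.oracle) * x)^2))
  })
  # sse is a 3 x n.rep matrix; average each estimator over the replications
  rowMeans(sse)
}

set.seed(1)
sigmas <- c(0.1, 0.5, 1, 2)  # assumed grid, matching the slider range above
risk <- t(sapply(sigmas, function(s) simulateRisk(k = 100, sigma.theta = s)))
rownames(risk) <- paste0("sigma.theta=", sigmas)
print(round(risk, 1))
```

Under this model the expected values are k for the MLE, k - (k - 2)/(1 + sigma.theta^2) for James-Stein, and k * sigma.theta^2/(1 + sigma.theta^2) for the oracle, so with k = 100 and sigma.theta = 1 the averages should land near 100, 51, and 50. The James-Stein column should stay below the MLE column throughout, with the gap shrinking as sigma.theta grows.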