Approximate methods: Variational Inference
Stat 221, Lecture 21
@snesterko
Roadmap
- Sam Wong on protein folding and PT/EE sampler
- Justify the need for shortcuts instead of MCMC
- Point estimation as an approximate method
- Connection to computing
- Variational Inference - general framework
- Introduce approximate methods
Example: document authorship
- Document authorship: \\( \pi_d \sim \rr{Dirichlet}_K(\vec{\alpha}) \\).
- Data \\( y_{dw} \given \pi_d \sim \rr{Poisson}(\sum_k\mu_{kw}\pi_{dk})\\).
- Author \\( k\\), document \\( d \\), word \\( w \\).
- \\( \pi_{dk} \\) represents the contribution of author \\( k \\) to document \\( d \\).
Likelihood
For document \\( d \\):
$$\begin{align}L(\alpha, \mu, \pi \given y) \propto & \prod_{w=1}^W\left[\left(\sum_k\mu_{kw}\pi_{dk}\right)^{y_{dw}} \right. \\
& \left.\cdot e^{-\left(\sum_k\mu_{kw}\pi_{dk}\right)}\right] \cdot \prod_{k=1}^K \pi_{dk}^{\alpha_k - 1}\end{align}$$
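A minimal numerical sketch of this per-document log-likelihood, up to an additive constant. The function name, shapes, and the synthetic test values below are assumptions for illustration only, not part of the model specification.

```python
import numpy as np

def doc_loglik(y_d, pi_d, mu, alpha):
    # y_d: (W,) word counts; pi_d: (K,) author shares; mu: (K, W) Poisson rates; alpha: (K,)
    rate = pi_d @ mu                                   # Poisson rates sum_k mu_kw * pi_dk, shape (W,)
    poisson_term = np.sum(y_d * np.log(rate) - rate)   # Poisson part, dropping log(y!) constants
    dirichlet_term = np.sum((alpha - 1.0) * np.log(pi_d))
    return poisson_term + dirichlet_term

# tiny synthetic check with assumed dimensions
rng = np.random.default_rng(0)
K, W = 3, 5
alpha = np.full(K, 2.0)
mu = rng.gamma(2.0, 1.0, size=(K, W))
pi_d = rng.dirichlet(alpha)
y_d = rng.poisson(pi_d @ mu)
print(doc_loglik(y_d, pi_d, mu, alpha))
```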
How do we perform inference?
A simple example
Approximate a correlated bivariate Normal*:
$$ Z \sim \rr{Normal} \left( \begin{pmatrix} \mu_1 \\ \mu_2 \end{pmatrix},
\begin{pmatrix} \Lambda_{11} & \Lambda_{12} \\ \Lambda_{21} & \Lambda_{22} \end{pmatrix}^{-1} \right) $$
where \\( \Lambda \\) is the precision matrix.
*A setting where we may not care about the off-diagonal terms**
**Why not simply set them to zero in our model?
Options
- Full MCMC:
- Metropolis or MH
- HMC on transformed space
- add Parallel Tempering (+ MPI)
- add Equi-energy sampler (+ MPI)
- "Exact" methods:
- Direct MLE
- Point estimation with EM*
- Approximate methods (Variational Inference).
*For EM estimation, we need to treat \\( \pi \\) as a latent variable.
Approximate methods
- Variational inference
- Approximate the distribution of missing data.
- Monte Carlo EM
- Use MCMC to estimate quantities in the E step.
- Simulated Annealing (not a focus here)
- An iterative maximization algorithm, plus the curvature at the mode
(for approximate uncertainty).
Why approximate methods are dangerous
- Correlation is often key.
- Ease of implementation versus statistical performance tradeoff
- Warning
- It is easy to "get an answer".
- Rigorous checks are necessary to see how much performance we
have sacrificed for ease of implementation.
Variational inference
Reference: Bishop Chapter 10, pp. 461-517.
Main idea:
- Assume we have a posterior distribution
\\( p( Z \given X) \\)
- The distribution is complex.
- We want to approximate it with another distribution
that is close to it but easier to work with.
Variational inference: start
Decompose log marginal probability
$$\rr{log} p(X) = \mathcal{L}(q) + \rr{KL}(q || p)$$
with
$$\begin{align}\mathcal{L}(q) &= \int q(Z) \rr{log} \left\{ {p(X, Z) \over q(Z)}\right\} dZ \\
\rr{KL}(q || p) & = - \int q(Z) \rr{log} \left\{ {p(Z \given X) \over q(Z) }\right\}dZ\end{align}$$
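A quick numeric check of this decomposition on a toy discrete \\( Z \\); all numbers below are assumed for illustration.

```python
import numpy as np

p_xz = np.array([0.1, 0.3, 0.2, 0.1])    # assumed joint p(X = x, Z = z) over 4 values of Z
p_x = p_xz.sum()                         # marginal p(X = x)
post = p_xz / p_x                        # posterior p(Z | X = x)
q = np.array([0.4, 0.3, 0.2, 0.1])       # an arbitrary approximating distribution q(Z)

lower_bound = np.sum(q * np.log(p_xz / q))   # L(q)
kl = -np.sum(q * np.log(post / q))           # KL(q || p)
assert np.isclose(np.log(p_x), lower_bound + kl)
```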
Why decompose like this?
- \\( \mathcal{L}(q) \\) is called the lower bound: since \\( \rr{KL}(q || p) \geq 0 \\),
we have \\( \mathcal{L}(q) \leq \rr{log} p(X) \\).
- Maximizing \\( \mathcal{L}(q) \\) with respect to \\( q(Z) \\) yields
\\( q(Z) = p(Z \given X) \\), and the KL divergence vanishes.
- We can also restrict \\( q \\) so that the restricted class is
easier to optimize over - this is the main idea in variational
inference.
One way to restrict \\( q\\)
- Use a parametric distribution \\( q(Z \given \omega) \\).
- Lower bound \\( \mathcal{L}(q) \\) then becomes a function
of \\( \omega \\).
- Can then use standard optimization techniques to
find optimal values for the parameters \\( \omega \\) (a sketch follows below).
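A minimal sketch of this parametric restriction on a toy discrete \\( Z \\): here \\( q(Z \given \omega) \\) is a softmax family and a generic optimizer maximizes the lower bound over \\( \omega \\). The joint probabilities and the parameterization are assumptions for illustration only.

```python
import numpy as np
from scipy.optimize import minimize

p_xz = np.array([0.1, 0.3, 0.2, 0.1])    # assumed joint p(X = x, Z = z) at the observed x

def neg_lower_bound(omega):
    q = np.exp(omega - omega.max())
    q /= q.sum()                          # q(Z | omega) via softmax
    return -np.sum(q * np.log(p_xz / q))  # negative lower bound -L(q)

res = minimize(neg_lower_bound, x0=np.zeros(4))
q_star = np.exp(res.x - res.x.max())
q_star /= q_star.sum()
print(q_star)                             # approaches the posterior p_xz / p_xz.sum()
```

With a fully flexible family, as here, the optimum recovers the exact posterior; restricting the family trades accuracy for tractability.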
Factorized restriction on \\( q \\)
- It is very common to assume parameter independence.
- Partition the elements of \\( Z \\) into \\( M \\) disjoint groups
that we denote with \\( Z_i \\), \\( i = 1, \ldots, M \\).
- Then impose the parameter independence restriction on \\( q \\):
$$q(Z) = \prod_{i=1}^M q_i(Z_i)$$
Optimal factorized \\( q \\)
Denote \\( q_j(Z_j)\\) with \\( q_j\\).
$$\begin{align} \mathcal{L}(q) = & \int \prod_i q_i \left\{ \rr{log} p(X, Z) - \sum_i \rr{log} q_i \right\} dZ \\
= & \int q_j \left\{ \int \rr{log} p(X,Z) \prod_{i \neq j} q_i dZ_i\right\} dZ_j \\
& - \int q_j \rr{log} q_j dZ_j + \rr{const} = \ldots\end{align}$$
$$\begin{align} \ldots = & \int q_j \rr{log} \tilde{p} (X, Z_j) dZ_j \\
& - \int q_j \rr{log} q_j d Z_j + \rr{const}\end{align}$$
Define \\( \tilde{p} \\) by the relation
$$\rr{log} \tilde{p}(X, Z_j) = E_{i\neq j} \left[ \rr{log} p(X, Z) \right]$$
where \\( E_{i\neq j} \left[ \ldots \right]\\) denotes expectation with respect to the factors
\\( q_i \\) with \\( i \neq j \\):
$$E_{i\neq j} \left[ \rr{log} p(X, Z) \right] = \int \rr{log} p(X,Z) \prod_{i \neq j} q_i dZ_i$$
Optimal \\( q \\)
- Keeping \\( \{ q_{i \neq j} \} \\) fixed, \\( \mathcal{L}(q) \\)
is, up to an additive constant, the negative KL divergence between \\( q_j(Z_j) \\) and
\\( \tilde{p}(X, Z_j) \\).
- The minimum of the KL divergence is at \\( q_j^*(Z_j) = \tilde{p}(X, Z_j) \\):
$$\rr{log} q_j^*(Z_j) = E_{i\neq j} \left[ \rr{log} p(X, Z)\right] + \rr{const}$$
$$q_j^*(Z_j) = { \rr{exp} \left(E_{i\neq j} \left[ \rr{log} p(X, Z)\right]\right)
\over \int \rr{exp}\left(E_{i\neq j} \left[ \rr{log} p(X, Z)\right]\right) dZ_j}$$
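A minimal sketch of iterating this update, \\( q_j^* \propto \rr{exp}\left(E_{i\neq j}[\rr{log} p]\right) \\), for two discrete blocks \\( Z_1, Z_2 \\) with no observed \\( X \\). The joint probabilities are assumed for illustration.

```python
import numpy as np

p = np.array([[0.30, 0.10],
              [0.05, 0.55]])       # assumed joint p(Z_1, Z_2); rows index Z_1, columns Z_2
q1 = np.array([0.5, 0.5])          # initial factor q_1(Z_1)
q2 = np.array([0.5, 0.5])          # initial factor q_2(Z_2)

for _ in range(100):
    log_q1 = np.log(p) @ q2                 # E_{q_2}[log p(Z_1, Z_2)]
    q1 = np.exp(log_q1 - log_q1.max())
    q1 /= q1.sum()
    log_q2 = q1 @ np.log(p)                 # E_{q_1}[log p(Z_1, Z_2)]
    q2 = np.exp(log_q2 - log_q2.max())
    q2 /= q2.sum()

# note: the factorized optimum need not match the true marginals p.sum(1), p.sum(0)
print(q1, q2)
```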
Example
Approximate a correlated bivariate Normal*:
$$ Z \sim \rr{Normal} \left( \begin{pmatrix} \mu_1 \\ \mu_2 \end{pmatrix},
\begin{pmatrix} \Lambda_{11} & \Lambda_{12} \\ \Lambda_{21} & \Lambda_{22} \end{pmatrix}^{-1} \right) $$
We can use the general identity to
obtain optimal \\( q \\).
*for simplicity, there is no data to condition on.
Example: optimal \\( q\\)
$$\begin{align} \rr{log} & q_1^*(z_1) = E_{z_2} \left[ \rr{log} p(z)\right] + \rr{const}\\
= & E_{z_2} \left[ - {1 \over 2} (z_1 - \mu_1)^2 \Lambda_{11}\right. \\
&\left.- (z_1 - \mu_1) \Lambda_{12}(z_2 - \mu_2)\right] + \rr{const} \\
= & - {1 \over 2} z_1^2 \Lambda_{11} + z_1 \mu_1 \Lambda_{11} - z_1 \Lambda_{12} \left( E[z_2] - \mu_2\right) \\
& + \rr{const}\end{align}$$
This means \\( q_1^* \\) is Normal (quadratic log density).
Explicit form for \\( q^* \\)
$$q_1^* (z_1) = \rr{Normal}(z_1 \given m_1, \Lambda_{11}^{-1})$$
with \\( m_1 = \mu_1 - \Lambda_{11}^{-1} \Lambda_{12}( E [z_2] - \mu_2) \\).
By symmetry,
$$q_2^* (z_2) = \rr{Normal}(z_2 \given m_2, \Lambda_{22}^{-1})$$
with \\( m_2 = \mu_2 - \Lambda_{22}^{-1} \Lambda_{21}( E [z_1] - \mu_1) \\).
Solutions are coupled; in general we need to iterate between the two updates (see the sketch below).
In this case, \\( E[z_1] = \mu_1\\) and \\( E[z_2] = \mu_2\\) solve the equations exactly.
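A minimal sketch of the coupled coordinate updates; the values of \\( \mu \\) and \\( \Lambda \\) are assumed for illustration.

```python
import numpy as np

mu = np.array([1.0, -2.0])            # assumed mean
Lam = np.array([[2.0, 0.8],
                [0.8, 1.0]])          # assumed precision matrix

m = np.zeros(2)                       # initial guesses for E[z_1], E[z_2]
for _ in range(50):
    m[0] = mu[0] - Lam[0, 1] * (m[1] - mu[1]) / Lam[0, 0]
    m[1] = mu[1] - Lam[1, 0] * (m[0] - mu[0]) / Lam[1, 1]

# converges to (mu_1, mu_2); the variational variances are 1 / Lambda_11 and 1 / Lambda_22
print(m)
```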
Important questions
- Variational inference seems computing-friendly, but what's
lost in performance?
- Is it equivalent to specifying a simpler model?
Document authorship example
Treat \\( \pi \\) as missing data. For document \\( d \\):
$$\begin{align}L(\alpha, \mu \given y, \pi) \propto & \prod_{w=1}^W\left[\left(\sum_k\mu_{kw}\pi_{dk}\right)^{y_{dw}} \right. \\
& \left.\cdot e^{-\left(\sum_k\mu_{kw}\pi_{dk}\right)}\right] \cdot \prod_{k=1}^K \pi_{dk}^{\alpha_k - 1}\end{align}$$
Different methods for the problem
- MCMC
- Metropolis-Hastings - could be very slow
- HMC - need to transform the space for \\( \pi \\) and \\( \alpha \\)
- Parallel Tempering/EE Sampler - high implementation cost, many likelihood
evaluations needed
- Point estimation
- EM - \\( p( \pi \given y) \\) is analytically intractable
- mcEM - could use HMC for the MC step
- vEM - a variational-inference adaptation of EM
What are we approximating here?
- The variational methods are applied to \\( p (\pi \given y) \\)
- Would it be easy to reformulate the model to "build in" this
approximation?
A variational solution
- We need more involved methods: there are both parameters and
missing data.
- We can end up with an interesting approximate solution -
an MLE, but with respect to a different distribution.
Introducing vEM
- Approximate the E step:
- Instead of calculating \\( p( Z \given X) \\) in the E step,
maximize the lower bound over a family of
approximating distributions \\( q(Z) \\).
Announcements
- T-shirt comp winner!
- Balancing work for pset5 and final project.
Final slide
- Next lecture: Variational EM and mcEM.