Approximate methods: Variational Inference

Stat 221, Lecture 21

@snesterko

Roadmap

  • Sam Wong on protein folding and PT/EE sampler
  • Justify the need for shortcuts instead of MCMC
    • Point estimation as an approximate method
    • Connection to computing
  • Variational Inference - general framework
  • Introduce approximate methods
    • vEM
    • mcEM

Example: document authorship

  • Document authorship \\( \pi_d \sim \rr{Dirichlet}_K(\vec{\alpha},1) \\).
  • Data \\( y_{dw} \given \pi_d \sim \rr{Poisson}(\sum_k\mu_{kw}\pi_{dk})\\).
  • Author \\( k\\), document \\( d \\), word \\( w \\).
  • \\( \pi_{dk} \\) represents the contribution of author \\( k \\) to document \\( d \\).

Likelihood

For document \\( d \\):

$$\begin{align}L(\alpha, \mu, \pi \given y) \propto & \prod_{w=1}^W\left[\left(\sum_k\mu_{kw}\pi_{dk}\right)^{y_{dw}} \right. \\ & \left.\cdot e^{-\left(\sum_k\mu_{kw}\pi_{dk}\right)}\right] \cdot \prod_{k=1}^K \pi_{dk}^{\alpha_k - 1}\end{align}$$
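A quick numerical sanity check: below is a minimal sketch of this (unnormalized) log posterior for a single document \\( d \\) in Python. The dimensions, the simulated \\( \mu \\), \\( \pi_d \\), and \\( y_d \\), and the name log_post_d are all hypothetical, chosen only for illustration.

```python
import numpy as np

K, W = 3, 5                              # hypothetical: 3 authors, 5 words
rng = np.random.default_rng(0)

alpha = np.ones(K)                       # Dirichlet parameters alpha_k
mu = rng.gamma(2.0, 1.0, size=(K, W))    # author-by-word rates mu_{kw}
pi_d = rng.dirichlet(alpha)              # author contributions pi_{dk}
y_d = rng.poisson(pi_d @ mu)             # observed word counts y_{dw}

def log_post_d(alpha, mu, pi_d, y_d):
    """Log of the expression above, dropping terms constant in (alpha, mu, pi)."""
    rate = pi_d @ mu                                        # sum_k mu_{kw} pi_{dk}
    poisson_part = np.sum(y_d * np.log(rate) - rate)        # Poisson terms
    dirichlet_part = np.sum((alpha - 1.0) * np.log(pi_d))   # Dirichlet prior terms
    return poisson_part + dirichlet_part

print(log_post_d(alpha, mu, pi_d, y_d))
```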

How do we perform inference?

A simple example

Approximate a correlated bivariate Normal*:

$$ Z \sim \rr{Normal} \left( \left( \begin{array}{c} \mu_1 \\ \mu_2\end{array}\right), \left( \begin{array}{cc} \Lambda_{11} & \Lambda_{12} \\ \Lambda_{21} & \Lambda_{22}\end{array}\right)^{-1} \right) $$

Here \\( \Lambda \\) is the precision matrix.

*A setting where we may not care about the off-diagonal terms**

**Why not simply set them to zero in our model?

Options

  • Full MCMC:
    • Metropolis or MH
    • HMC on transformed space
    • add Parallel Tempering (+ MPI)
    • add Equi-energy sampler (+ MPI)
  • "Exact" methods:
    • Direct MLE
    • Point estimation with EM*
  • Approximate methods (Variational Inference).

*for EM estimation, need to pretend \\( \pi \\) is a latent variable.

Approximate methods

  • Variational inference
    • Approximate the distribution of missing data.
  • Monte Carlo EM
    • Use MCMC to estimate quantities in the E step.
  • Simulated Annealing - not a focus here
    • An iterative maximization algorithm.

+ curvature at the mode: a point estimate together with the curvature there gives an approximate posterior.

Why approximate methods are dangerous

  • Correlation is often key.
  • Ease of implementation versus statistical performance tradeoff
  • Warning
    • It is easy to "get an answer".
    • Rigorous checks are necessary to see how much performance we have sacrificed for ease of implementation.

Variational inference

Reference: Bishop Chapter 10, pp. 461-517.

Main idea:

  • Assume we have a posterior distribution \\( p( Z \given X) \\)
  • The distribution is complex.
  • We want to approximate it with a simpler distribution that is close to it but easier to work with.

Variational inference: start

Decompose log marginal probability

$$\rr{log} p(X) = \mathcal{L}(q) + \rr{KL}(q || p)$$

with

$$\begin{align}\mathcal{L}(q) &= \int q(Z) \rr{log} \left\{ {p(X, Z) \over q(Z)}\right\} dZ \\ \rr{KL}(q || p) & = - \int q(Z) \rr{log} \left\{ {p(Z \given X) \over q(Z) }\right\}dZ\end{align}$$
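The decomposition is easy to verify numerically. Below is a minimal sketch on a hypothetical discrete toy model: \\( Z \\) takes four values, \\( X \\) is held fixed, \\( p(X, Z) \\) is an arbitrary positive vector, and \\( q(Z) \\) is an arbitrary distribution; all numbers are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)

p_xz = np.array([0.10, 0.25, 0.05, 0.20])   # joint p(X, Z = k) at the observed X
p_x = p_xz.sum()                            # marginal p(X)
p_z_given_x = p_xz / p_x                    # posterior p(Z | X)

q = rng.dirichlet(np.ones(4))               # an arbitrary variational q(Z)

lower_bound = np.sum(q * np.log(p_xz / q))          # L(q)
kl = -np.sum(q * np.log(p_z_given_x / q))           # KL(q || p(Z | X))

print(np.log(p_x), lower_bound + kl)                # equal up to floating-point error
```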

Why decompose like this?

  • \\( \mathcal{L}(q) \\) is called the lower bound.
  • Maximizing \\( \mathcal{L}(q) \\) with respect to \\( q(Z) \\) yields \\( q(Z) = p(Z \given X) \\), and the KL divergence vanishes.
  • We can also restrict \\( q \\) so that the restricted class is easier to optimize over - this is the main idea in variational inference.

One way to restrict \\( q\\)

  • Use a parametric distribution \\( q(Z \given \omega) \\).
  • Lower bound \\( \mathcal{L}(q) \\) then becomes a function of \\( \omega \\).
  • We can then use standard optimization techniques to find optimal values of the parameters \\( \omega \\) (see the sketch below).
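A minimal sketch of this parametric approach, under assumptions not in the slides: a hypothetical one-dimensional unnormalized target with a skew-normal shape, a Normal family \\( q(z \given \omega) \\) with \\( \omega = (m, \rr{log}\, s) \\), a Monte Carlo estimate of \\( \mathcal{L}(\omega) \\) using fixed standard-normal draws, and Nelder-Mead from scipy as the optimizer.

```python
import numpy as np
from scipy import optimize
from scipy.stats import norm

rng = np.random.default_rng(0)

def log_p_unnorm(z):
    # hypothetical unnormalized target: skew-normal shape, 2 * phi(z) * Phi(3 z)
    return norm.logpdf(z) + norm.logcdf(3.0 * z)

eps = rng.standard_normal(5000)          # fixed draws, so the objective is deterministic

def neg_lower_bound(omega):
    m, log_s = omega
    s = np.exp(log_s)
    z = m + s * eps                      # reparameterized draws from q(z | omega)
    log_q = norm.logpdf(z, loc=m, scale=s)
    # Monte Carlo estimate of -L(omega), up to the (constant) log normalizer of the target
    return -np.mean(log_p_unnorm(z) - log_q)

res = optimize.minimize(neg_lower_bound, x0=np.array([0.0, 0.0]), method="Nelder-Mead")
print("variational mean:", res.x[0], "variational sd:", np.exp(res.x[1]))
```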

Factorized restriction on \\( q \\)

  • It is very common to assume parameter independence.
  • Partition the elements of \\( Z \\) into \\( M \\) disjoint groups that we denote with \\( Z_i \\), \\( i = 1, \ldots, M \\).
  • Then impose the independence restriction on \\( q \\): $$q(Z) = \prod_{i=1}^M q_i(Z_i)$$

Optimal factorized \\( q \\)

Denote \\( q_j(Z_j)\\) with \\( q_j\\).

$$\begin{align} \mathcal{L}(q) = & \int \prod_i q_i \left\{ \rr{log} p(X, Z) - \sum_i \rr{log} q_i \right\} dZ \\ = & \int q_j \left\{ \int \rr{log} p(X,Z) \prod_{i \neq j} q_i dZ_i\right\} dZ_j \\ & - \int q_j \rr{log} q_j dZ_j + \rr{const} = \ldots\end{align}$$

$$\begin{align} \ldots = & \int q_j \rr{log} \tilde{p} (X, Z_j) dZ_j \\ & - \int q_j \rr{log} q_j d Z_j + \rr{const}\end{align}$$

Define \\( \tilde{p} \\) by the relation

$$\rr{log} \tilde{p}(X, Z_j) = E_{i\neq j} \left[ \rr{log} p(X, Z) \right]$$

where \\( E_{i\neq j} \left[ \ldots \right]\\) is the expectation with respect to all \\( q_i \\) with \\( i \neq j \\): $$E_{i\neq j} \left[ \rr{log} p(X, Z) \right] = \int \rr{log} p(X,Z) \prod_{i \neq j} q_i dZ_i$$

Optimal \\( q \\)

  • Keeping \\( \{ q_{i \neq j} \} \\) fixed, \\( \mathcal{L}(q) \\) is, up to an additive constant, the negative KL divergence between \\( q_j(Z_j) \\) and \\( \tilde{p}(X, Z_j) \\).
  • The minimum of the KL divergence is at \\( q_j^*(Z_j) = \tilde{p}(X, Z_j) \\): $$\rr{log} q_j^*(Z_j) = E_{i\neq j} \left[ \rr{log} p(X, Z)\right] + \rr{const}$$ $$q_j^*(Z_j) = { \rr{exp} \left(E_{i\neq j} \left[ \rr{log} p(X, Z)\right]\right) \over \int \rr{exp}\left(E_{i\neq j} \left[ \rr{log} p(X, Z)\right]\right) dZ_j}$$

Example

Approximate a correlated bivariate Normal*:

$$ Z \sim \rr{Normal} \left( \left( \begin{array}{c} \mu_1 \\ \mu_2\end{array}\right), \left( \begin{array}{cc} \Lambda_{11} & \Lambda_{12} \\ \Lambda_{21} & \Lambda_{22}\end{array}\right)^{-1} \right) $$

We can use the general identity to obtain optimal \\( q \\).

*for simplicity, there is no data to condition on.

Example: optimal \\( q\\)

$$\begin{align} \rr{log} & q_1^*(z_1) = E_{z_2} \left[ \rr{log} p(z)\right] + \rr{const}\\ = & E_{z_2} \left[ - {1 \over 2} (z_1 - \mu_1)^2 \Lambda_{11}\right. \\ &\left.- (z_1 - \mu_1) \Lambda_{12}(z_2 - \mu_2)\right] + \rr{const} \\ = & - {1 \over 2} z_1^2 \Lambda_{11} + z_1 \mu_1 \Lambda_{11} - z_1 \Lambda_{12} \left( E[z_2] - \mu_2\right) \\ & + \rr{const}\end{align}$$

This means \\( q_1^* \\) is Normal (quadratic log density).
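Filling in the algebra (a step the slide skips): completing the square in \\( z_1 \\) gives

$$\begin{align} \rr{log} q_1^*(z_1) = & -{1 \over 2} \Lambda_{11} \left( z_1 - m_1 \right)^2 + \rr{const}, \\ & m_1 = \mu_1 - \Lambda_{11}^{-1}\Lambda_{12}\left( E[z_2] - \mu_2 \right),\end{align}$$

which matches the explicit form below.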

Explicit form for \\( q^* \\)

$$q_1^* (z_1) = \rr{Normal}(z_1 \given m_1, \Lambda_{11}^{-1})$$

with \\( m_1 = \mu_1 - \Lambda_{11}^{-1} \Lambda_{12}( E [z_2] - \mu_2) \\).

By symmetry,

$$q_2^* (z_2) = \rr{Normal}(z_2 \given m_2, \Lambda_{22}^{-1})$$

with \\( m_2 = \mu_2 - \Lambda_{22}^{-1} \Lambda_{21}( E [z_1] - \mu_1) \\).

Solutions are coupled. In general, need iteration.

In this case, \\( E[z_1] = \mu_1\\) and \\( E[z_2] = \mu_2\\) work exactly.
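A minimal sketch of iterating these coupled updates, with made-up values for \\( \mu \\) and the precision matrix \\( \Lambda \\):

```python
import numpy as np

mu = np.array([1.0, -2.0])               # hypothetical mean
Lam = np.array([[2.0, 1.2],              # hypothetical precision matrix Lambda
                [1.2, 3.0]])

m = np.zeros(2)                          # initialize E[z1], E[z2]
for _ in range(50):
    m[0] = mu[0] - Lam[0, 1] * (m[1] - mu[1]) / Lam[0, 0]   # update m1 given q2
    m[1] = mu[1] - Lam[1, 0] * (m[0] - mu[0]) / Lam[1, 1]   # update m2 given q1

print(m)   # converges to (mu1, mu2)
# The factorized approximation gets the mean right, but its marginal variances
# 1/Lam[0,0] and 1/Lam[1,1] are never larger than the true marginal variances
# (the diagonal of Lam^{-1}), so it understates uncertainty whenever Lam12 != 0.
```

This is one concrete instance of the earlier warning that correlation is often key: the factorized \\( q \\) discards it entirely and also understates the marginal variances.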

Important questions

  • Variational inference seems computing-friendly, but what's lost in performance?
  • Is it equivalent to making simpler models?

Document authorship example

Treat \\( \pi \\) as missing data. For document \\( d \\):

$$\begin{align}L(\alpha, \mu \given y, \pi) \propto & \prod_{w=1}^W\left[\left(\sum_k\mu_{kw}\pi_{dk}\right)^{y_{dw}} \right. \\ & \left.\cdot e^{-\left(\sum_k\mu_{kw}\pi_{dk}\right)}\right] \cdot \prod_{k=1}^K \pi_{dk}^{\alpha_k - 1}\end{align}$$

Different methods for the problem

  • MCMC
    • Metropolis-Hastings - could be very slow
    • HMC - need to transform the space for \\( \pi \\) and \\( \alpha \\)
    • Parallel Tempering/EE Sampler - high implementation cost, many likelihood evaluations needed
  • Point estimation
    • EM - \\( p( \pi \given y) \\) is analytically intractable
    • mcEM - could use HMC for the MC step
    • vEM - Variational inference adaptation of EM

What are we approximating here?

  • The variational methods are applied to \\( p (\pi \given y) \\)
  • Would it be easy to reformulate the model to "build in" this approximation?

A variational solution

  • We need more involved methods: there are both parameters and missing data.
  • We can end up having an interesting approximate solution - an MLE, but with respect to a different distribution.

Introducing vEM

  • Approximate the E step:
    • Instead of calculating \\( p( Z \given X) \\) in the E step, maximize the lower bound with respect to a family of approximating distributions \\( q(Z) \\).

Announcements

  • T-shirt comp winner!
  • Balancing work for pset5 and final project.

Resources

Final slide

  • Next lecture: Variational EM and mcEM.