Approximate methods: Variational Inference

Stat 221, Lecture 21

@snesterko

• Sam Wong on protein folding and PT/EE sampler
• Justify the need for shortcuts instead of MCMC
• Point estimation as an approximate method
• Connection to computing
• Variational Inference - general framework
• Introduce approximate methods
• vEM
• mcEM

Example: document authorship

• Authorship proportions \$$\pi_d \sim \rr{Dirichlet}_K(\vec{\alpha}) \$$.
• Data \$$y_{dw} \given \pi_d \sim \rr{Poisson}(\sum_k\mu_{kw}\pi_{dk})\$$.
• Author \$$k\$$, document \$$d \$$, word \$$w \$$.
• \$$\pi_{dk} \$$ represents the contribution of author \$$k \$$ to document \$$d \$$ (see the simulation sketch below).
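
A minimal simulation sketch of this generative model (assuming numpy; the sizes \$$K, D, W \$$, the choice \$$\vec{\alpha} = \vec{1} \$$, and the gamma draw for \$$\mu \$$ are illustrative, not part of the model):

```python
import numpy as np

rng = np.random.default_rng(221)
K, D, W = 3, 5, 100                     # authors, documents, vocabulary size

alpha = np.ones(K)                      # Dirichlet parameter (illustrative)
mu = rng.gamma(2.0, 1.0, size=(K, W))   # arbitrary positive word rates mu[k, w]

pi = rng.dirichlet(alpha, size=D)       # pi[d, k]: contribution of author k to document d
rates = pi @ mu                         # rates[d, w] = sum_k mu[k, w] * pi[d, k]
y = rng.poisson(rates)                  # observed word counts y[d, w]
```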

Likelihood

For document \$$d \$$:

\begin{align}L(\alpha, \mu, \pi \given y) \propto & \prod_{w=1}^W\left[\left(\sum_k\mu_{kw}\pi_{dk}\right)^{y_{dw}} \right. \\ & \left.\cdot e^{-\left(\sum_k\mu_{kw}\pi_{dk}\right)}\right] \cdot \prod_{k=1}^K \pi_{dk}^{\alpha_k - 1}\end{align}

How do we perform inference?

A simple example

Approximate a correlated bivariate Normal*:

$$Z \sim \rr{Normal} \left( \left( \begin{array}{c} \mu_1 \\ \mu_2\end{array}\right), \Lambda^{-1} \right), \quad \Lambda = \left( \begin{array}{cc} \Lambda_{11} & \Lambda_{12} \\ \Lambda_{21} & \Lambda_{22}\end{array}\right)$$

where \$$\Lambda \$$ is the precision matrix.

*A setting where we may not care about the off-diagonal terms**

**Why not simply set them to zero in our model?

Options

• Full MCMC:
• Metropolis or MH
• HMC on transformed space
• add Parallel Tempering (+ MPI)
• add Equi-energy sampler (+ MPI)
• "Exact" methods:
• Direct MLE
• Point estimation with EM*
• Approximate methods (Variational Inference).

*for EM estimation, we need to pretend \$$\pi \$$ is a latent variable.

Approximate methods

• Variational inference
• Approximate the distribution of missing data.
• Monte Carlo EM
• Use MCMC to estimate quantities in the E step.
• Simulated Annealing (not our focus)
• An iterative maximization algorithm.

+ curvature at the mode (to quantify uncertainty around a point estimate)

Why approximate methods are dangerous

• Correlation is often key.
• Ease of implementation versus statistical performance tradeoff
• Warning
• It is easy to "get an answer".
• Rigorous checks are necessary to see how much performance we have sacrificed for ease of implementation.

Variational inference

Reference: Bishop Chapter 10, pp. 461-517.

Main idea:

• Assume we have a posterior distribution \$$p( Z \given X) \$$
• The distribution is complex.
• We want to approximate it with another distribution that is close to it but easier to work with.

Variational inference: start

Decompose log marginal probability

$$\rr{log} p(X) = \mathcal{L}(q) + \rr{KL}(q || p)$$

with

\begin{align}\mathcal{L}(q) &= \int q(Z) \rr{log} \left\{ {p(X, Z) \over q(Z)}\right\} dZ \\ \rr{KL}(q || p) & = - \int q(Z) \rr{log} \left\{ {p(Z \given X) \over q(Z) }\right\}dZ\end{align}
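
The decomposition is easy to verify numerically. A sketch with a made-up discrete \$$Z \$$ (three states) and a fixed observed \$$X \$$; all numbers are illustrative:

```python
import numpy as np

p_xz = np.array([0.10, 0.25, 0.15])   # joint p(X, Z=z) at the observed X (made up)
p_x = p_xz.sum()                      # marginal p(X)
post = p_xz / p_x                     # posterior p(Z | X)

q = np.array([0.2, 0.5, 0.3])         # an arbitrary q(Z)

lower_bound = np.sum(q * np.log(p_xz / q))   # L(q)
kl = -np.sum(q * np.log(post / q))           # KL(q || p)

assert np.isclose(np.log(p_x), lower_bound + kl)   # log p(X) = L(q) + KL(q || p)
```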

Why decompose like this?

• \$$\mathcal{L}(q) \$$ is called the lower bound: since \$$\rr{KL}(q || p) \geq 0 \$$, we have \$$\mathcal{L}(q) \leq \rr{log} p(X) \$$.
• Maximizing \$$\mathcal{L}(q) \$$ with respect to \$$q(Z) \$$ yields \$$q(Z) = p(Z \given X) \$$, and the KL divergence vanishes.
• We can also restrict \$$q \$$ so that the restricted class is easier to optimize over - this is the main idea in variational inference.

One way to restrict \$$q\$$

• Use a parametric distribution \$$q(Z \given \omega) \$$.
• Lower bound \$$\mathcal{L}(q) \$$ then becomes a function of \$$\omega \$$.
• We can then use standard optimization techniques to find optimal values for the parameters \$$\omega \$$ (see the sketch below).
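
For instance, take \$$q(Z \given \omega) \$$ to be a factorized Normal and the target to be the bivariate Normal from the earlier slide (with \$$\Lambda \$$ its precision matrix). The lower bound then has a closed form up to constants and can be handed to a generic optimizer. A sketch, assuming scipy; all numbers are illustrative:

```python
import numpy as np
from scipy.optimize import minimize

mu = np.array([1.0, -1.0])            # target mean (illustrative)
Lam = np.array([[2.0, 0.8],
                [0.8, 1.5]])          # target precision matrix (illustrative)

def neg_lower_bound(omega):
    """-L(omega) for q(z) = N(z1 | m1, s1^2) N(z2 | m2, s2^2),
    omega = (m1, m2, log s1, log s2), dropping omega-free constants."""
    m, log_s = omega[:2], omega[2:]
    s2 = np.exp(2 * log_s)                                # variances s_i^2
    e_log_p = -0.5 * ((m - mu) @ Lam @ (m - mu)           # E_q[log p(z)]
                      + Lam[0, 0] * s2[0] + Lam[1, 1] * s2[1])
    entropy = log_s.sum()                                 # H(q)
    return -(e_log_p + entropy)

res = minimize(neg_lower_bound, x0=np.zeros(4))
m_hat, s_hat = res.x[:2], np.exp(res.x[2:])
# m_hat -> mu and s_hat**2 -> 1 / diag(Lam), matching the factorized
# optimum derived later in the lecture
```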

Factorized restriction on \$$q \$$

• It is very common to assume parameter independence.
• Partition the elements of \$$Z \$$ into \$$M \$$ disjoint groups that we denote with \$$Z_i \$$, \$$i = 1, \ldots, M \$$.
• Then impose the independence restriction on \$$q \$$: $$q(Z) = \prod_{i=1}^M q_i(Z_i)$$

Optimal factorized \$$q \$$

Denote \$$q_j(Z_j)\$$ with \$$q_j\$$.

\begin{align} \mathcal{L}(q) = & \int \prod_i q_i \left\{ \rr{log} p(X, Z) - \sum_i \rr{log} q_i \right\} dZ \\ = & \int q_j \left\{ \int \rr{log} p(X,Z) \prod_{i \neq j} q_i dZ_i\right\} dZ_j \\ & - \int q_j \rr{log} q_j dZ_j + \rr{const} = \ldots\end{align}

\begin{align} \ldots = & \int q_j \rr{log} \tilde{p} (X, Z_j) dZ_j \\ & - \int q_j \rr{log} q_j d Z_j + \rr{const}\end{align}

Define \$$\tilde{p} \$$ by the relation

$$\rr{log} \tilde{p}(X, Z_j) = E_{i\neq j} \left[ \rr{log} p(X, Z) \right]$$

where \$$E_{i\neq j} \left[ \ldots \right]\$$ denotes expectation with respect to all \$$q_i \$$ with \$$i \neq j \$$: $$E_{i\neq j} \left[ \rr{log} p(X, Z) \right] = \int \rr{log} p(X,Z) \prod_{i \neq j} q_i dZ_i$$

Optimal \$$q \$$

• Keeping \$$\{ q_{i \neq j} \} \$$ fixed, \$$\mathcal{L}(q) \$$ is, up to an additive constant, the negative KL divergence between \$$q_j(Z_j) \$$ and \$$\tilde{p}(X, Z_j) \$$.
• The KL divergence is minimized, and the bound maximized, at \$$q_j^*(Z_j) = \tilde{p}(X, Z_j) \$$: $$\rr{log} q_j^*(Z_j) = E_{i\neq j} \left[ \rr{log} p(X, Z)\right] + \rr{const}$$ $$q_j^*(Z_j) = { \rr{exp} \left(E_{i\neq j} \left[ \rr{log} p(X, Z)\right]\right) \over \int \rr{exp}\left(E_{i\neq j} \left[ \rr{log} p(X, Z)\right]\right) dZ_j}$$

Example

Approximate a correlated bivariate Normal*:

$$Z \sim \rr{Normal} \left( \left( \begin{array}{c} \mu_1 \\ \mu_2\end{array}\right), \Lambda^{-1} \right), \quad \Lambda = \left( \begin{array}{cc} \Lambda_{11} & \Lambda_{12} \\ \Lambda_{21} & \Lambda_{22}\end{array}\right)$$

We can use the general identity to obtain optimal \$$q \$$.

*for simplicity, there is no data to condition on.

Example: optimal \$$q\$$

\begin{align} \rr{log} & q_1^*(z_1) = E_{z_2} \left[ \rr{log} p(z)\right] + \rr{const}\\ = & E_{z_2} \left[ - {1 \over 2} (z_1 - \mu_1)^2 \Lambda_{11}\right. \\ &\left.- (z_1 - \mu_1) \Lambda_{12}(z_2 - \mu_2)\right] + \rr{const} \\ = & - {1 \over 2} z_1^2 \Lambda_{11} + z_1 \mu_1 \Lambda_{11} - z_1 \Lambda_{12} \left( E[z_2] - \mu_2\right) \\ & + \rr{const}\end{align}

This means \$$q_1^* \$$ is Normal (quadratic log density).

Explicit form for \$$q^* \$$

$$q_1^* (z_1) = \rr{Normal}(z_1 \given m_1, \Lambda_{11}^{-1})$$

with \$$m_1 = \mu_1 - \Lambda_{11}^{-1} \Lambda_{12}( E [z_2] - \mu_2) \$$.

By symmetry,

$$q_2^* (z_2) = \rr{Normal}(z_2 \given m_2, \Lambda_{22}^{-1})$$

with \$$m_2 = \mu_2 - \Lambda_{22}^{-1} \Lambda_{21}( E [z_1] - \mu_1) \$$.

The solutions are coupled; in general, we need to iterate.

In this case, \$$E[z_1] = \mu_1 \$$ and \$$E[z_2] = \mu_2 \$$ solve the coupled equations exactly (see the sketch below).
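
A sketch of these coupled updates iterated to convergence (illustrative numbers; \$$\Lambda \$$ is the precision matrix):

```python
import numpy as np

mu = np.array([1.0, -1.0])            # target mean (illustrative)
Lam = np.array([[2.0, 0.8],
                [0.8, 1.5]])          # precision matrix (illustrative)

m = np.zeros(2)                       # current variational means E[z1], E[z2]
for _ in range(50):
    m[0] = mu[0] - Lam[0, 1] / Lam[0, 0] * (m[1] - mu[1])   # m1 update
    m[1] = mu[1] - Lam[1, 0] / Lam[1, 1] * (m[0] - mu[0])   # m2 update

# m converges to mu; the variational variances stay fixed at
# 1 / Lam[0, 0] and 1 / Lam[1, 1]
```

Each update plugs in the most recent mean of the other factor, which is what makes the scheme a coordinate ascent on the lower bound.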

Important questions

• Variational inference seems computing-friendly, but what's lost in performance?
• Is it equivalent to making simpler models?

Document authorship example

Treat \$$\pi \$$ as missing data. For document \$$d \$$:

\begin{align}L(\alpha, \mu \given y, \pi) \propto & \prod_{w=1}^W\left[\left(\sum_k\mu_{kw}\pi_{dk}\right)^{y_{dw}} \right. \\ & \left.\cdot e^{-\left(\sum_k\mu_{kw}\pi_{dk}\right)}\right] \cdot \prod_{k=1}^K \pi_{dk}^{\alpha_k - 1}\end{align}
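
A direct transcription of this expression into code may be useful for the MC-based options below (a sketch; shapes assumed: \$$\mu \$$ is \$$K \times W \$$, \$$\pi_d \$$ has length \$$K \$$, \$$y_d \$$ has length \$$W \$$):

```python
import numpy as np

def complete_log_lik_d(alpha, mu, pi_d, y_d):
    """Complete-data log-likelihood for one document d, up to an additive constant."""
    rates = mu.T @ pi_d                                    # sum_k mu[k, w] * pi_d[k]
    poisson_term = np.sum(y_d * np.log(rates) - rates)     # Poisson factors
    dirichlet_term = np.sum((alpha - 1.0) * np.log(pi_d))  # Dirichlet prior on pi_d
    return poisson_term + dirichlet_term
```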

Different methods for the problem

• MCMC
• Metropolis-Hastings - could be very slow
• HMC - need to transform the space for \$$\pi \$$ and \$$\alpha \$$
• Parallel Tempering/EE Sampler - high implementation cost, many likelihood evaluations needed
• Point estimation
• EM - \$$p( \pi \given y) \$$ is analytically intractable
• mcEM - could use HMC for the MC step
• vEM - variational inference adaptation of EM

What are we approximating here?

• The variational methods are applied to \$$p (\pi \given y) \$$
• Would it be easy to reformulate the model to "build in" this approximation?

A variational solution

• We need more involved methods: there are both parameters and missing data.
• We end up with an interesting approximate solution: an MLE, but with respect to a different distribution.

Introducing vEM

• Approximate the E step:
• Instead of calculating \$$p( Z \given X) \$$ in the E step, maximize the lower bound with respect to a family of approximating distributions \$$q(Z) \$$ (a schematic loop is sketched below).
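
A schematic of the resulting loop (a sketch only; `update_q` and `m_step` are hypothetical placeholders to be made concrete next lecture):

```python
def variational_em(y, init_params, update_q, m_step, n_iter=100):
    """Schematic vEM: both steps increase the same lower bound L(q, params).

    update_q(y, params): variational E step, argmax over the restricted family q(Z)
    m_step(y, q):        M step, argmax over params of E_q[log p(y, Z | params)]
    """
    params, q = init_params, None
    for _ in range(n_iter):
        q = update_q(y, params)   # approximate E step
        params = m_step(y, q)     # M step
    return params, q
```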

Announcements

• T-shirt comp winner!
• Balancing work for pset5 and final project.

Final slide

• Next lecture: Variational EM and mcEM.