Variational EM, Monte Carlo EM

Stat 221, Lecture 22

@snesterko

Roadmap

  • vEM, mcEM by example
  • Performance/usability tradeoffs of these methods

Factorized restriction on \\( q \\)

  • In variational inference, it is common to assume independence in \\(q \\).
  • Partition the elements of \\( Z \\) into \\( M \\) disjoint groups that we denote with \\( Z_i \\), \\( i = 1, \ldots, M \\).
  • Then impose the factorization (independence) restriction on \\( q \\): $$q(Z) = \prod_{i=1}^M q_i(Z_i)$$

Optimal factorized \\( q \\)

Denote \\( q_j(Z_j)\\) with \\( q_j\\).

$$\begin{align} \mathcal{L}(q) = & \int \prod_i q_i \left\{ \rr{log} p(X, Z) - \sum_i \rr{log} q_i \right\} dZ \\ = & \int q_j \left\{ \int \rr{log} p(X,Z) \prod_{i \neq j} q_i dZ_i\right\} dZ_j \\ & - \int q_j \rr{log} q_j dZ_j + \rr{const} = \ldots\end{align}$$

$$\begin{align} \ldots = & \int q_j \rr{log} \tilde{p} (X, Z_j) dZ_j \\ & - \int q_j \rr{log} q_j d Z_j + \rr{const}\end{align}$$

Define \\( \tilde{p} \\) by the relation

$$\rr{log} \tilde{p}(X, Z_j) = E_{i\neq j} \left[ \rr{log} p(X, Z) \right]$$

where \\( E_{i\neq j} \left[ \ldots \right]\\) denotes expectation with respect to all factors \\( q_i \\) with \\( i \neq j \\): $$E_{i\neq j} \left[ \rr{log} p(X, Z) \right] = \int \rr{log} p(X,Z) \prod_{i \neq j} q_i dZ_i$$

Optimal \\( q \\)

  • Keeping \\( \{ q_{i \neq j} \} \\) fixed, \\( \mathcal{L}(q) \\) is, up to an additive constant, the negative KL divergence between \\( q_j(Z_j) \\) and \\( \tilde{p}(X, Z_j) \\).
  • The KL divergence is minimized when \\( q_j^*(Z_j) \propto \tilde{p}(X, Z_j) \\): $$\rr{log} q_j^*(Z_j) = E_{i\neq j} \left[ \rr{log} p(X, Z)\right] + \rr{const}$$ $$q_j^*(Z_j) = { \rr{exp} \left(E_{i\neq j} \left[ \rr{log} p(X, Z)\right]\right) \over \int \rr{exp}\left(E_{i\neq j} \left[ \rr{log} p(X, Z)\right]\right) dZ_j}$$

Example

Approximate a correlated bivariate Normal* with mean \\( \mu \\) and precision matrix \\( \Lambda \\):

$$ Z \sim \rr{Normal} \left( \left( \begin{array}{c} \mu_1 \\ \mu_2\end{array}\right), \Lambda^{-1} \right), \quad \Lambda = \left( \begin{array}{cc} \Lambda_{11} & \Lambda_{12} \\ \Lambda_{21} & \Lambda_{22}\end{array}\right) $$

We can use the general identity to obtain optimal \\( q \\).

*for simplicity, there is no data to condition on.

Example: optimal \\( q\\)

$$\begin{align} \rr{log} & q_1^*(z_1) = E_{z_2} \left[ \rr{log} p(z)\right] + \rr{const}\\ = & E_{z_2} \left[ - {1 \over 2} (z_1 - \mu_1)^2 \Lambda_{11}\right. \\ &\left.- (z_1 - \mu_1) \Lambda_{12}(z_2 - \mu_2)\right] + \rr{const} \\ = & - {1 \over 2} z_1^2 \Lambda_{11} + z_1 \mu_1 \Lambda_{11} - z_1 \Lambda_{12} \left( E[z_2] - \mu_2\right) \\ & + \rr{const}\end{align}$$

This means \\( q_1^* \\) is Normal (quadratic log density).

Explicit form for \\( q^* \\)

$$q_1^* (z_1) = \rr{Normal}(z_1 \given m_1, \Lambda_{11}^{-1})$$

with \\( m_1 = \mu_1 - \Lambda_{11}^{-1} \Lambda_{12}( E [z_2] - \mu_2) \\).

By symmetry,

$$q_2^* (z_2) = \rr{Normal}(z_2 \given m_2, \Lambda_{22}^{-1})$$

with \\( m_2 = \mu_2 - \Lambda_{22}^{-1} \Lambda_{21}( E [z_1] - \mu_1) \\).

Solutions are coupled. In general, need iteration.

In this case, \\( E[z_1] = \mu_1\\) and \\( E[z_2] = \mu_2\\) work exactly.
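A quick numerical sketch of these coupled updates (the mean and precision values below are arbitrary illustrations, not from the slides):

```python
import numpy as np

# Arbitrary illustrative values: mean mu and precision matrix Lambda.
mu = np.array([1.0, -2.0])
Lam = np.array([[2.0, 0.8],
                [0.8, 1.5]])

# Coordinate-ascent on the factorized approximation: each factor is
# Normal with precision Lam[j, j], and only the means are coupled.
m = np.zeros(2)   # current guesses for E[z_1], E[z_2]
for _ in range(50):
    m[0] = mu[0] - Lam[0, 1] / Lam[0, 0] * (m[1] - mu[1])
    m[1] = mu[1] - Lam[1, 0] / Lam[1, 1] * (m[0] - mu[0])

print(m)   # converges to (mu_1, mu_2), as claimed above
```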

Example: document authorship

  • Document authorship \\( \pi_d \sim \rr{Dirichlet}_K(\vec{\alpha}) \\).
  • Data \\( y_{dw} \given \pi_d \sim \rr{Poisson}(\sum_k\mu_{kw}\pi_{dk})\\).
  • Author \\( k\\), document \\( d \\), word \\( w \\).
  • \\( \pi_{dk} \\) represents the contribution of author \\( k \\) to document \\( d \\).
  • \\( \mu_{kw} \\) represents author \\( k \\)'s rate of using word \\( w \\).

Likelihood

For document \\( d \\):

$$\begin{align}L(\alpha, \mu, \pi \given y) \propto & \prod_{w=1}^W\left[\left(\sum_k\mu_{kw}\pi_{dk}\right)^{y_{dw}} \right. \\ & \left.\cdot e^{-\left(\sum_k\mu_{kw}\pi_{dk}\right)}\right] \cdot \prod_{k=1}^K \pi_{dk}^{\alpha_k - 1}\end{align}$$
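As a sanity check, here is a minimal sketch of evaluating this log likelihood (up to an additive constant) for a single document; the array names and shapes are my own assumptions:

```python
import numpy as np

def doc_loglik(y_d, mu, pi_d, alpha):
    """Unnormalized log L(alpha, mu, pi | y) for one document.

    y_d   : (W,) word counts y_{dw}
    mu    : (K, W) author-word rates mu_{kw}
    pi_d  : (K,) authorship shares pi_{dk} (on the simplex)
    alpha : (K,) Dirichlet parameters
    """
    rates = pi_d @ mu                               # sum_k mu_{kw} pi_{dk}, shape (W,)
    poisson_part = np.sum(y_d * np.log(rates) - rates)
    dirichlet_part = np.sum((alpha - 1.0) * np.log(pi_d))
    return poisson_part + dirichlet_part
```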

How do we perform inference?

Options

  • Full MCMC:
    • Metropolis or MH
    • HMC on transformed space
    • add Parallel Tempering (+ MPI)
    • add Equi-energy sampler (+ MPI)
  • "Exact" methods:
    • Direct MLE - problematic
    • Point estimation with EM - problematic*
  • Approximate methods (Variational Inference).

*for EM, need to pretend \\( \pi \\) is a latent variable.

Challenges with MCMC

  • The posterior distributions of \\( \mu \\) and \\( \pi \\) are not available for direct simulations.
  • \\( \pi \\) lives on a simplex.
  • There could be (and probably is) multimodality.
  • The amount of data (word counts) and the number of documents can be large, which makes everything slow (likelihood/posterior evaluations, gradients).

Why is this a hard exercise for MCMC?

  • The answer is entangled with the available computing resources and our ability to implement good MCMC sampling algorithms.
  • But let's assume that resources are scarce, and that our ability to devise an elegant MCMC scheme, or to apply PT, HMC, or the generic EE sampler, does not take us far enough.
  • Is this assumption realistic?

Try exact EM

Treat \\( \pi \\) as missing data. For document \\( d \\):

$$\begin{align}L(\alpha, \mu \given y, \pi) \propto & \prod_{w=1}^W\left[\left(\sum_k\mu_{kw}\pi_{dk}\right)^{y_{dw}} \right. \\ & \left.\cdot e^{-\left(\sum_k\mu_{kw}\pi_{dk}\right)}\right] \cdot \prod_{k=1}^K \pi_{dk}^{\alpha_k - 1}\end{align}$$

What is the distribution of authorship contributions given the data, \\( p(\pi \given y) \\), needed for the E step? It is not available in closed form. Hard stop.

If we don't do MCMC, we must

Do one of the following:

  1. Simplify/change the model.
  2. Use approximate methods:
    • Local optimization algorithms for MLE treating \\( \pi \\) as a parameter (risk of Neyman-Scott problem).
    • Use variational inference.

What are we approximating here?

Depending on what we want to do, we could approximate different things:

  • The distribution of document contributions given the data, \\( p (\pi \given y) \\).
  • The distribution of author styles given the data, \\( p ( \mu \given y) \\).
  • Both, \\( p ( \pi, \mu \given y) \\).

A variational solution

  • There are both parameters and missing data.
  • An approximate solution: an MLE, but computed with respect to an approximating distribution \\( q \\) chosen to minimize the KL divergence to the true \\( p(Y_{\rr{mis}} \given \rr{data}) \\).

Introducing vEM

  • Approximate the E step:
    • Instead of calculating \\( p( Z \given X) \\) in the E step, maximize the lower bound with respect to a family of approximating distributions \\( q(Z) \\).

Main idea

  • Replace the \\( Q \\) function from EM with a lower bound $$\mathcal{L} = \rr{E}_q\rr{log}p(y_{\rr{mis}}, y_{\rr{obs}} \given \theta) - \rr{E}_q\rr{log}q$$
  • Up to the observed-data log likelihood, this is the negative KL divergence between \\( q \\) and \\( p(y_{\rr{mis}} \given y_{\rr{obs}}, \theta) \\).
  • Maximizing the lower bound therefore brings us closer to the observed-data log likelihood.
  • In a conditional maximization, optimize \\( \mathcal{L} \\) with respect to variational parameters of \\( q \\) and the remaining likelihood parameters.

Back to the example

  • Let's approximate \\( p (\pi \given y) \\).
  • We are lucky: the other model parameters have no priors, so we can simply maximize over them.
  • Let the variational distribution on \\( \pi \\) also be Dirichlet: $$\vec{\pi}_d \mathop{\sim}^{q} \rr{Dirichlet}(c_{d1},\dots, c_{dK})$$

Lower bound: $$\mathcal{L} = \rr{E}_ql = \rr{E}_q\left[\rr{log}p(y, \pi \given \alpha, \mu) - \rr{log}q(\pi)\right]$$

$$ l = \ldots$$

$$\begin{align} & \sum_{d=1}^D \sum_{w=1}^W \left[y_{dw}\rr{log}\left(\sum_{k=1}^K\pi_{dk} \mu_{kw}\right) - \sum_{k=1}^K\pi_{dk} \mu_{kw} \right] \\ & + \sum_{d=1}^D \left[ \sum_{k=1}^K (\alpha_k-1)\rr{log}\pi_{dk} + \rr{log}\Gamma(\sum_{k=1}^K \alpha_k) - \right.\\ & \left.\sum_{k=1}^K \rr{log}\Gamma(\alpha_k) \right] - \sum_{d=1}^D \left[\sum_{k=1}^K (c_{dk} - 1) \rr{log}\pi_{dk} + \right.\\ & \left.\rr{log} \Gamma(\sum_{k=1}^K c_{dk}) - \sum_{k=1}^K \rr{log}\Gamma(c_{dk})\right] \end{align}$$

Now need to take \\( \rr{E}_ql \\)

$$\begin{align}E_q \rr{log}(\pi_{dk}) & = \psi\left(c_{dk}\right) - \psi\left(\sum_{k=1}^K {c_{dk}}\right) \\ E_q(\pi_{dk}) & = \frac{c_{dk}}{\sum_{k=1}^K{c_{dk}}} \end{align}$$

\\( \psi \\) is the digamma function.
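Both expectations are available in closed form and cheap to evaluate. A sketch with scipy, storing the variational parameters \\( c_{dk} \\) as a \\( D \times K \\) array (the storage layout is my assumption):

```python
import numpy as np
from scipy.special import digamma

def dirichlet_expectations(c):
    """For pi_d ~ Dirichlet(c_d) under q, return
    A[d, k] = E_q log(pi_{dk}) and B[d, k] = E_q pi_{dk}."""
    row_sums = c.sum(axis=1, keepdims=True)
    A = digamma(c) - digamma(row_sums)
    B = c / row_sums
    return A, B
```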

But we need to deal with

$$\rr{E}_q\left[\rr{log}\left(\sum_{k=1}^K\pi_{dk} \mu_{kw}\right)\right]$$

  • Can't really use Jensen's inequality to take the log outside the expectation.

By Jensen's inequality

$$\begin{align}E_q \rr{log} \left(\sum_{k=1}^K \pi_{dk} \mu_{kw}\right) & \leq \rr{log} \sum_{k=1}^K \rr{E}_q \pi_{dk} \mu_{kw} \\ & = \rr{log} \sum_{k=1}^K \frac{c_{dk}}{\sum_{k=1}^K c_{dk}} \mu_{kw} \end{align}$$

  • Since \\( \mathcal{L} \\) is a lower bound and Jensen bounds this term from above, substituting it would not yield a valid ("lower-lower") bound on the log likelihood.
  • We need to compute, or otherwise approximate, the quantity itself; see the sketch below.
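One simple way to get at this term (my own suggestion; the slides only state that the quantity must be computed) is a Monte Carlo average over draws from the variational Dirichlet:

```python
import numpy as np

def C_dw_mc(c_d, mu_w, n_draws=2000, rng=None):
    """Monte Carlo estimate of E_q log(sum_k pi_{dk} mu_{kw}),
    where pi_d ~ Dirichlet(c_d) under q and mu_w = (mu_{1w}, ..., mu_{Kw})."""
    rng = np.random.default_rng() if rng is None else rng
    pi_draws = rng.dirichlet(c_d, size=n_draws)    # (n_draws, K)
    return np.mean(np.log(pi_draws @ mu_w))
```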

Putting it together

Let's make a few definitions for convenience:

  • \\( A_{dk} = \rr{E}_q \rr{log}(\pi_{dk}) = \psi\left(c_{dk}\right) - \psi\left(\sum_{k=1}^K {c_{dk}}\right) \\)
  • \\( B_{dk} = \rr{E}_q(\pi_{dk}) = \frac{c_{dk}}{\sum_{k=1}^K{c_{dk}}} \\)
  • \\( C_{dw} = E_q \rr{log} \left(\sum_{k=1}^K \pi_{dk} \mu_{kw}\right) \\)

Then the lower bound we will be optimizing is ...

$$\begin{align} & \sum_{d=1}^D \sum_{w=1}^W \left[y_{dw} C_{dw} - \sum_{k=1}^KB_{dk} \mu_{kw} \right] \\ & + \sum_{d=1}^D \left[ \sum_{k=1}^K (\alpha_k-1) A_{dk} + \rr{log}\Gamma(\sum_{k=1}^K \alpha_k) - \right.\\ & \left.\sum_{k=1}^K \rr{log}\Gamma(\alpha_k) \right] - \sum_{d=1}^D \left[\sum_{k=1}^K (c_{dk} - 1) A_{dk} + \right.\\ & \left.\rr{log} \Gamma(\sum_{k=1}^K c_{dk}) - \sum_{k=1}^K \rr{log}\Gamma(c_{dk})\right] \end{align}$$
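In code, the bound is straightforward to assemble from \\( A, B, C \\); a sketch reusing the helpers above (`gammaln` is the log-Gamma function, and \\( C \\) is assumed precomputed, e.g. by the Monte Carlo estimate):

```python
import numpy as np
from scipy.special import gammaln

def lower_bound(y, mu, alpha, c, A, B, C):
    """Variational lower bound from the slide.

    y: (D, W) counts, mu: (K, W), alpha: (K,), c: (D, K),
    A, B: (D, K) as defined above, C: (D, W).
    """
    poisson_term = np.sum(y * C - B @ mu)
    prior_term = (A @ (alpha - 1.0)).sum() + y.shape[0] * (
        gammaln(alpha.sum()) - gammaln(alpha).sum())
    q_term = np.sum((c - 1.0) * A) + np.sum(
        gammaln(c.sum(axis=1)) - gammaln(c).sum(axis=1))
    return poisson_term + prior_term - q_term
```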

The vEM algorithm

  1. Initialize some values for \\( \{\alpha, c, \mu \}\\).
  2. Compute \\( A_{dk}, B_{dk}, C_{dw} \\) given current values of \\( \{\alpha, \mu, c \}\\).
  3. Maximize \\( \mathcal{L} \\) with respect to \\( \{\alpha, \mu, c \} \\).
  4. Repeat 2 and 3 until convergence.
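A rough sketch of this loop, reusing the helpers above and a generic numerical optimizer (the optimizer, the log-scale parametrization, and holding \\( C \\) fixed during the inner maximization are all my own choices, not prescribed by the slides):

```python
import numpy as np
from scipy.optimize import minimize

def vem(y, K, n_outer=20, rng=None):
    """Sketch of the vEM loop above (slow but illustrative)."""
    rng = np.random.default_rng() if rng is None else rng
    D, W = y.shape
    # 1. Initialize {alpha, mu, c} with positive values (arbitrary choice).
    x = np.log(np.concatenate([np.ones(K),                    # alpha
                               rng.gamma(2.0, 1.0, K * W),    # mu
                               np.ones(D * K)]))              # c

    def unpack(z):
        alpha = np.exp(z[:K])
        mu = np.exp(z[K:K + K * W]).reshape(K, W)
        c = np.exp(z[K + K * W:]).reshape(D, K)
        return alpha, mu, c

    for _ in range(n_outer):
        alpha, mu, c = unpack(x)
        # 2. Expectations at current values; C is the expensive one.
        C = np.array([[C_dw_mc(c[d], mu[:, w], rng=rng) for w in range(W)]
                      for d in range(D)])

        # 3. Maximize the bound over {alpha, mu, c} on the log scale.
        #    A and B are recomputed in closed form inside the objective;
        #    C is held fixed here only because it is costly to re-estimate.
        def neg_bound(z):
            a, m, cc = unpack(z)
            A, B = dirichlet_expectations(cc)
            return -lower_bound(y, m, a, cc, A, B, C)

        x = minimize(neg_bound, x, method="L-BFGS-B",
                     options={"maxiter": 25}).x

    return unpack(x)   # 4. in practice, iterate until the bound stabilizes
```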

An alternative

  • Monte Carlo EM:
    • Simulate \\( \pi \\) using MCMC (back to the MCMC paradigm?)

Example: log-likelihood

With a symmetric Dirichlet prior (\\( \alpha_k \equiv \alpha \\)), the complete-data log likelihood is

$$\begin{align} l = & \sum_{d=1}^D\sum_{w=1}^W \left[ y_{dw}\left(\rr{log}\sum_{k=1}^K\pi_{dk}\mu_{kw}\right) \right.\\ & - \left.\sum_{k=1}^K\pi_{dk}\mu_{kw}\right] + D\left[\rr{log}\Gamma(K\alpha) - K\rr{log}\Gamma(\alpha)\right] \\ & + \sum_{d=1}^D(\alpha - 1)\sum_{k=1}^K \rr{log}\pi_{dk} \end{align}$$

\\( Q \\) function: \\( \rr{E}_{p( \pi \given y)}l \\)

$$\begin{align} Q = & \sum_{d=1}^D\sum_{w=1}^W \left[ y_{dw} \rr{E}_{p( \pi \given y)}\left(\rr{log}\sum_{k=1}^K\pi_{dk}\mu_{kw}\right) \right.\\ & - \left.\sum_{k=1}^K\rr{E}_{p( \pi \given y)}(\pi_{dk})\mu_{kw}\right] + D\left[\rr{log}\Gamma(K\alpha) - \right.\\ & \left. K\rr{log}\Gamma(\alpha)\right] + \sum_{d=1}^D(\alpha - 1)\sum_{k=1}^K \rr{E}_{p( \pi \given y)}\left[\rr{log}\pi_{dk}\right] \end{align}$$

mcEM algorithm

  1. Pick starting values for \\( \{\alpha, \mu \}\\).
  2. E step: simulate \\( n \\) draws of the missing data \\( \pi \\) from the complete-data likelihood, treated as an unnormalized posterior density, and approximate the expectations in the \\( Q \\) function by averaging the corresponding quantities over these draws at the current parameter values.
  3. Maximize \\( Q \\) with respect to \\( \{\alpha, \mu \}\\).
  4. Iterate between 2 and 3 until convergence.
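A sketch of the E step (step 2) for a single document, using an independence Metropolis sampler that proposes from the Dirichlet prior; the sampler choice is my own illustration, not something the slides prescribe (`alpha` is the length-\\( K \\) prior vector, e.g. `np.full(K, a)` for the symmetric case):

```python
import numpy as np

def sample_pi_d(y_d, mu, alpha, n_draws=2000, rng=None):
    """Draws of pi_d from p(pi_d | y_d), which is proportional to the
    Poisson likelihood times the Dirichlet(alpha) prior. Proposing from
    the prior makes the acceptance ratio a Poisson likelihood ratio."""
    rng = np.random.default_rng() if rng is None else rng

    def pois_ll(pi_d):                        # Poisson part of the log likelihood
        rates = pi_d @ mu                     # (W,)
        return np.sum(y_d * np.log(rates) - rates)

    pi_cur = rng.dirichlet(alpha)
    ll_cur = pois_ll(pi_cur)
    draws = np.empty((n_draws, len(alpha)))
    for i in range(n_draws):
        pi_prop = rng.dirichlet(alpha)        # independence proposal from the prior
        ll_prop = pois_ll(pi_prop)
        if np.log(rng.uniform()) < ll_prop - ll_cur:
            pi_cur, ll_cur = pi_prop, ll_prop
        draws[i] = pi_cur
    return draws                              # (n_draws, K)

# Monte Carlo approximations of the expectations in Q, per document d:
#   E[pi_{dk}]                   ~ draws.mean(axis=0)
#   E[log sum_k pi_{dk} mu_{kw}] ~ np.log(draws @ mu).mean(axis=0)
#   E[log pi_{dk}]               ~ np.log(draws).mean(axis=0)
```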

Computational issues

  • This is still a local maximization algorithm:
    • All considerations of conventional EM apply
    • Except that here the objective is an approximation (a variational lower bound or a Monte Carlo estimate of \\( Q \\))
  • There are extra parameters to optimize over (the parameters for the variational distribution \\( q \\))

Things to be careful about

  • By using variational methods, we give up reliable second moments: the factorized \\( q \\) typically underestimates posterior variance
  • Validation is needed to see the extent of performance loss

Parallelization

  • Naive parallelization with random starting points, to reach multiple local modes (sketched after this list)
  • Within-likelihood computation parallelization (not recommended unless necessary)
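A minimal sketch of the naive version with the standard library, assuming a fitting routine like the `vem` sketch above (picking the best run by its final lower bound is left out for brevity):

```python
from multiprocessing import Pool

import numpy as np

def one_run(seed, y, K):
    """One vEM fit from a random starting point, seeded for reproducibility."""
    return vem(y, K, rng=np.random.default_rng(seed))

def parallel_restarts(y, K, n_starts=8, n_workers=4):
    """Naive parallelization: independent restarts to explore different local modes."""
    with Pool(n_workers) as pool:
        fits = pool.starmap(one_run, [(s, y, K) for s in range(n_starts)])
    return fits
```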

Announcements

  • Final T-shirt competition is on!
  • Final projects, pset5

Resources

Final slide

  • Next lecture: Data Augmentation.