Elastic Weight Consolidation

Real-world data is often not available all at once, but rather arrives sequentially. It is therefore desirable to have algorithms for sequential learning, also called continual learning.

Formally, we have a collection of datasets $\mathcal{D}_1,\mathcal{D}_2,\cdots,\mathcal{D}_n$, seen in that order, and we want to achieve a reasonably low validation error on all of them.

The simplest alternative to sequential learning is joint training: after seeing datasets $\mathcal{D}_{1:i} := \mathcal{D}_1,\mathcal{D}_2,\cdots,\mathcal{D}_i$, we learn a neural network configuration by minibatch gradient descent, where each minibatch is drawn from all of $\mathcal{D} _{1:i}$. However, as the number of datasets seen increases, the variance of the minibatch samples increases with it, and convergence takes longer. This motivates sequential learning.
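As a concrete illustration (this sketch is not from the EWC paper), joint training over the union of datasets might look like the following, assuming `model`, `loss_fn`, and a list `datasets_seen` of torch `Dataset` objects:

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader

def train_jointly(model, loss_fn, datasets_seen,
                  epochs=1, lr=1e-3, batch_size=32):
    # Each minibatch is drawn from the union D_{1:i} of all datasets
    # seen so far, so every gradient step mixes examples from all tasks.
    loader = DataLoader(ConcatDataset(datasets_seen),
                        batch_size=batch_size, shuffle=True)
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
```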

Catastrophic forgetting: Sequential learning methods typically aim to avoid catastrophic forgetting, that is, forgetting previously learned tasks as new tasks are learned. Three broad families of continual learning approaches address it in different ways: regularization-based methods, replay (rehearsal) methods, and architectural methods that allocate task-specific parameters.

Bayesian formulation: Elastic weight consolidation (EWC), a regularization-based approach, formulates the continual learning problem from a Bayesian perspective. Let $\boldsymbol{\theta} _{1:i}^*$ denote the weight configuration reached at the end of task $i$, which is expected to perform well on all the datasets $\mathcal{D} _{1:i}$. Let $p(\boldsymbol{\theta})$ denote a distribution over weights; gradient descent typically finds a point estimate of its mode. From a Bayesian perspective, learning the second task amounts to maximizing the posterior $p(\boldsymbol{\theta}|\mathcal{D}_1, \mathcal{D}_2)$, with the first task's posterior $p(\boldsymbol{\theta} | \mathcal{D}_1)$ acting as the prior.

\[ \begin{align*} \max_{\boldsymbol{\theta}} p(\boldsymbol{\theta}|\mathcal{D}_1, \mathcal{D}_2) &= \max_{\boldsymbol{\theta}} \log p(\boldsymbol{\theta}|\mathcal{D}_1, \mathcal{D}_2) \\ &= \max_{\boldsymbol{\theta}} [ \log p(\mathcal{D}_1, \mathcal{D}_2|\boldsymbol{\theta}) + \log p(\boldsymbol{\theta}) ] \end{align*} \]

Here the first step uses the monotonicity of the logarithm (the maximizer is unchanged), and the second applies Bayes' rule and drops the evidence term $\log p(\mathcal{D}_1, \mathcal{D}_2)$, which does not depend on $\boldsymbol{\theta}$.

More generally,

\[ \max_{\boldsymbol{\theta}} \log p(\boldsymbol{\theta}|\mathcal{D}_{1:i}) = \max_{\boldsymbol{\theta}} [\log p(\mathcal{D}_{1:i}|\boldsymbol{\theta}) + \log p(\boldsymbol{\theta})] \]

Assuming the datasets are independent given the weights, the likelihood decomposes as:

\[ \log p(\mathcal{D}_{1:i}|\boldsymbol{\theta}) = \sum_{j=1}^i \log p(\mathcal{D}_{j}|\boldsymbol{\theta}) \]

With this decomposition, the posterior maximization becomes:

\[ \begin{align*} \max_{\boldsymbol{\theta}} \log p(\boldsymbol{\theta}|\mathcal{D}_{1:i}) &= \max_{\boldsymbol{\theta}} [\log p(\mathcal{D}_{1:i}|\boldsymbol{\theta}) + \log p(\boldsymbol{\theta})]\\ &= \max_{\boldsymbol{\theta}} [\log p(\mathcal{D}_{i}|\boldsymbol{\theta}) + \log p(\mathcal{D}_{1:i-1}|\boldsymbol{\theta}) + \log p(\boldsymbol{\theta})]\\ &= \max_{\boldsymbol{\theta}} [\log p(\mathcal{D}_{i}|\boldsymbol{\theta}) + \log p(\boldsymbol{\theta}|\mathcal{D}_{1:i-1})] \end{align*} \]

where the last step applies Bayes' rule to $\mathcal{D}_ {1:i-1}$ and again drops a constant, $\log p(\mathcal{D}_ {1:i-1})$.

The posterior maximization therefore depends both on the likelihood of the new dataset and on the posterior over the previous datasets, whose maximizer is $\boldsymbol{\theta}_ {1:i-1}^ * $. The corresponding loss can be minimized by adding a regularization term that prevents $\boldsymbol{\theta}$ from veering too far from $\boldsymbol{\theta}_ {1:i-1}^ * $ while training on $\mathcal{D}_ i$. Since this regularizer should preserve closeness to the previous solution, we can use the KL-divergence between $p(\boldsymbol{\theta}|\mathcal{D}_ {1:i})$ and $p(\boldsymbol{\theta}|\mathcal{D}_ {1:i-1})$.
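To see why this leads to a quadratic penalty, note that expanding the KL-divergence to second order around $\boldsymbol{\theta}_ {1:i-1}^ * $ makes the first-order term vanish (it is a mode of the previous posterior), leaving a quadratic form whose curvature is given by the Fisher information matrix $F$. As a sketch, under this Laplace-style approximation:

\[ KL(p(\boldsymbol{\theta}|\mathcal{D}_ {1:i}) || p(\boldsymbol{\theta}|\mathcal{D}_ {1:i-1})) \approx \frac{1}{2} (\boldsymbol{\theta} - \boldsymbol{\theta}_ {1:i-1}^ *)^\top F (\boldsymbol{\theta} - \boldsymbol{\theta}_ {1:i-1}^ *) \]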

In practice, EWC keeps only the diagonal of $F$ in this second-order approximation:

\[ KL(p(\boldsymbol{\theta}|\mathcal{D}_ {1:i}) || p(\boldsymbol{\theta}|\mathcal{D}_ {1:i-1})) \approx \frac{1}{2} \sum_j F_{jj} (\theta_j - \theta_{1:i-1,j}^*)^2 \]

Here, $F$ is the empirical Fisher information matrix, computed at $\boldsymbol{\theta}_ {1:i-1}^ * $; only its diagonal is used in the approximation.
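To make this concrete, here is a minimal PyTorch sketch (not the authors' implementation). The helper names `estimate_diag_fisher` and `ewc_penalty`, and the hyperparameter `lam` (the regularization strength $\lambda$), are illustrative assumptions; the sketch assumes a classification model and a standard `DataLoader`:

```python
import torch
import torch.nn.functional as F  # torch.nn.functional, not the Fisher matrix

def estimate_diag_fisher(model, loader, n_batches=100):
    # Diagonal empirical Fisher: average of squared gradients of the
    # log-likelihood of the observed labels, evaluated at the previous
    # task's solution (the current model parameters).
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    seen = 0
    for x, y in loader:
        if seen >= n_batches:
            break
        model.zero_grad()
        F.cross_entropy(model(x), y).backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        seen += 1
    return {n: f / max(seen, 1) for n, f in fisher.items()}

def ewc_penalty(model, fisher, old_params, lam):
    # (lam / 2) * sum_j F_jj * (theta_j - theta*_j)^2
    penalty = sum((fisher[n] * (p - old_params[n]) ** 2).sum()
                  for n, p in model.named_parameters())
    return 0.5 * lam * penalty
```

After finishing task $i-1$, one would store `old_params = {n: p.detach().clone() for n, p in model.named_parameters()}` along with the Fisher estimate, and then minimize `task_loss + ewc_penalty(model, fisher, old_params, lam)` while training on $\mathcal{D}_ i$.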

References

  1. Kirkpatrick, James, et al. "Overcoming catastrophic forgetting in neural networks." Proceedings of the National Academy of Sciences 114.13 (2017): 3521-3526.