変分推論法
目次
変分推論法の概要
ここでは,変分推論法の基本となる考え方を紹介する.
導入
変分推論法とは,主にBayes推定において,複雑な事後分布を,より単純な分布で近似する最適化ベースの学習方法 を指す.具体的にどのような分布で近似するかは後に回して,ここでは,どのように近似するかを述べ,最適化問題として定式化しておこう.変分推論法の包括的な説明としては,Bealの博士論文1やBleiのレビュー論文2が参考になりそう.この記事の実験で使ったJuliaプログラムはこちら.
問題設定
いま,手元には$N$次元データからなる有限集合$\mathcal{D}$があるとしよう.統計家は,データ点$x_1, \dots, x_{\#\mathcal{D}}\in\mathcal{D}$を,未知の真の分布$\mathcal{Q}$に従う独立な確率変数列$X_1, \dots, X_{\#\mathcal{D}}$の一組の実現値であると想定し,真の分布を推測するための確率モデル$\mathcal{P}_\theta$を作る.そして,予測分布$\mathcal{P}^* $を計算し,真の分布$\mathcal{Q}$はおおよそ${\mathcal{P}}^* $であろうと推測する.ただし,$\theta\in\Theta$は未知のパラメータであり,データと確率モデルから推定する必要がある.Bayes推定を用いる場合には,確率モデル$p(\cdot\mid\theta)$と,既知のパラメータ$\lambda\in\Lambda$を持つ事前分布$\pi_{\lambda}$を作り,次式で定義される事後分布を計算する: \begin{equation} \pi^*_{\lambda} (\theta\mid\mathcal{D}) = \frac{1}{\mathcal{M}_\lambda}\left[\prod_{x\in\mathcal{D}}p(x\mid\theta)\right]\pi_{\lambda}(\theta). \end{equation} ただし,$\mathcal{M}_{\lambda}$は周辺尤度と呼ばれ, \begin{equation} \mathcal{M}_{\lambda} = \int \left[\prod_{x\in\mathcal{D}}p(x\mid\theta)\right]\pi_{\lambda}(\theta)\mathrm{d}{\theta} \end{equation} で定義される.この事後分布に対して,(Bayes事後)予測分布を \begin{equation} {p_{\lambda}}^* (x) = \int p(x\mid\theta)\pi_{\lambda}^*(\theta\mid\mathcal{D})\mathrm{d}\theta \end{equation} と定義する.この記事では,事後分布の解析的な計算が難しい場合を考えよう.
基本的な考え方
経験Bayes法との組み合わせ
平均場近似による方法
ここでは,近似事後分布の族の決め方として,平均場近似を用いた方法を説明する.
方針
平均場近似による方法とは,近似事後分布を独立な分布から構成する方法を指す.いま,未知パラメータ$\theta$は既知の定数$K\geq 2$に対して, \begin{equation} \theta = (\theta_{(1)}, \dots, \theta_{(K)}) \end{equation} のように$K$個のブロックに分かれているとする.ここで,近似事後分布の族として,次の集合を考えよう: \begin{equation} \mathcal{R}_{\mathrm{MF}} = \left\{ r\mid r(\theta\mid\mathcal{D}) = \prod_{k=1}^K r^{(k)}(\theta_{(k)}\mid\mathcal{D}), \quad \text{各$r^{(k)}$は確率(密度)関数}\right\}. \end{equation} ここで,独立性のみを課している ことに注意しよう.各確率(密度)関数が何であるかは問わず,とにかく独立性のみを課すのである.実は,この分布族の中で最適な分布は,より具体的に表すことができる.まずはELBOを計算してみよう.すると,$l=1, \dots, K$に対して,次のような形になる: \begin{equation} \mathcal{L}_{\lambda}[r] = -D_{\mathrm{KL}}\left[r^{(l)}\Vert \frac{1}{Z_{\lambda}}\exp\left(\mathbb{E}_{k\neq l}\left[\log\left(\pi_{\lambda}(\theta)\prod_{x\in\mathcal{D}}p(x\mid\theta)\right)\right]\right)\right] + (\text{$r^{(l)}$に依存しない項}). \end{equation} ただし,$Z_{\lambda}$は正規化定数であり,期待値は \begin{equation} \mathbb{E}_{k\neq l}\left[f(\theta)\right] = \int f(\theta)\left(\prod_{k\neq l} r^{(k)}(\theta_{(k)}\mid\mathcal{D})\right) \mathrm{d}{\theta_{(1)}\cdots\theta_{(l-1)}\theta_{(l+1)}\cdots\theta_{(K)}} \end{equation} のようにとる.KLダイバージェンスの最小値を考えれば,$k\neq l$に対する各因子$r^{(k)}$が既知の下で,最適な近似事後分布は \begin{equation} \hat{r}_{\lambda}^{(l)} \propto \exp\left(\mathbb{E}_{k\neq l}\left[\log\left(\pi_{\lambda}(\theta)\prod_{x\in\mathcal{D}}p(x\mid\theta)\right)\right]\right) \end{equation} として得られる.しかし,ここで次のような問題が生じる.
- 各因子は未知であるから最適な分布は求まらないのでは?
- 最適な分布の確率(密度)関数が分かってもその分布の正体は分からないのでは?
例1:単純な例
正規分布のパラメータに関して,Bayes推定してみよう.次のようなモデルを考える: \begin{equation} x\mid \mu, \nu \sim \mathrm{N}(\mu, \nu^{-1}), \quad \mu\sim\mathrm{N}(0,1), \quad \nu\sim\mathrm{Gamma}(1,\beta). \end{equation} ただし,$\beta$は既知の定数である.真の事後分布を計算すると, \begin{equation} \pi_\beta^* (\mu, \nu\mid \mathcal{D}) \propto \nu^{\frac{\#\mathcal{D}}{2}}\exp\left[-\frac{1}{2}\mu^2-\frac{\nu}{2}\sum_{x\in\mathcal{D}}(x-\mu)^2-\beta\nu \right] \end{equation} となる.これに対し近似事後分布の族を,平均場近似を用いて次のように採ろう: \begin{equation} \mathcal{R}_{\mathrm{MF}} = \left\{r\mid r(\mu, \nu\mid\mathcal{D}) = r^{(1)}(\mu\mid\mathcal{D})r^{(2)}(\nu\mid\mathcal{D}), \text{各$r^{(k)}$は確率密度関数}\right\}. \end{equation} 最適な近似事後分布を(頑張って)計算すれば,次のようになるはず: \begin{align} \hat{r}^{(1)}_{\beta}(\mu\mid\mathcal{D}) &\propto \exp\left[-\frac{1}{2}\mu^2 - \frac{1}{2}\mathbb{E}_\nu[\nu]\sum_{x\in\mathcal{D}}(x-\mu)^2\right], \\ \hat{r}^{(2)}_{\beta}(\nu\mid\mathcal{D}) &\propto \exp\left[\frac{\#\mathcal{D}}{2}\log\nu - \beta\nu - \frac{\nu}{2}\sum_{x\in\mathcal{D}}\mathbb{E}_\mu[(x-\mu)^2] \right]. \end{align} ただし,期待値$\mathbb{E}_\mu$は近似事後分布$\hat{r}^{(1)}_\beta$に関して,期待値$\mathbb{E}_{\nu}$は近似事後分布$\hat{r}^{(2)}_\beta$に関してとるものとする.上の式から,近似事後分布の正体がそれぞれ正規分布とガンマ分布であることが分かる.そこで,$\hat{r}^{(1)}_\beta$と$\hat{r}^{(2)}_\beta$はそれぞれ, \begin{equation} \mathrm{N}(\hat{\mu}, \hat{\nu}^{-1}), \quad \mathrm{Gamma}(\hat{\alpha}, \hat{\beta}) \end{equation} の確率密度関数であるとしよう.すると,先程の最適な分布の式から,次式が得られる: \begin{align} \hat{\nu} &= 1 + \frac{\hat{\alpha}\#\mathcal{D}}{\hat{\beta}}\\ \hat{\mu} &= \frac{\hat{\alpha}}{\hat{\beta}\hat{\nu}}\left(\sum_{x\in\mathcal{D}}x\right)\\ \hat{\alpha} &= 1 + \frac{\#\mathcal{D}}{2}\\ \hat{\beta} &= \beta + \frac{1}{2}\left(\sum_{x\in\mathcal{D}}x^2\right)-\hat{\mu}\left(\sum_{x\in\mathcal{D}}x\right) + \frac{\#\mathcal{D}}{2}\hat{\mu}^2 + \frac{\#\mathcal{D}}{2\hat{\nu}}. \end{align} 適当な初期値から始めて順に更新していけば良いだろう.ELBOは次式のようになる: \begin{align} \mathcal{L}_{\beta}[\hat{r}_{\beta}] &= -\frac{1}{2}\log\hat{\nu} + \frac{1}{2} - \hat{\alpha}\log\hat{\beta} + \log\Gamma (\hat{\alpha}) - (\hat{\alpha}-1)(\psi(\hat{\alpha}) - \log\hat{\beta}) + \hat{\alpha}+ \frac{\#\mathcal{D}}{2}\left[\psi(\hat{\alpha}) - \log(2\pi\hat{\beta})\right]\\ &-\frac{\hat{\alpha}}{2\hat{\beta}}\left[\left(\sum_{x\in\mathcal{D}}x^2\right)-2\hat{\mu}\left(\sum_{x\in\mathcal{D}}x\right)+ \#\mathcal{D}\hat{\mu}^2 + \frac{\#\mathcal{D}}{\hat{\nu}}\right]-\frac{1}{2}\left(\hat{\mu}^2 + \frac{1}{\hat{\nu}}\right) + \log{\beta} - \frac{\hat{\alpha}\beta}{\hat{\beta}}. \end{align} ただし,$\Gamma$はガンマ関数,$\psi$はディガンマ関数である.
以下に数値実験結果を示す.人工的なデータを発生させた真の分布は混合正規分布 \begin{equation} \frac{1}{2}\mathrm{N}(-1.5, 1.5^2) + \frac{1}{2}\mathrm{N}(1.5, 1.5^2) \end{equation} とし,データサイズは$\#\mathcal{D}=30$とした.左下図の赤実線が真の分布,ヒストグラムが生成したデータである.真の分布から少しだけ左にずれている.右下図にはモデルの尤度関数を示した.
Gibbs samplerと変分推論の$2$つの手法で推定した.左下図の緑の点が真の事後分布からのサンプル,ヒートマップが近似事後分布の確率密度関数である.両者は重なっており,ほぼ識別できない.右下図の緑破線がGibbs samplerによる予測分布で,黄一点鎖線が変分推論法による予測分布である.こちらの図でも,推測方法による差はほとんど観察できない.また,データの偏りに影響されて,真の分布からは平均が少しずれている.
勾配計算による方法
前節の平均場近似では,更新式を常に導出できるとは限らない.ここでは,近似事後分布の形を予め指定し,勾配計算により最適な分布を探索する方法を説明する.
方針
この方法では,正規分布やガンマ分布のように,近似分布族の分布形を具体的に指定し,パラメータを勾配計算により調整する.例えば, \begin{equation} \mathcal{R}_{\mathrm{hoge}} = \left\{r_{\eta} \mid \text{$r_{\eta}$はhoge分布の確率(密度)関数}, \eta\in H\right\} \end{equation} と決め,近似事後分布のパラメータ勾配方向へと反復的に更新していく.ELBOが$H$上の関数$\mathcal{L}(\eta)$となることに注意しよう.ELBOを最大化したいのだから,勾配方向へ更新 するのが良いだろう: \begin{equation} \eta^{(n+1)} = \eta^{(n)} + \alpha_n \nabla_\eta \mathcal{L}(\eta^{(n)}). \end{equation} データサイズが大きい場合には,ミニバッチを利用するのも一案である.なお,上の更新式が解析的に求められない場合には,近似事後分布からのサンプルを用いて,積分項をモンテカルロ近似する必要がある.しかし,ここで問題が生じる.(微分と積分の順序交換を認めたとして,)積分項の勾配は \begin{equation} \nabla_\eta \int \left[\log p(x\mid\theta)\right]r_{\eta}(\theta\mid\mathcal{D})\mathrm{d}\theta = \int \left[ \log p(x\mid\theta)\right]\nabla_\eta r_{\eta}(\theta\mid\mathcal{D})\mathrm{d}\theta \end{equation} であるが,右辺から確率密度関数が消えてしまうため,このままではモンテカルロ近似ができない.こうした場合の対処法として,reparametrization trickが知られる.いま,近似事後分布に従う確率変数$\theta$に対して,近似事後分布のパラメータ$\eta$に依存しない分布$F$と,$\eta$に依存する(微分可能な)変換$g_\eta$が存在して, \begin{equation} \theta = g_\eta(\phi), \quad \phi\sim F \end{equation} と変換できるとしよう.このとき,分布$F$の密度関数を$f$として, \begin{equation} \int \left[\log p(x\mid\theta)\right] r_{\eta}(\theta\mid\mathcal{D}) \mathrm{d}\theta = \int \left[\log p(x\mid g_\eta(\phi))\right] f(\phi) \mathrm{d}\phi \end{equation} であるから,積分項の勾配の,$F$からのサンプル$\mathcal{S}$を利用したモンテカルロ近似 \begin{equation} \nabla_\eta\int \left[\log p(x\mid\theta)\right] r_{\eta}(\theta\mid\mathcal{D}) \mathrm{d}\theta \simeq \frac{1}{\#\mathcal{S}}\sum_{\phi\in\mathcal{S}}\nabla_\eta \log p(x\mid g_\eta(\phi)) \end{equation} が得られる.まとめると, \begin{equation} \hat{\mathcal{L}}_\lambda(\eta) = \sum_{x\in\mathcal{D}}\left[\frac{1}{\#\mathcal{S}}\sum_{\phi\in\mathcal{S}}\nabla_\eta \log p(x\mid g_\eta(\phi)) \right] - D_{\mathrm{KL}}[r_{\eta} \Vert \pi_{\lambda}] \end{equation} をELBOの勾配の近似値として採用する.また,制約なし最適化にするために変数変換などの工夫が必要なこともある.まあ,細々とした注意点は次の例で確認しよう.
例2:ロジスティック回帰
次のようなロジスティックモデルを考えよう: \begin{equation} y\mid x, w \sim \mathrm{Bernoulli}(\sigma (w_1 + w_2 x)), \quad w = (w_1, w_2). \end{equation} ただし,重み$w$の事前分布を標準正規分布とする.手元の入出力の組からなるデータ$\mathcal{D}$から事後分布と予測分布を計算しよう.近似事後分布の族としては,次のような正規分布族が適当だろう: \begin{equation} \mathcal{R}_{\mathrm{N}} = \left\{ r_{\eta} \mid \text{$r_\eta$は$\mathrm{N}(m_\eta, \mathrm{diag}(s_\eta\odot s_\eta))$の確率密度関数}, \eta = [m, \log s ]^{\top}\in\mathbb{R}^4 \right\}. \end{equation} ただし,ベクトルに対する対数は成分ごとにとるものと約束する.標準偏差$s$に対して対数を作用させることで,制約なし最適化にしていることに注意しよう.以上の設定の下で,ELBOは次式で与えられる: \begin{equation} \mathcal{L}(\eta) = \sum_{(x,y)\in\mathcal{D}}\int \left[\log p(y\mid x, w)\right]r_\eta(w\mid\mathcal{D})\mathrm{d}w - \left[\frac{1}{2}\left(\| m\|^2 + \|s \|^2 \right) - \sum_{j=1}^{2}\log s |_{j} - \frac{1}{2}\right]. \end{equation} 上式の積分項の勾配の近似を考えよう.近似事後分布に従う確率変数$w$は \begin{equation} w = g_\eta(\phi) = m + s \odot \phi, \quad \phi\sim\mathrm{N}(0,\mathit{I}) \end{equation} と書ける.従って,標準正規分布からのサンプル$\mathcal{S}$を利用して, \begin{equation} \nabla_\eta \int\left[\log p(y\mid x, w)\right]\mathrm{d}w \simeq \frac{1}{\#\mathcal{S}}\sum_{\phi\in\mathcal{S}} \nabla_\eta\log p(y\mid x, g_\eta(\phi)) \end{equation} と近似すれば良い.
以下に数値実験結果を示す.真の分布は真値を$w_1=-4, w_2=4$とするロジスティックモデルとした.入力$x$は乱数で作り,データサイズは$\#\mathcal{D}=10$とした.左下図の赤実線が真の成功確率,青点が生成したデータである.右下図にはモデルの尤度関数を示した.
HMCと変分推論の$2$つの手法で推定した.左下図の緑の点が真の事後分布からのサンプル,ヒートマップが近似事後分布の確率密度関数である.両者はおおよそ重なっている.右下図の緑破線がHMCによる予測成功確率で,黄一点鎖線が変分推論法による予測成功確率である.こちらの図でも,推測方法による差はほとんど観察できない.