14.4. LSTM#
14.4.1. LSTMの仕組み#
Show code cell source
# packageのimport
from typing import Any, Union, Callable, Type, TypeVar
from tqdm.std import trange,tqdm
import numpy as np
import numpy.typing as npt
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import seaborn as sns
plt.style.use("bmh")
# pytorch関連のimport
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
14.4.1.1. outputゲート(o)#
tanh(ct) の各要素に対して「それらが次時刻の隠れ状態としてどれだけ重要か」ということを調整する。
14.4.1.2. forgetゲート(f)#
記憶セルに対して「何を忘れるか」を明示的に支持する。
14.4.1.3. 新しい記憶セル(g)#
新しく覚えるべき情報を記憶セルに追加する。
14.4.1.4. inputゲート(i)#
新しい記憶セル(g)の各要素が、新たに追加する情報としてどれだけ価値があるかを判断する。このinputゲートによって、何も考えずに新しい情報を追加するのではなく、追加する情報の取捨選択を行う。
「+」ノードの逆伝播は上流から伝わる勾配をそのまま流すだけなので、勾配の変化(劣化)は起きない。残る「×」ノードに関して、これは「行列の積」ではなく「要素ごとの積(アダマール積)」であり、毎時刻、異なるゲート値によって要素毎の積の計算が実施される。ここに勾配消失を起こさない理由がある。
「×」ノードの計算はforgetゲートによってコントロールされている。ここで、forgetノードが「忘れるべき」と判断した記憶セルの要素に対しては、その勾配の要素は小さくなる。一方で、forgetゲートが「忘れてはいけない」と判断した要素に対しては、その勾配の要素は劣化することなく過去方向へ伝わる。そのため、記憶セルの勾配は、(長期にわたって覚えておくべき情報に対しては)勾配消失を起こさずに伝播することが期待できる。
14.4.2. Dropout#
十分以上にパラメータ数の多いニューラルネットワークは,表現能力が豊かすぎるせいで 過学習 が起きることがあります.
過学習:訓練データにフィットしすぎてテストデータでの評価が悪くなること.未知のデータに対応できるようなシンプルな構造を学習したいのに,訓練データを全て表現しきるような細かすぎる構造を獲得してしまった状態.
ニューラルネットワークの訓練において,これを解決するために Dropout法 が使われます.
ドロップアウトはミニバッチごとに一定の確率で選んだニューロンを使用しないようにします。
この確率を ドロップアウト確率 と呼びます。
ドロップアウトではミニバッチごとに一定割合のニューロンを使わないようにすることによって、特定のニューロン(パラメータ)への依存度を減らすことでオーバーフィッティングを減らします。
もう一つの背景として、複数のモデルを構築し各モデルの予測を平均すると、多くの場合で予測精度が上がるというものです。
Dropoutクラスの引数
p (float) – probability of an element to be zeroed. Default: 0.5
inplace (bool) – If set to True, will do this operation in-place. Default: False
>>> m = nn.Dropout(p=0.2)
>>> input = torch.randn(20, 16)
>>> output = m(input)
これをnn.Linearの後に挟むことで,過学習を抑制することができます.ただし,pの値はハイパーパラメータです.
14.4.3. Batch Normalization#
こちらも過学習を抑制する効果のあるトリックです.nn.Linearの後ろにつけます.
Deep LearningにおけるBatch Normalizationの理解メモと、実際にその効果を見てみる
ネットワークのある層毎に、\(\boldsymbol{H}\)という行列を定義します。これは、各行がminibatchの1つのデータ、各列 がそれぞれのactivationとなるような值をとる行列です。 batch sizeが128, hidden unitの数を 256 とすると、\(\boldsymbol{H}\) は128×256の行列になります。 これを正規化するために、次のように \(\boldsymbol{H}^{\prime}\) と置き直すのがbatch normalizationです。
ここで、 \(\boldsymbol{\mu}\) とのは、その層での各ユニットの平均、及び標準偏差のベクトルを表しています。 上記の式は行列とべクトルの演算になってしまっていますが, \(H\) 各行にそれぞれのべクトルかかるよう にbroadcastingしています。 即ち、 \(\mathrm{H}_{\mathrm{i}, \mathrm{j}}\) の正規化は、 \(f, g\) を使って計算されます。 トレーニング時の \(\mu , \sigma\) は次で与えられます。
\(\delta\)は、標準偏差が0になってしまうのを防ぐ、108 \(10^8\) ような小さな值です。 推定時は、minibatchなどはないため、 \(\boldsymbol{\mu} , \boldsymbol{\sigma}\) は訓練中に計算したものの移動平均を使います。 ネットワークに入れて計算するとなると少しややこしいですが、考え方自体は非常に単純ですね!