13.9. Skip-Gram#
Note
このノートは Skip-Gram実装課題 のヒントになるように書かれています.
CBoWと遂になる単語埋め込みベクトル作成手法である Skip-Gram を実装します.ここでは,Negative Samplingのような技術を使わず,出力層でsoftmax関数を利用することで実装を簡単にしています.そのため計算コストが膨大になる傾向があり,大規模なコーパスに適用することはお勧めしません.
Skip-Gramの計算コストの大きさには出力層のSoftmax活性化関数が大きな影響を与えます.そのため,高速化を行うためにはSoftmaxを Negative Sampling と呼ばれるアルゴリズムで代用することになります.これについてはこのブログが実装の助けになります.また,直接Skip-Gramを紹介しているわけではないのですが,CBoWの説明の中でこれを説明しているゼロから作るDeep Learning ❷ ―自然言語処理編も非常に参考になるでしょう.
# packageのimport
import re
import math
from typing import Any
from tqdm.std import trange,tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.sparse import lil_matrix
# pytorch関連のimport
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import skorch
from skorch import NeuralNetClassifier, NeuralNetRegressor
from skorch.callbacks import Callback, EpochScoring
from torch.utils.data import Dataset
from janome.tokenizer import Tokenizer
13.9.1. データの読み込み#
コーパスにはja.text8のサブセットを利用します.発展課題に取り組む場合も,指定されているハイパーパラメータやコーパスのサイズが実行困難である場合は適宜修正してください.ただし,その場合はskip_gram.pyの先頭行に,docstringを用意してその旨を書いてください.あるいはCLIのオプションにしてもいいかもしれません.
with open("./data/ja.text8") as f:
text8 = f.read()
print(text8[:200])
#LIMIT = math.floor(len(text8)*0.1)
LIMIT = 100_0000
print(f"{LIMIT}/ {len(text8)}")
text8 = text8[:LIMIT]
ちょん 掛け ( ちょん がけ 、 丁 斧 掛け ・ 手斧 掛け と も 表記 ) と は 、 相撲 の 決まり 手 の ひとつ で ある 。 自分 の 右 ( 左 ) 足 の 踵 を 相手 の 右 ( 左 ) 足 の 踵 に 掛け 、 後方 に 捻っ て 倒す 技 。 手斧 ( ちょう な ) を かける 仕草 に 似 て いる こと から 、 ちょう な が 訛っ て ちょん 掛け と なっ
1000000/ 46507793
13.9.2. 形態素解析#
コーパス内の単語(トークン)全てを利用すると語彙が多くなりすぎるので,ここでは名詞(それも一般名詞と固有名詞)のみを利用します.そのために形態素解析を行う必要があるので,python製の形態素解析器であるjanomeを利用しています.形態素解析器はこれ以外にもMecabなどが有名です.
13.9.2.1. 形態素解析器 janome#
ja.text8の一部に品詞分解を行なった結果を以下に示します.
t = Tokenizer()
sample_text = "".join(text8[:50].split())
for token in t.tokenize(sample_text):
print(token.surface, "\t", token.part_of_speech.split(","))
ちょん ['名詞', '一般', '*', '*']
掛け ['名詞', '接尾', '一般', '*']
( ['記号', '括弧開', '*', '*']
ちょん ['名詞', '一般', '*', '*']
がけ ['名詞', '接尾', '一般', '*']
、 ['記号', '読点', '*', '*']
丁 ['名詞', '固有名詞', '人名', '姓']
斧 ['名詞', '一般', '*', '*']
掛け ['名詞', '接尾', '一般', '*']
・ ['記号', '一般', '*', '*']
手斧 ['名詞', '一般', '*', '*']
掛け ['名詞', '接尾', '一般', '*']
と ['助詞', '格助詞', '引用', '*']
も ['助詞', '係助詞', '*', '*']
表記 ['名詞', 'サ変接続', '*', '*']
) ['記号', '括弧閉', '*', '*']
と ['助詞', '格助詞', '引用', '*']
は ['助詞', '係助詞', '*', '*']
、 ['記号', '読点', '*', '*']
相撲 ['名詞', '一般', '*', '*']
13.9.2.2. janomeを使った語彙辞書作成#
活用する語彙をまとめた辞書(word2id, id2word)を作成します.この実装はダーティなので,実際に自然言語処理を行う場合は参考にしないでください.
def my_analyzer(text):
#text = code_regex.sub('', text)
#tokens = text.split()
#tokens = filter(lambda token: re.search(r'[ぁ-ん]+|[ァ-ヴー]+|[一-龠]+', token), tokens)
tokens = []
for token in tqdm(t.tokenize(text)):
pos = token.part_of_speech.split(",")
if "名詞" == pos[0]:
if "一般" == pos[1] or "固有名詞" == pos[1]:
tokens.append(token.surface)
tokens = filter(lambda token: re.search(r'[ぁ-ん]+|[ァ-ヴー]+|[一-龠]+', token), tokens)
return tokens
def build_contexts_and_target(corpus, window_size:int=5)->tuple[np.ndarray,np.ndarray]:
contexts = []
target = []
vocab = set()
_window_size = window_size//2
# 文ごとに分割
preprocessed_corpus = corpus.replace(" ","")
# posを見て単語ごとに分割
tokens = list(my_analyzer(preprocessed_corpus))
# 新しい語彙を追加
vocab = vocab | set(tokens)
# スライディングウィンドウ
for i in trange(_window_size, len(tokens)-_window_size):
# ウィンドウの真ん中をtargetにする
target.append(tokens[i])
# 真ん中以外の単語をcontextsへ
tmp = tokens[i-_window_size:i]
tmp += tokens[i+1:i+1+_window_size]
contexts.append(tmp)
# 辞書作成
id2word = list(vocab)
word2id = {word:id for id,word in enumerate(id2word)}
vocab_size = len(word2id)
# contextsとtargetを単語id配列へ置き換え
contexts_id_list = [[word2id[word] for word in doc] for doc in contexts]
target_id_list = [word2id[word] for word in target]
contexts = lil_matrix((len(contexts_id_list), vocab_size),dtype=np.float32)
for index, _contexts_id_list in enumerate(contexts_id_list):
#tmp = np.eye(vocab_size)[np.array(_contexts_id_list)]
for word_id in _contexts_id_list:
contexts[index, word_id] +=1.
target = np.array(target_id_list)
return contexts.tocsr().astype(np.float32), target, word2id, id2word
WINDOW_SIZE = 11
contexts, target, word2id, id2word = build_contexts_and_target(text8, window_size=WINDOW_SIZE)
print(f"contextsのshape: {contexts.shape}")
Show code cell output
0it [00:00, ?it/s]
1031it [00:00, 9207.91it/s]
3384it [00:00, 15945.77it/s]
4973it [00:00, 14136.67it/s]
7290it [00:00, 16801.08it/s]
9911it [00:00, 19503.88it/s]
11894it [00:00, 17662.89it/s]
14688it [00:00, 20015.87it/s]
17443it [00:00, 21492.97it/s]
19626it [00:01, 19385.60it/s]
21942it [00:01, 20228.29it/s]
24600it [00:01, 19012.62it/s]
27245it [00:01, 20455.99it/s]
29947it [00:01, 21519.38it/s]
32146it [00:01, 19971.38it/s]
34186it [00:01, 20058.81it/s]
36732it [00:01, 21187.94it/s]
38881it [00:02, 19209.84it/s]
41367it [00:02, 20392.16it/s]
44323it [00:02, 22351.98it/s]
46600it [00:02, 20562.91it/s]
49044it [00:02, 21314.90it/s]
51545it [00:02, 19706.76it/s]
54201it [00:02, 20885.00it/s]
57259it [00:02, 22870.02it/s]
59596it [00:02, 20649.21it/s]
62042it [00:03, 21511.60it/s]
64868it [00:03, 22962.30it/s]
67216it [00:03, 20470.28it/s]
69684it [00:03, 21235.41it/s]
72375it [00:03, 19860.99it/s]
74930it [00:03, 20951.33it/s]
77557it [00:03, 22169.03it/s]
79835it [00:03, 20464.69it/s]
82490it [00:04, 21592.60it/s]
85158it [00:04, 22843.15it/s]
87492it [00:04, 20145.55it/s]
90217it [00:04, 21777.19it/s]
93355it [00:04, 20684.09it/s]
96171it [00:04, 22188.52it/s]
98915it [00:04, 22964.88it/s]
101276it [00:04, 20695.44it/s]
103797it [00:05, 21508.38it/s]
106294it [00:05, 22354.19it/s]
108585it [00:05, 20573.30it/s]
110986it [00:05, 21142.76it/s]
113351it [00:05, 19119.68it/s]
115930it [00:05, 20666.51it/s]
118738it [00:05, 22417.33it/s]
121050it [00:05, 20500.66it/s]
123478it [00:05, 21138.16it/s]
126155it [00:06, 22051.38it/s]
128404it [00:06, 20331.06it/s]
131015it [00:06, 21627.94it/s]
133783it [00:06, 19737.50it/s]
136575it [00:06, 21104.03it/s]
139184it [00:06, 22008.09it/s]
141443it [00:06, 20336.34it/s]
144166it [00:06, 21509.16it/s]
146774it [00:07, 22653.08it/s]
149088it [00:07, 20462.27it/s]
151535it [00:07, 21069.40it/s]
153692it [00:07, 19337.81it/s]
156152it [00:07, 20612.80it/s]
158966it [00:07, 22302.98it/s]
161248it [00:07, 20797.37it/s]
163827it [00:07, 21824.98it/s]
166330it [00:07, 22336.57it/s]
168597it [00:08, 20933.34it/s]
171093it [00:08, 21761.61it/s]
173663it [00:08, 19606.54it/s]
176402it [00:08, 21169.04it/s]
178955it [00:08, 22142.95it/s]
181228it [00:08, 20072.82it/s]
184132it [00:08, 22041.71it/s]
187285it [00:08, 24487.93it/s]
189811it [00:09, 21763.63it/s]
192223it [00:09, 22250.37it/s]
194628it [00:09, 20218.48it/s]
197490it [00:09, 21886.23it/s]
200093it [00:09, 22864.94it/s]
202446it [00:09, 20549.72it/s]
205254it [00:09, 22066.05it/s]
208023it [00:09, 23121.44it/s]
210394it [00:10, 20775.12it/s]
213253it [00:10, 22675.59it/s]
215602it [00:10, 20756.17it/s]
218624it [00:10, 22845.53it/s]
221348it [00:10, 23806.36it/s]
223794it [00:10, 21371.23it/s]
226502it [00:10, 22538.14it/s]
229410it [00:10, 20618.47it/s]
231829it [00:11, 21059.41it/s]
234744it [00:11, 22564.11it/s]
237064it [00:11, 20716.34it/s]
239851it [00:11, 22198.74it/s]
242636it [00:11, 20444.58it/s]
245386it [00:11, 21886.29it/s]
248174it [00:11, 23221.07it/s]
250569it [00:11, 20837.15it/s]
253681it [00:12, 22960.38it/s]
256660it [00:12, 24316.69it/s]
259163it [00:12, 21965.20it/s]
261707it [00:12, 22343.22it/s]
264102it [00:12, 19962.96it/s]
266947it [00:12, 21613.89it/s]
269677it [00:12, 22882.45it/s]
272037it [00:12, 20781.29it/s]
274815it [00:12, 22043.62it/s]
277695it [00:13, 20336.98it/s]
280492it [00:13, 21831.75it/s]
283322it [00:13, 23121.58it/s]
285708it [00:13, 20859.52it/s]
288313it [00:13, 22135.23it/s]
290985it [00:13, 23067.38it/s]
293356it [00:13, 21029.72it/s]
296111it [00:13, 22521.41it/s]
298431it [00:14, 20682.75it/s]
301136it [00:14, 21924.03it/s]
303773it [00:14, 23026.40it/s]
306130it [00:14, 20672.88it/s]
308772it [00:14, 21915.16it/s]
311264it [00:14, 19904.46it/s]
314143it [00:14, 21457.10it/s]
317011it [00:14, 22812.50it/s]
319354it [00:15, 21175.80it/s]
322056it [00:15, 22014.84it/s]
324449it [00:15, 19909.35it/s]
327130it [00:15, 21500.45it/s]
329983it [00:15, 23328.20it/s]
332392it [00:15, 21184.59it/s]
334783it [00:15, 21789.33it/s]
337871it [00:15, 20501.95it/s]
340898it [00:16, 22644.90it/s]
343623it [00:16, 23466.44it/s]
346043it [00:16, 21297.18it/s]
348436it [00:16, 21590.39it/s]
351042it [00:16, 22718.74it/s]
353369it [00:16, 20369.71it/s]
356445it [00:16, 22845.44it/s]
358815it [00:16, 20713.17it/s]
361440it [00:16, 21746.02it/s]
363003it [00:17, 21313.04it/s]
0%| | 0/80264 [00:00<?, ?it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80264/80264 [00:00<00:00, 849904.99it/s]
contextsのshape: (80264, 17871)
13.9.3. クラスの作成#
Skip-gramをnn.Moduleのサブクラスとして実装します.
クラスの実装には上のskip-gramアーキテクチャ図を参考にしてください.高速化のテクニックなどは不要です.(もちろん実装できる人は実装してもOK)
class SkipGram(nn.Module):
def __init__(self, vocab_size:int, embedding_dim:int)->None:
super().__init__()
...
def forward(self, input:torch.Tensor)->torch.Tensor:
...
13.9.4. 損失関数の作成#
Skip-Gramはクラス分類の体裁をとっているので,損失関数にはCross Entropyを用います.ただしPyTorchで用意されているnn.CrossEntropyを用いることは(おそらく)できないので,自作しましょう.
Hint
条件:
batch_size=128, vocab_size=11342のとき,以下が損失関数に入力されると仮定して実装してください.
SkipGramがforwardメソッドから出力するtensor.shapeは「torch.Size([128, 11342])」,
正解データとして利用するtensor.shapeは「torch.Size([128, 11342])」
callbackにおいて,ここで実装したcross entropyを使ってperplexityを計算します.
class BowCrossEntropy(nn.Module):
def forward(self, input, target):
"""
inputはSkip-gramの出力です.
targetは予測したいcontextsです.
"""
...
13.9.5. trainerの準備と訓練#
ここまでの実装が終わったら,あとは訓練用のプログラムを書くだけです.この解説ではskorchを利用して楽をします.Skip-Gramはクラス分類の体裁を取っていると言いましたが,出力はcategoricalではなくmultinomialです.つまり 一つのデータに対して正解ラベルが複数あります .これはskorchのNeuralNetClassifier
では上手く扱えないので,NeuralNetRegressor
を使っています.
Note
NeuralNetClassifierは主に1データ1ラベルの場合に利用します.今回の例でも使えないわけではないのですが,標準で設定された「正答率を表示するコールバック」が動作してしまうので利用を見送りました.
EpochScoring(lambda net,X=None,y=None: np.exp(net.history_[-1, "valid_loss"]), name="valid_ppl"),
はエポックの終わりに呼び出されるコールバック関数の雛形であるEpochScoring
を利用して,Perplexityを計算します.targetもcontextsもnp.ndarrayのままでfitに渡します.
trainerが中でdatasetやdataloaderを用意してくれます.
contextsはscipy.sparse.lil_matrix or scipy.sparse.csr_matrixになっているので,
toarray
メソッドでnp.ndarrayに戻しています.
trainer = NeuralNetRegressor(
SkipGram(len(word2id), 50),
optimizer=optim.Adam,
criterion=BowCrossEntropy,
max_epochs=20,
batch_size=128,
lr=0.01,
callbacks=[
EpochScoring(lambda net,X=None,y=None: np.exp(net.history_[-1, "valid_loss"]), name="valid_ppl"),
EpochScoring(lambda net,X=None,y=None: np.exp(net.history_[-1, "train_loss"]), name="train_ppl", on_train=True,)
],
device="cpu", # 適宜変更
)
trainer.fit(target, contexts.toarray())
Show code cell output
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[7], line 15
1 trainer = NeuralNetRegressor(
2 SkipGram(len(word2id), 50),
3 optimizer=optim.Adam,
(...)
12 device="cpu", # 適宜変更
13 )
---> 15 trainer.fit(target, contexts.toarray())
File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/skorch/regressor.py:82, in NeuralNetRegressor.fit(self, X, y, **fit_params)
71 """See ``NeuralNet.fit``.
72
73 In contrast to ``NeuralNet.fit``, ``y`` is non-optional to
(...)
77
78 """
79 # pylint: disable=useless-super-delegation
80 # this is actually a pylint bug:
81 # https://github.com/PyCQA/pylint/issues/1085
---> 82 return super(NeuralNetRegressor, self).fit(X, y, **fit_params)
File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/skorch/net.py:1300, in NeuralNet.fit(self, X, y, **fit_params)
1268 """Initialize and fit the module.
1269
1270 If the module was already initialized, by calling fit, the
(...)
1297
1298 """
1299 if not self.warm_start or not self.initialized_:
-> 1300 self.initialize()
1302 self.partial_fit(X, y, **fit_params)
1303 return self
File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/skorch/net.py:888, in NeuralNet.initialize(self)
886 self._initialize_module()
887 self._initialize_criterion()
--> 888 self._initialize_optimizer()
889 self._initialize_history()
891 self._validate_params()
File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/skorch/net.py:859, in NeuralNet._initialize_optimizer(self, reason)
856 msg = self._format_reinit_msg("optimizer", triggered_directly=False)
857 print(msg)
--> 859 self.initialize_optimizer()
861 # register the virtual params for all optimizers
862 for name in self._optimizers:
File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/skorch/net.py:625, in NeuralNet.initialize_optimizer(self, triggered_directly)
621 args, kwargs = self.get_params_for_optimizer(
622 'optimizer', named_parameters)
624 # pylint: disable=attribute-defined-outside-init
--> 625 self.optimizer_ = self.optimizer(*args, **kwargs)
626 return self
File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/optim/adam.py:33, in Adam.__init__(self, params, lr, betas, eps, weight_decay, amsgrad, foreach, maximize, capturable, differentiable, fused)
27 raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
29 defaults = dict(lr=lr, betas=betas, eps=eps,
30 weight_decay=weight_decay, amsgrad=amsgrad,
31 maximize=maximize, foreach=foreach, capturable=capturable,
32 differentiable=differentiable, fused=fused)
---> 33 super().__init__(params, defaults)
35 if fused:
36 if differentiable:
File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/optim/optimizer.py:187, in Optimizer.__init__(self, params, defaults)
185 param_groups = list(params)
186 if len(param_groups) == 0:
--> 187 raise ValueError("optimizer got an empty parameter list")
188 if not isinstance(param_groups[0], dict):
189 param_groups = [{'params': param_groups}]
ValueError: optimizer got an empty parameter list
13.9.6. 類似単語検索#
cbowと同様に単語埋め込みベクトルを使って,類似単語の検索を行います.
def get_similar_words(query, word_embeddings, topn=5, word2id=word2id, ):
"""単語埋め込みベクトルを使って似た単語を検索する
Args:
query (str): 類似単語を検索したい単語
topn (int, optional): 検索結果の表示個数. Defaults to 5.
word2id (dict[str,int], optional): 単語→単語idの辞書. Defaults to word2id.
word_embeddings (np.ndarray, optional): 単語埋め込み行列.必ず(語彙数x埋め込み次元数)の行列であること. Defaults to word_embeddings.
"""
id=word2id[query]
E = (word_embeddings.T / np.linalg.norm(word_embeddings,ord=2, axis=1)).T # {(V,L).T / (V)}.T = (V,L)
target_vector = E[id]
cossim = E @ target_vector # (V,L)@(L)=(V)
sorted_index = np.argsort(cossim)[::-1][1:topn+1] # 最も似たベクトルは自分自身なので先頭を除外
print(f">>> {query}")
_id2word = list(word2id.keys())
for rank, i in enumerate(sorted_index):
print(f"{rank+1}:{_id2word[i]} \t{cossim[i]}")
word_embeddings = trainer.module_.embedding.weight.detach().cpu().numpy()
get_similar_words("ロボット", word_embeddings, )
>>> ロボット
1:ユニバーサル 0.898608386516571
2:ポルト 0.7893995642662048
3:ロボティックス 0.763614296913147
4:テラ 0.742680013179779
5:関節 0.7259170413017273
get_similar_words("サッカー", word_embeddings, )
get_similar_words("日本", word_embeddings, )
get_similar_words("女王", word_embeddings, )
get_similar_words("機械学習", word_embeddings, )
>>> サッカー
1:リーグ 0.734089195728302
2:専業 0.7245967388153076
3:ヴァンフォーレ 0.6850863695144653
4:選手 0.6845436692237854
5:アルビレックス 0.6741206645965576
>>> 日本
1:ほん 0.6705817580223083
2:米国 0.6255179047584534
3:王者 0.6063108444213867
4:社団 0.5765134692192078
5:蓄音機 0.5684884786605835
>>> 女王
1:ヴィクトリアシリーズ 0.6750556826591492
2:後塵 0.649889349937439
3:ティアラカップ 0.641579806804657
4:ボウラー 0.6231715083122253
5:シェクター 0.6060587763786316
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[11], line 4
2 get_similar_words("日本", word_embeddings, )
3 get_similar_words("女王", word_embeddings, )
----> 4 get_similar_words("機械学習", word_embeddings, )
Cell In[9], line 10, in get_similar_words(query, word_embeddings, topn, word2id)
1 def get_similar_words(query, word_embeddings, topn=5, word2id=word2id, ):
2 """単語埋め込みベクトルを使って似た単語を検索する
3
4 Args:
(...)
8 word_embeddings (np.ndarray, optional): 単語埋め込み行列.必ず(語彙数x埋め込み次元数)の行列であること. Defaults to word_embeddings.
9 """
---> 10 id=word2id[query]
11 E = (word_embeddings.T / np.linalg.norm(word_embeddings,ord=2, axis=1)).T # {(V,L).T / (V)}.T = (V,L)
12 target_vector = E[id]
KeyError: '機械学習'
今回の解説ではja.text8のサブセットを利用しているせいで,この単語埋め込みがカバーしている語彙に「機械学習」は含まれていないようです.