13.9. Skip-Gram#

Note

このノートは Skip-Gram実装課題 のヒントになるように書かれています.

CBoWと遂になる単語埋め込みベクトル作成手法である Skip-Gram を実装します.ここでは,Negative Samplingのような技術を使わず,出力層でsoftmax関数を利用することで実装を簡単にしています.そのため計算コストが膨大になる傾向があり,大規模なコーパスに適用することはお勧めしません.

Skip-Gramの計算コストの大きさには出力層のSoftmax活性化関数が大きな影響を与えます.そのため,高速化を行うためにはSoftmaxを Negative Sampling と呼ばれるアルゴリズムで代用することになります.これについてはこのブログが実装の助けになります.また,直接Skip-Gramを紹介しているわけではないのですが,CBoWの説明の中でこれを説明しているゼロから作るDeep Learning ❷ ―自然言語処理編も非常に参考になるでしょう.

# packageのimport
import re
import math 
from typing import Any
from tqdm.std import trange,tqdm
import numpy as np 
import matplotlib.pyplot as plt 
import seaborn as sns
from scipy.sparse import lil_matrix

# pytorch関連のimport
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim 
import skorch
from skorch import NeuralNetClassifier, NeuralNetRegressor
from skorch.callbacks import Callback, EpochScoring
from torch.utils.data import Dataset

from janome.tokenizer import Tokenizer

13.9.1. データの読み込み#

コーパスにはja.text8のサブセットを利用します.発展課題に取り組む場合も,指定されているハイパーパラメータやコーパスのサイズが実行困難である場合は適宜修正してください.ただし,その場合はskip_gram.pyの先頭行に,docstringを用意してその旨を書いてください.あるいはCLIのオプションにしてもいいかもしれません.

with open("./data/ja.text8") as f:
    text8 = f.read()
print(text8[:200])

#LIMIT = math.floor(len(text8)*0.1)
LIMIT = 100_0000
print(f"{LIMIT}/ {len(text8)}")
text8 = text8[:LIMIT]
ちょん 掛け ( ちょん がけ 、 丁 斧 掛け ・ 手斧 掛け と も 表記 ) と は 、 相撲 の 決まり 手 の ひとつ で ある 。 自分 の 右 ( 左 ) 足 の 踵 を 相手 の 右 ( 左 ) 足 の 踵 に 掛け 、 後方 に 捻っ て 倒す 技 。 手斧 ( ちょう な ) を かける 仕草 に 似 て いる こと から 、 ちょう な が 訛っ て ちょん 掛け と なっ 
1000000/ 46507793

13.9.2. 形態素解析#

コーパス内の単語(トークン)全てを利用すると語彙が多くなりすぎるので,ここでは名詞(それも一般名詞と固有名詞)のみを利用します.そのために形態素解析を行う必要があるので,python製の形態素解析器であるjanomeを利用しています.形態素解析器はこれ以外にもMecabなどが有名です.

13.9.2.1. 形態素解析器 janome#

ja.text8の一部に品詞分解を行なった結果を以下に示します.

t = Tokenizer()
sample_text = "".join(text8[:50].split())
for token in t.tokenize(sample_text):
    print(token.surface, "\t", token.part_of_speech.split(","))
ちょん 	 ['名詞', '一般', '*', '*']
掛け 	 ['名詞', '接尾', '一般', '*']
( 	 ['記号', '括弧開', '*', '*']
ちょん 	 ['名詞', '一般', '*', '*']
がけ 	 ['名詞', '接尾', '一般', '*']
、 	 ['記号', '読点', '*', '*']
丁 	 ['名詞', '固有名詞', '人名', '姓']
斧 	 ['名詞', '一般', '*', '*']
掛け 	 ['名詞', '接尾', '一般', '*']
・ 	 ['記号', '一般', '*', '*']
手斧 	 ['名詞', '一般', '*', '*']
掛け 	 ['名詞', '接尾', '一般', '*']
と 	 ['助詞', '格助詞', '引用', '*']
も 	 ['助詞', '係助詞', '*', '*']
表記 	 ['名詞', 'サ変接続', '*', '*']
) 	 ['記号', '括弧閉', '*', '*']
と 	 ['助詞', '格助詞', '引用', '*']
は 	 ['助詞', '係助詞', '*', '*']
、 	 ['記号', '読点', '*', '*']
相撲 	 ['名詞', '一般', '*', '*']

13.9.2.2. janomeを使った語彙辞書作成#

活用する語彙をまとめた辞書(word2id, id2word)を作成します.この実装はダーティなので,実際に自然言語処理を行う場合は参考にしないでください.

def my_analyzer(text):
    #text = code_regex.sub('', text)
    #tokens = text.split()
    #tokens = filter(lambda token: re.search(r'[ぁ-ん]+|[ァ-ヴー]+|[一-龠]+', token), tokens)
    tokens = []
    for token in tqdm(t.tokenize(text)):
        pos = token.part_of_speech.split(",")
        if "名詞" == pos[0]:
            if "一般" == pos[1] or "固有名詞" == pos[1]:
                tokens.append(token.surface)
    tokens = filter(lambda token: re.search(r'[ぁ-ん]+|[ァ-ヴー]+|[一-龠]+', token), tokens)
    return tokens 

def build_contexts_and_target(corpus, window_size:int=5)->tuple[np.ndarray,np.ndarray]:
    contexts = []
    target = []
    vocab = set()
    _window_size = window_size//2
    # 文ごとに分割
    preprocessed_corpus = corpus.replace(" ","")
    # posを見て単語ごとに分割
    tokens = list(my_analyzer(preprocessed_corpus))

    # 新しい語彙を追加
    vocab = vocab | set(tokens)

    # スライディングウィンドウ
    for i in trange(_window_size, len(tokens)-_window_size):
        # ウィンドウの真ん中をtargetにする
        target.append(tokens[i])
        # 真ん中以外の単語をcontextsへ
        tmp = tokens[i-_window_size:i]
        tmp += tokens[i+1:i+1+_window_size]
        contexts.append(tmp)

    # 辞書作成
    id2word = list(vocab)
    word2id = {word:id for id,word in enumerate(id2word)}
    vocab_size = len(word2id)


    # contextsとtargetを単語id配列へ置き換え
    contexts_id_list = [[word2id[word] for word in doc] for doc in contexts]
    target_id_list = [word2id[word] for word in target]


    contexts = lil_matrix((len(contexts_id_list), vocab_size),dtype=np.float32)
    for index, _contexts_id_list in enumerate(contexts_id_list):
        #tmp = np.eye(vocab_size)[np.array(_contexts_id_list)]
        for word_id in _contexts_id_list:
            contexts[index, word_id] +=1.

    target = np.array(target_id_list)
    return contexts.tocsr().astype(np.float32), target, word2id, id2word

WINDOW_SIZE = 11
contexts, target, word2id, id2word = build_contexts_and_target(text8, window_size=WINDOW_SIZE)
print(f"contextsのshape: {contexts.shape}")
Hide code cell output
0it [00:00, ?it/s]
1031it [00:00, 9207.91it/s]
3384it [00:00, 15945.77it/s]
4973it [00:00, 14136.67it/s]
7290it [00:00, 16801.08it/s]
9911it [00:00, 19503.88it/s]
11894it [00:00, 17662.89it/s]
14688it [00:00, 20015.87it/s]
17443it [00:00, 21492.97it/s]
19626it [00:01, 19385.60it/s]
21942it [00:01, 20228.29it/s]
24600it [00:01, 19012.62it/s]
27245it [00:01, 20455.99it/s]
29947it [00:01, 21519.38it/s]
32146it [00:01, 19971.38it/s]
34186it [00:01, 20058.81it/s]
36732it [00:01, 21187.94it/s]
38881it [00:02, 19209.84it/s]
41367it [00:02, 20392.16it/s]
44323it [00:02, 22351.98it/s]
46600it [00:02, 20562.91it/s]
49044it [00:02, 21314.90it/s]
51545it [00:02, 19706.76it/s]
54201it [00:02, 20885.00it/s]
57259it [00:02, 22870.02it/s]
59596it [00:02, 20649.21it/s]
62042it [00:03, 21511.60it/s]
64868it [00:03, 22962.30it/s]
67216it [00:03, 20470.28it/s]
69684it [00:03, 21235.41it/s]
72375it [00:03, 19860.99it/s]
74930it [00:03, 20951.33it/s]
77557it [00:03, 22169.03it/s]
79835it [00:03, 20464.69it/s]
82490it [00:04, 21592.60it/s]
85158it [00:04, 22843.15it/s]
87492it [00:04, 20145.55it/s]
90217it [00:04, 21777.19it/s]
93355it [00:04, 20684.09it/s]
96171it [00:04, 22188.52it/s]
98915it [00:04, 22964.88it/s]
101276it [00:04, 20695.44it/s]
103797it [00:05, 21508.38it/s]
106294it [00:05, 22354.19it/s]
108585it [00:05, 20573.30it/s]
110986it [00:05, 21142.76it/s]
113351it [00:05, 19119.68it/s]
115930it [00:05, 20666.51it/s]
118738it [00:05, 22417.33it/s]
121050it [00:05, 20500.66it/s]
123478it [00:05, 21138.16it/s]
126155it [00:06, 22051.38it/s]
128404it [00:06, 20331.06it/s]
131015it [00:06, 21627.94it/s]
133783it [00:06, 19737.50it/s]
136575it [00:06, 21104.03it/s]
139184it [00:06, 22008.09it/s]
141443it [00:06, 20336.34it/s]
144166it [00:06, 21509.16it/s]
146774it [00:07, 22653.08it/s]
149088it [00:07, 20462.27it/s]
151535it [00:07, 21069.40it/s]
153692it [00:07, 19337.81it/s]
156152it [00:07, 20612.80it/s]
158966it [00:07, 22302.98it/s]
161248it [00:07, 20797.37it/s]
163827it [00:07, 21824.98it/s]
166330it [00:07, 22336.57it/s]
168597it [00:08, 20933.34it/s]
171093it [00:08, 21761.61it/s]
173663it [00:08, 19606.54it/s]
176402it [00:08, 21169.04it/s]
178955it [00:08, 22142.95it/s]
181228it [00:08, 20072.82it/s]
184132it [00:08, 22041.71it/s]
187285it [00:08, 24487.93it/s]
189811it [00:09, 21763.63it/s]
192223it [00:09, 22250.37it/s]
194628it [00:09, 20218.48it/s]
197490it [00:09, 21886.23it/s]
200093it [00:09, 22864.94it/s]
202446it [00:09, 20549.72it/s]
205254it [00:09, 22066.05it/s]
208023it [00:09, 23121.44it/s]
210394it [00:10, 20775.12it/s]
213253it [00:10, 22675.59it/s]
215602it [00:10, 20756.17it/s]
218624it [00:10, 22845.53it/s]
221348it [00:10, 23806.36it/s]
223794it [00:10, 21371.23it/s]
226502it [00:10, 22538.14it/s]
229410it [00:10, 20618.47it/s]
231829it [00:11, 21059.41it/s]
234744it [00:11, 22564.11it/s]
237064it [00:11, 20716.34it/s]
239851it [00:11, 22198.74it/s]
242636it [00:11, 20444.58it/s]
245386it [00:11, 21886.29it/s]
248174it [00:11, 23221.07it/s]
250569it [00:11, 20837.15it/s]
253681it [00:12, 22960.38it/s]
256660it [00:12, 24316.69it/s]
259163it [00:12, 21965.20it/s]
261707it [00:12, 22343.22it/s]
264102it [00:12, 19962.96it/s]
266947it [00:12, 21613.89it/s]
269677it [00:12, 22882.45it/s]
272037it [00:12, 20781.29it/s]
274815it [00:12, 22043.62it/s]
277695it [00:13, 20336.98it/s]
280492it [00:13, 21831.75it/s]
283322it [00:13, 23121.58it/s]
285708it [00:13, 20859.52it/s]
288313it [00:13, 22135.23it/s]
290985it [00:13, 23067.38it/s]
293356it [00:13, 21029.72it/s]
296111it [00:13, 22521.41it/s]
298431it [00:14, 20682.75it/s]
301136it [00:14, 21924.03it/s]
303773it [00:14, 23026.40it/s]
306130it [00:14, 20672.88it/s]
308772it [00:14, 21915.16it/s]
311264it [00:14, 19904.46it/s]
314143it [00:14, 21457.10it/s]
317011it [00:14, 22812.50it/s]
319354it [00:15, 21175.80it/s]
322056it [00:15, 22014.84it/s]
324449it [00:15, 19909.35it/s]
327130it [00:15, 21500.45it/s]
329983it [00:15, 23328.20it/s]
332392it [00:15, 21184.59it/s]
334783it [00:15, 21789.33it/s]
337871it [00:15, 20501.95it/s]
340898it [00:16, 22644.90it/s]
343623it [00:16, 23466.44it/s]
346043it [00:16, 21297.18it/s]
348436it [00:16, 21590.39it/s]
351042it [00:16, 22718.74it/s]
353369it [00:16, 20369.71it/s]
356445it [00:16, 22845.44it/s]
358815it [00:16, 20713.17it/s]
361440it [00:16, 21746.02it/s]
363003it [00:17, 21313.04it/s]

  0%|                                                                                                                                                                                                                                                                           | 0/80264 [00:00<?, ?it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80264/80264 [00:00<00:00, 849904.99it/s]

contextsのshape: (80264, 17871)

13.9.3. クラスの作成#

Skip-gramをnn.Moduleのサブクラスとして実装します.

クラスの実装には上のskip-gramアーキテクチャ図を参考にしてください.高速化のテクニックなどは不要です.(もちろん実装できる人は実装してもOK)

class SkipGram(nn.Module):
    def __init__(self, vocab_size:int, embedding_dim:int)->None:
        super().__init__()
        ...

    def forward(self, input:torch.Tensor)->torch.Tensor:
        ...

13.9.4. 損失関数の作成#

Skip-Gramはクラス分類の体裁をとっているので,損失関数にはCross Entropyを用います.ただしPyTorchで用意されているnn.CrossEntropyを用いることは(おそらく)できないので,自作しましょう.

Hint

条件:

  • batch_size=128, vocab_size=11342のとき,以下が損失関数に入力されると仮定して実装してください.

    • SkipGramがforwardメソッドから出力するtensor.shapeは「torch.Size([128, 11342])」,

    • 正解データとして利用するtensor.shapeは「torch.Size([128, 11342])」

  • callbackにおいて,ここで実装したcross entropyを使ってperplexityを計算します.

class BowCrossEntropy(nn.Module):
    def forward(self, input, target):
        """
        inputはSkip-gramの出力です.
        targetは予測したいcontextsです.
        """
        ...

13.9.5. trainerの準備と訓練#

ここまでの実装が終わったら,あとは訓練用のプログラムを書くだけです.この解説ではskorchを利用して楽をします.Skip-Gramはクラス分類の体裁を取っていると言いましたが,出力はcategoricalではなくmultinomialです.つまり 一つのデータに対して正解ラベルが複数あります .これはskorchのNeuralNetClassifierでは上手く扱えないので,NeuralNetRegressor を使っています.

Note

  • NeuralNetClassifierは主に1データ1ラベルの場合に利用します.今回の例でも使えないわけではないのですが,標準で設定された「正答率を表示するコールバック」が動作してしまうので利用を見送りました.

  • EpochScoring(lambda net,X=None,y=None: np.exp(net.history_[-1, "valid_loss"]), name="valid_ppl"), はエポックの終わりに呼び出されるコールバック関数の雛形であるEpochScoringを利用して,Perplexityを計算します.

  • targetもcontextsもnp.ndarrayのままでfitに渡します.

    • trainerが中でdatasetやdataloaderを用意してくれます.

    • contextsはscipy.sparse.lil_matrix or scipy.sparse.csr_matrixになっているので,toarrayメソッドでnp.ndarrayに戻しています.

trainer = NeuralNetRegressor(
    SkipGram(len(word2id), 50),
    optimizer=optim.Adam,
    criterion=BowCrossEntropy,
    max_epochs=20,
    batch_size=128,
    lr=0.01,
    callbacks=[
        EpochScoring(lambda net,X=None,y=None: np.exp(net.history_[-1, "valid_loss"]), name="valid_ppl"), 
        EpochScoring(lambda net,X=None,y=None: np.exp(net.history_[-1, "train_loss"]), name="train_ppl", on_train=True,)
    ],
    device="cpu", # 適宜変更
)

trainer.fit(target, contexts.toarray())
Hide code cell output
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[7], line 15
      1 trainer = NeuralNetRegressor(
      2     SkipGram(len(word2id), 50),
      3     optimizer=optim.Adam,
   (...)
     12     device="cpu", # 適宜変更
     13 )
---> 15 trainer.fit(target, contexts.toarray())

File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/skorch/regressor.py:82, in NeuralNetRegressor.fit(self, X, y, **fit_params)
     71 """See ``NeuralNet.fit``.
     72 
     73 In contrast to ``NeuralNet.fit``, ``y`` is non-optional to
   (...)
     77 
     78 """
     79 # pylint: disable=useless-super-delegation
     80 # this is actually a pylint bug:
     81 # https://github.com/PyCQA/pylint/issues/1085
---> 82 return super(NeuralNetRegressor, self).fit(X, y, **fit_params)

File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/skorch/net.py:1300, in NeuralNet.fit(self, X, y, **fit_params)
   1268 """Initialize and fit the module.
   1269 
   1270 If the module was already initialized, by calling fit, the
   (...)
   1297 
   1298 """
   1299 if not self.warm_start or not self.initialized_:
-> 1300     self.initialize()
   1302 self.partial_fit(X, y, **fit_params)
   1303 return self

File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/skorch/net.py:888, in NeuralNet.initialize(self)
    886 self._initialize_module()
    887 self._initialize_criterion()
--> 888 self._initialize_optimizer()
    889 self._initialize_history()
    891 self._validate_params()

File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/skorch/net.py:859, in NeuralNet._initialize_optimizer(self, reason)
    856         msg = self._format_reinit_msg("optimizer", triggered_directly=False)
    857     print(msg)
--> 859 self.initialize_optimizer()
    861 # register the virtual params for all optimizers
    862 for name in self._optimizers:

File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/skorch/net.py:625, in NeuralNet.initialize_optimizer(self, triggered_directly)
    621 args, kwargs = self.get_params_for_optimizer(
    622     'optimizer', named_parameters)
    624 # pylint: disable=attribute-defined-outside-init
--> 625 self.optimizer_ = self.optimizer(*args, **kwargs)
    626 return self

File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/optim/adam.py:33, in Adam.__init__(self, params, lr, betas, eps, weight_decay, amsgrad, foreach, maximize, capturable, differentiable, fused)
     27     raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
     29 defaults = dict(lr=lr, betas=betas, eps=eps,
     30                 weight_decay=weight_decay, amsgrad=amsgrad,
     31                 maximize=maximize, foreach=foreach, capturable=capturable,
     32                 differentiable=differentiable, fused=fused)
---> 33 super().__init__(params, defaults)
     35 if fused:
     36     if differentiable:

File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/optim/optimizer.py:187, in Optimizer.__init__(self, params, defaults)
    185 param_groups = list(params)
    186 if len(param_groups) == 0:
--> 187     raise ValueError("optimizer got an empty parameter list")
    188 if not isinstance(param_groups[0], dict):
    189     param_groups = [{'params': param_groups}]

ValueError: optimizer got an empty parameter list

13.9.6. 類似単語検索#

cbowと同様に単語埋め込みベクトルを使って,類似単語の検索を行います.

def get_similar_words(query, word_embeddings, topn=5, word2id=word2id, ):
    """単語埋め込みベクトルを使って似た単語を検索する

    Args:
        query (str): 類似単語を検索したい単語
        topn (int, optional): 検索結果の表示個数. Defaults to 5.
        word2id (dict[str,int], optional): 単語→単語idの辞書. Defaults to word2id.
        word_embeddings (np.ndarray, optional): 単語埋め込み行列.必ず(語彙数x埋め込み次元数)の行列であること. Defaults to word_embeddings.
    """
    id=word2id[query]
    E = (word_embeddings.T / np.linalg.norm(word_embeddings,ord=2, axis=1)).T # {(V,L).T / (V)}.T = (V,L)
    target_vector = E[id]
    cossim = E @ target_vector # (V,L)@(L)=(V)
    sorted_index = np.argsort(cossim)[::-1][1:topn+1] # 最も似たベクトルは自分自身なので先頭を除外

    print(f">>> {query}")
    _id2word = list(word2id.keys())
    for rank, i in enumerate(sorted_index):
        print(f"{rank+1}:{_id2word[i]} \t{cossim[i]}")

word_embeddings = trainer.module_.embedding.weight.detach().cpu().numpy()

get_similar_words("ロボット", word_embeddings, )
>>> ロボット
1:ユニバーサル 	0.898608386516571
2:ポルト 	0.7893995642662048
3:ロボティックス 	0.763614296913147
4:テラ 	0.742680013179779
5:関節 	0.7259170413017273
get_similar_words("サッカー", word_embeddings, )
get_similar_words("日本", word_embeddings, )
get_similar_words("女王", word_embeddings, )
get_similar_words("機械学習", word_embeddings, )
>>> サッカー
1:リーグ 	0.734089195728302
2:専業 	0.7245967388153076
3:ヴァンフォーレ 	0.6850863695144653
4:選手 	0.6845436692237854
5:アルビレックス 	0.6741206645965576
>>> 日本
1:ほん 	0.6705817580223083
2:米国 	0.6255179047584534
3:王者 	0.6063108444213867
4:社団 	0.5765134692192078
5:蓄音機 	0.5684884786605835
>>> 女王
1:ヴィクトリアシリーズ 	0.6750556826591492
2:後塵 	0.649889349937439
3:ティアラカップ 	0.641579806804657
4:ボウラー 	0.6231715083122253
5:シェクター 	0.6060587763786316
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[11], line 4
      2 get_similar_words("日本", word_embeddings, )
      3 get_similar_words("女王", word_embeddings, )
----> 4 get_similar_words("機械学習", word_embeddings, )

Cell In[9], line 10, in get_similar_words(query, word_embeddings, topn, word2id)
      1 def get_similar_words(query, word_embeddings, topn=5, word2id=word2id, ):
      2     """単語埋め込みベクトルを使って似た単語を検索する
      3 
      4     Args:
   (...)
      8         word_embeddings (np.ndarray, optional): 単語埋め込み行列.必ず(語彙数x埋め込み次元数)の行列であること. Defaults to word_embeddings.
      9     """
---> 10     id=word2id[query]
     11     E = (word_embeddings.T / np.linalg.norm(word_embeddings,ord=2, axis=1)).T # {(V,L).T / (V)}.T = (V,L)
     12     target_vector = E[id]

KeyError: '機械学習'

今回の解説ではja.text8のサブセットを利用しているせいで,この単語埋め込みがカバーしている語彙に「機械学習」は含まれていないようです.