13.8. [発展課題]skip_gram.py#
Skip-Gramを実装し,max_epochs=100, minibatch_size=512として訓練し,「サッカー」,「日本」,「女王」,「機械学習」について類似単語を類似度の高い順に上位5個表示するプログラムを作成してください.
cbow.pyを参考にしてください.
学習にはja.text8を利用してください.
雛形:
class SkipGram(nn.Module):
def __init__(self):
super().__init__()
...
def forward(self,x):
...
13.8.1. 実装#
13.8.2. 実行結果#
13.8.2.1. Usage#
argparserのdescriptionやhelpに説明を書き込んで,--help
オプションで使い方が表示できるようにしてください.
(datasci) mriki@RikinoMac _prml % python skipgram.py -h
usage: skipgram.py [-h] [--learning_rate LEARNING_RATE] [--batch_size BATCH_SIZE] [--embedding_dim EMBEDDING_DIM] [--seed SEED] [--max_epochs MAX_EPOCHS] [--char_limit CHAR_LIMIT] [--device DEVICE] [--data_path DATA_PATH] [--save_path SAVE_PATH] [--window_size WINDOW_SIZE] [--query QUERY] [--topn TOPN]
Skip-Gramの訓練をja.text8で行う
options:
-h, --help show this help message and exit
--learning_rate LEARNING_RATE
--batch_size BATCH_SIZE
--embedding_dim EMBEDDING_DIM
--seed SEED
--max_epochs MAX_EPOCHS
--char_limit CHAR_LIMIT
ja.text8の先頭から何文字を利用するか.Noneの場合は全てを使う. ex. 1_000_000
--device DEVICE
--data_path DATA_PATH
訓練用コーパスの保存場所
--save_path SAVE_PATH
学習済みモデルのファイル名.すでに存在していた場合はそれを読み込んで利用する
--window_size WINDOW_SIZE
--query QUERY 文字列を渡すと類似する単語をtopn個検索する
--topn TOPN 検索単語数
13.8.2.2. 実行#
初回学習時:
(datasci) mriki@RikinoMac _prml % python skipgram.py --char_limit 1000000 --seed 7012 --save_path ./skipgram.pkl --max_epochs 2
全文書の文字数が46507793あり,その内1000000だけを利用します.
前処理...
363003it [00:16, 21819.80it/s]
100%|███████████████████████| 80264/80264 [00:00<00:00, 900528.08it/s]
contextsのshape: (80264, 17871)
訓練開始...
epoch train_loss train_ppl valid_loss valid_ppl dur
------- ------------ ----------- ------------ ----------- -------
1 9.8043 18111.2683 10.1251 24961.8040 14.5564
2 8.4743 4789.8352 10.6640 42786.7327 14.5319
学習済みの場合:
(datasci) mriki@RikinoMac _prml % python skipgram.py --char_limit 1000000 --seed 7012 --save_path ./skipgram.pkl --max_epochs 2
./skipgram.pklから学習済みモデルを読み込みます...
学習済みでクエリを検索する場合:
(datasci) mriki@RikinoMac _prml % python skipgram.py --save_path ./skipgram.pkl --query 日本
./skipgram.pklから学習済みモデルを読み込みます...
>>> 日本
1:古代 0.9472917318344116
2:文明 0.9328379034996033
3:社会 0.931919515132904
4:文化 0.9224883317947388
5:天皇 0.9139895439147949
(datasci) mriki@RikinoMac _prml % python skipgram.py --save_path ./skipgram.pkl --query ロボット
./skipgram.pklから学習済みモデルを読み込みます...
>>> ロボット
1:ロボティックス 0.8361095190048218
2:ステーション 0.8090811967849731
3:ぼう 0.8085721135139465
4:ロケット 0.7877843976020813
5:地球 0.7545166611671448