9. NMF (Non-Negative Matrix Factorization; 非負値行列因子分解)#
9.1. NMFとは#
NMFは,非負値(\(>=0\))の元行列 \(\mathbf{V}\) を他の2つの非負値な行列 \(\mathbf{W}, \mathbf{H}\)の積で近似するアルゴリズムです.例えばユーザーごとの購買履歴を保存した行列\(\mathbf{V}\)が与えられた時に,これをユーザー数\(D\)✖️所与の埋め込み次元\(K\)(\(K\)は元の特徴数よりもかなり小さい値)の行列\(\mathbf{W}\)と\(K\)✖️特徴数\(F\)の行列\(\mathbf{H}\)に分解するようなタスクです.これらの二つの行列に何かしらの演算(ここでは積を取ります)をして,元の行列に近い行列を再構築できるようにすることで,より小さい二つの行列で元の行列を圧縮することができていると言えます.このようなタスクを行列分解と呼びます.
学習にはさまざまな方法があります.
乗法更新式
損失関数として定義するユークリッド距離やIダイバージェンスを,パラメタ更新の度に小さくするような更新式を利用します.これは数学的に損失関数が単調減少することが証明されています.[Lee and Seung, 1999, Lee and Seung, 2000]
勾配法
Neural Netの訓練でも利用されるアルゴリズムです.そのまま使うと\(W\)や\(H\)には0未満の値が含まれてしまうので,0未満になる場合は0で置き換えるような処理を追加して利用します.収束するとは限りませんが,実装は簡単です.[Lin, 2007]
9.2. scikit-learnを使った実験#
NMFはscikit-learnに実装されているので,これを利用してみましょう.
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.decomposition import NMF
from sklearn.exceptions import NotFittedError
from tqdm.auto import trange
import plotly.express as px
import matplotlib.pyplot as plt
/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
今回は20news groupsというテキストデータを利用します.BoWに変換しているので,行列の要素は全て0以上になります.
news_train = fetch_20newsgroups(subset="train")
news_test = fetch_20newsgroups(subset="test")
vectorizer = CountVectorizer(lowercase=True, max_features=1000, stop_words="english", min_df=2, max_df=0.5)
X_train = vectorizer.fit_transform(news_train.data)
X_test = vectorizer.transform(news_test.data)
id2word = {id:key for id,key in enumerate(vectorizer.get_feature_names())}
word2id = {key:id for id,key in id2word.items()}
/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function get_feature_names is deprecated; get_feature_names is deprecated in 1.0 and will be removed in 1.2. Please use get_feature_names_out instead.
warnings.warn(msg, category=FutureWarning)
nmf = NMF(n_components=20)
nmf.fit(X_train)
W_doc_k = nmf.transform(X_test)
print(W_doc_k.shape)
H_k_vocab = nmf.components_
print(H_k_vocab.shape)
(7532, 20)
(20, 1000)
nmf.reconstruction_err_ #/ X_train.shape[0]
1488.5032852593972
9.3. NumPyを使って実装する#
ユークリッド距離を損失関数にして,乗法更新式による訓練を行います.
できるだけ行列計算を行う行にはshapeをコメントしてあります.プログラムを読む時に参考にしてください.
def update_Vt_by_euclid(X, U, Vt):
_X = U @ Vt # (D,F)=(D,K)@(K,F)
_bias = (U.T @ X) / (U.T @ _X) # (D,K).T@(D,F) / (D,K).T@(D,F)
_bias[np.isnan(_bias)] = 0.0
Vt *= _bias # (K,F)=(K,F)*(K,F)
return Vt
def update_U_by_euclid(X,U,Vt):
_X = U @ Vt # (D,F)=(D,K)@(K,F)
_bias = (X @ Vt.T) / (_X @ Vt.T) # (D,F)@(K,F).T / (D,F)@(K,F).T
_bias[np.isnan(_bias)] = 0
U *= _bias # (D,K)=(D,K)(D,K)
return U
def cost_fn_by_euclid(X,_X):
return np.linalg.norm(X - _X, axis=1).mean()
class MyNMF():
def __init__(self, n_components:int=2, max_iter:int=100, rng:bool=None, divergence="euclid"):
self.n_components = n_components
self.max_iter = max_iter
self.rng_ = rng if rng is not None else np.random.default_rng(2**1000)
self.divergence = divergence
self.is_fitted = False
self.cost_ = []
if self.divergence == "euclid":
self.update_Vt = update_Vt_by_euclid
self.update_U = update_U_by_euclid
self.cost_fn = cost_fn_by_euclid
else:
NotImplementedError('divergenceは["eculid",]から選択')
def fit_transform(self, X:np.ndarray,y=None):
X = X.astype(np.float64)
self._n_features = X.shape[1]
# Initialize two small matrices from a uniform distribution
_U = self.rng_.uniform(0,1,
size=[X.shape[0],self.n_components],
).astype(X.dtype) # (D,K)
_Vt = self.rng_.uniform(0,1,
size=[self.n_components, self._n_features],
).astype(X.dtype) # (K,F)
# update parameters
for i in trange(self.max_iter):
_Vt = self.update_Vt(X,_U,_Vt)
_U = self.update_U(X,_U,_Vt)
_X = _U@_Vt
self.cost_.append(self.cost_fn(X, _X))
# output
self.components_ = _Vt # Store _Vt in instance variable to be accessed from outside
self.is_fitted = True # Raise the flag
return _U
def fit(self,X,y=None):
self.fit_transform(X)
return self
def transform(self, X):
if not self.is_fitted:
raise NotFittedError(f"{self.__class__.__name__}.transformはfit後にのみ利用できる")
if self.components_.shape[1] != X.shape[1]:
raise ValueError("Xと訓練データの特徴数が異なっている")
X = X.astype(np.float64)
# Initialize U from a uniform distribution
U = self.rng_.uniform(0,1,
size=[X.shape[0],self.n_components],
).astype(X.dtype) # (K,F)
for i in trange(self.max_iter):
U = self.update_U(X,U,self.components_)
return U
9.3.1. 訓練の実行#
mynmf = MyNMF(20,max_iter=100)
U = mynmf.fit_transform(X_train)
Show code cell output
0%| | 0/100 [00:00<?, ?it/s]
1%|██▌ | 1/100 [00:00<00:13, 7.34it/s]
3%|███████▊ | 3/100 [00:00<00:10, 9.37it/s]
5%|█████████████ | 5/100 [00:00<00:09, 9.85it/s]
6%|███████████████▋ | 6/100 [00:00<00:09, 9.74it/s]
7%|██████████████████▎ | 7/100 [00:00<00:10, 8.84it/s]
8%|████████████████████▉ | 8/100 [00:00<00:10, 8.50it/s]
9%|███████████████████████▍ | 9/100 [00:01<00:10, 8.80it/s]
11%|████████████████████████████▌ | 11/100 [00:01<00:09, 9.49it/s]
12%|███████████████████████████████▏ | 12/100 [00:01<00:09, 9.24it/s]
13%|█████████████████████████████████▊ | 13/100 [00:01<00:11, 7.85it/s]
14%|████████████████████████████████████▍ | 14/100 [00:01<00:10, 8.32it/s]
15%|███████████████████████████████████████ | 15/100 [00:01<00:09, 8.72it/s]
16%|█████████████████████████████████████████▌ | 16/100 [00:01<00:09, 8.49it/s]
17%|████████████████████████████████████████████▏ | 17/100 [00:01<00:10, 8.26it/s]
18%|██████████████████████████████████████████████▊ | 18/100 [00:02<00:10, 8.04it/s]
19%|█████████████████████████████████████████████████▍ | 19/100 [00:02<00:10, 7.87it/s]
20%|████████████████████████████████████████████████████ | 20/100 [00:02<00:10, 7.80it/s]
21%|██████████████████████████████████████████████████████▌ | 21/100 [00:02<00:10, 7.82it/s]
22%|█████████████████████████████████████████████████████████▏ | 22/100 [00:02<00:10, 7.52it/s]
23%|███████████████████████████████████████████████████████████▊ | 23/100 [00:02<00:10, 7.33it/s]
24%|██████████████████████████████████████████████████████████████▍ | 24/100 [00:02<00:09, 7.93it/s]
26%|███████████████████████████████████████████████████████████████████▌ | 26/100 [00:03<00:08, 8.41it/s]
27%|██████████████████████████████████████████████████████████████████████▏ | 27/100 [00:03<00:09, 7.98it/s]
28%|████████████████████████████████████████████████████████████████████████▊ | 28/100 [00:03<00:08, 8.19it/s]
29%|███████████████████████████████████████████████████████████████████████████▍ | 29/100 [00:03<00:08, 8.45it/s]
30%|██████████████████████████████████████████████████████████████████████████████ | 30/100 [00:03<00:08, 7.95it/s]
31%|████████████████████████████████████████████████████████████████████████████████▌ | 31/100 [00:03<00:08, 8.42it/s]
32%|███████████████████████████████████████████████████████████████████████████████████▏ | 32/100 [00:03<00:09, 6.81it/s]
33%|█████████████████████████████████████████████████████████████████████████████████████▊ | 33/100 [00:04<00:09, 7.01it/s]
34%|████████████████████████████████████████████████████████████████████████████████████████▍ | 34/100 [00:04<00:08, 7.54it/s]
35%|███████████████████████████████████████████████████████████████████████████████████████████ | 35/100 [00:04<00:08, 8.05it/s]
36%|█████████████████████████████████████████████████████████████████████████████████████████████▌ | 36/100 [00:04<00:08, 7.70it/s]
37%|████████████████████████████████████████████████████████████████████████████████████████████████▏ | 37/100 [00:04<00:08, 7.02it/s]
38%|██████████████████████████████████████████████████████████████████████████████████████████████████▊ | 38/100 [00:04<00:08, 7.35it/s]
39%|█████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 39/100 [00:04<00:07, 7.94it/s]
40%|████████████████████████████████████████████████████████████████████████████████████████████████████████ | 40/100 [00:04<00:07, 8.27it/s]
41%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 41/100 [00:05<00:07, 7.77it/s]
42%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 42/100 [00:05<00:07, 8.02it/s]
44%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 44/100 [00:05<00:06, 8.95it/s]
46%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 46/100 [00:05<00:06, 8.97it/s]
47%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 47/100 [00:05<00:06, 8.81it/s]
48%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 48/100 [00:05<00:05, 8.88it/s]
49%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 49/100 [00:05<00:05, 9.13it/s]
50%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 50/100 [00:06<00:05, 9.03it/s]
51%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 51/100 [00:06<00:05, 9.10it/s]
52%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 52/100 [00:06<00:05, 9.13it/s]
53%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 53/100 [00:06<00:05, 9.08it/s]
54%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 54/100 [00:06<00:04, 9.25it/s]
55%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 55/100 [00:06<00:04, 9.17it/s]
56%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 56/100 [00:06<00:04, 8.91it/s]
58%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 58/100 [00:06<00:04, 9.43it/s]
60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 60/100 [00:07<00:04, 9.82it/s]
62%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 62/100 [00:07<00:03, 9.97it/s]
64%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 64/100 [00:07<00:03, 10.06it/s]
66%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 66/100 [00:07<00:03, 10.17it/s]
68%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 68/100 [00:07<00:03, 10.28it/s]
70%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 70/100 [00:08<00:02, 10.28it/s]
72%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 72/100 [00:08<00:02, 10.32it/s]
74%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 74/100 [00:08<00:02, 10.32it/s]
76%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 76/100 [00:08<00:02, 10.37it/s]
78%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 78/100 [00:08<00:02, 10.41it/s]
80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 80/100 [00:09<00:01, 10.35it/s]
82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 82/100 [00:09<00:01, 10.33it/s]
84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 84/100 [00:09<00:01, 10.00it/s]
86%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 86/100 [00:09<00:01, 10.12it/s]
88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 88/100 [00:09<00:01, 10.14it/s]
90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 90/100 [00:09<00:00, 10.19it/s]
92%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 92/100 [00:10<00:00, 10.03it/s]
94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 94/100 [00:10<00:00, 10.09it/s]
96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 96/100 [00:10<00:00, 10.13it/s]
98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 98/100 [00:10<00:00, 10.18it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:11<00:00, 9.91it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:11<00:00, 9.09it/s]
fig = plt.figure()
ax = fig.add_subplot()
_cost = np.array(mynmf.cost_) #/X_train.shape[0]
ax.plot(_cost)
ax.set_xlabel("Number of update")
ax.set_ylabel("Reconstruction error")
ax.set_title("Reconstruction Error")
print("最後の更新時の損失関数の値:",_cost[-1])
最後の更新時の損失関数の値: 10.900271385543064
9.4. 参考文献#
9.4.1. 論文等#
- LS99
D D Lee and H S Seung. Learning the parts of objects by non-negative matrix factorization. Nature, 401(6755):788–791, October 1999. doi:10.1038/44565.
- LS00
Daniel Lee and H Sebastian Seung. Algorithms for non-negative matrix factorization. In T Leen, T Dietterich, and V Tresp, editors, Advances in Neural Information Processing Systems, volume 13. MIT Press, 2000.
- Lin07
Chih-Jen Lin. Projected gradient methods for nonnegative matrix factorization. Neural Comput., 19(10):2756–2779, October 2007. doi:10.1162/neco.2007.19.10.2756.