9. NMF (Non-Negative Matrix Factorization; 非負値行列因子分解)#

9.1. NMFとは#

NMFは,非負値(\(>=0\))の元行列 \(\mathbf{V}\) を他の2つの非負値な行列 \(\mathbf{W}, \mathbf{H}\)の積で近似するアルゴリズムです.例えばユーザーごとの購買履歴を保存した行列\(\mathbf{V}\)が与えられた時に,これをユーザー数\(D\)✖️所与の埋め込み次元\(K\)\(K\)は元の特徴数よりもかなり小さい値)の行列\(\mathbf{W}\)\(K\)✖️特徴数\(F\)の行列\(\mathbf{H}\)に分解するようなタスクです.これらの二つの行列に何かしらの演算(ここでは積を取ります)をして,元の行列に近い行列を再構築できるようにすることで,より小さい二つの行列で元の行列を圧縮することができていると言えます.このようなタスクを行列分解と呼びます.

学習にはさまざまな方法があります.

  1. 乗法更新式

    • 損失関数として定義するユークリッド距離やIダイバージェンスを,パラメタ更新の度に小さくするような更新式を利用します.これは数学的に損失関数が単調減少することが証明されています.[Lee and Seung, 1999, Lee and Seung, 2000]

  2. 勾配法

    • Neural Netの訓練でも利用されるアルゴリズムです.そのまま使うと\(W\)\(H\)には0未満の値が含まれてしまうので,0未満になる場合は0で置き換えるような処理を追加して利用します.収束するとは限りませんが,実装は簡単です.[Lin, 2007]

9.2. scikit-learnを使った実験#

NMFはscikit-learnに実装されているので,これを利用してみましょう.

import numpy as np
import pandas as pd 
import plotly.express as px 
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.decomposition import NMF 
from sklearn.exceptions import NotFittedError
from tqdm.auto import trange
import plotly.express as px 
import matplotlib.pyplot as plt
/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

今回は20news groupsというテキストデータを利用します.BoWに変換しているので,行列の要素は全て0以上になります.

news_train = fetch_20newsgroups(subset="train")
news_test = fetch_20newsgroups(subset="test")
vectorizer = CountVectorizer(lowercase=True, max_features=1000, stop_words="english", min_df=2, max_df=0.5)
X_train = vectorizer.fit_transform(news_train.data)
X_test = vectorizer.transform(news_test.data)
id2word = {id:key for id,key in enumerate(vectorizer.get_feature_names())}
word2id = {key:id for id,key in id2word.items()}
/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function get_feature_names is deprecated; get_feature_names is deprecated in 1.0 and will be removed in 1.2. Please use get_feature_names_out instead.
  warnings.warn(msg, category=FutureWarning)
nmf = NMF(n_components=20)
nmf.fit(X_train)
W_doc_k = nmf.transform(X_test)
print(W_doc_k.shape)
H_k_vocab = nmf.components_
print(H_k_vocab.shape)
(7532, 20)
(20, 1000)
nmf.reconstruction_err_ #/ X_train.shape[0]
1488.5032852593972

9.3. NumPyを使って実装する#

ユークリッド距離を損失関数にして,乗法更新式による訓練を行います.

できるだけ行列計算を行う行にはshapeをコメントしてあります.プログラムを読む時に参考にしてください.

def update_Vt_by_euclid(X, U, Vt):
    _X = U @ Vt # (D,F)=(D,K)@(K,F)
    _bias = (U.T @ X) / (U.T @ _X) # (D,K).T@(D,F) / (D,K).T@(D,F)
    _bias[np.isnan(_bias)] = 0.0
    Vt *= _bias # (K,F)=(K,F)*(K,F)
    return Vt 

def update_U_by_euclid(X,U,Vt):        
    _X = U @ Vt # (D,F)=(D,K)@(K,F)
    _bias = (X @ Vt.T) / (_X @ Vt.T) # (D,F)@(K,F).T / (D,F)@(K,F).T
    _bias[np.isnan(_bias)] = 0
    U *= _bias # (D,K)=(D,K)(D,K)
    return U

def cost_fn_by_euclid(X,_X):
    return np.linalg.norm(X - _X, axis=1).mean()

class MyNMF():
    def __init__(self, n_components:int=2, max_iter:int=100, rng:bool=None, divergence="euclid"):
        self.n_components = n_components
        self.max_iter = max_iter
        self.rng_ = rng if rng is not None else np.random.default_rng(2**1000) 
        self.divergence = divergence
        self.is_fitted = False 
        self.cost_ = []
        if self.divergence == "euclid":
            self.update_Vt = update_Vt_by_euclid
            self.update_U = update_U_by_euclid
            self.cost_fn = cost_fn_by_euclid
        else:
            NotImplementedError('divergenceは["eculid",]から選択')

    def fit_transform(self, X:np.ndarray,y=None):
        X = X.astype(np.float64)
        self._n_features = X.shape[1]
        
        # Initialize two small matrices from a uniform distribution
        _U = self.rng_.uniform(0,1, 
                              size=[X.shape[0],self.n_components],
                              ).astype(X.dtype) # (D,K)
        _Vt = self.rng_.uniform(0,1,
                                size=[self.n_components, self._n_features],
                                ).astype(X.dtype) # (K,F)
        
        # update parameters
        for i in trange(self.max_iter):
            _Vt = self.update_Vt(X,_U,_Vt)
            _U = self.update_U(X,_U,_Vt)
            _X = _U@_Vt
            self.cost_.append(self.cost_fn(X, _X))
        
        # output
        self.components_ = _Vt # Store _Vt in instance variable to be accessed from outside
        self.is_fitted = True # Raise the flag
        return _U
    
    def fit(self,X,y=None):
        self.fit_transform(X)
        return self
    
    def transform(self, X):
        if not self.is_fitted:
            raise NotFittedError(f"{self.__class__.__name__}.transformはfit後にのみ利用できる")
        if self.components_.shape[1] != X.shape[1]:
            raise ValueError("Xと訓練データの特徴数が異なっている")
        X = X.astype(np.float64)
        
        # Initialize U from a uniform distribution
        U = self.rng_.uniform(0,1, 
                              size=[X.shape[0],self.n_components],
                              ).astype(X.dtype) # (K,F)
        
        for i in trange(self.max_iter):
            U = self.update_U(X,U,self.components_)
        return U

9.3.1. 訓練の実行#

mynmf = MyNMF(20,max_iter=100)
U = mynmf.fit_transform(X_train)
Hide code cell output
  0%|                                                                                                                                                                                                                                                                             | 0/100 [00:00<?, ?it/s]
  1%|██▌                                                                                                                                                                                                                                                                  | 1/100 [00:00<00:13,  7.34it/s]
  3%|███████▊                                                                                                                                                                                                                                                             | 3/100 [00:00<00:10,  9.37it/s]
  5%|█████████████                                                                                                                                                                                                                                                        | 5/100 [00:00<00:09,  9.85it/s]
  6%|███████████████▋                                                                                                                                                                                                                                                     | 6/100 [00:00<00:09,  9.74it/s]
  7%|██████████████████▎                                                                                                                                                                                                                                                  | 7/100 [00:00<00:10,  8.84it/s]
  8%|████████████████████▉                                                                                                                                                                                                                                                | 8/100 [00:00<00:10,  8.50it/s]
  9%|███████████████████████▍                                                                                                                                                                                                                                             | 9/100 [00:01<00:10,  8.80it/s]
 11%|████████████████████████████▌                                                                                                                                                                                                                                       | 11/100 [00:01<00:09,  9.49it/s]
 12%|███████████████████████████████▏                                                                                                                                                                                                                                    | 12/100 [00:01<00:09,  9.24it/s]
 13%|█████████████████████████████████▊                                                                                                                                                                                                                                  | 13/100 [00:01<00:11,  7.85it/s]
 14%|████████████████████████████████████▍                                                                                                                                                                                                                               | 14/100 [00:01<00:10,  8.32it/s]
 15%|███████████████████████████████████████                                                                                                                                                                                                                             | 15/100 [00:01<00:09,  8.72it/s]
 16%|█████████████████████████████████████████▌                                                                                                                                                                                                                          | 16/100 [00:01<00:09,  8.49it/s]
 17%|████████████████████████████████████████████▏                                                                                                                                                                                                                       | 17/100 [00:01<00:10,  8.26it/s]
 18%|██████████████████████████████████████████████▊                                                                                                                                                                                                                     | 18/100 [00:02<00:10,  8.04it/s]
 19%|█████████████████████████████████████████████████▍                                                                                                                                                                                                                  | 19/100 [00:02<00:10,  7.87it/s]
 20%|████████████████████████████████████████████████████                                                                                                                                                                                                                | 20/100 [00:02<00:10,  7.80it/s]
 21%|██████████████████████████████████████████████████████▌                                                                                                                                                                                                             | 21/100 [00:02<00:10,  7.82it/s]
 22%|█████████████████████████████████████████████████████████▏                                                                                                                                                                                                          | 22/100 [00:02<00:10,  7.52it/s]
 23%|███████████████████████████████████████████████████████████▊                                                                                                                                                                                                        | 23/100 [00:02<00:10,  7.33it/s]
 24%|██████████████████████████████████████████████████████████████▍                                                                                                                                                                                                     | 24/100 [00:02<00:09,  7.93it/s]
 26%|███████████████████████████████████████████████████████████████████▌                                                                                                                                                                                                | 26/100 [00:03<00:08,  8.41it/s]
 27%|██████████████████████████████████████████████████████████████████████▏                                                                                                                                                                                             | 27/100 [00:03<00:09,  7.98it/s]
 28%|████████████████████████████████████████████████████████████████████████▊                                                                                                                                                                                           | 28/100 [00:03<00:08,  8.19it/s]
 29%|███████████████████████████████████████████████████████████████████████████▍                                                                                                                                                                                        | 29/100 [00:03<00:08,  8.45it/s]
 30%|██████████████████████████████████████████████████████████████████████████████                                                                                                                                                                                      | 30/100 [00:03<00:08,  7.95it/s]
 31%|████████████████████████████████████████████████████████████████████████████████▌                                                                                                                                                                                   | 31/100 [00:03<00:08,  8.42it/s]
 32%|███████████████████████████████████████████████████████████████████████████████████▏                                                                                                                                                                                | 32/100 [00:03<00:09,  6.81it/s]
 33%|█████████████████████████████████████████████████████████████████████████████████████▊                                                                                                                                                                              | 33/100 [00:04<00:09,  7.01it/s]
 34%|████████████████████████████████████████████████████████████████████████████████████████▍                                                                                                                                                                           | 34/100 [00:04<00:08,  7.54it/s]
 35%|███████████████████████████████████████████████████████████████████████████████████████████                                                                                                                                                                         | 35/100 [00:04<00:08,  8.05it/s]
 36%|█████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                                                                      | 36/100 [00:04<00:08,  7.70it/s]
 37%|████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                                                                                                   | 37/100 [00:04<00:08,  7.02it/s]
 38%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                                                                                                                 | 38/100 [00:04<00:08,  7.35it/s]
 39%|█████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                                                                                              | 39/100 [00:04<00:07,  7.94it/s]
 40%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                                                                            | 40/100 [00:04<00:07,  8.27it/s]
 41%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                                                         | 41/100 [00:05<00:07,  7.77it/s]
 42%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                                                                                      | 42/100 [00:05<00:07,  8.02it/s]
 44%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                                                                                 | 44/100 [00:05<00:06,  8.95it/s]
 46%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                                            | 46/100 [00:05<00:06,  8.97it/s]
 47%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                                                                         | 47/100 [00:05<00:06,  8.81it/s]
 48%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                                                                                       | 48/100 [00:05<00:05,  8.88it/s]
 49%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                                                                    | 49/100 [00:05<00:05,  9.13it/s]
 50%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                                                  | 50/100 [00:06<00:05,  9.03it/s]
 51%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                               | 51/100 [00:06<00:05,  9.10it/s]
 52%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                                                            | 52/100 [00:06<00:05,  9.13it/s]
 53%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                                                                          | 53/100 [00:06<00:05,  9.08it/s]
 54%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                                                       | 54/100 [00:06<00:04,  9.25it/s]
 55%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                                     | 55/100 [00:06<00:04,  9.17it/s]
 56%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                  | 56/100 [00:06<00:04,  8.91it/s]
 58%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                                                             | 58/100 [00:06<00:04,  9.43it/s]
 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                        | 60/100 [00:07<00:04,  9.82it/s]
 62%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                                  | 62/100 [00:07<00:03,  9.97it/s]
 64%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                             | 64/100 [00:07<00:03, 10.06it/s]
 66%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                        | 66/100 [00:07<00:03, 10.17it/s]
 68%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                                   | 68/100 [00:07<00:03, 10.28it/s]
 70%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                              | 70/100 [00:08<00:02, 10.28it/s]
 72%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                        | 72/100 [00:08<00:02, 10.32it/s]
 74%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                   | 74/100 [00:08<00:02, 10.32it/s]
 76%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                              | 76/100 [00:08<00:02, 10.37it/s]
 78%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                         | 78/100 [00:08<00:02, 10.41it/s]
 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 80/100 [00:09<00:01, 10.35it/s]
 82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                              | 82/100 [00:09<00:01, 10.33it/s]
 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                         | 84/100 [00:09<00:01, 10.00it/s]
 86%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                    | 86/100 [00:09<00:01, 10.12it/s]
 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 88/100 [00:09<00:01, 10.14it/s]
 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                          | 90/100 [00:09<00:00, 10.19it/s]
 92%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                    | 92/100 [00:10<00:00, 10.03it/s]
 94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍               | 94/100 [00:10<00:00, 10.09it/s]
 96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌          | 96/100 [00:10<00:00, 10.13it/s]
 98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊     | 98/100 [00:10<00:00, 10.18it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:11<00:00,  9.91it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:11<00:00,  9.09it/s]

fig = plt.figure()
ax = fig.add_subplot()
_cost = np.array(mynmf.cost_) #/X_train.shape[0]
ax.plot(_cost)
ax.set_xlabel("Number of update")
ax.set_ylabel("Reconstruction error")
ax.set_title("Reconstruction Error")
print("最後の更新時の損失関数の値:",_cost[-1])
最後の更新時の損失関数の値: 10.900271385543064
_images/a415014f6627befcbbb7a1032b6429764420320c662a4e62ba0c3264a1c9c27b.png

9.4. 参考文献#

9.4.1. 論文等#

LS99

D D Lee and H S Seung. Learning the parts of objects by non-negative matrix factorization. Nature, 401(6755):788–791, October 1999. doi:10.1038/44565.

LS00

Daniel Lee and H Sebastian Seung. Algorithms for non-negative matrix factorization. In T Leen, T Dietterich, and V Tresp, editors, Advances in Neural Information Processing Systems, volume 13. MIT Press, 2000.

Lin07

Chih-Jen Lin. Projected gradient methods for nonnegative matrix factorization. Neural Comput., 19(10):2756–2779, October 2007. doi:10.1162/neco.2007.19.10.2756.