リアルタイム学習手法:オンライン機械学習の実装例のご紹介

この記事は約15分で読めます。

機械学習は一般的には大量の学習データを用意して、一括で読込んでモデルの学習を行います。
この学習方法はバッチ学習といいます。
一度に全てのデータを処理するため、新しいデータの追加や大量データに対する反応性において柔軟に対応できない場合があります。

これに対しオンライン学習とは、流れてくるデータにリアルタイムで学習する手法です。
これは大量かつ頻繁に更新されるデータをすぐにモデルに取り入れて学習させたい場合に有効な手段の1つです。
オンライン機械学習とは、このような学習方法を用いて行われる機械学習全体を指します。

今回はオンライン機械学習モデルをPythonで実装する例を紹介しようと思います。
その中でも特に今回は、Confidence Weighted Learning: CWとSoft Confidence Weighted Learning: SCWの2つのオンライン機械学習の手法に焦点を当てます。

オンライン機械学習について

オンライン機械学習は、流れてくるデータにリアルタイムで学習し、予測を行うという特性を持っていて、以下のようなメリットを持っています。

  1. 最新データのリアルタイム入力に対応
  2. 大量のデータに対応
  3. データの安全性
  4. ハードウェアリソースの節約

1については、オンライン機械学習はデータがリアルタイムに到着するという設定において最適です。
新しいデータが到着する度にモデルを更新するため、最新の情報を反映した予測が可能となります。
これは、常に最新のデータに基づいて精度を更新し続けることが求められる状況に対応します。
例えば、金融市場の予測やリアルタイムでの商品推薦などに非常に有用です。

2については、一部のデータだけを使用して学習を行えるため、データが膨大な量で一度に全てをメモリに保持できない状況においても、全体のデータを通じた学習結果を得ることができる点で有効です。

3については、一度に一部のデータしか必要としないため、個人情報などのセンシティブな情報を含むデータの取り扱いにも有用です。
データは一度学習に用いられた後、すぐに破棄することができ、データの保管リスクを軽減します。

4については、全てのデータを一度にメモリに保管する必要がないため、性能の低い端末でも処理が可能です。
これは、リソース制約のある環境での適用に有利です。

しかし、オンライン機械学習にはいくつかのデメリットも存在します。
例えば、データがランダムに到着しない場合や、データの順序が結果に影響を与える場合、オンライン機械学習の性能は下がる可能性があります。
また、ノイズの多いデータや外れ値が含まれる場合、その影響を即座に学習してしまうため、予測精度が低下することもあります。

これらの理由から、オンライン機械学習は、高度に動的で大規模なデータ、または敏感なデータを扱う多くの現実的なアプリケーションに対して利点を提供しますが、適用する際はデータの性質と目的に応じて慎重に考えることが重要です。

より詳しい内容については、下記の書籍が参考になります。
機械学習プロフェッショナルシリーズの中で、オンライン機械学習をテーマにした一冊です。

今回実装する手法は以下の論文で紹介されているCW・SCWを実装してみます。

Confidence Weighted Learning: CW

まずはCWの実装になります。

CWのアルゴリズムのざっくりとしたイメージは、線形分離をする際の各次元の重みが多次元正規分布に従うとして考え、その各重みに対応する分散を重みの信頼度として考えたモデルになります。
分散が大きければ大きいほど、その重みについては自信がないと解釈します。

この場合、共分散を考えることもできますが、おそらく共分散は保存する必要はなさそうな気がします。
使わなそうだけど、何かに使えないかなーとか考えてしまいますね。

早速、Pythonによる実装が下記になります。
以下は試しにアヤメの分類データを学習させてみたものになります。
実装自体は実はそんなにコード量は多くなく作れてしまいます。

import numpy as np
import pandas as pd
from sklearn import datasets
from scipy import stats

# irisデータセット
iris = datasets.load_iris()
data = pd.DataFrame(data= np.c_[iris["data"], iris["target"]], columns= iris["feature_names"] + ["target"])
selected_species = [0, 1] # setosa, versicolor
df = data[data["target"].isin(selected_species)] # setosa, versicolorのデータ
df = df.reindex(np.random.permutation(df.index)) # シャッフル
data_x = df[iris["feature_names"]].as_matrix() # sepal length (cm), sepal width (cm), petal length (cm), petal width (cm)
data_t = [1 if i == selected_species[0] else -1 for i in df["target"]] # ラベル{1, -1}
train_x, test_x = data_x[:80], data_x[80:]
train_t, test_t = data_t[:80], data_t[80:]
print("学習データ件数: ", len(train_x))
print("テストデータ件数: ", len(test_x))

# confidence weighted learning クラス
class CW():
    def __init__(self, in_size):
        self.mu = np.zeros(in_size)
        self.sigma = np.eye(in_size)
        self.eta = 0.95
        self.phi = stats.norm.ppf(self.eta)
        self.psi = 1+self.phi**2/2
        self.xi = 1+self.phi**2

    def train(self, x, t):
        #  学習
        m_t = t*self.mu.dot(x)
        v_t = x.dot(self.sigma).dot(x)
        alpha_t = max(0, (-m_t*self.psi+np.sqrt((m_t**2)*(self.phi**4)/4+v_t*(self.phi**2)*self.xi))/(v_t*self.xi))
        u_t = ((-alpha_t*v_t*self.phi+np.sqrt((alpha_t**2)*(v_t**2)*(self.phi**2)+4*v_t))**2)/4
        beta_t = (alpha_t*self.phi)/(np.sqrt(u_t)+v_t*alpha_t*self.phi)
        self.mu = self.mu+alpha_t*t*self.sigma.dot(x)
        self.sigma = self.sigma-beta_t*self.sigma.dot(x)[:,np.newaxis].dot(x.dot(self.sigma)[np.newaxis,:])

    def predict(self, x):
        # 予測
        if x.dot(self.mu) > 0:
            return 1
        else:
            return -1

# 正解率を計算する関数
def get_accuracy(model, dataset_x, dataset_t):
    result = []
    for x, t in zip(dataset_x, dataset_t):
        if model.predict(x)*t > 0: # 正解すれば1, 間違えれば-1
            result.append(1)
        else:
            result.append(0)
        accuracy = sum(result)/len(result)
        return accuracy

# 定数
EPOCH_NUM = 1

# CWクラス
cw = CW(in_size=len(iris["feature_names"]))

# 学習
for epoch in range(EPOCH_NUM):
    for x, t in zip(train_x, train_t):
        cw.train(x, t)
    accuracy1 = get_accuracy(cw, train_x, train_t) # 学習データの正解率
    accuracy2 = get_accuracy(cw, test_x, test_t) # バリデーションデータの正解率
    print("train/accuracy: {}, test/accuracy: {}".format(accuracy1, accuracy2)) # ログ
学習データ件数:  80
テストデータ件数:  20
train/accuracy: 1.0, test/accuracy: 1.0

アヤメのデータは簡単な問題なので、精度はあまり気にしないで良いと思います。

線形分離可能な問題ではかなり有効なアルゴリズムのように思います。
いろいろと試してみると分かりますが、かなり少ないサンプル数(最初の1,2件の入力)でもすぐに収束する性質があるのが分かります。
故に、バイアスがかかったデータや外れ値なデータが入力されてしまった場合に、その一瞬だけ精度が大きく悪化する可能性がありそうです。

Soft Confidence Weighted Learning: SCW

次にSCWを実装してみます。

CWでは線形分離不可の問題に対してはうまく学習できず、境界が大きくぶれてしまうという弱点があります。
それを克服するため、幾分か耐性をもたせたモデルがSCWです。
より具体的には、ソフトマージンを導入することで誤分類されたインスタンスに対するペナルティを緩和し、より頑健な学習を提供します。
予測の不確実性をモデルに組み込むことで、データストリームの中で新しいパターンが出現した場合でも迅速に対応できるようにすることを目指しています。

前述の論文では2つのやり方(Prop.1, Prop.2)が提案されていましたので、両方とも実装を紹介して終わりにしようと思います。

Prop.1

# soft confidence weighted learning prop1クラス
class SCW1():
    def __init__(self, in_size):
        self.mu = np.zeros(in_size)
        self.sigma = np.eye(in_size)
        self.eta = 0.95
        self.C = 1
        self.phi = stats.norm.ppf(self.eta)
        self.psi = 1+self.phi**2/2
        self.xi = 1+self.phi**2

    def train(self, x, t):
        #  学習
        m_t = t*self.mu.dot(x)
        v_t = x.dot(self.sigma).dot(x)
        alpha_t = min(self.C, max(0, (-m_t*self.psi+np.sqrt((m_t**2)*(self.phi**4)/4+v_t*(self.phi**2)*self.xi))/(v_t*self.xi)))
        u_t = ((-alpha_t*v_t*self.phi+np.sqrt((alpha_t**2)*(v_t**2)*(self.phi**2)+4*v_t))**2)/4
        beta_t = (alpha_t*self.phi)/(np.sqrt(u_t)+v_t*alpha_t*self.phi)
        self.mu = self.mu+alpha_t*t*self.sigma.dot(x)
        self.sigma = self.sigma-beta_t*self.sigma.dot(x)[:,np.newaxis].dot(x.dot(self.sigma)[np.newaxis,:])

    def predict(self, x):
        # 予測
        if x.dot(self.mu) > 0:
            return 1
        else:
            return -1

# SCW1クラス
scw1 = SCW1(in_size=len(iris["feature_names"]))

# 学習
for epoch in range(EPOCH_NUM):
    for x, t in zip(train_x, train_t):
        scw1.train(x, t)
    accuracy1 = get_accuracy(scw1, train_x, train_t) # 学習データの正解率
    accuracy2 = get_accuracy(scw1, test_x, test_t) # バリデーションデータの正解率
    print("train/accuracy: {}, test/accuracy: {}".format(accuracy1, accuracy2)) # ログ
train/accuracy: 1.0, test/accuracy: 1.0

Prop.2

# soft confidence weighted learning prop2クラス
class SCW2():
    def __init__(self, in_size):
        self.mu = np.zeros(in_size)
        self.sigma = np.eye(in_size)
        self.eta = 0.95
        self.C = 1
        self.phi = stats.norm.ppf(self.eta)
        self.psi = 1+self.phi**2/2
        self.xi = 1+self.phi**2

    def train(self, x, t):
        #  学習
        m_t = t*self.mu.dot(x)
        v_t = x.dot(self.sigma).dot(x)
        n_t = v_t+1/2*self.C
        gamma_t = self.phi*np.sqrt((self.phi**2)*(m_t**2)*(v_t**2)+4*n_t*v_t*(n_t+v_t*(self.phi**2)))
        alpha_t = max(0, (-(2*m_t*n_t+(self.phi**2)*m_t*v_t)+gamma_t)/(2*(n_t**2+n_t*v_t*(self.phi**2))))
        u_t = ((-alpha_t*v_t*self.phi+np.sqrt((alpha_t**2)*(v_t**2)*(self.phi**2)+4*v_t))**2)/4
        beta_t = (alpha_t*self.phi)/(np.sqrt(u_t)+v_t*alpha_t*self.phi)
        self.mu = self.mu+alpha_t*t*self.sigma.dot(x)
        self.sigma = self.sigma-beta_t*self.sigma.dot(x)[:,np.newaxis].dot(x.dot(self.sigma)[np.newaxis,:])

    def predict(self, x):
        # 予測
        if x.dot(self.mu) > 0:
            return 1
        else:
            return -1

# SCW2クラス
scw2 = SCW2(in_size=len(iris["feature_names"]))

# 学習
for epoch in range(EPOCH_NUM):
    for x, t in zip(train_x, train_t):
        scw2.train(x, t)
    accuracy1 = get_accuracy(scw2, train_x, train_t) # 学習データの正解率
    accuracy2 = get_accuracy(scw2, test_x, test_t) # バリデーションデータの正解率
    print("train/accuracy: {}, test/accuracy: {}".format(accuracy1, accuracy2)) # ログ
train/accuracy: 1.0, test/accuracy: 1.0

まとめ

今回はオンライン機械学習について解説し、CW、SCWのPythonによる実装について紹介しました。

重みを正規分布で表現するという考え方は面白いですね。
そのためモデル自体はかなり軽量ですので、とてもクイックに学習が可能です。

これらの手法はリアルタイムのデータストリームを扱うための強力なツールであり、各データ点の到着に応じてモデルをリアルタイムに更新します。
しかし、モデリング手法を選ぶ際には、常にその手法が問題の性質やデータの特徴に適しているかどうかを考慮することが重要です。
特に実務においては、モデルはリッチであればリッチであるほど良いという単純なものでもなく、精度や計算リソースのコスト、処理時間などのバランスをケースバイケースで考慮しながらアプローチしなければならない時が多くあります。
オンライン機械学習もあくまで一つの選択肢であり、その利点と制約を理解した上で、適切なツールを選択するよう心掛けることが重要です。