テストデータの精度が低いのはなぜ?ディープラーニングの落とし穴と解決策

過学習
イラスト出展:OpenAI dall-e-3

初めに

ディープラーニングのお悩みに回答するね

この記事では、ディープラーニングで訓練データの正解率(Actually)は高いのに、テストデータの正解率(Actually)が極めて低いという悩みに焦点を当て、その原因と対策を解説します。筆者が実際に経験した事例も交えながら、わかりやすく説明していきます。

過学習

この記事を読めばわかること

  • 訓練データとテストデータの正解率の違いが理解できる
  • テストデータの正解率が低い原因を3つに分類できる
  • 各原因に対する具体的な対策方法がわかる

本記事の対象者

  • ディープラーニングの基本的な知識を持っている方
  • 自分でモデルを組んで学習させている方
  • 訓練データとテストデータの正解率に差があり悩んでいる方

訓練データとテストデータの正解率とは?

ディープラーニングでは、モデルの性能を評価するために、訓練データとテストデータの2種類のデータセットを使用します。

  • 訓練データ: モデルの学習に使用されるデータセットです。
  • テストデータ: モデルの学習には使用されず、学習済みのモデルの性能を評価するために使用されるデータセットです。

訓練データの正解率は、モデルが学習に使用したデータに対する正解率です。一方、テストデータの正解率は、モデルが学習に使用していないデータに対する正解率です。

例えるなら、訓練データは模擬試験、テストデータは本試験です。

  • 訓練データ (模擬試験): モデルは、訓練データを使って何度も学習し、模擬試験で良い点を取れるように訓練されます。
  • テストデータ (本試験): 学習済みのモデルは、テストデータを使って、実際にどの程度の知識を習得しているのか、つまり本試験でどの程度の点数が取れるのかが評価されます。

なぜテストデータの正解率が低いのは問題なのか?

ディープラーニングモデルの目的は、未知のデータに対しても高い精度で予測することです。訓練データの正解率が高くても、テストデータの正解率が低いということは、モデルが訓練データに特化してしまい、汎化性能が低いことを意味します。

例えるなら、訓練データは模擬試験、テストデータは本試験です。いくら模擬試験で良い点を取れても、本試験で良い点が取れなければ意味がありません。テストデータの正解率が低いディープラーニングのモデルは、あまり実用性の無いモデルになります。

テストデータの正解率が低い原因を3つに分類

訓練データの正解率が高いのにテストデータの正解率が低い場合、その原因は大きく3つに分類できます。

1. データに隔たりがある

訓練データとテストデータの間にデータの偏りがあると、モデルは訓練データの偏った特徴に特化して学習してしまい、テストデータに対してはうまく予測できなくなってしまいます。

例えば犬と猫の分類において、訓練データの犬と猫の比率が9対1であった場合、常に犬と答えれば正解率は90%になりますので、ディープラーニングのモデルが常に犬と答えるように学習することがあります。この場合、テストデータの犬と猫の比率が1対1であれば、正解率は50%にしかなりません。

データの隔たり
出展:OpenAI dall-e-3

2. 汎化出来ていない

モデルが訓練データに特化しすぎてしまい、訓練データ以外のデータに対してはうまく対応できない状態を「汎化できていない」といいます。これは、過学習と呼ばれる現象によって起こります。

例えば模擬試験をひたすら勉強し模擬試験は満点を取れたけど、応用が出来なくて本試験の点数が良くなかった。みたいな状況です。
既出のデータに対する結果は丸暗記したけど、応用力(汎化)が出来てなくて、新しいデータに正しく答えられないのが過学習です。
この場合もテストデータに対する正解率は低くなります。

過学習
出展:OpenAI dall-e-3

3. 入力と出力に因果関係が無い

入力データと出力データ間に本来の因果関係がない場合、訓練データの正解率は高くなる可能性がありますが、テストデータでは意味のない結果になる可能性があります。

例えばじゃんけんの過去の結果(グー、チョキ、パー)を学習し、次の手を予測するモデルを考えます。訓練データに「グーが3回続いたら次はパーが出る」というパターンが含まれていると、モデルはこのパターンを学習し、訓練データの正解率は高くなります。しかし、実際にはじゃんけんの結果に過去の結果との因果関係はなく、次の手はランダムに決まります。そのため、テストデータでは正解率は低くなるでしょう。

データの因果関係
出展:OpenAI dall-e-3

各原因に対する対策

それぞれの原因に対して、具体的な対策方法を以下に示します。

1. データに隔たりがある場合の対策

データの隔たりを解消するために、以下の方法を試してみましょう。

  • 少ないデータのサンプルを増やす: 最も効果的な方法です。データ収集やデータ作成を行い、少ないデータのサンプルを増やしましょう。
  • 少ないデータをコピーし増幅する: データ収集が難しい場合は、既存のデータのコピーと改変(回転、ぼかし、色調変更など)によってデータを増幅することも有効です。
  • 多いデータを減らす: 多いデータのサンプルを減らすことで、データの偏りを解消することもできます。ただし、データ量が減ると学習不足になる可能性もあるため、注意が必要です。
  • データの重みを調整する: 学習時に、少ないデータのサンプルに対して重みを大きく設定することで、少ないデータの影響を大きくすることができます。

2. 汎化できていない場合の対策

過学習を防ぎ、汎化性能を高めるために、以下の対策を検討しましょう。

  • データに適したモデルを選択する: 例えば、画像分類では畳み込みニューラルネットワーク(CNN)が有効です。CNNは画像の空間的な特徴を捉えることに長けており、汎化性能が高いモデルです。
  • 正則化手法を用いる: 正則化手法は、モデルのパラメータの大きさを制限することで、過学習を防ぐ効果があります。代表的な手法として、L1正則化、L2正則化、ドロップアウトなどがあります。
  • 早期停止(アーリーストップ): 学習の途中で、テストデータに対する性能が向上しなくなったら学習を停止する手法です。過学習を防ぎ、汎化性能を向上させる効果が期待できます。

3. 入力と出力に因果関係が無い場合の対策

因果関係がないデータで学習しても意味がないため、データ自体を見直す必要があります。例えば、じゃんけんの例では、過去の結果だけでなく、対戦相手の情報(性別、年齢、過去の対戦成績など)も入力データに加えることで、因果関係を構築できる可能性があります。

各対策の具体例(Kerasコードサンプル付き)

1. データに隔たりがある場合の対策

1-1. 少ないデータのサンプルを増やす

  • データ収集:
    • WebスクレイピングやAPIを利用して、足りないデータを取得する。
    • 実験やアンケートを実施して、新しいデータを作成する。
  • データ作成:
    • 画像データであれば、画像編集ソフトやライブラリを用いて、既存の画像を加工する。
    • テキストデータであれば、既存の文章を改変したり、新しい文章を生成する。

1-2. 少ないデータをコピーし増幅する

  • 画像データの場合:
    • 画像の回転、反転、スケール変更、色調変更などを施す。
    • KerasのImageDataGeneratorを使用すると、これらの変換を簡単に適用できる。
from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# トレーニングデータの増幅
train_generator = datagen.flow_from_directory(
    'train_data',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'
)

# モデルの学習
model.fit(train_generator, epochs=10)

1-3. 多いデータを減らす

  • サンプリング: データセットからランダムにサンプルを抽出する。
  • データ選択: データの偏りを解消するために、特定の種類のデータを削除する。
# ランダムサンプリングの例
import random

# トレーニングデータから50%をランダムに選択
train_data_subset = random.sample(train_data, int(len(train_data) * 0.5))

# モデルの学習
model.fit(train_data_subset, epochs=10) 

1-4. データの重みを調整する

  • クラスウェイト: 各クラスのサンプル数に反比例する重みを設定する。
  • Kerasのclass_weightパラメータを使用する。
from sklearn.utils import class_weight

# クラスウェイトを計算
class_weights = class_weight.compute_class_weight('balanced',
                                                 np.unique(train_labels),
                                                 train_labels)

# モデルの学習
model.fit(train_data, train_labels, epochs=10, class_weight=class_weights)

2. 汎化できていない場合の対策

2-1. データに適したモデルを選択する

  • CNN: 画像データの場合、畳み込みニューラルネットワーク(CNN)は空間的な特徴を捉えることに長けており、汎化性能が高い。
  • RNN: 時系列データの場合、再帰型ニューラルネットワーク(RNN)が有効。
  • Transformer: 自然言語処理の場合、Transformerは文脈を理解する能力が高く、汎化性能が高い。
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

# CNNモデルの例
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(10, activation='softmax'))

# モデルのコンパイル
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# モデルの学習
model.fit(train_data, train_labels, epochs=10)

2-2. 正則化手法を用いる

  • L1正則化: モデルのパラメータの絶対値の合計を最小化する。
  • L2正則化: モデルのパラメータの2乗の合計を最小化する。
  • ドロップアウト: 学習時に、ランダムにニューロンを無効化することで、過学習を防ぐ。
from tensorflow.keras.layers import Dropout

# L2正則化とドロップアウトの例
model = Sequential()
model.add(Dense(128, activation='relu', kernel_regularizer='l2', input_shape=(100,)))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))

# モデルのコンパイル
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# モデルの学習
model.fit(train_data, train_labels, epochs=10)

2-3. 早期停止(アーリーストップ)

  • EarlyStoppingコールバック: KerasのEarlyStoppingコールバックを使用することで、テストデータに対する性能が向上しなくなったら学習を停止できる。
from tensorflow.keras.callbacks import EarlyStopping

# 早期停止の例
early_stopping = EarlyStopping(monitor='val_loss', patience=3)

# モデルの学習
model.fit(train_data, train_labels, epochs=10, callbacks=[early_stopping])

3. 入力と出力に因果関係が無い場合の対策

  • データ自体を見直す: データに因果関係が含まれていない場合は、データの取得方法や内容を見直す必要がある。
  • 因果関係を構築できる新たなデータを追加する: 例えば、じゃんけんの例では、対戦相手の情報(性別、年齢、過去の対戦成績など)を追加する。

困ったときは

困った時は、Chat GPTやGoogle Gemini等を使うのもひとつの方法です。機械学習の知識がある程度あれば、Chat GPTやGoogle Gemini等を使い対策前進出来ると思います。機械学習の難しいところはモデルよりも実はデータ集めで、そこは自分の力で何とかする必要がありますが、標準的なモデルの構築であればChat GPTやGoogle Geminiを活用すると何とかなることが多いです。

最後に

本記事が少しでも皆さまの役にたてたら幸いです。

一緒に楽しくAIについて学びましょう。