テストデータの精度が低いのはなぜ?ディープラーニングの落とし穴と解決策
初めに
ディープラーニングのお悩みに回答するね
この記事では、ディープラーニングで訓練データの正解率(Actually)は高いのに、テストデータの正解率(Actually)が極めて低いという悩みに焦点を当て、その原因と対策を解説します。筆者が実際に経験した事例も交えながら、わかりやすく説明していきます。
この記事を読めばわかること
- 訓練データとテストデータの正解率の違いが理解できる
- テストデータの正解率が低い原因を3つに分類できる
- 各原因に対する具体的な対策方法がわかる
本記事の対象者
- ディープラーニングの基本的な知識を持っている方
- 自分でモデルを組んで学習させている方
- 訓練データとテストデータの正解率に差があり悩んでいる方
訓練データとテストデータの正解率とは?
ディープラーニングでは、モデルの性能を評価するために、訓練データとテストデータの2種類のデータセットを使用します。
- 訓練データ: モデルの学習に使用されるデータセットです。
- テストデータ: モデルの学習には使用されず、学習済みのモデルの性能を評価するために使用されるデータセットです。
訓練データの正解率は、モデルが学習に使用したデータに対する正解率です。一方、テストデータの正解率は、モデルが学習に使用していないデータに対する正解率です。
例えるなら、訓練データは模擬試験、テストデータは本試験です。
- 訓練データ (模擬試験): モデルは、訓練データを使って何度も学習し、模擬試験で良い点を取れるように訓練されます。
- テストデータ (本試験): 学習済みのモデルは、テストデータを使って、実際にどの程度の知識を習得しているのか、つまり本試験でどの程度の点数が取れるのかが評価されます。
なぜテストデータの正解率が低いのは問題なのか?
ディープラーニングモデルの目的は、未知のデータに対しても高い精度で予測することです。訓練データの正解率が高くても、テストデータの正解率が低いということは、モデルが訓練データに特化してしまい、汎化性能が低いことを意味します。
例えるなら、訓練データは模擬試験、テストデータは本試験です。いくら模擬試験で良い点を取れても、本試験で良い点が取れなければ意味がありません。テストデータの正解率が低いディープラーニングのモデルは、あまり実用性の無いモデルになります。
テストデータの正解率が低い原因を3つに分類
訓練データの正解率が高いのにテストデータの正解率が低い場合、その原因は大きく3つに分類できます。
1. データに隔たりがある
訓練データとテストデータの間にデータの偏りがあると、モデルは訓練データの偏った特徴に特化して学習してしまい、テストデータに対してはうまく予測できなくなってしまいます。
例えば犬と猫の分類において、訓練データの犬と猫の比率が9対1であった場合、常に犬と答えれば正解率は90%になりますので、ディープラーニングのモデルが常に犬と答えるように学習することがあります。この場合、テストデータの犬と猫の比率が1対1であれば、正解率は50%にしかなりません。
2. 汎化出来ていない
モデルが訓練データに特化しすぎてしまい、訓練データ以外のデータに対してはうまく対応できない状態を「汎化できていない」といいます。これは、過学習と呼ばれる現象によって起こります。
例えば模擬試験をひたすら勉強し模擬試験は満点を取れたけど、応用が出来なくて本試験の点数が良くなかった。みたいな状況です。
既出のデータに対する結果は丸暗記したけど、応用力(汎化)が出来てなくて、新しいデータに正しく答えられないのが過学習です。
この場合もテストデータに対する正解率は低くなります。
3. 入力と出力に因果関係が無い
入力データと出力データ間に本来の因果関係がない場合、訓練データの正解率は高くなる可能性がありますが、テストデータでは意味のない結果になる可能性があります。
例えばじゃんけんの過去の結果(グー、チョキ、パー)を学習し、次の手を予測するモデルを考えます。訓練データに「グーが3回続いたら次はパーが出る」というパターンが含まれていると、モデルはこのパターンを学習し、訓練データの正解率は高くなります。しかし、実際にはじゃんけんの結果に過去の結果との因果関係はなく、次の手はランダムに決まります。そのため、テストデータでは正解率は低くなるでしょう。
各原因に対する対策
それぞれの原因に対して、具体的な対策方法を以下に示します。
1. データに隔たりがある場合の対策
データの隔たりを解消するために、以下の方法を試してみましょう。
- 少ないデータのサンプルを増やす: 最も効果的な方法です。データ収集やデータ作成を行い、少ないデータのサンプルを増やしましょう。
- 少ないデータをコピーし増幅する: データ収集が難しい場合は、既存のデータのコピーと改変(回転、ぼかし、色調変更など)によってデータを増幅することも有効です。
- 多いデータを減らす: 多いデータのサンプルを減らすことで、データの偏りを解消することもできます。ただし、データ量が減ると学習不足になる可能性もあるため、注意が必要です。
- データの重みを調整する: 学習時に、少ないデータのサンプルに対して重みを大きく設定することで、少ないデータの影響を大きくすることができます。
2. 汎化できていない場合の対策
過学習を防ぎ、汎化性能を高めるために、以下の対策を検討しましょう。
- データに適したモデルを選択する: 例えば、画像分類では畳み込みニューラルネットワーク(CNN)が有効です。CNNは画像の空間的な特徴を捉えることに長けており、汎化性能が高いモデルです。
- 正則化手法を用いる: 正則化手法は、モデルのパラメータの大きさを制限することで、過学習を防ぐ効果があります。代表的な手法として、L1正則化、L2正則化、ドロップアウトなどがあります。
- 早期停止(アーリーストップ): 学習の途中で、テストデータに対する性能が向上しなくなったら学習を停止する手法です。過学習を防ぎ、汎化性能を向上させる効果が期待できます。
3. 入力と出力に因果関係が無い場合の対策
因果関係がないデータで学習しても意味がないため、データ自体を見直す必要があります。例えば、じゃんけんの例では、過去の結果だけでなく、対戦相手の情報(性別、年齢、過去の対戦成績など)も入力データに加えることで、因果関係を構築できる可能性があります。
各対策の具体例(Kerasコードサンプル付き)
1. データに隔たりがある場合の対策
1-1. 少ないデータのサンプルを増やす
- データ収集:
- WebスクレイピングやAPIを利用して、足りないデータを取得する。
- 実験やアンケートを実施して、新しいデータを作成する。
- データ作成:
- 画像データであれば、画像編集ソフトやライブラリを用いて、既存の画像を加工する。
- テキストデータであれば、既存の文章を改変したり、新しい文章を生成する。
1-2. 少ないデータをコピーし増幅する
- 画像データの場合:
- 画像の回転、反転、スケール変更、色調変更などを施す。
- KerasのImageDataGeneratorを使用すると、これらの変換を簡単に適用できる。
from tensorflow.keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest'
)
# トレーニングデータの増幅
train_generator = datagen.flow_from_directory(
'train_data',
target_size=(224, 224),
batch_size=32,
class_mode='categorical'
)
# モデルの学習
model.fit(train_generator, epochs=10)
1-3. 多いデータを減らす
- サンプリング: データセットからランダムにサンプルを抽出する。
- データ選択: データの偏りを解消するために、特定の種類のデータを削除する。
# ランダムサンプリングの例
import random
# トレーニングデータから50%をランダムに選択
train_data_subset = random.sample(train_data, int(len(train_data) * 0.5))
# モデルの学習
model.fit(train_data_subset, epochs=10)
1-4. データの重みを調整する
- クラスウェイト: 各クラスのサンプル数に反比例する重みを設定する。
- Kerasのclass_weightパラメータを使用する。
from sklearn.utils import class_weight
# クラスウェイトを計算
class_weights = class_weight.compute_class_weight('balanced',
np.unique(train_labels),
train_labels)
# モデルの学習
model.fit(train_data, train_labels, epochs=10, class_weight=class_weights)
2. 汎化できていない場合の対策
2-1. データに適したモデルを選択する
- CNN: 画像データの場合、畳み込みニューラルネットワーク(CNN)は空間的な特徴を捉えることに長けており、汎化性能が高い。
- RNN: 時系列データの場合、再帰型ニューラルネットワーク(RNN)が有効。
- Transformer: 自然言語処理の場合、Transformerは文脈を理解する能力が高く、汎化性能が高い。
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# CNNモデルの例
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(10, activation='softmax'))
# モデルのコンパイル
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
# モデルの学習
model.fit(train_data, train_labels, epochs=10)
2-2. 正則化手法を用いる
- L1正則化: モデルのパラメータの絶対値の合計を最小化する。
- L2正則化: モデルのパラメータの2乗の合計を最小化する。
- ドロップアウト: 学習時に、ランダムにニューロンを無効化することで、過学習を防ぐ。
from tensorflow.keras.layers import Dropout
# L2正則化とドロップアウトの例
model = Sequential()
model.add(Dense(128, activation='relu', kernel_regularizer='l2', input_shape=(100,)))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
# モデルのコンパイル
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
# モデルの学習
model.fit(train_data, train_labels, epochs=10)
2-3. 早期停止(アーリーストップ)
- EarlyStoppingコールバック: KerasのEarlyStoppingコールバックを使用することで、テストデータに対する性能が向上しなくなったら学習を停止できる。
from tensorflow.keras.callbacks import EarlyStopping
# 早期停止の例
early_stopping = EarlyStopping(monitor='val_loss', patience=3)
# モデルの学習
model.fit(train_data, train_labels, epochs=10, callbacks=[early_stopping])
3. 入力と出力に因果関係が無い場合の対策
- データ自体を見直す: データに因果関係が含まれていない場合は、データの取得方法や内容を見直す必要がある。
- 因果関係を構築できる新たなデータを追加する: 例えば、じゃんけんの例では、対戦相手の情報(性別、年齢、過去の対戦成績など)を追加する。
困ったときは
困った時は、Chat GPTやGoogle Gemini等を使うのもひとつの方法です。機械学習の知識がある程度あれば、Chat GPTやGoogle Gemini等を使い対策前進出来ると思います。機械学習の難しいところはモデルよりも実はデータ集めで、そこは自分の力で何とかする必要がありますが、標準的なモデルの構築であればChat GPTやGoogle Geminiを活用すると何とかなることが多いです。
最後に
本記事が少しでも皆さまの役にたてたら幸いです。
一緒に楽しくAIについて学びましょう。