Skip to main content

学習済みモデルの保存と再利用

なぜモデルを保存するの?

機械学習のモデルを一度学習させたら、その「賢くなった状態」を保存しておきたくなります。なぜなら、

  • 再利用のため: アプリケーションを再起動するたびに、何時間もかけてモデルを再学習させるのは非効率です。
  • 他の人との共有: チームメンバーに学習済みのモデルを渡して、同じ予測結果を再現してもらえます。
  • 本番環境への展開: 学習させたモデルをWebサービスなどの実際のアプリケーションに組み込むために使います。

このプロセスをモデルの永続化 (Persistence) と呼びます。

モデルの保存方法: joblib を使おう

Pythonでオブジェクトを保存する方法はいくつかありますが、scikit-learnでは joblib というライブラリを使うことが推奨されています。joblib は、特にNumPyの大きな配列を含むようなPythonオブジェクトを効率的に扱えるように設計されています。

joblib のインストール

joblibはscikit-learnと一緒にインストールされることが多いですが、もし単体でインストールする場合は、Colabのセルで以下のコマンドを実行します。

!pip install joblib

ステップ1:モデルを保存する (joblib.dump)

学習済みのモデルをファイルとして保存するには、joblib.dump() 関数を使います。

使い方: joblib.dump(保存したいオブジェクト, 'ファイル名')

例: 演習で作成した modelcafe_model.joblib という名前で保存してみましょう。

# joblibライブラリをインポート
import joblib

# --- ここにモデルの学習までのコードがあるとする ---
# from sklearn.tree import DecisionTreeClassifier
# model = DecisionTreeClassifier()
# model.fit(X_train, y_train)
# -----------------------------------------

# 学習済みモデル(model)を'cafe_model.joblib'というファイル名で保存
joblib.dump(model, 'cafe_model.joblib')
joblib.dump(X_test, 'X_test.joblib')

print("モデルが 'cafe_model.joblib' として保存されました。")

これを実行すると、Colabのファイル一覧に cafe_model.joblib というファイルが表示されるはずです。

ステップ2:保存したモデルを読み込む (joblib.load)

保存したモデルは、joblib.load() 関数を使っていつでもプログラムに読み込むことができます。

使い方: 読み込み先の変数 = joblib.load('ファイル名')

例: 先ほど保存した cafe_model.joblibloaded_model という新しい変数に読み込んでみましょう。

# joblibライブラリをインポート
import joblib

# 'cafe_model.joblib' ファイルからモデルを読み込む
loaded_model = joblib.load('cafe_model.joblib')

print("モデルが正常に読み込まれました。")

ステップ3:読み込んだモデルで予測する

読み込んだモデルは、元の model と全く同じように使うことができます。例えば、テストデータを使って予測をしてみましょう。

# 必要なライブラリをインポート
import joblib
import pandas as pd
from sklearn.metrics import accuracy_score

# 保存したモデルを読み込む
loaded_model = joblib.load('cafe_model.joblib')
# テストデータを読み込む
X_test = joblib.load('X_test.joblib')

# 読み込んだモデルを使って予測を実行
# X_testは演習で作成したテストデータとします
predictions = loaded_model.predict(X_test)

# 答えのファイルを読み込む
answers_df = pd.read_csv('cafe_test_answers.csv')
y_test_answers = answers_df['success']

# 結果の確認
accuracy = accuracy_score(y_test_answers, predictions)

print(f"読み込んだモデルの正解率 (Accuracy): {accuracy}")

元のモデルで計算した正解率と全く同じ結果が出れば、モデルの保存と読み込みが成功している証拠です。