学習済みモデルの保存と再利用
なぜモデルを保存するの?
機械学習のモデルを一度学習させたら、その「賢くなった状態」を保存しておきたくなります。なぜなら、
- 再利用のため: アプリケーションを再起動するたびに、何時間もかけてモデルを再学習させるのは非効率です。
- 他の人との共有: チームメンバーに学習済みのモデルを渡して、同じ予測結果を再現してもらえます。
- 本番環境への展開: 学習させたモデルをWebサービスなどの実際のアプリケーションに組み込むために使います。
このプロセスをモデルの永続化 (Persistence) と呼びます。
モデルの保存方法: joblib を使おう
Pythonでオブジェクトを保存する方法はいくつかありますが、scikit-learnでは joblib というライブラリを使うことが推奨されています。joblib は、特にNumPyの大きな配列を含むようなPythonオブジェクトを効率的に扱えるように設計されています。
joblib のインストール
joblibはscikit-learnと一緒にインストールされることが多いですが、もし単体でインストールする場合は、Colabのセルで以下のコマンドを実行します。
!pip install joblib
ステップ1:モデルを保存する (joblib.dump)
学習済みのモデルをファイルとして保存するには、joblib.dump() 関数を使います。
使い方:
joblib.dump(保存したいオブジェクト, 'ファイル名')
例:
演習で作成した model を cafe_model.joblib という名前で保存してみましょう。
# joblibライブラリをインポート
import joblib
# --- ここにモデルの学習までのコードがあるとする ---
# from sklearn.tree import DecisionTreeClassifier
# model = DecisionTreeClassifier()
# model.fit(X_train, y_train)
# -----------------------------------------
# 学習済みモデル(model)を'cafe_model.joblib'というファイル名で保存
joblib.dump(model, 'cafe_model.joblib')
joblib.dump(X_test, 'X_test.joblib')
print("モデルが 'cafe_model.joblib' として保存されました。")
これを実行すると、Colabのファイル一覧に cafe_model.joblib というファイルが表示されるはずです。
ステップ2:保存したモデルを読み込む (joblib.load)
保存したモデルは、joblib.load() 関数を使っていつでもプログラムに読み込むことができます。
使い方:
読み込み先の変数 = joblib.load('ファイル名')
例:
先ほど保存した cafe_model.joblib を loaded_model という新しい変数に読み込んでみましょう。
# joblibライブラリをインポート
import joblib
# 'cafe_model.joblib' ファイルからモデルを読み込む
loaded_model = joblib.load('cafe_model.joblib')
print("モデルが正常に読み込まれました。")
ステップ3:読み込んだモデルで予測する
読み込んだモデルは、元の model と全く同じように使うことができます。例えば、テストデータを使って予測をしてみましょう。
# 必要なライブラリをインポート
import joblib
import pandas as pd
from sklearn.metrics import accuracy_score
# 保存したモデルを読み込む
loaded_model = joblib.load('cafe_model.joblib')
# テストデータを読み込む
X_test = joblib.load('X_test.joblib')
# 読み込んだモデルを使って予測を実行
# X_testは演習で作成したテストデータとします
predictions = loaded_model.predict(X_test)
# 答えのファイルを読み込む
answers_df = pd.read_csv('cafe_test_answers.csv')
y_test_answers = answers_df['success']
# 結果の確認
accuracy = accuracy_score(y_test_answers, predictions)
print(f"読み込んだモデルの正解率 (Accuracy): {accuracy}")
元のモデルで計算した正解率と全く同じ結果が出れば、モデルの保存と読み込みが成功している証拠です。