OpenAI

強化学習「Stable Baselines」モデルの保存と読み込みをしてみる

このままでは、学習させて確認したモデルは、実行させるたびに学習が必要になってしまい時間がかかります。一度学習したモデルの保存と、読み込みが簡単にできるようになっているので実際に動かしてみます。

概要

PPO2は、ActorCriticRLModelを継承しており、ActorCriticRLModel はBaseRLModelを継承しています。モデルを定義しているbase_class.pyには、saveとloadが宣言されていることが確認できます。

モデルの保存

model.save("model")

モデルの読み込み

model = PPO2.load("model") 

※先日のPPO2モデルを使うことを前提にした場合

サンプルコード

サンプルコード(リンク)を実行します。ppo2_cartpole.zipという名前でモデルが保存されていることが確認できます。

import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines import PPO2

env = gym.make('CartPole-v1')

model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
model.save("ppo2_cartpole")

del model # remove to demonstrate saving and loading

model = PPO2.load("ppo2_cartpole")

# Enjoy trained agent
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()