JKになりたい

何か書きたいことを書きます。主にWeb方面の技術系記事が多いかも。

stable baselines3+imitationで模倣学習(BC&AIRL)

stable baselines3では、模倣学習のfeatureがimitationというライブラリに移譲されることになりました。

stable-baselines3.readthedocs.io

github.com

これにより、(過渡期である事も要因であるとは思いますが)以前は非常に簡単にできていた模倣学習に一手間必要になりました。

そこで今回は、stable baselines3とimitationを使った模倣学習の実行について備忘録を残しておきたいと思います(これを書いている時点では、ドキュメントがほぼなかったので・・)


エキスパートデータの収集

言わずもかな逆強化学習ではまずエキスパートデータを用意する必要があります。 今回はatariのインベーダー環境に対して、人手で操作してエキスパートデータを生成してみたいと思います

stable baselinesの2系ではstable_baselines.gail.generate_expert_trajを使うことで簡単に軌跡の記録ができましたが、3系ではimitationのTrajectoryデータ型として記録していきます

imitation/types.py at bb363a6c1039f29d4670647734fd21de9aa2e85d · HumanCompatibleAI/imitation · GitHub

1つのエピソードに対して1つのTrajectoryオブジェクトとなり、複数エピソードではTrajectoryのArrayとして保存することになります。

以下は、実装例です

from imitation.data.types import Trajectory
from stable_baselines3.common.atari_wrappers import *
import gym
import pyglet
from pyglet.window import key
import time
import pickle

def get_key_state(win, key_handler):
    key_state = set()
    win.dispatch_events()
    for key_code, pressed in key_handler.items():
        if pressed:
            key_state.add(key_code)
    return key_state

def human_expert(_state, win, key_handler):
    key_state = get_key_state(win, key_handler)
    action = 0
    if key.SPACE in key_state:
        action = 1
    elif key.LEFT in key_state:
        action = 3
    elif key.RIGHT in key_state:
        action = 4
    time.sleep(1.0 / 30.0)
    return action

def main():
    record_episodes = 1
    ENV_ID = 'SpaceInvaders-v0'
    env = gym.make(ENV_ID)
    env.render()

    win = pyglet.window.Window(width=300, height=100, vsync=False)
    key_handler = pyglet.window.key.KeyStateHandler()
    win.push_handlers(key_handler)
    pyglet.app.platform_event_loop.start()
    while len(get_key_state(win, key_handler)) == 0:
        time.sleep(1.0 / 30.0)
    
    trajectorys = []
    for i in range(0, record_episodes):
        state = env.reset()
        actions = []
        infos = []
        observations = [state]
        while True:
            env.render()
            action = human_expert(state, win, key_handler)
            state, reward, done, info = env.step(action)
            actions.append(action)
            observations.append(state)
            infos.append(info)
            if done:
                ts = Trajectory(obs=np.array(observations), acts=np.array(actions), infos=np.array(infos))
                trajectorys.append(ts)
                break
    with open("invader_expert.pickle", mode="wb") as f:
        pickle.dump(trajectorys, f)
if __name__ == '__main__':
    main()

キー入力を受け取る部分の実装はnpakaさんの以下の記事を参考にさせて頂きました

note.com

インベーダーが下手すぎてエキスパートとは程遠いデータを収集することができました。

BCによる学習

こちらはimitationのquickstartと同様です。

https://github.com/HumanCompatibleAI/imitation/blob/master/examples/quickstart.py

with open("invader_expert.pickle", "rb") as f:
    trajectories = pickle.load(f)
transitions = rollout.flatten_trajectories(trajectories)

ENV_ID = 'SpaceInvaders-v0'
venv = util.make_vec_env(ENV_ID, n_envs=2)
logger.configure(".BC/")
bc_trainer = bc.BC(venv.observation_space, venv.action_space, expert_data=transitions)
bc_trainer.train(n_epochs=100)
bc_trainer.save_policy('space_invader_policy_v0')

保存したpolicyは以下のようにロードができます

bc_trainer = bc.reconstruct_policy("space_invader_policy_v0")

BC→順強化学習とモデルの実行

まず、BCの学習により得たポリシーを使ってインベーダーをプレイさせてみます。 こちらはポリシーのpredictを使って簡単に実行が可能です。

def main():
    ENV_ID = 'SpaceInvaders-v0'
    env = gym.make(ENV_ID)
    bc_trainer = bc.reconstruct_policy("space_invader_policy_v0")
    state = env.reset()
    while True:
        env.render()
        action = bc_trainer.predict(state)
        state, reward, done, info = env.step(action)
        if done:
            break
if __name__ == '__main__':
    main()

次に、このポリシーをベースに順強化学習をさせることも可能です。

エキスパートのデータを渡して、初期の探索を手助けしてあげた後、順強化学習による最適化を進めたい・・というのが直感的なニーズですが、良いのか悪いのかは正直よくわかりません。

こちらは、単純に強化学習モデルクラスの第一引数であるPolicyに先ほどのものを入れて初期化→再学習すればよさそうなんですが、素直に入れると動きません。

そこで、ハック的なテクニックですが以下のようにする必要があります。

class CopyPolicy(ActorCriticPolicy):
    def __new__(cls, *args, **kwargs):
        return bc_trainer.policy

model = sb3.PPO(CopyPolicy, venv, verbose=0)
model.learn(total_timesteps=128000, callback=callback)

このアプローチは以下のissueで言及されています

github.com

AIRLによる学習

こちらも、imitationのquickstartに記載されている通りで大丈夫です

https://github.com/HumanCompatibleAI/imitation/blob/master/examples/quickstart.py

最後に.gen_algo.saveでPPOモデルを保存します。 ゲームの実行にはこのモデルが必要になります。

venv = util.make_vec_env(ENV_ID, n_envs=2)
logger.configure(".AIRL/")
airl_trainer = adversarial.AIRL(
    venv,
    expert_data=transitions,
    expert_batch_size=32,
    gen_algo=sb3.PPO("MlpPolicy", venv, verbose=1, n_steps=1024),
)
airl_trainer.train(total_timesteps=2048)
airl_trainer.gen_algo.save("airl_trainer_gen_algo")

モデルの実行

PPO.loadで先ほど保存したモデルをロードしてきます。 後は特に違いはありません。

def main():
    ENV_ID = 'SpaceInvaders-v0'
    env = gym.make(ENV_ID)
    model = PPO.load("airl_trainer_gen_algo")

    state = env.reset()
    while True:
        env.render()
        action = model.predict(state)
        state, reward, done, info = env.step(action)
        if done:
            break
if __name__ == '__main__':
    main()