test_serial_entry_reward_model.py
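"""
Smoke tests for the reward-model serial entry pipelines (ding / DI-engine).

Covers reward models trained from expert CartPole demonstrations (pdeil, gail,
pwil, red) as well as the RND and ICM intrinsic-reward pipelines. Note: this
module docstring is an added summary of the tests below, not part of the
original file.
"""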
import pytest
import os
from ditk import logging
from easydict import EasyDict
from copy import deepcopy
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config # noqa
from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
    serial_pipeline_reward_model_onpolicy
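
# Reward model configurations used to parametrize test_irl below; each entry
# selects a different reward model type (pdeil, gail, pwil, red) together with
# the hyper-parameters used for this smoke test.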
cfg = [
    {
        'type': 'pdeil',
        'alpha': 0.5,
        'discrete_action': False,
    },
    {
        'type': 'gail',
        'input_size': 5,
        'hidden_size': 64,
        'batch_size': 64,
    },
    {
        'type': 'pwil',
        's_size': 4,
        'a_size': 2,
        'sample_size': 500,
    },
    {
        'type': 'red',
        'sample_size': 5000,
        'obs_shape': 4,
        'action_shape': 1,
        'hidden_size_list': [64, 1],
        'update_per_collect': 200,
        'batch_size': 128,
    },
]
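

# End-to-end IRL smoke test: train a quick expert PPO policy on CartPole,
# collect expert demonstrations, then train DQN guided by the given reward model.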
@pytest.mark.unittest
@pytest.mark.parametrize('reward_model_config', cfg)
def test_irl(reward_model_config):
    reward_model_config = EasyDict(reward_model_config)
    # train a quick expert policy (off-policy PPO on CartPole)
    config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
    expert_policy = serial_pipeline(config, seed=0, max_train_iter=2)
    # collect expert demo data
    collect_count = 10000
    expert_data_path = 'expert_data.pkl'
    state_dict = expert_policy.collect_mode.state_dict()
    config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
    collect_demo_data(
        config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
    )
    # irl + rl training: DQN trained with the learned reward model
    cp_cartpole_dqn_config = deepcopy(cartpole_dqn_config)
    cp_cartpole_dqn_create_config = deepcopy(cartpole_dqn_create_config)
    cp_cartpole_dqn_create_config.reward_model = dict(type=reward_model_config.type)
    if reward_model_config.type == 'gail':
        # gail reads its demonstrations via data_path instead of expert_data_path
        reward_model_config['data_path'] = '.'
    else:
        reward_model_config['expert_data_path'] = expert_data_path
    cp_cartpole_dqn_config.reward_model = reward_model_config
    cp_cartpole_dqn_config.policy.collect.n_sample = 128
    serial_pipeline_reward_model_offpolicy(
        (cp_cartpole_dqn_config, cp_cartpole_dqn_create_config), seed=0, max_train_iter=2
    )
    # clean up generated checkpoints, logs and expert data
    os.popen("rm -rf ckpt_* log expert_data.pkl")
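

# RND intrinsic reward on CartPole PPO, using the on-policy pipeline.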
@pytest.mark.unittest
def test_rnd():
    config = [deepcopy(cartpole_ppo_rnd_config), deepcopy(cartpole_ppo_rnd_create_config)]
    try:
        serial_pipeline_reward_model_onpolicy(config, seed=0, max_train_iter=2)
    except Exception:
        assert False, "pipeline fail"
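

# ICM intrinsic reward on CartPole PPO, using the off-policy pipeline.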
@pytest.mark.unittest
def test_icm():
    config = [deepcopy(cartpole_ppo_icm_config), deepcopy(cartpole_ppo_icm_create_config)]
    try:
        serial_pipeline_reward_model_offpolicy(config, seed=0, max_train_iter=2)
    except Exception:
        assert False, "pipeline fail"