from typing import List, Any, Dict, Callable
import torch
import torch.nn as nn
import numpy as np
import treetensor.torch as ttorch
from ding.utils.data import default_collate
from ding.torch_utils import to_tensor, to_ndarray, unsqueeze, squeeze
from ding.torch_utils import NoiseLinearLayer


def set_noise_mode(module: nn.Module, noise_enabled: bool):
"""
Overview:
Recursively set the 'force_noise' flag on all NoiseLinearLayer modules within the given module.
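    Examples:
        >>> # A minimal usage sketch: ``model`` is assumed to be a network that already contains
        >>> # ``NoiseLinearLayer`` submodules (e.g. the head of a NoisyNet-style DQN).
        >>> set_noise_mode(model, True)   # force noise on, e.g. during data collection
        >>> set_noise_mode(model, False)  # force noise off, e.g. during evaluation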
"""
for m in module.modules():
if isinstance(m, NoiseLinearLayer):
m.force_noise = noise_enabled


def default_preprocess_learn(
data: List[Any],
use_priority_IS_weight: bool = False,
use_priority: bool = False,
use_nstep: bool = False,
ignore_done: bool = False,
) -> Dict[str, torch.Tensor]:
"""
Overview:
        Default data pre-processing in the policy's ``_forward_learn`` method, including stacking the batch data \
        and handling ``ignore_done``, n-step reward reshaping and priority IS weights.
Arguments:
- data (:obj:`List[Any]`): The list of a training batch samples, each sample is a dict of PyTorch Tensor.
        - use_priority_IS_weight (:obj:`bool`): Whether to apply priority IS (importance sampling) weight \
            correction. If True, the per-sample ``weight`` is set to the priority IS weight.
        - use_priority (:obj:`bool`): Whether prioritized sampling is used. It must be True whenever \
            ``use_priority_IS_weight`` is True.
        - use_nstep (:obj:`bool`): Whether to use the n-step TD error. If True, the reward is reshaped so that \
            the n-step dimension comes first.
        - ignore_done (:obj:`bool`): Whether to ignore the done flag. If True, ``done`` is set to all zeros.
Returns:
- data (:obj:`Dict[str, torch.Tensor]`): The preprocessed dict data whose values can be directly used for \
the following model forward and loss computation.
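    Examples:
        >>> # A minimal sketch with a fake batch of 2 discrete-action transitions; the field names follow
        >>> # this function, but the concrete shapes below are illustrative assumptions.
        >>> data = [
        >>>     {'obs': torch.randn(4), 'action': torch.tensor([1]), 'reward': torch.randn(1), 'done': torch.tensor(0.)}
        >>>     for _ in range(2)
        >>> ]
        >>> batch = default_preprocess_learn(data)
        >>> assert batch['obs'].shape == (2, 4) and batch['action'].shape == (2, )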
"""
# data preprocess
elem = data[0]
if isinstance(elem['action'], (np.ndarray, torch.Tensor)) and elem['action'].dtype in [np.int64, torch.int64]:
data = default_collate(data, cat_1dim=True) # for discrete action
else:
data = default_collate(data, cat_1dim=False) # for continuous action
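    # Squeeze redundant trailing dims of scalar fields, e.g. value/adv of shape (B, 1) -> (B, ).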
if 'value' in data and data['value'].dim() == 2 and data['value'].shape[1] == 1:
data['value'] = data['value'].squeeze(-1)
if 'adv' in data and data['adv'].dim() == 2 and data['adv'].shape[1] == 1:
data['adv'] = data['adv'].squeeze(-1)
if ignore_done:
data['done'] = torch.zeros_like(data['done']).float()
else:
data['done'] = data['done'].float()
if data['done'].dim() == 2 and data['done'].shape[1] == 1:
data['done'] = data['done'].squeeze(-1)
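    # Per-sample loss weight: when priority IS correction is enabled, use the replay buffer's IS weight;
    # otherwise fall back to an optional user-provided 'weight' field (or None, i.e. uniform weighting).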
if use_priority_IS_weight:
assert use_priority, "Use IS Weight correction, but Priority is not used."
if use_priority and use_priority_IS_weight:
if 'priority_IS' in data:
data['weight'] = data['priority_IS']
        else:  # for compatibility
data['weight'] = data['IS']
else:
data['weight'] = data.get('weight', None)
if use_nstep:
# reward reshaping for n-step
reward = data['reward']
if len(reward.shape) == 1:
reward = reward.unsqueeze(1)
        # Reshape the reward so that the n-step dimension comes first:
        # single-agent reward: (batch_size, nstep) -> (nstep, batch_size)
        # multi-agent reward:  (batch_size, agent_dim, nstep) -> (nstep, batch_size, agent_dim)
if reward.ndim == 2:
# For a 2D tensor, simply transpose it to get (nstep, batch_size)
data['reward'] = reward.transpose(0, 1).contiguous()
elif reward.ndim == 3:
# For a 3D tensor, move the last dimension to the front to get (nstep, batch_size, agent_dim)
data['reward'] = reward.permute(2, 0, 1).contiguous()
else:
raise ValueError("The 'reward' tensor must be either 2D or 3D. Got shape: {}".format(reward.shape))
else:
if data['reward'].dim() == 2 and data['reward'].shape[1] == 1:
data['reward'] = data['reward'].squeeze(-1)
return data


def single_env_forward_wrapper(forward_fn: Callable) -> Callable:
"""
Overview:
        Wrap the policy so that it supports gym-style interaction with a single environment.
Arguments:
- forward_fn (:obj:`Callable`): The original forward function of policy.
Returns:
- wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy.
Examples:
>>> env = gym.make('CartPole-v0')
>>> policy = DQNPolicy(...)
>>> forward_fn = single_env_forward_wrapper(policy.eval_mode.forward)
>>> obs = env.reset()
>>> action = forward_fn(obs)
>>> next_obs, rew, done, info = env.step(action)
"""

    def _forward(obs):
        # Wrap the single observation into the batched, env-id-keyed format that the policy expects,
        # i.e. obs -> {env_id: obs with an extra batch dim}.
        obs = {0: unsqueeze(to_tensor(obs))}
        action = forward_fn(obs)[0]['action']
        # Strip the batch dim and convert the chosen action back to np.ndarray for env.step.
        action = to_ndarray(squeeze(action))
return action
return _forward


def single_env_forward_wrapper_ttorch(forward_fn: Callable, cuda: bool = True) -> Callable:
"""
Overview:
        Wrap the policy so that it supports gym-style interaction with a single environment, operating on \
        treetensor (ttorch) data.
Arguments:
- forward_fn (:obj:`Callable`): The original forward function of policy.
        - cuda (:obj:`bool`): Whether to use CUDA in the policy. If True, this function will move the input data \
            to the CUDA device (when available).
Returns:
- wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy.
Examples:
>>> env = gym.make('CartPole-v0')
>>> policy = PPOFPolicy(...)
>>> forward_fn = single_env_forward_wrapper_ttorch(policy.eval)
>>> obs = env.reset()
>>> action = forward_fn(obs)
>>> next_obs, rew, done, info = env.step(action)
"""

    def _forward(obs):
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
obs = ttorch.as_tensor(obs).unsqueeze(0)
if cuda and torch.cuda.is_available():
obs = obs.cuda()
action = forward_fn(obs).action
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
action = action.squeeze(0).cpu().numpy()
return action
return _forward