-
Notifications
You must be signed in to change notification settings - Fork 398
fix(pu): fix noise layer's usage #866
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,7 @@ def __init__( | |
norm_type: Optional[str] = None, | ||
dropout: Optional[float] = None, | ||
init_bias: Optional[float] = None, | ||
noise: bool = False, | ||
) -> None: | ||
""" | ||
Overview: | ||
|
@@ -57,6 +58,8 @@ def __init__( | |
- dropout (:obj:`Optional[float]`): The dropout rate of the dropout layer. \ | ||
if ``None`` then default disable dropout layer. | ||
- init_bias (:obj:`Optional[float]`): The initial value of the last layer bias in the head network. \ | ||
- noise (:obj:`bool`): Whether use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \ | ||
Default ``False``. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Default to |
||
""" | ||
super(DQN, self).__init__() | ||
# Squeeze data from tuple, list or dict to single object. For example, from (4, ) to 4 | ||
|
@@ -90,7 +93,8 @@ def __init__( | |
layer_num=head_layer_num, | ||
activation=activation, | ||
norm_type=norm_type, | ||
dropout=dropout | ||
dropout=dropout, | ||
noise=noise, | ||
) | ||
else: | ||
self.head = head_cls( | ||
|
@@ -99,7 +103,8 @@ def __init__( | |
head_layer_num, | ||
activation=activation, | ||
norm_type=norm_type, | ||
dropout=dropout | ||
dropout=dropout, | ||
noise=noise, | ||
) | ||
if init_bias is not None and head_cls == DuelingHead: | ||
# Zero the last layer bias of advantage head | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
from ding.utils.data import default_collate, default_decollate | ||
|
||
from .base_policy import Policy | ||
from .common_utils import default_preprocess_learn | ||
from .common_utils import default_preprocess_learn, set_noise_mode | ||
|
||
|
||
@POLICY_REGISTRY.register('dqn') | ||
|
@@ -248,6 +248,8 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: | |
.. note:: | ||
For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. | ||
""" | ||
set_noise_mode(self._learn_model, True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use noisy_net to control this line Another question: how to deal with target_model in noisy net |
||
|
||
# Data preprocessing operations, such as stack data, cpu to cuda device | ||
data = default_preprocess_learn( | ||
data, | ||
|
@@ -384,6 +386,12 @@ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: | |
data = default_collate(list(data.values())) | ||
if self._cuda: | ||
data = to_device(data, self._device) | ||
# Use the add_noise parameter to decide noise mode. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename to |
||
# Default to True if the parameter is not provided. | ||
if self._cfg.collect.get("add_noise", True): | ||
set_noise_mode(self._collect_model, True) | ||
else: | ||
set_noise_mode(self._collect_model, False) | ||
self._collect_model.eval() | ||
with torch.no_grad(): | ||
output = self._collect_model.forward(data, eps=eps) | ||
|
@@ -476,6 +484,8 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: | |
data = default_collate(list(data.values())) | ||
if self._cuda: | ||
data = to_device(data, self._device) | ||
# Ensure that in evaluation mode noise is disabled. | ||
set_noise_mode(self._eval_model, False) | ||
self._eval_model.eval() | ||
with torch.no_grad(): | ||
output = self._eval_model.forward(data) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from easydict import EasyDict | ||
|
||
demon_attack_dqn_config = dict( | ||
exp_name='DemonAttack_dqn_collect-not-noise_seed0', | ||
env=dict( | ||
collector_env_num=8, | ||
evaluator_env_num=8, | ||
n_evaluator_episode=8, | ||
stop_value=1e6, | ||
env_id='DemonAttackNoFrameskip-v4', | ||
frame_stack=4, | ||
), | ||
policy=dict( | ||
cuda=True, | ||
priority=False, | ||
model=dict( | ||
obs_shape=[4, 84, 84], | ||
action_shape=6, | ||
encoder_hidden_size_list=[128, 128, 512], | ||
noise=True, | ||
), | ||
nstep=3, | ||
discount_factor=0.99, | ||
learn=dict( | ||
update_per_collect=10, | ||
batch_size=32, | ||
learning_rate=0.0001, | ||
target_update_freq=500, | ||
), | ||
# collect=dict(n_sample=96, add_noise=True), | ||
collect=dict(n_sample=96, add_noise=False), | ||
eval=dict(evaluator=dict(eval_freq=4000, )), | ||
other=dict( | ||
eps=dict( | ||
type='exp', | ||
start=1., | ||
end=0.05, | ||
decay=250000, | ||
), | ||
replay_buffer=dict(replay_buffer_size=100000, ), | ||
), | ||
), | ||
) | ||
demon_attack_dqn_config = EasyDict(demon_attack_dqn_config) | ||
main_config = demon_attack_dqn_config | ||
demon_attack_dqn_create_config = dict( | ||
env=dict( | ||
type='atari', | ||
import_names=['dizoo.atari.envs.atari_env'], | ||
), | ||
env_manager=dict(type='subprocess'), | ||
policy=dict(type='dqn'), | ||
) | ||
demon_attack_dqn_create_config = EasyDict(demon_attack_dqn_create_config) | ||
create_config = demon_attack_dqn_create_config | ||
|
||
if __name__ == '__main__': | ||
# or you can enter `ding -m serial -c demon_attack_dqn_config.py -s 0` | ||
from ding.entry import serial_pipeline | ||
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(10e6)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NoiseLinearLayer
to boost exploration