# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A DDPG/NAF agent. | |
Implements the Deep Deterministic Policy Gradient (DDPG) algorithm from | |
"Continuous control with deep reinforcement learning" - Lilicrap et al. | |
https://arxiv.org/abs/1509.02971, and the Normalized Advantage Functions (NAF) | |
algorithm "Continuous Deep Q-Learning with Model-based Acceleration" - Gu et al. | |
https://arxiv.org/pdf/1603.00748. | |
""" | |

import tensorflow as tf
import gin.tf

from utils import utils
from agents import ddpg_networks as networks

slim = tf.contrib.slim


class DdpgAgent(object):
  """An RL agent that learns using the DDPG algorithm.

  Example usage:

    def critic_net(states, actions):
      ...
    def actor_net(states, num_action_dims):
      ...

    Given a tensorflow environment tf_env,
    (of type learning.deepmind.rl.environments.tensorflow.python.tfpyenvironment)

    obs_spec = tf_env.observation_spec()
    action_spec = tf_env.action_spec()

    ddpg_agent = agent.DdpgAgent(obs_spec,
                                 action_spec,
                                 actor_net=actor_net,
                                 critic_net=critic_net)

    we can perform actions on the environment as follows:

    state = tf_env.observations()[0]
    action = ddpg_agent.actor_net(tf.expand_dims(state, 0))[0, :]
    transition_type, reward, discount = tf_env.step([action])

    Train:

    critic_loss = ddpg_agent.critic_loss(states, actions, rewards, discounts,
                                         next_states)
    actor_loss = ddpg_agent.actor_loss(states)

    critic_train_op = slim.learning.create_train_op(
        critic_loss,
        critic_optimizer,
        variables_to_train=ddpg_agent.get_trainable_critic_vars(),
    )
    actor_train_op = slim.learning.create_train_op(
        actor_loss,
        actor_optimizer,
        variables_to_train=ddpg_agent.get_trainable_actor_vars(),
    )
  """

  ACTOR_NET_SCOPE = 'actor_net'
  CRITIC_NET_SCOPE = 'critic_net'
  TARGET_ACTOR_NET_SCOPE = 'target_actor_net'
  TARGET_CRITIC_NET_SCOPE = 'target_critic_net'

  def __init__(self,
               observation_spec,
               action_spec,
               actor_net=networks.actor_net,
               critic_net=networks.critic_net,
               td_errors_loss=tf.losses.huber_loss,
               dqda_clipping=0.,
               actions_regularizer=0.,
               target_q_clipping=None,
               residual_phi=0.0,
               debug_summaries=False):
    """Constructs a DDPG agent.

    Args:
      observation_spec: A TensorSpec defining the observations.
      action_spec: A BoundedTensorSpec defining the actions.
      actor_net: A callable that creates the actor network. Must take the
        following arguments: states, num_actions. Please see
        networks.actor_net for an example.
      critic_net: A callable that creates the critic network. Must take the
        following arguments: states, actions. Please see networks.critic_net
        for an example.
      td_errors_loss: A callable defining the loss function for the critic
        TD error.
      dqda_clipping: (float) clips the gradient dqda element-wise between
        [-dqda_clipping, dqda_clipping]. Does not perform clipping if
        dqda_clipping == 0.
      actions_regularizer: A scalar; when positive, penalizes the norm of the
        actions. This can prevent saturation of actions for the actor_loss.
      target_q_clipping: (tuple of floats) clips target q values within
        (low, high) values when computing the critic loss.
      residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that
        interpolates between Q-learning and the residual gradient algorithm.
        http://www.leemon.com/papers/1995b.pdf
      debug_summaries: If True, add summaries to help debug behavior.
    Raises:
      ValueError: If `dqda_clipping` is < 0.
    """
    self._observation_spec = observation_spec[0]
    self._action_spec = action_spec[0]
    self._state_shape = tf.TensorShape([None]).concatenate(
        self._observation_spec.shape)
    self._action_shape = tf.TensorShape([None]).concatenate(
        self._action_spec.shape)
    self._num_action_dims = self._action_spec.shape.num_elements()
    self._scope = tf.get_variable_scope().name
    self._actor_net = tf.make_template(
        self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
    self._critic_net = tf.make_template(
        self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
    self._target_actor_net = tf.make_template(
        self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
    self._target_critic_net = tf.make_template(
        self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
    self._td_errors_loss = td_errors_loss
    if dqda_clipping < 0:
      raise ValueError('dqda_clipping must be >= 0.')
    self._dqda_clipping = dqda_clipping
    self._actions_regularizer = actions_regularizer
    self._target_q_clipping = target_q_clipping
    self._residual_phi = residual_phi
    self._debug_summaries = debug_summaries

  def _batch_state(self, state):
    """Converts a state to a batched state.

    Args:
      state: A [num_state_dims] state tensor, or a list/tuple whose first
        element is such a tensor.
    Returns:
      A [1, num_state_dims] state tensor.
    """
    if isinstance(state, (tuple, list)):
      state = state[0]
    if state.get_shape().ndims == 1:
      state = tf.expand_dims(state, 0)
    return state

  def action(self, state):
    """Returns the next action for the state.

    Args:
      state: A [num_state_dims] tensor representing a state.
    Returns:
      A [num_action_dims] tensor representing the action.
    """
    return self.actor_net(self._batch_state(state), stop_gradients=True)[0, :]

  def sample_action(self, state, stddev=1.0):
    """Returns the action for the state with additive Gaussian noise.

    Args:
      state: A [num_state_dims] tensor representing a state.
      stddev: Standard deviation of the additive Gaussian exploration noise.
    Returns:
      A [num_action_dims] action tensor, clipped to the action spec.
    """
    agent_action = self.action(state)
    agent_action += tf.random_normal(tf.shape(agent_action)) * stddev
    return utils.clip_to_spec(agent_action, self._action_spec)
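
  # Illustrative usage, reusing the names from the class docstring example
  # (the stddev value is an arbitrary example, not a recommendation):
  #   action = ddpg_agent.sample_action(state, stddev=0.1)
  #   transition_type, reward, discount = tf_env.step([action])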

  def actor_net(self, states, stop_gradients=False):
    """Returns the output of the actor network.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      stop_gradients: (boolean) if true, gradients cannot be propagated
        through this operation.
    Returns:
      A [batch_size, num_action_dims] tensor of actions.
    Raises:
      ValueError: If `states` does not have the expected dimensions.
    """
    self._validate_states(states)
    actions = self._actor_net(states, self._action_spec)
    if stop_gradients:
      actions = tf.stop_gradient(actions)
    return actions

  def critic_net(self, states, actions, for_critic_loss=False):
    """Returns the output of the critic network.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
      for_critic_loss: (boolean) Passed through to the critic network;
        indicates whether the output is used to compute the critic loss.
    Returns:
      q values: A [batch_size] tensor of q values.
    Raises:
      ValueError: If `states` or `actions` do not have the expected dimensions.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    return self._critic_net(states, actions,
                            for_critic_loss=for_critic_loss)

  def target_actor_net(self, states):
    """Returns the output of the target actor network.

    The target network is used to compute stable targets for training.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
    Returns:
      A [batch_size, num_action_dims] tensor of actions.
    Raises:
      ValueError: If `states` does not have the expected dimensions.
    """
    self._validate_states(states)
    actions = self._target_actor_net(states, self._action_spec)
    return tf.stop_gradient(actions)

  def target_critic_net(self, states, actions, for_critic_loss=False):
    """Returns the output of the target critic network.

    The target network is used to compute stable targets for training.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
      for_critic_loss: (boolean) Passed through to the critic network;
        indicates whether the output is used to compute the critic loss.
    Returns:
      q values: A [batch_size] tensor of q values.
    Raises:
      ValueError: If `states` or `actions` do not have the expected dimensions.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    return tf.stop_gradient(
        self._target_critic_net(states, actions,
                                for_critic_loss=for_critic_loss))

  def value_net(self, states, for_critic_loss=False):
    """Returns the output of the critic evaluated with the actor.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
    Returns:
      q values: A [batch_size] tensor of q values.
    """
    actions = self.actor_net(states)
    return self.critic_net(states, actions,
                           for_critic_loss=for_critic_loss)

  def target_value_net(self, states, for_critic_loss=False):
    """Returns the output of the target critic evaluated with the target actor.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
    Returns:
      q values: A [batch_size] tensor of q values.
    """
    target_actions = self.target_actor_net(states)
    return self.target_critic_net(states, target_actions,
                                  for_critic_loss=for_critic_loss)

  def critic_loss(self, states, actions, rewards, discounts,
                  next_states):
    """Computes a loss for training the critic network.

    The loss is the td_errors_loss (Huber loss by default) between the
    critic's Q-value predictions and one-step TD targets computed from the
    target networks.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
      rewards: A [batch_size, ...] tensor representing a batch of rewards,
        broadcastable to the critic net output.
      discounts: A [batch_size, ...] tensor representing a batch of discounts,
        broadcastable to the critic net output.
      next_states: A [batch_size, num_state_dims] tensor representing a batch
        of next states.
    Returns:
      A rank-0 tensor representing the critic loss.
    Raises:
      ValueError: If any of the inputs do not have the expected dimensions, or
        if their batch_sizes do not match.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    self._validate_states(next_states)
    target_q_values = self.target_value_net(next_states, for_critic_loss=True)
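    # One-step TD target: reward + discount * Q_target(next_state,
    # pi_target(next_state)); the per-sample discounts tensor plays the
    # role of gamma.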
    td_targets = target_q_values * discounts + rewards
    if self._target_q_clipping is not None:
      td_targets = tf.clip_by_value(td_targets, self._target_q_clipping[0],
                                    self._target_q_clipping[1])
    q_values = self.critic_net(states, actions, for_critic_loss=True)
    td_errors = td_targets - q_values
    if self._debug_summaries:
      gen_debug_td_error_summaries(
          target_q_values, q_values, td_targets, td_errors)
    loss = self._td_errors_loss(td_targets, q_values)
    if self._residual_phi > 0.0:  # compute residual gradient loss
      residual_q_values = self.value_net(next_states, for_critic_loss=True)
      residual_td_targets = residual_q_values * discounts + rewards
      if self._target_q_clipping is not None:
        residual_td_targets = tf.clip_by_value(residual_td_targets,
                                               self._target_q_clipping[0],
                                               self._target_q_clipping[1])
      residual_td_errors = residual_td_targets - q_values
      residual_loss = self._td_errors_loss(
          residual_td_targets, residual_q_values)
      loss = (loss * (1.0 - self._residual_phi) +
              residual_loss * self._residual_phi)
    return loss

  def actor_loss(self, states):
    """Computes a loss for training the actor network.

    Note that the output does not represent an actual loss. It is called a
    loss only in the sense that its gradient w.r.t. the actor network weights
    is the correct gradient for training the actor network,
    i.e. dloss/dweights = (dq/da) * (da/dweights),
    which is the gradient used in Algorithm 1 of Lillicrap et al.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
    Returns:
      A rank-0 tensor representing the actor loss.
    Raises:
      ValueError: If `states` does not have the expected dimensions.
    """
    self._validate_states(states)
    actions = self.actor_net(states, stop_gradients=False)
    critic_values = self.critic_net(states, actions)
    q_values = self.critic_function(critic_values, states)
    dqda = tf.gradients([q_values], [actions])[0]
    dqda_unclipped = dqda
    if self._dqda_clipping > 0:
      dqda = tf.clip_by_value(dqda, -self._dqda_clipping, self._dqda_clipping)
    actions_norm = tf.norm(actions)
    if self._debug_summaries:
      with tf.name_scope('dqda'):
        tf.summary.scalar('actions_norm', actions_norm)
        tf.summary.histogram('dqda', dqda)
        tf.summary.histogram('dqda_unclipped', dqda_unclipped)
        tf.summary.histogram('actions', actions)
        for a in range(self._num_action_dims):
          tf.summary.histogram('dqda_unclipped_%d' % a, dqda_unclipped[:, a])
          tf.summary.histogram('dqda_%d' % a, dqda[:, a])
    actions_norm *= self._actions_regularizer
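    # Surrogate loss: with labels stop_gradient(dqda + actions) and
    # predictions actions, the squared-error gradient w.r.t. actions is
    # proportional to -dqda, so minimizing this loss moves the actions (and,
    # via the chain rule, the actor weights) in the direction that increases
    # the critic's Q values, as described in the docstring above.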
    return slim.losses.mean_squared_error(tf.stop_gradient(dqda + actions),
                                          actions,
                                          scope='actor_loss') + actions_norm

  def critic_function(self, critic_values, states, weights=None):
    """Computes q values based on critic_net outputs, states, and weights.

    Args:
      critic_values: A tf.float32 [batch_size, ...] tensor representing
        outputs from the critic net.
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      weights: A list or Numpy array or tensor with a shape broadcastable to
        `critic_values`.
    Returns:
      A tf.float32 [batch_size] tensor representing q values.
    """
    del states  # unused args
    if weights is not None:
      weights = tf.convert_to_tensor(weights, dtype=critic_values.dtype)
      critic_values *= weights
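    # Collapse any non-batch dimensions so the result is a [batch_size]
    # vector of scalar q values.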
    if critic_values.shape.ndims > 1:
      critic_values = tf.reduce_sum(critic_values,
                                    list(range(1, critic_values.shape.ndims)))
    critic_values.shape.assert_has_rank(1)
    return critic_values

  def update_targets(self, tau=1.0):
    """Performs a soft update of the target network parameters.

    For each weight w_s in the actor/critic networks, and its corresponding
    weight w_t in the target actor/critic networks, a soft update is:
    w_t = (1 - tau) * w_t + tau * w_s

    Args:
      tau: A float scalar in [0, 1]
    Returns:
      An operation that performs a soft update of the target network
      parameters.
    Raises:
      ValueError: If `tau` is not in [0, 1].
    """
    if tau < 0 or tau > 1:
      raise ValueError('Input `tau` should be in [0, 1].')
    update_actor = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
        tau)
    update_critic = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
        tau)
    return tf.group(update_actor, update_critic, name='update_targets')
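
  # Illustrative usage: tau=1.0 copies the online weights into the targets
  # (e.g. once at initialization), while a small tau such as 0.005 yields the
  # slowly-moving targets typically used during training:
  #   init_targets_op = ddpg_agent.update_targets(tau=1.0)
  #   soft_update_op = ddpg_agent.update_targets(tau=0.005)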

  def get_trainable_critic_vars(self):
    """Returns a list of trainable variables in the critic network.

    Returns:
      A list of trainable variables in the critic network.
    """
    return slim.get_trainable_variables(
        utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))

  def get_trainable_actor_vars(self):
    """Returns a list of trainable variables in the actor network.

    Returns:
      A list of trainable variables in the actor network.
    """
    return slim.get_trainable_variables(
        utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))

  def get_critic_vars(self):
    """Returns a list of all variables in the critic network.

    Returns:
      A list of all variables in the critic network.
    """
    return slim.get_model_variables(
        utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))

  def get_actor_vars(self):
    """Returns a list of all variables in the actor network.

    Returns:
      A list of all variables in the actor network.
    """
    return slim.get_model_variables(
        utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))

  def _validate_states(self, states):
    """Raises a value error if `states` does not have the expected shape.

    Args:
      states: A tensor.
    Raises:
      ValueError: If states.shape or states.dtype are not compatible with
        observation_spec.
    """
    states.shape.assert_is_compatible_with(self._state_shape)
    if not states.dtype.is_compatible_with(self._observation_spec.dtype):
      raise ValueError('states.dtype={} is not compatible with'
                       ' observation_spec.dtype={}'.format(
                           states.dtype, self._observation_spec.dtype))

  def _validate_actions(self, actions):
    """Raises a value error if `actions` does not have the expected shape.

    Args:
      actions: A tensor.
    Raises:
      ValueError: If actions.shape or actions.dtype are not compatible with
        action_spec.
    """
    actions.shape.assert_is_compatible_with(self._action_shape)
    if not actions.dtype.is_compatible_with(self._action_spec.dtype):
      raise ValueError('actions.dtype={} is not compatible with'
                       ' action_spec.dtype={}'.format(
                           actions.dtype, self._action_spec.dtype))


class TD3Agent(DdpgAgent):
  """An RL agent that learns using the TD3 algorithm.

  TD3 extends DDPG with a pair of critic networks; the target value used in
  the critic loss is the minimum of the two target critics, evaluated at a
  noise-perturbed target action (target policy smoothing).
  """

  ACTOR_NET_SCOPE = 'actor_net'
  CRITIC_NET_SCOPE = 'critic_net'
  CRITIC_NET2_SCOPE = 'critic_net2'
  TARGET_ACTOR_NET_SCOPE = 'target_actor_net'
  TARGET_CRITIC_NET_SCOPE = 'target_critic_net'
  TARGET_CRITIC_NET2_SCOPE = 'target_critic_net2'

  def __init__(self,
               observation_spec,
               action_spec,
               actor_net=networks.actor_net,
               critic_net=networks.critic_net,
               td_errors_loss=tf.losses.huber_loss,
               dqda_clipping=0.,
               actions_regularizer=0.,
               target_q_clipping=None,
               residual_phi=0.0,
               debug_summaries=False):
    """Constructs a TD3 agent.

    Args:
      observation_spec: A TensorSpec defining the observations.
      action_spec: A BoundedTensorSpec defining the actions.
      actor_net: A callable that creates the actor network. Must take the
        following arguments: states, num_actions. Please see
        networks.actor_net for an example.
      critic_net: A callable that creates the critic network. Must take the
        following arguments: states, actions. Please see networks.critic_net
        for an example.
      td_errors_loss: A callable defining the loss function for the critic
        TD error.
      dqda_clipping: (float) clips the gradient dqda element-wise between
        [-dqda_clipping, dqda_clipping]. Does not perform clipping if
        dqda_clipping == 0.
      actions_regularizer: A scalar; when positive, penalizes the norm of the
        actions. This can prevent saturation of actions for the actor_loss.
      target_q_clipping: (tuple of floats) clips target q values within
        (low, high) values when computing the critic loss.
      residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that
        interpolates between Q-learning and the residual gradient algorithm.
        http://www.leemon.com/papers/1995b.pdf
      debug_summaries: If True, add summaries to help debug behavior.
    Raises:
      ValueError: If `dqda_clipping` is < 0.
    """
    self._observation_spec = observation_spec[0]
    self._action_spec = action_spec[0]
    self._state_shape = tf.TensorShape([None]).concatenate(
        self._observation_spec.shape)
    self._action_shape = tf.TensorShape([None]).concatenate(
        self._action_spec.shape)
    self._num_action_dims = self._action_spec.shape.num_elements()
    self._scope = tf.get_variable_scope().name
    self._actor_net = tf.make_template(
        self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
    self._critic_net = tf.make_template(
        self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
    self._critic_net2 = tf.make_template(
        self.CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True)
    self._target_actor_net = tf.make_template(
        self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
    self._target_critic_net = tf.make_template(
        self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
    self._target_critic_net2 = tf.make_template(
        self.TARGET_CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True)
    self._td_errors_loss = td_errors_loss
    if dqda_clipping < 0:
      raise ValueError('dqda_clipping must be >= 0.')
    self._dqda_clipping = dqda_clipping
    self._actions_regularizer = actions_regularizer
    self._target_q_clipping = target_q_clipping
    self._residual_phi = residual_phi
    self._debug_summaries = debug_summaries

  def get_trainable_critic_vars(self):
    """Returns a list of trainable variables in the critic network.

    NOTE: This gets the vars of both critic networks.

    Returns:
      A list of trainable variables in the critic network.
    """
    return (
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)))

  def critic_net(self, states, actions, for_critic_loss=False):
    """Returns the output of the critic network(s).

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
      for_critic_loss: (boolean) If True, returns the outputs of both critic
        networks; otherwise only the first critic's output is returned.
    Returns:
      q values: A [batch_size] tensor of q values, or a pair of such tensors
        when `for_critic_loss` is True.
    Raises:
      ValueError: If `states` or `actions` do not have the expected dimensions.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    values1 = self._critic_net(states, actions,
                               for_critic_loss=for_critic_loss)
    values2 = self._critic_net2(states, actions,
                                for_critic_loss=for_critic_loss)
    if for_critic_loss:
      return values1, values2
    return values1

  def target_critic_net(self, states, actions, for_critic_loss=False):
    """Returns the output of the target critic network(s).

    The target network is used to compute stable targets for training.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
      for_critic_loss: (boolean) If True, returns the outputs of both target
        critic networks; otherwise only the first target critic's output is
        returned.
    Returns:
      q values: A [batch_size] tensor of q values, or a pair of such tensors
        when `for_critic_loss` is True.
    Raises:
      ValueError: If `states` or `actions` do not have the expected dimensions.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    values1 = tf.stop_gradient(
        self._target_critic_net(states, actions,
                                for_critic_loss=for_critic_loss))
    values2 = tf.stop_gradient(
        self._target_critic_net2(states, actions,
                                 for_critic_loss=for_critic_loss))
    if for_critic_loss:
      return values1, values2
    return values1

  def value_net(self, states, for_critic_loss=False):
    """Returns the output of the critic evaluated with the actor.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
    Returns:
      q values: A [batch_size] tensor of q values.
    """
    actions = self.actor_net(states)
    return self.critic_net(states, actions,
                           for_critic_loss=for_critic_loss)

  def target_value_net(self, states, for_critic_loss=False):
    """Returns the output of the target critic evaluated with the target actor.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
    Returns:
      q values: A pair of identical [batch_size] tensors holding the minimum
        of the two target critics' q values.
    """
    target_actions = self.target_actor_net(states)
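    # Target policy smoothing (TD3): perturb the target action with clipped
    # Gaussian noise, then take the element-wise minimum of the two target
    # critics to reduce Q-value overestimation.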
    noise = tf.clip_by_value(
        tf.random_normal(tf.shape(target_actions), stddev=0.2), -0.5, 0.5)
    values1, values2 = self.target_critic_net(
        states, target_actions + noise,
        for_critic_loss=for_critic_loss)
    values = tf.minimum(values1, values2)
    return values, values

  def update_targets(self, tau=1.0):
    """Performs a soft update of the target network parameters.

    For each weight w_s in the actor/critic networks, and its corresponding
    weight w_t in the target actor/critic networks, a soft update is:
    w_t = (1 - tau) * w_t + tau * w_s

    Args:
      tau: A float scalar in [0, 1]
    Returns:
      An operation that performs a soft update of the target network
      parameters.
    Raises:
      ValueError: If `tau` is not in [0, 1].
    """
    if tau < 0 or tau > 1:
      raise ValueError('Input `tau` should be in [0, 1].')
    update_actor = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
        tau)
    # NOTE: This updates both critic networks.
    update_critic = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
        tau)
    return tf.group(update_actor, update_critic, name='update_targets')


def gen_debug_td_error_summaries(
    target_q_values, q_values, td_targets, td_errors):
  """Generates debug summaries for the critic given a set of batch samples.

  Args:
    target_q_values: set of predicted next state values.
    q_values: current predicted values from the critic network.
    td_targets: discounted target_q_values with the immediate reward added.
    td_errors: the difference between td_targets and q_values.
  """
  with tf.name_scope('td_errors'):
    tf.summary.histogram('td_targets', td_targets)
    tf.summary.histogram('q_values', q_values)
    tf.summary.histogram('target_q_values', target_q_values)
    tf.summary.histogram('td_errors', td_errors)

    with tf.name_scope('td_targets'):
      tf.summary.scalar('mean', tf.reduce_mean(td_targets))
      tf.summary.scalar('max', tf.reduce_max(td_targets))
      tf.summary.scalar('min', tf.reduce_min(td_targets))

    with tf.name_scope('q_values'):
      tf.summary.scalar('mean', tf.reduce_mean(q_values))
      tf.summary.scalar('max', tf.reduce_max(q_values))
      tf.summary.scalar('min', tf.reduce_min(q_values))

    with tf.name_scope('target_q_values'):
      tf.summary.scalar('mean', tf.reduce_mean(target_q_values))
      tf.summary.scalar('max', tf.reduce_max(target_q_values))
      tf.summary.scalar('min', tf.reduce_min(target_q_values))

    with tf.name_scope('td_errors'):
      tf.summary.scalar('mean', tf.reduce_mean(td_errors))
      tf.summary.scalar('max', tf.reduce_max(td_errors))
      tf.summary.scalar('min', tf.reduce_min(td_errors))
      tf.summary.scalar('mean_abs', tf.reduce_mean(tf.abs(td_errors)))