all.approximation¶
- class all.approximation.Approximation(model, optimizer=None, checkpointer=None, clip_grad=0, device=None, loss_scaling=1, name='approximation', scheduler=None, target=None, writer=<all.logging.DummyWriter object>)¶
Bases:
object
Base function approximation object.
This defines a Pytorch-based function approximation object that wraps key functionality useful for reinforcement learning, including decaying learning rates, model checkpointing, loss scaling, gradient clipping, target networks, and tensorboard logging. This enables increased code reusability and simpler Agent implementations.
- Parameters
model (torch.nn.Module) – A Pytorch module representing the model used to approximate the function. This could be a convolution network, a fully connected network, or any other Pytorch-compatible model.
optimizer (torch.optim.Optimizer) – An optimizer initialized with the model parameters, e.g., SGD, Adam, RMSprop, etc.
checkpointer (all.approximation.checkpointer.Checkpointer) – A Checkpointer object that periodically saves the model and its parameters to the disk. Default: A PeriodicCheckpointer that saves the model once every 200 updates.
clip_grad (float, optional) – If non-zero, clips the norm of the gradient to this value in order to prevent large updates and improve stability. See torch.nn.utils.clip_grad.
device (string, optional) – The device that the model is on. If none is passed, the device will be automatically determined based on model.parameters().
loss_scaling (float, optional) – Multiplies the loss by this value before performing a backwards pass. Useful when used with multi-headed networks with shared feature layers.
name (str, optional) – The name of the function approximator used for logging.
scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – A learning rate scheduler initialized with the given optimizer. step() will be called after every update.
target (all.approximation.target.TargetNetwork, optional) – A target network object to be used during optimization. A target network updates more slowly than the base model that is being optimized, allowing for a more stable optimization target.
writer (all.logging.Writer, optional) – A Writer object used for logging. The standard object logs to tensorboard; however, other types of Writer objects may be implemented by the user.
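Taken together, these options compose into a single update step. The following is a minimal pure-Python sketch of the control flow that reinforce() performs; it is conceptual only (a scalar parameter with a hand-computed gradient, plain SGD, and a hypothetical halving scheduler), not the library's implementation, which operates on torch tensors:

```python
# Conceptual sketch of the reinforce() control flow (NOT the library code).
# A single scalar parameter is trained with plain SGD; the "scheduler"
# halves the learning rate after every update.

class SketchApproximation:
    def __init__(self, w, lr=0.5, clip_grad=1.0, loss_scaling=1.0):
        self.w = w
        self.lr = lr
        self.clip_grad = clip_grad        # max allowed |gradient|
        self.loss_scaling = loss_scaling  # multiplies the loss (here, its gradient)

    def reinforce(self, grad):
        # 1. loss scaling: scaling the loss scales its gradient
        grad *= self.loss_scaling
        # 2. gradient clipping: clamp the gradient norm
        if self.clip_grad:
            grad = max(-self.clip_grad, min(self.clip_grad, grad))
        # 3. optimizer step (SGD)
        self.w -= self.lr * grad
        # 4. scheduler step: decay the learning rate
        self.lr *= 0.5
        return self  # returns self, like the real reinforce()

f = SketchApproximation(w=0.0)
# gradient of the loss (w - 3)^2 at w=0 is -6; clipped to -1, so w += 0.5
f.reinforce(2 * (f.w - 3.0))
print(f.w, f.lr)  # -> 0.5 0.25
```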
- eval(*inputs)¶
Run a forward pass of the model in eval mode with no_grad. The model is returned to its previous mode after the forward pass is made.
- no_grad(*inputs)¶
Run a forward pass of the model in no_grad mode.
- reinforce(loss)¶
Backpropagate the loss through the model and make an update step. Internally, this will perform most of the activities associated with a typical training loop, depending on the configuration of the object: gradient clipping, learning rate scheduling, logging, checkpointing, etc.
- Parameters
loss (torch.Tensor) – The loss computed for a batch of inputs.
- Returns
The current Approximation object
- Return type
self
- step()¶
Given that a backward pass has been made, run an optimization step. Internally, this will perform most of the activities associated with a typical training loop, depending on the configuration of the object: gradient clipping, learning rate scheduling, logging, checkpointing, etc.
- Returns
The current Approximation object
- Return type
self
- target(*inputs)¶
Run a forward pass of the target network.
- zero_grad()¶
Clears the gradients of all optimized tensors.
- Returns
The current Approximation object
- Return type
self
- class all.approximation.DummyCheckpointer¶
Bases:
all.approximation.checkpointer.Checkpointer
- init(*inputs)¶
- class all.approximation.FeatureNetwork(model, optimizer=None, name='feature', **kwargs)¶
Bases:
all.approximation.approximation.Approximation
A special type of Approximation that accumulates gradients before backpropagating them. This is useful when features are shared between network heads.
The __call__ function caches the computation graph and detaches the output. Then, various functions approximators may backpropagate to the output. The reinforce() function will then backpropagate the accumulated gradients on the output through the original computation graph.
- reinforce()¶
Backward pass of the model.
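The caching behavior described above can be illustrated without PyTorch. In this hypothetical sketch (not the library's autograd-based implementation), a one-weight linear feature layer feeds two heads; each head backpropagates a gradient onto the "detached" feature output, and reinforce() applies the accumulated sum back through the chain rule:

```python
# Sketch of gradient accumulation through a shared feature layer
# (conceptual only; the real FeatureNetwork uses torch autograd).

class SketchFeatureNetwork:
    def __init__(self, w):
        self.w = w       # single feature-layer weight
        self._x = None   # cached input (stands in for the computation graph)
        self._grad = 0.0 # gradients accumulated on the detached output

    def __call__(self, x):
        self._x = x           # cache for the later backward pass
        return self.w * x     # feature output, conceptually "detached"

    def accumulate(self, grad):
        self._grad += grad    # a head backpropagates onto the output

    def reinforce(self, lr=0.1):
        # chain rule: d(feature)/dw = x, applied to the summed gradient
        self.w -= lr * self._grad * self._x
        self._grad = 0.0

features = SketchFeatureNetwork(w=1.0)
phi = features(2.0)          # shared features used by both heads
features.accumulate(0.5)     # gradient from, say, a value head
features.accumulate(-0.2)    # gradient from, say, a policy head
features.reinforce()
print(features.w)            # -> 1.0 - 0.1 * 0.3 * 2.0 = 0.94
```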
- class all.approximation.FixedTarget(update_frequency)¶
Bases:
all.approximation.target.abstract.TargetNetwork
- init(model)¶
- update()¶
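A FixedTarget copies the live model's parameters only once every update_frequency updates, so the optimization target stays frozen in between. The following pure-Python sketch of that schedule is conceptual only; the class name and list-based "parameters" are illustrative, not the library's implementation:

```python
# Sketch of FixedTarget behavior: target parameters are copied from the
# live model only once every `update_frequency` calls to update().

class SketchFixedTarget:
    def __init__(self, update_frequency):
        self.update_frequency = update_frequency
        self._updates = 0

    def init(self, params):
        self._model = params
        self._target = list(params)  # initial copy of the parameters

    def update(self):
        self._updates += 1
        if self._updates % self.update_frequency == 0:
            self._target = list(self._model)  # sync with the live model

model_params = [0.0]
target = SketchFixedTarget(update_frequency=3)
target.init(model_params)
for step in range(5):
    model_params[0] += 1.0  # the live model keeps training
    target.update()
# the target last synced at update 3 (model value 3.0), not at update 5
print(target._target)       # -> [3.0]
```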
- class all.approximation.Identity(device, name='identity', **kwargs)¶
Bases:
all.approximation.approximation.Approximation
An Approximation that represents the identity function.
Because the model has no parameters, reinforce and step do nothing.
- reinforce()¶
Backpropagate the loss through the model and make an update step. Internally, this will perform most of the activities associated with a typical training loop, depending on the configuration of the object: gradient clipping, learning rate scheduling, logging, checkpointing, etc.
- Parameters
loss (torch.Tensor) – The loss computed for a batch of inputs.
- Returns
The current Approximation object
- Return type
self
- step()¶
Given that a backward pass has been made, run an optimization step. Internally, this will perform most of the activities associated with a typical training loop, depending on the configuration of the object: gradient clipping, learning rate scheduling, logging, checkpointing, etc.
- Returns
The current Approximation object
- Return type
self
- class all.approximation.PeriodicCheckpointer(frequency)¶
Bases:
all.approximation.checkpointer.Checkpointer
- init(model, filename)¶
- class all.approximation.PolyakTarget(rate)¶
Bases:
all.approximation.target.abstract.TargetNetwork
A TargetNetwork that updates its parameters using Polyak averaging.
- init(model)¶
- update()¶
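Polyak averaging blends the target toward the live model a little on every update, rather than copying it periodically. A minimal sketch, assuming rate is the interpolation coefficient toward the live model (check the library source for the exact convention):

```python
# Sketch of Polyak (exponential moving) averaging of target parameters.
# Assumes: target <- (1 - rate) * target + rate * model on every update.

def polyak_update(target, model, rate):
    return [(1.0 - rate) * t + rate * m for t, m in zip(target, model)]

target = [0.0]
model = [1.0]
for _ in range(3):
    target = polyak_update(target, model, rate=0.5)
# the target drifts toward the model: 0.5, 0.75, 0.875
print(target)  # -> [0.875]
```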
- class all.approximation.QContinuous(model, optimizer, name='q', **kwargs)¶
- class all.approximation.QDist(model, optimizer, n_actions, n_atoms, v_min, v_max, name='q_dist', **kwargs)¶
Bases:
all.approximation.approximation.Approximation
- project(dist, support)¶
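project() implements the categorical projection from distributional RL (the C51 algorithm): each probability on a shifted support is redistributed onto the fixed atoms between v_min and v_max by clamping and linear interpolation. The following pure-Python sketch of that projection is illustrative; the function signature and argument names are assumptions, and the real method operates on batched torch tensors:

```python
import math

# Sketch of the C51 categorical projection: probability mass on a shifted
# support is redistributed onto the fixed atoms [v_min, ..., v_max].
# Conceptual only; signature and names are hypothetical.

def project(dist, shifted_support, v_min, v_max, n_atoms):
    delta = (v_max - v_min) / (n_atoms - 1)
    out = [0.0] * n_atoms
    for p, z in zip(dist, shifted_support):
        z = min(max(z, v_min), v_max)   # clamp shifted atom onto the support
        b = (z - v_min) / delta         # fractional index on the fixed atoms
        l, u = math.floor(b), math.ceil(b)
        if l == u:                      # lands exactly on an atom
            out[l] += p
        else:                           # split mass between the two neighbors
            out[l] += p * (u - b)
            out[u] += p * (b - l)
    return out

# three atoms at [-1, 0, 1]; mass at 0.5 splits evenly between 0 and 1
print(project([0.0, 1.0, 0.0], [-1.0, 0.5, 2.0], -1.0, 1.0, 3))
# -> [0.0, 0.5, 0.5]
```

Note that the projection conserves total probability: clamping moves out-of-range atoms to the boundary rather than discarding them.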
- class all.approximation.QNetwork(model, optimizer=None, name='q', **kwargs)¶
- class all.approximation.TrivialTarget¶
Bases:
all.approximation.target.abstract.TargetNetwork
- init(model)¶
- update()¶
- class all.approximation.VNetwork(model, optimizer, name='v', **kwargs)¶