all.approximation

class all.approximation.Approximation(model, optimizer=None, checkpointer=None, clip_grad=0, device=None, loss_scaling=1, name='approximation', scheduler=None, target=None, writer=DummyWriter())

Bases: object

Base function approximation object.

This defines a Pytorch-based function approximation object that wraps key functionality useful for reinforcement learning, including decaying learning rates, model checkpointing, loss scaling, gradient clipping, target networks, and tensorboard logging. This enables increased code reusability and simpler Agent implementations.

Parameters
  • model (torch.nn.Module) – A Pytorch module representing the model used to approximate the function. This could be a convolution network, a fully connected network, or any other Pytorch-compatible model.

  • optimizer (torch.optim.Optimizer) – An optimizer initialized with the model parameters, e.g. SGD, Adam, RMSprop, etc.

  • checkpointer (all.approximation.checkpointer.Checkpointer) – A Checkpointer object that periodically saves the model and its parameters to the disk. Default: A PeriodicCheckpointer that saves the model once every 200 updates.

  • clip_grad (float, optional) – If non-zero, clips the norm of the gradient to this value in order to prevent large updates and improve stability. See torch.nn.utils.clip_grad.

  • device (string, optional) – The device that the model is on. If none is passed, the device will be automatically determined based on model.parameters().

  • loss_scaling (float, optional) – Multiplies the loss by this value before performing a backwards pass. Useful when used with multi-headed networks with shared feature layers.

  • name (str, optional) – The name of the function approximator used for logging.

  • scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – A learning rate scheduler initialized with the given optimizer. step() will be called after every update.

  • target (all.approximation.target.TargetNetwork, optional) – A target network object to be used during optimization. A target network updates more slowly than the base model being optimized, allowing for a more stable optimization target.

  • writer (all.logging.Writer, optional) – A Writer object used for logging. The standard object logs to tensorboard; however, other types of Writer objects may be implemented by the user.
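
Taken together, these options replace the boilerplate of a hand-written update loop. As a rough sketch (not the library's actual implementation), the machinery that such an object bundles looks like this in plain PyTorch; the model, hyperparameter values, and the reinforce helper below are all illustrative:

```python
import torch

# Illustrative components: a tiny model, an optimizer over its parameters,
# and a learning rate schedule attached to that optimizer.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
clip_grad = 0.5      # maximum gradient norm
loss_scaling = 1.0   # multiplier applied before the backward pass

def reinforce(loss):
    (loss * loss_scaling).backward()                               # scaled backward pass
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)  # clip for stability
    optimizer.step()                                               # apply the update
    optimizer.zero_grad()                                          # clear for the next batch
    scheduler.step()                                               # decay the learning rate

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
reinforce(loss)
```

An Approximation performs this sequence for you, alongside checkpointing, target network updates, and logging, so that Agent code only ever calls reinforce(loss).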

eval(*inputs)

Run a forward pass of the model in eval mode with no_grad. The model is returned to its previous mode after the forward pass is made.

no_grad(*inputs)

Run a forward pass of the model in no_grad mode.

reinforce(loss)

Backpropagate the loss through the model and make an update step. Internally, this performs most of the bookkeeping of a standard training loop, depending on the configuration of the object: gradient clipping, learning rate scheduling, logging, checkpointing, etc.

Parameters

loss (torch.Tensor) – The loss computed for a batch of inputs.

Returns

The current Approximation object

Return type

self

step()

Given that a backward pass has already been made, run an optimization step. Internally, this performs most of the bookkeeping of a standard training loop, depending on the configuration of the object: gradient clipping, learning rate scheduling, logging, checkpointing, etc.

Returns

The current Approximation object

Return type

self

target(*inputs)

Run a forward pass of the target network.

zero_grad()

Clears the gradients of all optimized tensors.

Returns

The current Approximation object

Return type

self

class all.approximation.Checkpointer

Bases: abc.ABC

abstract init(model, filename)
class all.approximation.DummyCheckpointer

Bases: all.approximation.checkpointer.Checkpointer

init(*inputs)
class all.approximation.FeatureNetwork(model, optimizer=None, name='feature', **kwargs)

Bases: all.approximation.approximation.Approximation

A special type of Approximation that accumulates gradients before backpropagating them. This is useful when features are shared between network heads.

The __call__ function caches the computation graph and detaches the output. Various function approximators may then backpropagate to the detached output. The reinforce() function subsequently backpropagates the accumulated gradients through the original computation graph.

reinforce()

Backward pass of the model.
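
The cache-and-detach pattern described above can be sketched in plain PyTorch. This is a hypothetical two-layer setup, not the library's code: gradients from the head first accumulate on the detached features, and the final backward call replays them through the original graph into the shared layer.

```python
import torch

# Hypothetical shared feature layer and network head; FeatureNetwork
# generalizes this to many heads sharing one feature extractor.
shared = torch.nn.Linear(4, 8)
head = torch.nn.Linear(8, 2)

x = torch.randn(3, 4)
features = shared(x)                              # original computation graph
cached = features.detach().requires_grad_(True)   # detached output handed to heads

loss = head(cached).sum()
loss.backward()                 # gradients stop at `cached`; `shared` is untouched
assert shared.weight.grad is None

features.backward(cached.grad)  # replay accumulated gradients through `shared`
```

This way, several heads can each call backward() against the detached features, and a single replay updates the shared layers with the sum of their gradients.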

class all.approximation.FixedTarget(update_frequency)

Bases: all.approximation.target.abstract.TargetNetwork

init(model)
update()
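
A FixedTarget holds a frozen copy of the model and refreshes it every update_frequency calls to update(). A minimal sketch of that behaviour, using a plain dict in place of a torch module (the class and attribute names here are illustrative, not the library's implementation):

```python
import copy

class FixedTargetSketch:
    """Refresh a frozen copy of the model every `update_frequency` updates."""

    def __init__(self, update_frequency):
        self._update_frequency = update_frequency
        self._updates = 0
        self._model = None
        self._target = None

    def init(self, model):
        self._model = model
        self._target = copy.deepcopy(model)

    def update(self):
        self._updates += 1
        if self._updates % self._update_frequency == 0:
            self._target = copy.deepcopy(self._model)

    def __call__(self):
        return self._target

# a dict standing in for a network's parameters
params = {"w": 0.0}
target = FixedTargetSketch(update_frequency=2)
target.init(params)
params["w"] = 1.0   # the online model trains...
target.update()     # ...but the target lags behind
stale = target()["w"]
target.update()     # the second call hits the refresh interval
fresh = target()["w"]
```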
class all.approximation.Identity(device, name='identity', **kwargs)

Bases: all.approximation.approximation.Approximation

An Approximation that represents the identity function.

Because the model has no parameters, reinforce and step do nothing.

reinforce()

Because the Identity model has no parameters, this is a no-op: the loss is ignored and no update is performed.

Parameters

loss (torch.Tensor) – The loss computed for a batch of inputs.

Returns

The current Approximation object

Return type

self

step()

Because the Identity model has no parameters, this is a no-op: no optimization step is performed.

Returns

The current Approximation object

Return type

self

class all.approximation.PeriodicCheckpointer(frequency)

Bases: all.approximation.checkpointer.Checkpointer

init(model, filename)
class all.approximation.PolyakTarget(rate)

Bases: all.approximation.target.abstract.TargetNetwork

A TargetNetwork that updates using Polyak averaging.

init(model)
update()
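
Rather than copying the model at intervals, a Polyak target blends toward the online model a little on every update. Assuming rate is the fraction of the online parameters mixed in per step (a sketch with plain floats; the real parameters are tensors):

```python
def polyak_average(target_params, online_params, rate):
    """Soft target update: move each target parameter a fraction `rate`
    toward its online counterpart."""
    return [(1.0 - rate) * t + rate * o
            for t, o in zip(target_params, online_params)]

# after one step, the first target parameter has moved 10% of the way
# toward the online value; the second was already equal and stays put
updated = polyak_average([0.0, 2.0], [1.0, 2.0], rate=0.1)
```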
class all.approximation.QContinuous(model, optimizer, name='q', **kwargs)

Bases: all.approximation.approximation.Approximation

class all.approximation.QDist(model, optimizer, n_actions, n_atoms, v_min, v_max, name='q_dist', **kwargs)

Bases: all.approximation.approximation.Approximation

project(dist, support)
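
project(dist, support) performs the categorical projection used by distributional methods such as C51: probability mass sitting at arbitrary support positions is redistributed onto the fixed atoms between v_min and v_max. A plain-Python sketch of the idea (not the library's tensorized implementation; the function name and arguments here are illustrative):

```python
import math

def project_sketch(probs, support, v_min, v_max, n_atoms):
    """Redistribute mass at `support` positions onto `n_atoms` atoms
    evenly spaced over [v_min, v_max], preserving total probability."""
    dz = (v_max - v_min) / (n_atoms - 1)
    out = [0.0] * n_atoms
    for p, z in zip(probs, support):
        z = min(max(z, v_min), v_max)    # clamp into the representable range
        b = (z - v_min) / dz             # fractional atom index
        lower, upper = math.floor(b), math.ceil(b)
        if lower == upper:               # lands exactly on an atom
            out[lower] += p
        else:                            # split mass between the two neighbours
            out[lower] += p * (upper - b)
            out[upper] += p * (b - lower)
    return out

# mass at 0.5 splits evenly between the atoms at 0.0 and 1.0
projected = project_sketch([1.0], [0.5], v_min=-1.0, v_max=1.0, n_atoms=3)
```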
class all.approximation.QNetwork(model, optimizer=None, name='q', **kwargs)

Bases: all.approximation.approximation.Approximation

class all.approximation.TargetNetwork

Bases: abc.ABC

abstract init(model)
abstract update()
class all.approximation.TrivialTarget

Bases: all.approximation.target.abstract.TargetNetwork

init(model)
update()
class all.approximation.VNetwork(model, optimizer, name='v', **kwargs)

Bases: all.approximation.approximation.Approximation