all.approximation

class all.approximation.Approximation(model, optimizer=None, checkpointer=None, clip_grad=0, device=None, loss_scaling=1, name='approximation', scheduler=None, target=None, logger=<all.logging.dummy.DummyLogger object>)

Bases: object

Base function approximation object.

This defines a PyTorch-based function approximation object that wraps key functionality useful for reinforcement learning, including decaying learning rates, model checkpointing, loss scaling, gradient clipping, target networks, and TensorBoard logging. This enables increased code reusability and simpler Agent implementations.

Parameters:
  • model (torch.nn.Module) – A PyTorch module representing the model used to approximate the function. This could be a convolutional network, a fully connected network, or any other PyTorch-compatible model.

  • optimizer (torch.optim.Optimizer) – An optimizer initialized with the model parameters, e.g. SGD, Adam, RMSprop, etc.

  • checkpointer (all.approximation.checkpointer.Checkpointer) – A Checkpointer object that periodically saves the model and its parameters to the disk. Default: A PeriodicCheckpointer that saves the model once every 200 updates.

  • clip_grad (float, optional) – If non-zero, clips the norm of the gradient to this value in order to prevent large updates and improve stability. See torch.nn.utils.clip_grad.

  • device (string, optional) – The device that the model is on. If none is passed, the device will be automatically determined based on model.parameters().

  • loss_scaling (float, optional) – Multiplies the loss by this value before performing a backward pass. Useful with multi-headed networks that share feature layers.

  • name (str, optional) – The name of the function approximator used for logging.

  • scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – A learning rate scheduler initialized with the given optimizer. step() will be called after every update.

  • target (all.approximation.target.TargetNetwork, optional) – A target network object to be used during optimization. A target network updates more slowly than the base model that is being optimized, allowing for a more stable optimization target.

  • logger (all.logging.Logger, optional) – A Logger object used for logging. The standard object logs to TensorBoard; however, other types of Logger objects may be implemented by the user.
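
The clip_grad parameter above describes norm-based gradient clipping. The following is a minimal pure-Python sketch of that rule, assuming the global L2 norm over all gradients is what gets clipped. The helper name clip_grad_norm is hypothetical, and the library's internals may differ in detail:

```python
import math

def clip_grad_norm(grads, max_norm):
    """Rescale a flat list of gradients so their L2 norm is at most max_norm.

    A max_norm of 0 disables clipping, mirroring the clip_grad=0 default.
    """
    if max_norm == 0:
        return list(grads)
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)
    scale = max_norm / total_norm
    return [g * scale for g in grads]

clipped = clip_grad_norm([3.0, 4.0], max_norm=1.0)   # norm 5 is rescaled to norm 1
unchanged = clip_grad_norm([3.0, 4.0], max_norm=0)   # clipping disabled
```

Rescaling preserves the direction of the gradient and only limits its magnitude, which is why it tends to stabilize training without biasing the update direction.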

eval(*inputs)

Run a forward pass of the model in eval mode with no_grad. The model is returned to its previous mode after the forward pass is made.
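
The "return to the previous mode" behavior can be sketched as follows. ToyModel is a hypothetical stand-in for a torch.nn.Module that only tracks a training flag, and eval_forward is an illustrative helper, not the library's code. Gradient tracking is left out because the toy model has no autograd:

```python
class ToyModel:
    """Stand-in for a torch.nn.Module; only tracks the training flag."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def __call__(self, x):
        # Pretend dropout-like behaviour: outputs differ between modes.
        return x * (0.5 if self.training else 1.0)

def eval_forward(model, *inputs):
    """Run the model in eval mode, then restore whatever mode it was in before."""
    was_training = model.training
    model.train(False)
    try:
        return model(*inputs)
    finally:
        model.train(was_training)

model = ToyModel()
out = eval_forward(model, 2.0)
```

Restoring the mode in a finally clause means the model is not left in eval mode if the forward pass raises.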

no_grad(*inputs)

Run a forward pass of the model in no_grad mode.

reinforce(loss)

Backpropagate the loss through the model and make an update step. Internally, this will perform most of the activities associated with a control loop in standard machine learning environments, depending on the configuration of the object: gradient clipping, learning rate schedules, logging, checkpointing, etc.

Parameters:

loss (torch.Tensor) – The loss computed for a batch of inputs.

Returns:

The current Approximation object

Return type:

self
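
As a rough illustration of reinforce, the sketch below combines a backward pass with an update step and returns the object so that calls can be chained. ToyApproximation is hypothetical: it takes (x, y) instead of a loss tensor because it computes its own gradient by hand, and it assumes plain SGD on a single scalar weight:

```python
class ToyApproximation:
    """Hypothetical scalar model y = w * x trained on squared error."""

    def __init__(self, w=0.0, lr=0.1, loss_scaling=1.0):
        self.w = w
        self.lr = lr
        self.loss_scaling = loss_scaling
        self.grad = 0.0
        self.updates = 0

    def backward(self, x, y):
        # Gradient with respect to w of loss_scaling * (w * x - y) ** 2.
        self.grad += self.loss_scaling * 2.0 * (self.w * x - y) * x
        return self

    def step(self):
        self.w -= self.lr * self.grad   # plain SGD update
        self.grad = 0.0                 # clear gradients for the next batch
        self.updates += 1
        return self

    def reinforce(self, x, y):
        # Backward pass followed by an update step; returning self allows chaining.
        return self.backward(x, y).step()

approx = ToyApproximation(w=0.0, lr=0.1)
approx.reinforce(1.0, 1.0).reinforce(1.0, 1.0)

# Scaling the loss scales the gradient, and therefore the size of the update.
scaled = ToyApproximation(w=0.0, lr=0.1, loss_scaling=0.5).reinforce(1.0, 1.0)
```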

step(loss=None)

Given that a backward pass has been made, run an optimization step. Internally, this will perform most of the activities associated with a control loop in standard machine learning environments, depending on the configuration of the object: gradient clipping, learning rate schedules, logging, checkpointing, etc.

Parameters:

loss (torch.Tensor, optional) – The loss to log for this update step.

Returns:

The current Approximation object

Return type:

self
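
The bookkeeping that step performs around the optimizer update can be pictured with the toy below. The ordering of events, the multiplicative learning rate decay, and the names ToyStepper, lr_decay, and checkpoint_every are all assumptions made for illustration; the library may sequence these activities differently:

```python
class ToyStepper:
    """Hypothetical bookkeeping for one optimization step; no real model involved."""

    def __init__(self, lr=1.0, lr_decay=0.5, checkpoint_every=2):
        self.lr = lr
        self.lr_decay = lr_decay
        self.checkpoint_every = checkpoint_every
        self.updates = 0
        self.events = []

    def step(self, loss=None):
        if loss is not None:
            self.events.append(f"log loss={loss}")   # loss is only logged if given
        self.events.append(f"update lr={self.lr}")
        self.lr *= self.lr_decay                     # scheduler advances after every update
        self.updates += 1
        if self.updates % self.checkpoint_every == 0:
            self.events.append("checkpoint")         # periodic save
        return self

stepper = ToyStepper()
stepper.step(loss=0.25).step()
```

After the two chained calls, stepper.events records one logged loss, two updates at decreasing learning rates, and a single checkpoint.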

target(*inputs)

Run a forward pass of the target network.
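
One common kind of target network keeps a frozen copy of the model and refreshes it only every few updates. The sketch below illustrates that idea with plain dictionaries standing in for model weights. ToyFixedTarget is hypothetical and only loosely modeled on the init(model) and update() methods listed for the target classes in this module:

```python
import copy

class ToyFixedTarget:
    """Hypothetical target that copies the source weights every update_frequency updates."""

    def __init__(self, update_frequency):
        self.update_frequency = update_frequency
        self.updates = 0
        self.source = None
        self.weights = None

    def init(self, source):
        self.source = source
        self.weights = copy.deepcopy(source)   # start as an exact copy

    def update(self):
        self.updates += 1
        if self.updates % self.update_frequency == 0:
            self.weights = copy.deepcopy(self.source)

online = {"w": 0.0}
target = ToyFixedTarget(update_frequency=3)
target.init(online)
history = []
for _ in range(4):
    online["w"] += 1.0        # the online model changes on every update
    target.update()
    history.append(target.weights["w"])
```

Here the target lags behind the online weights and only catches up on the third update, which is what makes it a more stable optimization target.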

zero_grad()

Clears the gradients of all optimized tensors.

Returns:

The current Approximation object

Return type:

self

class all.approximation.Checkpointer

Bases: ABC

abstract init(model, filename)

class all.approximation.DummyCheckpointer

Bases: Checkpointer

init(*inputs)

class all.approximation.FeatureNetwork(model, optimizer=None, name='feature', **kwargs)

Bases: Approximation

An Approximation that accepts a state and updates the observation key based on the given model.
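
A minimal sketch of this behavior, assuming a state can be pictured as a dictionary with an "observation" entry. The helper apply_features and the toy model double are hypothetical and only illustrate the idea of replacing the raw observation with features while leaving the rest of the state alone:

```python
def apply_features(state, model):
    """Return a copy of state whose 'observation' entry is replaced by model(observation)."""
    updated = dict(state)                       # leave the input state untouched
    updated["observation"] = model(state["observation"])
    return updated

def double(observation):
    # Hypothetical feature extractor: doubles each raw value.
    return [2 * v for v in observation]

state = {"observation": [1, 2, 3], "reward": 1.0, "done": False}
features = apply_features(state, double)
```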

class all.approximation.FixedTarget(update_frequency)

Bases: TargetNetwork

init(model)

update()

class all.approximation.Identity(device, name='identity', **kwargs)

Bases: Approximation

An Approximation that represents the identity function.

Because the model has no parameters, reinforce and step do nothing.

reinforce()

Backpropagate the loss through the model and make an update step. Internally, this will perform most of the activities associated with a control loop in standard machine learning environments, depending on the configuration of the object: gradient clipping, learning rate schedules, logging, checkpointing, etc.

Parameters:

loss (torch.Tensor) – The loss computed for a batch of inputs.

Returns:

The current Approximation object

Return type:

self

step()

Given that a backward pass has been made, run an optimization step. Internally, this will perform most of the activities associated with a control loop in standard machine learning environments, depending on the configuration of the object: gradient clipping, learning rate schedules, logging, checkpointing, etc.

Parameters:

loss (torch.Tensor, optional) – The loss to log for this update step.

Returns:

The current Approximation object

Return type:

self

class all.approximation.PeriodicCheckpointer(frequency)

Bases: Checkpointer

init(model, filename)

class all.approximation.PolyakTarget(rate)

Bases: TargetNetwork

TargetNetwork that updates using polyak averaging
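
Polyak averaging moves each target weight a small fraction of the way toward the corresponding model weight on every update. The sketch below assumes that rate is the weight given to the source model; the helper polyak_update is hypothetical and works on plain lists rather than tensors:

```python
def polyak_update(target, source, rate):
    """Move each target weight a fraction `rate` of the way toward the source weight."""
    return [(1.0 - rate) * t + rate * s for t, s in zip(target, source)]

target = [0.0, 10.0]
source = [1.0, 0.0]
target = polyak_update(target, source, rate=0.1)   # roughly [0.1, 9.0]
```

Under this assumption, a rate of 1 copies the source on every update, while a small rate makes the target change slowly and smoothly.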

init(model)

update()

class all.approximation.QContinuous(model, optimizer, name='q', **kwargs)

Bases: Approximation

class all.approximation.QDist(model, optimizer, n_actions, n_atoms, v_min, v_max, name='q_dist', **kwargs)

Bases: Approximation
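
The n_atoms, v_min, and v_max arguments suggest a categorical value distribution over a fixed, evenly spaced support, with project(dist, support) mapping a distribution defined on some other support back onto it. The following is a pure-Python sketch of the standard categorical projection under that assumption. The helper make_atoms and the extra n_atoms, v_min, and v_max arguments to project are hypothetical and are there only to keep the example self-contained:

```python
import math

def make_atoms(n_atoms, v_min, v_max):
    """Evenly spaced support values from v_min to v_max inclusive."""
    delta = (v_max - v_min) / (n_atoms - 1)
    return [v_min + i * delta for i in range(n_atoms)]

def project(dist, support, n_atoms, v_min, v_max):
    """Redistribute probabilities `dist` located at `support` onto the fixed atoms."""
    delta = (v_max - v_min) / (n_atoms - 1)
    projected = [0.0] * n_atoms
    for p, z in zip(dist, support):
        z = min(max(z, v_min), v_max)           # clamp into [v_min, v_max]
        b = (z - v_min) / delta                 # fractional atom index
        lower, upper = math.floor(b), math.ceil(b)
        if lower == upper:
            projected[lower] += p               # lands exactly on an atom
        else:
            projected[lower] += p * (upper - b) # split mass between neighbours
            projected[upper] += p * (b - lower)
    return projected

atoms = make_atoms(n_atoms=5, v_min=0.0, v_max=4.0)
shifted = [z + 0.5 for z in atoms]              # e.g. a reward of 0.5 added to each atom
result = project([0.2] * 5, shifted, n_atoms=5, v_min=0.0, v_max=4.0)
```

Each shifted value splits its probability between the two nearest atoms in proportion to its distance from them, and values beyond v_max are clamped onto the last atom, so the result still sums to 1.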

project(dist, support)

class all.approximation.QNetwork(model, optimizer=None, name='q', **kwargs)

Bases: Approximation

class all.approximation.TargetNetwork

Bases: ABC

abstract init(model)

abstract update()

class all.approximation.TrivialTarget

Bases: TargetNetwork

init(model)

update()

class all.approximation.VNetwork(model, optimizer, name='v', **kwargs)

Bases: Approximation