all.nn

class all.nn.Aggregation(*args: Any, **kwargs: Any)

Bases: torch.nn.

Aggregation layer for the Dueling architecture.

https://arxiv.org/abs/1511.06581 This layer computes a Q function by combining an estimate of V with an estimate of the advantage. The advantage is normalized by subtracting the average advantage so that we can properly

forward(value, advantages)
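
The normalization described above can be sketched in plain PyTorch. This is an illustrative sketch, not the library's implementation: the function name `dueling_aggregate` and the tensor shapes (`value` as `(batch, 1)`, `advantages` as `(batch, n_actions)`) are assumptions made for the example.

```python
import torch

def dueling_aggregate(value: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
    """Combine V (batch, 1) and A (batch, n_actions) into Q (batch, n_actions)."""
    # Subtract the per-state mean advantage so the advantages average to zero
    # across actions, then add the state value (broadcast over the action axis).
    return value + advantages - advantages.mean(dim=1, keepdim=True)

value = torch.tensor([[1.0], [2.0]])
advantages = torch.tensor([[1.0, 2.0, 3.0], [0.0, 0.0, 6.0]])
q = dueling_aggregate(value, advantages)
```

With this normalization, the mean of each row of `q` equals the corresponding entry of `value`.
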
class all.nn.CategoricalDueling(*args: Any, **kwargs: Any)

Bases: torch.nn.

Dueling architecture for C51/Rainbow

forward(features)
class all.nn.Dueling(*args: Any, **kwargs: Any)

Bases: torch.nn.

Implementation of the head for the Dueling architecture.

https://arxiv.org/abs/1511.06581 This module computes a Q function by computing an estimate of V and an estimate of the advantage, then combining them with a special Aggregation layer.

forward(features)
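
A minimal sketch of such a head is shown below, assuming the two estimators are passed in as separate modules. The class name `DuelingHeadSketch`, its constructor arguments, and the layer sizes are hypothetical; the constructor signature of `all.nn.Dueling` is not documented here.

```python
import torch
from torch import nn

class DuelingHeadSketch(nn.Module):
    """Hypothetical dueling head: two streams over shared features."""

    def __init__(self, value_model: nn.Module, advantage_model: nn.Module):
        super().__init__()
        self.value_model = value_model
        self.advantage_model = advantage_model

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        value = self.value_model(features)           # shape (batch, 1)
        advantages = self.advantage_model(features)  # shape (batch, n_actions)
        # Aggregate: Q = V + (A - mean(A)), with the mean taken over actions.
        return value + advantages - advantages.mean(dim=1, keepdim=True)

# 8 input features and 4 actions are arbitrary sizes chosen for the example.
head = DuelingHeadSketch(nn.Linear(8, 1), nn.Linear(8, 4))
q_values = head(torch.randn(5, 8))
```
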
class all.nn.Flatten(*args: Any, **kwargs: Any)

Bases: torch.nn.

Flatten a tensor, e.g., between conv2d and linear layers.

The maintainers FINALLY added this to torch.nn, but I am leaving it in for compatibility for the moment.

forward(x)
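
As a rough illustration of what such a layer does, the sketch below keeps the batch dimension and collapses the rest, then compares the result with the built-in `torch.nn.Flatten`. The class name `FlattenSketch` is hypothetical and the body is an assumption about the behavior, not a copy of the library's code.

```python
import torch
from torch import nn

class FlattenSketch(nn.Module):
    """Hypothetical flatten layer: (batch, d1, d2, ...) -> (batch, d1*d2*...)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Keep the batch dimension and collapse every other dimension.
        return x.reshape(x.size(0), -1)

x = torch.randn(2, 3, 4, 4)       # e.g. the output of a conv2d layer
flat = FlattenSketch()(x)         # shape (2, 48)
builtin = nn.Flatten()(x)         # built-in equivalent with default arguments
```
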
class all.nn.Linear0(*args: Any, **kwargs: Any)

Bases: torch.nn.

reset_parameters()
class all.nn.NoisyFactorizedLinear(*args: Any, **kwargs: Any)

Bases: torch.nn.

NoisyNet layer with factorized Gaussian noise

N.B. nn.Linear already initializes weight and bias to

forward(input)
reset_parameters()
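
For orientation, factorized Gaussian noise as described in the NoisyNets paper (https://arxiv.org/abs/1706.10295) draws one noise vector per input and one per output and combines them with an outer product, instead of sampling one value per weight. The sketch below follows that description; the helper names `scale_noise` and `factorized_noise` are hypothetical, and this is not necessarily how `NoisyFactorizedLinear` is implemented.

```python
import torch

def scale_noise(size: int) -> torch.Tensor:
    """Sample x ~ N(0, 1) and apply f(x) = sign(x) * sqrt(|x|)."""
    x = torch.randn(size)
    return x.sign() * x.abs().sqrt()

def factorized_noise(in_features: int, out_features: int):
    """Build weight and bias noise from in_features + out_features samples."""
    eps_in = scale_noise(in_features)
    eps_out = scale_noise(out_features)
    # Outer product: weight_noise[i, j] = eps_out[i] * eps_in[j].
    weight_noise = eps_out[:, None] * eps_in[None, :]
    bias_noise = eps_out
    return weight_noise, bias_noise

weight_noise, bias_noise = factorized_noise(in_features=5, out_features=3)
```
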
class all.nn.NoisyLinear(*args: Any, **kwargs: Any)

Bases: torch.nn.

Implementation of Linear layer for NoisyNets

https://arxiv.org/abs/1706.10295 NoisyNets are a replacement for epsilon-greedy exploration. Gaussian noise is added to the weights of the output layer, resulting in a stochastic policy. Exploration is implicitly learned at a per-state and per-action level, resulting in smarter exploration.

forward(x)
reset_parameters()
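
The idea can be sketched as a linear layer whose weights are perturbed by learned, per-parameter Gaussian noise on every forward pass. Everything below is an assumption made for illustration: the class name `NoisyLinearSketch`, the `sigma_init` argument, and its default of 0.017 (the value the paper suggests for independent Gaussian noise) are not taken from `all.nn.NoisyLinear`.

```python
import torch
from torch import nn
import torch.nn.functional as F

class NoisyLinearSketch(nn.Linear):
    """Hypothetical noisy linear layer with independent Gaussian noise."""

    def __init__(self, in_features: int, out_features: int, sigma_init: float = 0.017):
        super().__init__(in_features, out_features)
        # Learnable noise scales, one per weight and one per bias entry.
        self.sigma_weight = nn.Parameter(torch.full((out_features, in_features), sigma_init))
        self.sigma_bias = nn.Parameter(torch.full((out_features,), sigma_init))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Sample fresh noise on every call, so repeated calls differ.
        weight = self.weight + self.sigma_weight * torch.randn_like(self.weight)
        bias = self.bias + self.sigma_bias * torch.randn_like(self.bias)
        return F.linear(x, weight, bias)

layer = NoisyLinearSketch(4, 2)
x = torch.ones(3, 4)
y1 = layer(x)
y2 = layer(x)  # same input, different noise sample
```

Because the noise scales are parameters, gradient descent can shrink or grow them, which is how the amount of exploration comes to be learned rather than scheduled.
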
class all.nn.RLNetwork(*args: Any, **kwargs: Any)

Bases: torch.nn.

Wraps a network such that States can be given as input.

forward(state)
class all.nn.Scale(*args: Any, **kwargs: Any)

Bases: torch.nn.

forward(x)
class all.nn.TanhActionBound(*args: Any, **kwargs: Any)

Bases: torch.nn.

forward(x)
all.nn.td_loss(loss)
all.nn.weighted_mse_loss(input, target, weight, reduction='mean')
all.nn.weighted_smooth_l1_loss(input, target, weight, reduction='mean')
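
The two weighted losses share the signature shown above. A plausible reading, sketched below, is that each computes an element-wise loss, multiplies it by `weight`, and then reduces. The `_sketch` function names are hypothetical, and treating any `reduction` other than `'mean'` as a sum is an assumption, not documented behavior.

```python
import torch
import torch.nn.functional as F

def weighted_mse_loss_sketch(input, target, weight, reduction="mean"):
    # Element-wise squared error, scaled by a per-sample weight.
    loss = weight * (target - input) ** 2
    return loss.mean() if reduction == "mean" else loss.sum()

def weighted_smooth_l1_loss_sketch(input, target, weight, reduction="mean"):
    # Element-wise smooth L1 (Huber-style) error, scaled by a per-sample weight.
    loss = weight * F.smooth_l1_loss(input, target, reduction="none")
    return loss.mean() if reduction == "mean" else loss.sum()

prediction = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 0.0, 3.5])
weight = torch.tensor([1.0, 0.5, 2.0])  # e.g. importance-sampling weights

mse = weighted_mse_loss_sketch(prediction, target, weight)
huber_sum = weighted_smooth_l1_loss_sketch(prediction, target, weight, reduction="sum")
```
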