all.nn
- class all.nn.Aggregation(*args: Any, **kwargs: Any)
Bases: Module
Aggregation layer for the Dueling architecture.
https://arxiv.org/abs/1511.06581
This layer computes a Q function by combining an estimate of V with an estimate of the advantage. The advantage is normalized by subtracting the average advantage over all actions, which forces the action-independent part of the Q estimate to be represented by V.
- forward(value, advantages)
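The aggregation rule above can be sketched in plain Python for a single state. This is a minimal illustration, not the library's code: dueling_aggregate is a hypothetical name, and the sketch assumes one scalar value estimate and a list holding one advantage per action, whereas all.nn.Aggregation presumably works on batched tensors.

```python
from typing import List


def dueling_aggregate(value: float, advantages: List[float]) -> List[float]:
    """Hypothetical helper: combine V(s) with A(s, a) into Q(s, a).

    Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a')
    """
    mean_advantage = sum(advantages) / len(advantages)
    # Subtracting the mean makes the advantages sum to zero, so any
    # action-independent offset has to be carried by the value estimate.
    return [value + a - mean_advantage for a in advantages]


# One state, three actions.
q = dueling_aggregate(10.0, [1.0, 2.0, 6.0])
print(q)  # [8.0, 9.0, 13.0]
```

Adding the same constant to every advantage leaves the resulting Q-values unchanged, which is the identifiability property the mean subtraction is meant to provide.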
- class all.nn.CategoricalDueling(*args: Any, **kwargs: Any)
Bases: Module
Dueling architecture for C51/Rainbow.
- forward(features)
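For C51-style distributional heads, the Rainbow paper combines the value and advantage streams per atom and then applies a softmax over atoms for each action. The sketch below illustrates that combination in plain Python under stated assumptions: the value stream yields one logit per atom, the advantage stream yields one logit per (action, atom) pair, and the function names are hypothetical. Whether all.nn.CategoricalDueling applies the softmax itself or returns logits for a later step is not stated here.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_dueling(value_logits: List[float],
                        advantage_logits: List[List[float]]) -> List[List[float]]:
    """Hypothetical helper: one probability distribution over atoms per action.

    value_logits: one logit per atom, shape [atoms] (assumption)
    advantage_logits: one row per action, shape [actions][atoms] (assumption)
    """
    num_actions = len(advantage_logits)
    num_atoms = len(value_logits)
    # Mean advantage for each atom, taken across actions.
    mean_adv = [sum(row[i] for row in advantage_logits) / num_actions
                for i in range(num_atoms)]
    distributions = []
    for row in advantage_logits:
        logits = [value_logits[i] + row[i] - mean_adv[i] for i in range(num_atoms)]
        distributions.append(softmax(logits))
    return distributions


dists = categorical_dueling([0.0, 1.0, 2.0], [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
```

Subtracting the per-atom mean gives each atom the same identifiability property that the scalar Dueling aggregation has for Q-values.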
- class all.nn.Dueling(*args: Any, **kwargs: Any)
Bases: Module
Implementation of the head for the Dueling architecture.
https://arxiv.org/abs/1511.06581
This module computes a Q function by computing an estimate of V and an estimate of the advantage, then combining them with the Aggregation layer.
- forward(features)
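A rough sketch of how such a head is wired: two streams read the same feature vector, one producing a scalar V and one producing an advantage per action, and their outputs are combined by mean-subtracted aggregation. The linear streams, the weights, and the dueling_head name are all hypothetical; in the library the streams would presumably be PyTorch modules supplied by the caller.

```python
from typing import Callable, List

Vector = List[float]


def dot(w: Vector, x: Vector) -> float:
    return sum(wi * xi for wi, xi in zip(w, x))


def dueling_head(features: Vector,
                 value_stream: Callable[[Vector], float],
                 advantage_stream: Callable[[Vector], Vector]) -> Vector:
    """Hypothetical head: Q-values for every action from shared features."""
    value = value_stream(features)           # scalar V(s)
    advantages = advantage_stream(features)  # one A(s, a) per action
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]


# Hypothetical linear streams over a 2-dimensional feature vector, 3 actions.
value_weights = [1.0, 1.0]
advantage_weights = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

q = dueling_head(
    [2.0, 4.0],
    value_stream=lambda x: dot(value_weights, x),
    advantage_stream=lambda x: [dot(w, x) for w in advantage_weights],
)
print(q)  # [4.0, 6.0, 8.0]
```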
- class all.nn.Flatten(*args: Any, **kwargs: Any)
Bases: Module
Flatten a tensor, e.g., between conv2d and linear layers.
The maintainers finally added this to torch.nn (as torch.nn.Flatten), but I am leaving it in for compatibility for the moment.
- forward(x)
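The shape bookkeeping is the main thing Flatten does. This sketch assumes, as with torch.nn.Flatten's default start_dim=1, that the leading batch dimension is kept and all remaining dimensions are collapsed into one; flattened_shape and flatten_batch are hypothetical helpers that work on tuples and nested lists rather than tensors.

```python
import math
from typing import Any, List, Tuple


def flattened_shape(shape: Tuple[int, ...]) -> Tuple[int, int]:
    """Shape after flattening everything except the leading batch dimension."""
    batch, *rest = shape
    return (batch, math.prod(rest))


def flatten_batch(batch: List[Any]) -> List[List[Any]]:
    """Flatten each sample of a nested-list batch into a flat list."""
    def flatten(x: Any) -> List[Any]:
        if isinstance(x, list):
            return [v for item in x for v in flatten(item)]
        return [x]
    return [flatten(sample) for sample in batch]


# A conv stack producing 64 channels of 7x7 feature maps for a batch of 32
# feeds a linear layer with 64 * 7 * 7 = 3136 input features.
print(flattened_shape((32, 64, 7, 7)))  # (32, 3136)
print(flatten_batch([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```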
- class all.nn.NoisyFactorizedLinear(*args: Any, **kwargs: Any)
Bases: Linear
NoisyNet layer with factorized Gaussian noise.
N.B. nn.Linear already initializes weight and bias to
- forward(input)
- reset_parameters()
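In the NoisyNet paper, factorized Gaussian noise for a layer with p inputs and q outputs draws only p + q unit Gaussians instead of the p * q + q needed for independent noise. Each sample is passed through f(x) = sgn(x) * sqrt(|x|), the weight noise is the outer product of the output and input vectors, and the bias noise is the output vector. A plain-Python sketch of that sampling step follows; the function names are hypothetical, and how often NoisyFactorizedLinear resamples its noise is not stated here.

```python
import math
import random
from typing import List, Tuple


def f(x: float) -> float:
    # Noise-shaping function from the NoisyNet paper: f(x) = sgn(x) * sqrt(|x|).
    return math.copysign(math.sqrt(abs(x)), x)


def factorized_noise(in_features: int, out_features: int,
                     rng: random.Random) -> Tuple[List[List[float]], List[float]]:
    """Hypothetical helper: weight noise [out][in] and bias noise [out]."""
    eps_in = [f(rng.gauss(0.0, 1.0)) for _ in range(in_features)]
    eps_out = [f(rng.gauss(0.0, 1.0)) for _ in range(out_features)]
    # Weight noise is the outer product of the output and input noise vectors,
    # so only in_features + out_features Gaussians are drawn.
    weight_noise = [[o * i for i in eps_in] for o in eps_out]
    bias_noise = eps_out
    return weight_noise, bias_noise


w_eps, b_eps = factorized_noise(in_features=4, out_features=3, rng=random.Random(0))
```

Because the weight noise is rank one, it is cheaper to sample than independent noise, which the paper motivates by the cost of generating random numbers for large layers.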
- class all.nn.NoisyLinear(*args: Any, **kwargs: Any)
Bases: Linear
Implementation of a Linear layer for NoisyNets.
https://arxiv.org/abs/1706.10295
NoisyNets are a replacement for epsilon-greedy exploration. Gaussian noise is added to the weights of the output layer, resulting in a stochastic policy. Because the scale of the noise is itself learned, exploration is learned implicitly at a per-state and per-action level rather than following a fixed schedule.
- forward(x)
- reset_parameters()
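A noisy linear layer replaces each weight and bias with a learnable mean plus a learnable scale times sampled noise: y = (mu_w + sigma_w * eps_w) x + mu_b + sigma_b * eps_b. The description does not say which noise scheme NoisyLinear uses; since NoisyFactorizedLinear is a separate class, this sketch assumes independent Gaussian noise (one sample per weight and per bias), redrawn on every call. The function and parameter names are hypothetical.

```python
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def noisy_linear(x: Vector, mu_w: Matrix, sigma_w: Matrix,
                 mu_b: Vector, sigma_b: Vector, rng: random.Random) -> Vector:
    """y = (mu_w + sigma_w * eps_w) x + mu_b + sigma_b * eps_b, fresh noise per call."""
    out = []
    for j in range(len(mu_w)):
        # Assumption: independent Gaussian noise, one sample per weight and bias.
        w_row = [mu + sigma * rng.gauss(0.0, 1.0)
                 for mu, sigma in zip(mu_w[j], sigma_w[j])]
        b = mu_b[j] + sigma_b[j] * rng.gauss(0.0, 1.0)
        out.append(sum(w * xi for w, xi in zip(w_row, x)) + b)
    return out


x = [1.0, 2.0]
mu_w = [[0.5, -0.5], [1.0, 1.0]]
mu_b = [0.0, 1.0]
# With every sigma at zero the layer reduces to an ordinary linear layer.
zeros_w = [[0.0, 0.0], [0.0, 0.0]]
print(noisy_linear(x, mu_w, zeros_w, mu_b, [0.0, 0.0], random.Random(0)))  # [-0.5, 4.0]
```

Since the sigmas are trained alongside the means, the network can shrink them where exploration is no longer useful.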
- class all.nn.RLNetwork(*args: Any, **kwargs: Any)
Bases: Module
Wraps a network so that State objects can be given as input.
- forward(state)
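The wrapping idea can be illustrated with a hypothetical State class that carries an observation and a wrapper that unpacks it before calling the underlying model. Neither class reflects the library's actual State or RLNetwork; which fields RLNetwork reads and what it returns are assumptions here.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class State:
    # Hypothetical stand-in for a State object; only the observation is modeled.
    observation: Any


class StateWrapper:
    """Hypothetical wrapper: accept a State and feed its observation to a model."""

    def __init__(self, model: Callable[[Any], Any]):
        self.model = model

    def __call__(self, state: State) -> Any:
        return self.model(state.observation)


net = StateWrapper(lambda obs: [2 * v for v in obs])
print(net(State(observation=[1, 2, 3])))  # [2, 4, 6]
```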
- all.nn.td_loss(loss)
- all.nn.weighted_mse_loss(input, target, weight, reduction='mean')
- all.nn.weighted_smooth_l1_loss(input, target, weight, reduction='mean')
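The two weighted losses are documented only by their signatures. The sketch below shows one plausible reading, stated as assumptions: weight scales each element's loss before the reduction, reduction accepts 'mean' or 'sum', and the smooth L1 loss uses a threshold of 1 (matching the default beta=1.0 of torch.nn.functional.smooth_l1_loss). The function names are hypothetical; the parameter names mirror the documented signatures.

```python
from typing import List


def _reduce(losses: List[float], reduction: str) -> float:
    # Assumption: only 'mean' and 'sum' are supported.
    if reduction == 'mean':
        return sum(losses) / len(losses)
    if reduction == 'sum':
        return sum(losses)
    raise ValueError(f"unsupported reduction: {reduction!r}")


def weighted_mse(input: List[float], target: List[float], weight: List[float],
                 reduction: str = 'mean') -> float:
    # Per-element squared error, scaled by its weight before reducing.
    losses = [w * (t - x) ** 2 for x, t, w in zip(input, target, weight)]
    return _reduce(losses, reduction)


def weighted_smooth_l1(input: List[float], target: List[float], weight: List[float],
                       reduction: str = 'mean') -> float:
    # Huber-style loss with threshold 1: quadratic near zero, linear beyond.
    losses = []
    for x, t, w in zip(input, target, weight):
        d = abs(x - t)
        losses.append(w * (0.5 * d ** 2 if d < 1.0 else d - 0.5))
    return _reduce(losses, reduction)


print(weighted_mse([0.0, 0.0], [1.0, 3.0], [1.0, 0.5]))        # 2.75
print(weighted_smooth_l1([0.0, 0.0], [0.5, 3.0], [1.0, 2.0]))  # 2.5625
```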