all.core

State(x[, device])

An environment State.

StateArray(x, shape[, device])

An n-dimensional array of environment State objects.

class all.core.State(x, device='cpu', **kwargs)

Bases: dict

An environment State.

An environment State represents all of the information available to an agent at a given timestep, including the observation, reward, and the done flag. The State object contains useful utilities for creating StateArray objects, constructing State objects for OpenAI gym environments, masking the output of networks based on the done flag, etc.

Parameters
  • x (dict) –

    A dictionary containing all state information. Any key/value can be included, but the following keys are standard:

    observation (torch.tensor) (required):

    A tensor representing the current observation available to the agent.

    reward (float) (optional):

    The reward for the previous state/action. Defaults to 0.

    done (bool) (optional):

    Whether or not this is a terminal state. Defaults to False.

    mask (float) (optional):

    The mask (0 or 1) for the current state.

  • device (string) – The torch device on which component tensors are stored.
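The standard keys and their defaults can be illustrated with a small pure-Python model. This is a sketch, not the library's implementation: `make_state` is a hypothetical helper, plain lists stand in for torch tensors, and deriving a missing `mask` from `done` is an assumption based on the description of the `mask` property.

```python
# Hypothetical toy model of the documented State defaults. Plain Python
# lists stand in for torch tensors; this is not the library's own code.
def make_state(x):
    """Return a new dict holding x plus the standard State defaults."""
    if "observation" not in x:
        raise ValueError("a State requires an 'observation' component")
    state = dict(x)                   # copy so the caller's dict is untouched
    state.setdefault("reward", 0.0)   # documented default reward
    state.setdefault("done", False)   # documented default done flag
    # Assumption: if no mask is supplied, derive it from the done flag,
    # matching the documented meaning of the mask property.
    state.setdefault("mask", 0.0 if state["done"] else 1.0)
    return state


start = make_state({"observation": [0.0, 0.5]})
final = make_state({"observation": [1.0, 0.5], "reward": 1.0, "done": True})
print(start["reward"], start["done"], start["mask"])  # 0.0 False 1.0
print(final["mask"])                                   # 0.0
```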

apply(model, *keys)

Apply a model to the state. Automatically selects the correct keys, reshapes the input/output as necessary and applies the mask.

Parameters
  • model (torch.nn.Module) – A torch Module which accepts the components corresponding to the given keys as args.

  • keys (string) – Strings corresponding to the desired components of the state. E.g., apply(model, 'observation', 'reward') would pass the observation and reward as arguments to the model.

Returns

The output of the model.
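The steps apply() is described as performing for a single State (select components, add a batch dimension, call the model, remove the batch dimension, apply the mask) can be sketched in plain Python. The standalone `apply` function and the toy `value_model` are hypothetical stand-ins, with lists in place of tensors.

```python
# Hypothetical sketch of what apply(model, *keys) is described as doing for a
# single State. Lists stand in for tensors; not the library's own code.
def apply(state, model, *keys):
    batched_args = [[state[key]] for key in keys]  # batch of size 1 per key
    batched_output = model(*batched_args)
    output = batched_output[0]                     # drop the batch dimension
    return output * state["mask"]                  # zero out terminal states


def value_model(observations, rewards):
    """Toy batch-style 'model': one output per batch entry."""
    return [sum(obs) + r for obs, r in zip(observations, rewards)]


live = {"observation": [1.0, 2.0], "reward": 0.5, "mask": 1.0}
terminal = {"observation": [1.0, 2.0], "reward": 0.5, "mask": 0.0}
print(apply(live, value_model, "observation", "reward"))      # 3.5
print(apply(terminal, value_model, "observation", "reward"))  # 0.0
```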

apply_mask(tensor)

Applies the mask to the given tensor, generally to prevent backpropagation through terminal states.

Parameters

tensor (torch.tensor) – The tensor to apply the mask to.

Returns

A torch.tensor with the mask applied.
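A typical place for the mask is a bootstrapped one-step target, where the value estimate of a terminal next state must not contribute; multiplying by 0.0 also zeroes the gradient through that entry. The sketch below shows only the arithmetic, using lists in place of tensors. `apply_mask` and `td_targets` here are hypothetical helpers, and the library's agents may structure this computation differently.

```python
# Hypothetical sketch: masking bootstrapped values in a one-step TD target.
# Lists stand in for tensors; td_targets is an illustrative helper only.
def apply_mask(values, masks):
    """Elementwise product; entries for terminal states (mask 0.0) become 0.0."""
    return [v * m for v, m in zip(values, masks)]


def td_targets(rewards, next_values, next_masks, discount):
    masked = apply_mask(next_values, next_masks)
    return [r + discount * v for r, v in zip(rewards, masked)]


# The second transition ends the episode, so its next value is ignored.
print(td_targets([1.0, 1.0], [10.0, 10.0], [1.0, 0.0], discount=0.5))  # [6.0, 1.0]
```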

classmethod array(list_of_states)

Construct a StateArray from a list of State or StateArray objects. The shape of the resulting StateArray is (N, *M), where N is the length of the input list and M is the shape of the component State or StateArray objects (M is empty for State objects, whose shape is ()).

Parameters

list_of_states – A list of State or StateArray objects that all have the same shape.

Returns

A StateArray object.
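The shape rule for array() can be modeled without PyTorch. In this sketch each item is a hypothetical (components, shape) pair and lists stand in for stacked tensors; it illustrates the (N, *M) rule only and is not the library's implementation.

```python
# Hypothetical model of State.array's shape rule. Each item is a
# (components, shape) pair; lists stand in for stacked tensors.
def stack(items):
    shapes = {shape for _, shape in items}
    if len(shapes) != 1:
        raise ValueError("all states must have the same shape")
    (inner_shape,) = shapes
    keys = items[0][0].keys()
    components = {key: [c[key] for c, _ in items] for key in keys}
    return components, (len(items), *inner_shape)


single_states = [({"reward": float(i)}, ()) for i in range(3)]
components, shape = stack(single_states)
print(shape, components["reward"])  # (3,) [0.0, 1.0, 2.0]

arrays = [({"reward": [0.0] * 4}, (4,)) for _ in range(2)]
print(stack(arrays)[1])             # (2, 4)
```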

as_input(key)

Gets the value for a given key and reshapes it to a batch-style tensor suitable as input to a PyTorch module.

Parameters

key (string) – The component of the state to select.

Returns

A torch.tensor containing the value of the component with a batch dimension added.

as_output(tensor)

Reshapes the output of a batch-style PyTorch module to match the original shape of the state.

Parameters

tensor (torch.tensor) – The output of a batch-style PyTorch module.

Returns

A torch.tensor containing the output in the appropriate shape.
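The round trip through as_input() and as_output() can be sketched with nested lists in place of tensors. The standalone functions below are hypothetical; they assume that a State (shape ()) gets a batch dimension of size 1 and that a StateArray's leading dimensions are merged into one batch dimension and split back afterwards.

```python
from math import prod

# Hypothetical sketch of the as_input / as_output round trip. Nested lists
# stand in for tensors; `shape` is the State or StateArray shape.
def as_input(value, shape):
    """Merge the leading len(shape) dimensions into one batch dimension."""
    if shape == ():
        return [value]            # a single State gets a batch of size 1
    batch = value
    for _ in range(len(shape) - 1):
        batch = [item for row in batch for item in row]
    return batch


def as_output(batch, shape):
    """Split the batch dimension back into the original leading dimensions."""
    if shape == ():
        return batch[0]
    if len(batch) != prod(shape):
        raise ValueError("batch size does not match the state shape")
    result = list(batch)
    for size in reversed(shape[1:]):
        result = [result[i:i + size] for i in range(0, len(result), size)]
    return result


def double(batch):                # toy batch-style "model"
    return [2 * x for x in batch]


observations = [[1, 2, 3], [4, 5, 6]]    # a StateArray of shape (2, 3)
print(as_output(double(as_input(observations, (2, 3))), (2, 3)))
# [[2, 4, 6], [8, 10, 12]]
print(as_output(double(as_input(7, ())), ()))  # 14
```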

property done

A boolean that is true if the state is a terminal state, and false otherwise.

classmethod from_gym(state, device='cpu', dtype=numpy.float32)

Constructs a State object given the return value of an OpenAI gym reset()/step(action) call.

Parameters
  • state (tuple) – The return value of an OpenAI gym reset()/step(action) call.

  • device (string) – The device on which to store resulting tensors.

  • dtype – The data type of the observation (a NumPy dtype; defaults to numpy.float32).

Returns

A State object.
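The kind of conversion from_gym() performs can be sketched in plain Python. This assumes the classic gym API, in which reset() returns an observation and step(action) returns (observation, reward, done, info); newer gym releases return different tuples, which this sketch does not handle. `from_gym_sketch` is a hypothetical name, float() stands in for the dtype conversion, and the info dict is simply ignored here.

```python
# Hypothetical sketch of the conversion from_gym is described as performing.
# Assumes the classic gym API: reset() returns an observation and step()
# returns (observation, reward, done, info). float() stands in for the dtype
# conversion, and a list stands in for the observation tensor.
def from_gym_sketch(result):
    if isinstance(result, tuple) and len(result) == 4:
        observation, reward, done, _info = result       # a step(action) result
    else:
        observation, reward, done = result, 0.0, False  # a reset() result
    return {
        "observation": [float(v) for v in observation],
        "reward": float(reward),
        "done": bool(done),
        "mask": 0.0 if done else 1.0,
    }


print(from_gym_sketch([0, 1]))
# {'observation': [0.0, 1.0], 'reward': 0.0, 'done': False, 'mask': 1.0}
print(from_gym_sketch(([0, 1], 1, True, {}))["mask"])  # 0.0
```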

property mask

A float that is 1.0 if the state is non-terminal, or 0.0 otherwise.

property observation

A tensor containing the current observation.

property reward

A float representing the reward for the previous state/action pair.

property shape

The shape of the State or StateArray. A State always has shape ().

update(key, value)

Adds a key/value pair to the state, or updates an existing key/value pair. Note that this is NOT an in-place operation; it returns a new State or StateArray.

Parameters
  • key (string) – The name of the state component to update.

  • value (any) – The value of the new state component.

Returns

A State or StateArray object with the given component added/updated.
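Because update() is not in-place, the original object is left unchanged. The hypothetical dict-based sketch below shows that behavior; it copies the mapping rather than modifying it.

```python
# Hypothetical sketch of a non-in-place update on a dict-based state.
def update(state, key, value):
    new_state = dict(state)   # shallow copy; the original is not modified
    new_state[key] = value
    return new_state


state = {"observation": [0.0, 1.0], "reward": 0.0}
updated = update(state, "reward", 1.0)
print(state["reward"], updated["reward"])  # 0.0 1.0
```

With the library's method the result therefore has to be rebound, as in state = state.update('reward', 1.0); calling state.update(...) and discarding the return value leaves state unchanged.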

class all.core.StateArray(x, shape, device='cpu', **kwargs)

Bases: all.core.state.State

An n-dimensional array of environment State objects.

Internally, all components of the states are represented as n-dimensional tensors. This allows for batch-style processing and easy manipulation of states. Usually, a StateArray should be constructed using the State.array() function.

Parameters
  • x (dict) –

    A dictionary containing all state information. Each value should be a tensor in which the first n dimensions match the shape of the StateArray. The following keys are standard:

    observation (torch.tensor) (required):

    A tensor representing the observations for each state.

    reward (torch.FloatTensor) (optional):

    A tensor representing the rewards for the previous state/action pairs.

    done (torch.BoolTensor) (optional):

    A tensor representing whether each state is terminal.

    mask (torch.FloatTensor) (optional):

    A tensor representing the mask for each state.

  • device (string) – The torch device on which component tensors are stored.
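The requirement that each component's leading dimensions match the StateArray shape can be expressed as a small check. In this hypothetical sketch nested lists stand in for tensors; the library may validate its inputs differently or not at all.

```python
# Hypothetical check that every component's leading dimensions match the
# StateArray shape, as the parameter description requires. Nested lists
# stand in for tensors.
def leading_shape(value, ndim):
    dims = []
    for _ in range(ndim):
        if not isinstance(value, list):
            break
        dims.append(len(value))
        value = value[0] if value else None
    return tuple(dims)


def is_valid_state_array(x, shape):
    if "observation" not in x:
        return False
    return all(leading_shape(v, len(shape)) == tuple(shape) for v in x.values())


x = {
    "observation": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # 3 states, 2 features
    "reward": [0.0, 1.0, 0.0],
    "done": [False, False, True],
}
print(is_valid_state_array(x, (3,)))                      # True
print(is_valid_state_array(dict(x, reward=[0.0]), (3,)))  # False
```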

apply_mask(tensor)

Applies the mask to the given tensor, generally to prevent backpropagation through terminal states.

Parameters

tensor (torch.tensor) – The tensor to apply the mask to.

Returns

A torch.tensor with the mask applied.

as_input(key)

Gets the value for a given key and reshapes it to a batch-style tensor suitable as input to a PyTorch module.

Parameters

key (string) – The component of the state to select.

Returns

A torch.tensor containing the value of the component with a batch dimension added.

as_output(tensor)

Reshapes the output of a batch-style PyTorch module to match the original shape of the state.

Parameters

tensor (torch.tensor) – The output of a batch-style PyTorch module.

Returns

A torch.tensor containing the output in the appropriate shape.

property done

A torch.BoolTensor with the same shape as the StateArray, which is True for each terminal state and False otherwise.

flatten()

Converts an n-dimensional StateArray to a 1-dimensional StateArray.

Returns

A 1-dimensional StateArray.
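Flattening merges the leading dimensions of every component into one, so a StateArray of shape (2, 2) becomes one of shape (4,). The sketch below is hypothetical: nested lists stand in for tensors, and it assumes the StateArray has at least one dimension.

```python
from math import prod

# Hypothetical sketch of flatten(): merge the leading dimensions of every
# component into one, giving a 1-dimensional StateArray. Nested lists stand
# in for tensors; shape must have at least one dimension.
def flatten(x, shape):
    flat = {}
    for key, value in x.items():
        for _ in range(len(shape) - 1):
            value = [item for row in value for item in row]
        flat[key] = value
    return flat, (prod(shape),)


# A (2, 2) StateArray, e.g. 2 time steps from 2 parallel environments.
x = {
    "observation": [[[0, 0], [0, 1]], [[1, 0], [1, 1]]],
    "done": [[False, False], [False, True]],
}
flat, shape = flatten(x, (2, 2))
print(shape)         # (4,)
print(flat["done"])  # [False, False, False, True]
```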

property mask

A torch.FloatTensor with the same shape as the StateArray, which is 1.0 for each non-terminal state and 0.0 otherwise.

property observation

A tensor containing the observation of each state; its first n dimensions match the shape of the StateArray.

property reward

A torch.FloatTensor containing, for each state, the reward for the previous state/action pair.

property shape

The shape of the StateArray.

update(key, value)

Adds a key/value pair to the StateArray, or updates an existing key/value pair. The value should be a tensor whose first n dimensions match the shape of the StateArray. Note that this is NOT an in-place operation; it returns a new StateArray.

Parameters
  • key (string) – The name of the state component to update.

  • value (any) – The value of the new state component.

Returns

A StateArray object with the given component added/updated.

view(shape)

Analogous to torch.Tensor.view(), returns a new StateArray object containing the same data but with a different shape.

Parameters

shape – The desired shape of the new StateArray.

Returns

A StateArray with the given shape.
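Viewing a StateArray under a new shape keeps the number of states fixed and only regroups the leading dimensions of each component. The hypothetical sketch below does this for a single component using nested lists in place of tensors; unlike torch.Tensor.view, it does not accept -1 as an inferred dimension.

```python
from math import prod

# Hypothetical sketch of view(shape) for one component: check that the
# number of states is unchanged, then regroup the leading dimensions.
# Nested lists stand in for tensors; -1 is not supported here.
def view(value, old_shape, new_shape):
    if prod(old_shape) != prod(new_shape):
        raise ValueError(f"cannot view shape {old_shape} as {new_shape}")
    flat = value
    for _ in range(len(old_shape) - 1):      # merge the old leading dims
        flat = [item for row in flat for item in row]
    for size in reversed(new_shape[1:]):      # split into the new leading dims
        flat = [flat[i:i + size] for i in range(0, len(flat), size)]
    return flat


rewards = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]      # a StateArray of shape (6,)
print(view(rewards, (6,), (2, 3)))            # [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
```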