all.core
- MultiagentState
- State – An environment State.
- StateArray – An n-dimensional array of environment State objects.
- class all.core.MultiagentState(x, device='cpu', **kwargs)
Bases: State
- property agent
- classmethod from_zoo(agent, state, device='cpu', dtype=<class 'numpy.float32'>)
Constructs a MultiagentState object for the given agent from the return value of an OpenAI gym-style reset()/step(action) call.
- Parameters:
agent – The identifier of the agent to which the state belongs.
state (tuple) – The return value of an OpenAI gym-style reset()/step(action) call.
device (string) – The device on which to store resulting tensors.
dtype – The type of the observation.
- Returns:
A MultiagentState object.
- to(device)
- class all.core.State(x, device='cpu', **kwargs)
Bases: dict
An environment State.
An environment State represents all of the information available to an agent at a given timestep, including the observation, reward, and the done flag. The State object contains useful utilities for creating StateArray objects, constructing State objects for OpenAI gym environments, masking the output of networks based on the done flag, etc.
- Parameters:
x (dict) –
A dictionary containing all state information. Any key/value can be included, but the following keys are standard:
- observation (torch.tensor) (required):
A tensor representing the current observation available to the agent.
- reward (float) (optional):
The reward for the previous state/action. Defaults to 0.
- done (bool) (optional):
Whether or not this is a terminal state. Defaults to False.
- mask (float) (optional):
The mask (0 or 1) for the current state. Defaults to 0 if the state is terminal, and 1 otherwise.
device (string) – The torch device on which component tensors are stored.
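The defaults listed above can be illustrated with a minimal, dependency-free sketch. It assumes a plain dict stands in for State and Python lists stand in for torch tensors; the helper name make_state is hypothetical and not part of all.core. The default mask follows the mask property described below (1. for a non-terminal state, 0. otherwise).

```python
# Hypothetical sketch: a plain dict stands in for State, lists for tensors.
def make_state(x):
    """Return a copy of x with the standard State defaults filled in."""
    if "observation" not in x:
        raise ValueError("a State must contain an observation")
    state = dict(x)  # copy so the caller's dict is not modified
    state.setdefault("reward", 0.0)  # reward defaults to 0
    state.setdefault("done", False)  # done defaults to False
    # mask: 1. for a non-terminal state, 0. for a terminal one
    state.setdefault("mask", 0.0 if state["done"] else 1.0)
    return state


s = make_state({"observation": [0.1, 0.2]})
print(s["reward"], s["done"], s["mask"])  # 0.0 False 1.0
```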
- apply(model, *keys)
Apply a model to the state. Automatically selects the correct keys, reshapes the input/output as necessary and applies the mask.
- Parameters:
model (torch.nn.Module) – A torch Module which accepts the components corresponding to the given keys as args.
keys (string) – Strings corresponding to the desired components of the state. E.g., apply(model, ‘observation’, ‘reward’) would pass the observation and reward as arguments to the model.
- Returns:
The output of the model.
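As a rough illustration of the steps apply() is described as performing (select the keys, reshape the input, run the model, reshape the output, apply the mask), here is a hypothetical pure-Python sketch for a single state. The names apply and toy_model are illustrative only; lists stand in for tensors, and the model is an ordinary function rather than a torch.nn.Module.

```python
# Hypothetical sketch of the apply() pipeline for a single (unbatched) state.
def apply(state, model, *keys):
    inputs = [[state[key]] for key in keys]  # as_input: add a batch dimension of size 1
    batch_output = model(*inputs)            # the model operates on batches
    output = batch_output[0]                 # as_output: drop the batch dimension
    return output * state["mask"]            # apply_mask: zero out terminal states


def toy_model(observations, rewards):
    """A stand-in 'network' returning one scalar per batch element."""
    return [sum(obs) + r for obs, r in zip(observations, rewards)]


state = {"observation": [1.0, 2.0], "reward": 0.5, "mask": 1.0}
print(apply(state, toy_model, "observation", "reward"))  # 3.5
```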
- apply_mask(tensor)
Applies the mask to the given tensor, generally to prevent backpropagation through terminal states.
- Parameters:
tensor (torch.tensor) – The tensor to apply the mask to.
- Returns:
A torch.tensor with the mask applied.
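A common place for masking is bootstrapping from the next state's value: when the next state is terminal its mask is 0, so the bootstrapped term vanishes (and with real tensors, multiplying by zero also means that term contributes no gradient). The sketch below uses lists in place of tensors; apply_mask and td_targets are hypothetical helpers, and the one-step target r + discount · mask · V(s') is a standard construction rather than something this page specifies.

```python
def apply_mask(values, masks):
    """Multiply each state's value by that state's mask (1. or 0.)."""
    return [v * m for v, m in zip(values, masks)]


def td_targets(rewards, next_values, next_masks, discount=0.99):
    """One-step targets r + discount * mask * V(s'); terminal next states add nothing."""
    masked = apply_mask(next_values, next_masks)
    return [r + discount * v for r, v in zip(rewards, masked)]


print(td_targets([1.0, 1.0], [10.0, 10.0], [1.0, 0.0], discount=0.5))  # [6.0, 1.0]
```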
- classmethod array(list_of_states)
Construct a StateArray from a list of State or StateArray objects. The shape of the resulting StateArray is (N, …M), where N is the length of the input list and M is the shape of the component State or StateArray objects.
- Parameters:
list_of_states – A list of State or StateArray objects with a matching shape.
- Returns:
A StateArray object.
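The shape rule for array() can be sketched with a toy container. ToyState and array below are hypothetical stand-ins that use nested lists instead of tensors and assume every input has the same keys; they only demonstrate how N inputs of shape M produce an object of shape (N, *M).

```python
from dataclasses import dataclass


@dataclass
class ToyState:
    """Hypothetical stand-in: component name -> value, plus a shape."""
    data: dict
    shape: tuple = ()  # a single State always has shape ()


def array(list_of_states):
    """Sketch of State.array: stack N objects of shape M into shape (N, *M)."""
    shape = list_of_states[0].shape
    if any(s.shape != shape for s in list_of_states):
        raise ValueError("all inputs must have the same shape")
    keys = list_of_states[0].data.keys()  # assumes every input has the same keys
    # Each component gains a new leading dimension of length N.
    data = {k: [s.data[k] for s in list_of_states] for k in keys}
    return ToyState(data, (len(list_of_states), *shape))


a = array([ToyState({"reward": 1.0}), ToyState({"reward": 0.0})])
print(a.shape, a.data["reward"])  # (2,) [1.0, 0.0]
b = array([a, a, a])
print(b.shape)  # (3, 2)
```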
- as_input(key)
Gets the value for a given key and reshapes it to a batch-style tensor suitable as input to a pytorch module.
- Parameters:
key (string) – The component of the state to select.
- Returns:
A torch.tensor containing the value of the component with a batch dimension added.
- as_output(tensor)
Reshapes the output of a batch-style pytorch module to match the original shape of the state.
- Parameters:
tensor (torch.tensor) – The output of a batch-style pytorch module.
- Returns:
A torch.tensor containing the output in the appropriate shape.
- property done
A boolean that is true if the state is a terminal state, and false otherwise.
- classmethod from_gym(gym_output, device='cpu', dtype=<class 'numpy.float32'>)
Constructs a State object given the return value of an OpenAI gym reset()/step(action) call.
- Parameters:
gym_output (tuple) – The output value of an OpenAI gym reset()/step(action) call
device (string) – The device on which to store resulting tensors.
dtype – The type of the observation.
- Returns:
A State object.
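A minimal sketch of this conversion, assuming the classic gym API in which reset() returns only an observation and step(action) returns an (observation, reward, done, info) tuple. Newer gym and Gymnasium releases use different return signatures, so treat this as illustrative. from_gym here is a hypothetical pure-Python function: float stands in for numpy.float32, lists stand in for tensors, and info is simply discarded.

```python
# Hypothetical sketch assuming the classic gym reset()/step() return values.
def from_gym(gym_output, dtype=float):
    if isinstance(gym_output, tuple):
        # step(action): (observation, reward, done, info); info is ignored here
        observation, reward, done, _info = gym_output
    else:
        # reset(): the observation alone, with default reward and done flag
        observation, reward, done = gym_output, 0.0, False
    return {
        "observation": [dtype(v) for v in observation],  # cast each element to dtype
        "reward": float(reward),
        "done": bool(done),
        "mask": 0.0 if done else 1.0,
    }


print(from_gym(([0, 1], 1, True, {})))
# {'observation': [0.0, 1.0], 'reward': 1.0, 'done': True, 'mask': 0.0}
```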
- property mask
A float that is 1. if the state is non-terminal, or 0. otherwise.
- property observation
A tensor containing the current observation.
- property reward
A float representing the reward for the previous state/action pair.
- property shape
The shape of the State or StateArray. A State always has shape ().
- to(device)
- update(key, value)
Adds a key/value pair to the state, or updates an existing key/value pair. Note that this is NOT an in-place operation, but returns a new State or StateArray.
- Parameters:
key (string) – The name of the state component to update.
value (any) – The value of the new state component.
- Returns:
A State or StateArray object with the given component added/updated.
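Because update() returns a new object instead of modifying the existing one, the original state keeps its old value. A hypothetical dict-based sketch of that behavior (not the library's implementation):

```python
# Hypothetical sketch of a non-in-place update on a dict-based state.
def update(state, key, value):
    new_state = dict(state)  # shallow copy: the original is left untouched
    new_state[key] = value   # add a new key or overwrite an existing one
    return new_state


s1 = {"observation": [0.0], "reward": 0.0}
s2 = update(s1, "reward", 1.0)
print(s1["reward"], s2["reward"])  # 0.0 1.0
```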
- class all.core.StateArray(x, shape, device='cpu', **kwargs)
Bases: State
An n-dimensional array of environment State objects.
Internally, all components of the states are represented as n-dimensional tensors. This allows for batch-style processing and easy manipulation of states. Usually, a StateArray should be constructed using the State.array() function.
- Parameters:
x (dict) –
A dictionary containing all state information. Each value should be a tensor whose first n dimensions match the shape of the StateArray. The following keys are standard:
- observation (torch.tensor) (required):
A tensor representing the observations for each state.
- reward (torch.FloatTensor) (optional):
A tensor representing the rewards for the previous state/action pairs.
- done (torch.BoolTensor) (optional):
A tensor representing whether each state is terminal.
- mask (torch.FloatTensor) (optional):
A tensor representing the mask for each state.
shape – The shape of the StateArray.
device (string) – The torch device on which component tensors are stored.
- apply_mask(tensor)
Applies the mask to the given tensor, generally to prevent backpropagation through terminal states.
- Parameters:
tensor (torch.tensor) – The tensor to apply the mask to.
- Returns:
A torch.tensor with the mask applied.
- as_input(key)
Gets the value for a given key and reshapes it to a batch-style tensor suitable as input to a pytorch module.
- Parameters:
key (string) – The component of the state to select.
- Returns:
A torch.tensor containing the value of the component, with the leading dimensions of the StateArray flattened into a single batch dimension.
- as_output(tensor)
Reshapes the output of a batch-style pytorch module to match the original shape of the state.
- Parameters:
tensor (torch.tensor) – The output of a batch-style pytorch module.
- Returns:
A torch.tensor containing the output in the appropriate shape.
- batch_execute(minibatch_size, fn)
Executes fn on the StateArray in minibatches of at most minibatch_size states, in order to reduce memory consumption.
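The minibatching idea can be sketched as follows, assuming a plain list in place of a StateArray and a function fn that maps a minibatch to a list of per-state results. The helper is hypothetical; here the per-minibatch results are concatenated as lists, whereas with tensors the analogous step would be a tensor concatenation.

```python
def batch_execute(states, minibatch_size, fn):
    """Apply fn to consecutive minibatches and concatenate the results."""
    results = []
    for start in range(0, len(states), minibatch_size):
        minibatch = states[start:start + minibatch_size]  # at most minibatch_size items
        results.extend(fn(minibatch))  # only one minibatch is processed at a time
    return results


print(batch_execute(list(range(5)), 2, lambda batch: [x * 10 for x in batch]))
# [0, 10, 20, 30, 40]
```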
- classmethod cat(state_array_list, axis=0)
Concatenates a list of StateArray objects along the given axis (by default axis 0, the batch dimension).
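A dict-of-lists sketch of concatenation along the batch axis. It is hypothetical, handles only axis 0, and assumes all inputs share the same keys:

```python
def cat(state_arrays):
    """Sketch: concatenate dict-of-lists state arrays along axis 0 only."""
    keys = state_arrays[0].keys()  # assumes all inputs share the same keys
    return {k: [item for sa in state_arrays for item in sa[k]] for k in keys}


a = {"observation": [[0], [1]], "reward": [1.0, 0.0]}
b = {"observation": [[2]], "reward": [0.5]}
print(cat([a, b]))
# {'observation': [[0], [1], [2]], 'reward': [1.0, 0.0, 0.5]}
```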
- property done
A boolean tensor indicating, for each state, whether it is a terminal state.
- flatten()
Converts an n-dimensional StateArray to a 1-dimensional StateArray.
- Returns:
A 1-dimensional StateArray.
- property mask
A float tensor containing 1. for each non-terminal state and 0. otherwise.
- property observation
A tensor containing the observation for each state.
- property reward
A float tensor containing the reward for each previous state/action pair.
- property shape
The shape of the StateArray.
- update(key, value)
Adds a key/value pair to the StateArray, or updates an existing key/value pair. The value should be a tensor whose first n dimensions match the shape of the StateArray. Note that this is NOT an in-place operation; it returns a new StateArray.
- Parameters:
key (string) – The name of the state component to update.
value (any) – The value of the new state component.
- Returns:
A StateArray object with the given component added/updated.
- view(shape)
Analogous to torch.Tensor.view(), returns a new StateArray object containing the same data but with a different shape.
- Parameters:
shape – The desired shape of the new StateArray.
- Returns:
A StateArray with the given shape.
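view() and flatten() change only the shape, not the data. The hypothetical ToyStateArray below stores each component as a flat, row-major list with one entry per state, so reshaping amounts to replacing the shape tuple after checking that the number of states is unchanged, much as torch.Tensor.view() requires the number of elements to match.

```python
from math import prod


class ToyStateArray:
    """Hypothetical stand-in: each component is a flat list with one entry per state."""

    def __init__(self, data, shape):
        self.shape = tuple(shape)
        for key, values in data.items():
            if len(values) != prod(self.shape):
                raise ValueError(f"{key!r} does not match shape {self.shape}")
        self.data = data

    def view(self, shape):
        # Same underlying data, new shape; the number of states must not change.
        if prod(shape) != prod(self.shape):
            raise ValueError(f"cannot view shape {self.shape} as {tuple(shape)}")
        return ToyStateArray(self.data, shape)

    def flatten(self):
        # A 1-dimensional array is a view with shape (number of states,).
        return self.view((prod(self.shape),))


grid = ToyStateArray({"reward": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]}, (2, 3))
print(grid.view((3, 2)).shape, grid.flatten().shape)  # (3, 2) (6,)
```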