
Convert List[np.ndarray] to np.ndarray before using torch.as_tensor #4183

Merged 2 commits on Jul 2, 2020
13 changes: 12 additions & 1 deletion ml-agents/mlagents/trainers/models_torch.py
@@ -1,8 +1,9 @@
 from enum import Enum
-from typing import Callable, NamedTuple
+from typing import Callable, NamedTuple, List, Optional

 import torch
 from torch import nn
+import numpy as np

 from mlagents.trainers.distributions_torch import (
     GaussianDistribution,
@@ -19,6 +20,16 @@
 EPSILON = 1e-7


+def list_to_tensor(
+    ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = None
+) -> torch.Tensor:
+    """
+    Converts a list of numpy arrays into a tensor. MUCH faster than
+    calling as_tensor on the list directly.
+    """
+    return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)
+
+
 class ActionType(Enum):
     DISCRETE = "discrete"
     CONTINUOUS = "continuous"
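A minimal usage sketch for the new helper, assuming a list-of-arrays input like the trainers pass in. The array shape, batch size, and timeit harness below are illustrative only; only list_to_tensor itself comes from this PR.

import timeit
import numpy as np
import torch

from mlagents.trainers.models_torch import list_to_tensor

# A buffer field typically behaves like a Python list with one numpy array
# per step; the shape and length here are made up for illustration.
obs_list = [np.random.rand(84).astype(np.float32) for _ in range(1024)]

fast = list_to_tensor(obs_list)   # stack once in numpy, then wrap as a tensor
slow = torch.as_tensor(obs_list)  # old path: torch walks the Python list

assert torch.equal(fast, slow)
print("list_to_tensor:", timeit.timeit(lambda: list_to_tensor(obs_list), number=100))
print("torch.as_tensor:", timeit.timeit(lambda: torch.as_tensor(obs_list), number=100))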
7 changes: 4 additions & 3 deletions ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
@@ -12,6 +12,7 @@
 from mlagents.trainers.optimizer import Optimizer
 from mlagents.trainers.settings import TrainerSettings, RewardSignalType
 from mlagents.trainers.trajectory import SplitObservations
+from mlagents.trainers.models_torch import list_to_tensor


 class TorchOptimizer(Optimizer):  # pylint: disable=W0223
@@ -79,21 +80,21 @@ def get_value_estimates(
     def get_trajectory_value_estimates(
         self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
     ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
-        vector_obs = [torch.as_tensor(batch["vector_obs"])]
+        vector_obs = [list_to_tensor(batch["vector_obs"])]
         if self.policy.use_vis_obs:
             visual_obs = []
             for idx, _ in enumerate(
                 self.policy.actor_critic.network_body.visual_encoders
             ):
-                visual_ob = torch.as_tensor(batch["visual_obs%d" % idx])
+                visual_ob = list_to_tensor(batch["visual_obs%d" % idx])
                 visual_obs.append(visual_ob)
         else:
             visual_obs = []

         memory = torch.zeros([1, len(vector_obs[0]), self.policy.m_size])

         next_obs = np.concatenate(next_obs, axis=-1)
-        next_obs = [torch.as_tensor(next_obs).unsqueeze(0)]
+        next_obs = [list_to_tensor(next_obs).unsqueeze(0)]
         next_memory = torch.zeros([1, 1, self.policy.m_size])

         value_estimates, mean_value = self.policy.actor_critic.critic_pass(
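To make the change above concrete: the fields pulled out of the batch behave like lists of per-step numpy arrays, which is exactly the input list_to_tensor stacks in one call. A small sketch with made-up shapes; the plain lists below are stand-ins for the real buffer fields, not the AgentBuffer API.

import numpy as np
from mlagents.trainers.models_torch import list_to_tensor

# Stand-ins for batch["vector_obs"] and batch["visual_obs0"]; shapes are hypothetical.
vector_field = [np.zeros(8, dtype=np.float32) for _ in range(128)]
visual_field = [np.zeros((84, 84, 3), dtype=np.float32) for _ in range(128)]

vector_obs = [list_to_tensor(vector_field)]   # one tensor of shape (128, 8)
visual_obs = [list_to_tensor(visual_field)]   # one tensor of shape (128, 84, 84, 3)

# next_obs is concatenated into a single ndarray first, so np.asanyarray inside
# list_to_tensor is a no-op and torch.as_tensor can reuse the same memory.
next_obs = np.concatenate([np.zeros(4, dtype=np.float32), np.zeros(4, dtype=np.float32)], axis=-1)
next_obs = [list_to_tensor(next_obs).unsqueeze(0)]   # shape (1, 8)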
1 change: 0 additions & 1 deletion ml-agents/mlagents/trainers/policy/nn_policy.py
@@ -21,7 +21,6 @@ def __init__(
         seed: int,
         brain: BrainParameters,
         trainer_settings: TrainerSettings,
-        is_training: bool,
         model_path: str,
         load: bool,
         tanh_squash: bool = False,
23 changes: 12 additions & 11 deletions ml-agents/mlagents/trainers/ppo/optimizer_torch.py
@@ -7,6 +7,7 @@
 from mlagents.trainers.policy.torch_policy import TorchPolicy
 from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
 from mlagents.trainers.settings import TrainerSettings, PPOSettings
+from mlagents.trainers.models_torch import list_to_tensor


 class TorchPPOOptimizer(TorchOptimizer):
@@ -91,18 +92,18 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         returns = {}
         old_values = {}
         for name in self.reward_signals:
-            old_values[name] = torch.as_tensor(batch["{}_value_estimates".format(name)])
-            returns[name] = torch.as_tensor(batch["{}_returns".format(name)])
+            old_values[name] = list_to_tensor(batch["{}_value_estimates".format(name)])
+            returns[name] = list_to_tensor(batch["{}_returns".format(name)])

-        vec_obs = [torch.as_tensor(batch["vector_obs"])]
-        act_masks = torch.as_tensor(batch["action_mask"])
+        vec_obs = [list_to_tensor(batch["vector_obs"])]
+        act_masks = list_to_tensor(batch["action_mask"])
         if self.policy.use_continuous_act:
-            actions = torch.as_tensor(batch["actions"]).unsqueeze(-1)
+            actions = list_to_tensor(batch["actions"]).unsqueeze(-1)
         else:
-            actions = torch.as_tensor(batch["actions"], dtype=torch.long)
+            actions = list_to_tensor(batch["actions"], dtype=torch.long)

         memories = [
-            torch.as_tensor(batch["memory"][i])
+            list_to_tensor(batch["memory"][i])
             for i in range(0, len(batch["memory"]), self.policy.sequence_length)
         ]
         if len(memories) > 0:
@@ -113,7 +114,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             for idx, _ in enumerate(
                 self.policy.actor_critic.network_body.visual_encoders
             ):
-                vis_ob = torch.as_tensor(batch["visual_obs%d" % idx])
+                vis_ob = list_to_tensor(batch["visual_obs%d" % idx])
                 vis_obs.append(vis_ob)
         else:
             vis_obs = []
@@ -127,10 +128,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         )
         value_loss = self.ppo_value_loss(values, old_values, returns)
         policy_loss = self.ppo_policy_loss(
-            torch.as_tensor(batch["advantages"]),
+            list_to_tensor(batch["advantages"]),
             log_probs,
-            torch.as_tensor(batch["action_probs"]),
-            torch.as_tensor(batch["masks"], dtype=torch.int32),
+            list_to_tensor(batch["action_probs"]),
+            list_to_tensor(batch["masks"], dtype=torch.int32),
         )
         loss = (
             policy_loss
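The dtype argument on list_to_tensor covers the integer cases in the update above; here is a short, hypothetical illustration (the field contents are invented) of the discrete-action and mask conversions.

import numpy as np
import torch
from mlagents.trainers.models_torch import list_to_tensor

# Buffers store floats; the dtype argument is forwarded to torch.as_tensor.
action_field = [np.array([2.0], dtype=np.float32) for _ in range(64)]
mask_field = [np.array(1.0, dtype=np.float32) for _ in range(64)]

actions = list_to_tensor(action_field, dtype=torch.long)   # int64 tensor, shape (64, 1)
masks = list_to_tensor(mask_field, dtype=torch.int32)      # int32 tensor, shape (64,)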