
Add --namespace-packages to mypy for mlagents #3075

Merged 18 commits on Dec 12, 2019
6 changes: 5 additions & 1 deletion .pre-commit-config.yaml
@@ -11,11 +11,15 @@ repos:

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.750
# Currently mypy may assert after logging one message. To get all the messages at once, change repo and rev to
# repo: https://github.com/chriselion/mypy
# rev: 3d0b6164a9487a6c5cf9d144110b86600fd85e25
# This is a fork with the assert disabled, although precommit has trouble installing it sometimes.
hooks:
- id: mypy
name: mypy-ml-agents
files: "ml-agents/.*"
args: [--ignore-missing-imports, --disallow-incomplete-defs]
args: [--ignore-missing-imports, --disallow-incomplete-defs, --namespace-packages]
- id: mypy
name: mypy-ml-agents-envs
files: "ml-agents-envs/.*"
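For context, a brief sketch of the mechanism the new flag covers. The mlagents namespace is split across the ml-agents and ml-agents-envs source trees, and --namespace-packages tells mypy to discover modules inside packages that have no __init__.py (PEP 420 implicit namespace packages). The layout and paths below are illustrative, not copied from the repo:

```python
# A minimal, self-contained sketch of an implicit namespace package split
# across two distributions; the directory names are illustrative only.
import pathlib
import sys
import tempfile

root = pathlib.Path(tempfile.mkdtemp())
for dist, subpkg in [("ml-agents", "trainers"), ("ml-agents-envs", "envs")]:
    pkg_dir = root / dist / "mlagents" / subpkg
    pkg_dir.mkdir(parents=True)
    (pkg_dir / "__init__.py").write_text("")
    # Note: no mlagents/__init__.py is created, so "mlagents" is an implicit
    # namespace package shared by both source trees.

sys.path[:0] = [str(root / "ml-agents"), str(root / "ml-agents-envs")]

import mlagents  # noqa: E402

print(list(mlagents.__path__))  # spans both directories
# mypy only maps files inside such __init__.py-less packages to module names
# when run with --namespace-packages, which is what this hook change adds.
```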
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/action_info.py
@@ -1,6 +1,6 @@
from typing import NamedTuple, Any, Dict, Optional
from typing import NamedTuple, Any, Dict

ActionInfoOutputs = Optional[Dict[str, Any]]
ActionInfoOutputs = Dict[str, Any]


class ActionInfo(NamedTuple):
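A small sketch of the effect of dropping Optional from ActionInfoOutputs; the ActionInfo field names here are stand-ins and may not match the real class exactly. Callers now construct ActionInfo with an empty dict instead of None, so consumers no longer need None checks before indexing outputs:

```python
# Sketch only; the real ActionInfo is defined in action_info.py.
from typing import Any, Dict, NamedTuple

ActionInfoOutputs = Dict[str, Any]  # was Optional[Dict[str, Any]] before this PR


class ActionInfo(NamedTuple):
    action: Any
    value: Any
    outputs: ActionInfoOutputs


# With the Optional removed, "no action" is represented by empty containers,
# which is what tf_policy.get_action and the test further down now return/expect.
empty = ActionInfo([], [], {})
assert "value_estimates" not in empty.outputs  # hypothetical key; no None check needed
```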
4 changes: 4 additions & 0 deletions ml-agents/mlagents/trainers/demo_loader.py
@@ -153,4 +153,8 @@ def load_demonstration(
break
pos += next_pos
obs_decoded += 1
if not brain_params:
raise RuntimeError(
f"No BrainParameters found in demonstration file at {file_path}."
)
return brain_params, info_action_pairs, total_expected
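For illustration, a hedged sketch of how the new guard surfaces to callers (the demo path below is made up): load_demonstration now raises instead of returning a possibly-None BrainParameters, which keeps its declared return type non-Optional:

```python
# Illustrative only; the .demo path is hypothetical.
from mlagents.trainers.demo_loader import load_demonstration

try:
    brain_params, info_action_pairs, total_expected = load_demonstration(
        "demos/expert_pyramids.demo"
    )
except RuntimeError as err:
    # Raised when the file contains no BrainParameters, rather than silently
    # returning None and failing later with a type error.
    print(f"Could not load demonstration: {err}")
```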
9 changes: 4 additions & 5 deletions ml-agents/mlagents/trainers/env_manager.py
@@ -1,19 +1,18 @@
from abc import ABC, abstractmethod
from typing import List, Dict, NamedTuple, Optional
from typing import List, Dict, NamedTuple
from mlagents.trainers.brain import AllBrainInfo, BrainParameters
from mlagents.trainers.policy import Policy
from mlagents.trainers.action_info import ActionInfo


class EnvironmentStep(NamedTuple):
previous_all_brain_info: Optional[AllBrainInfo]
previous_all_brain_info: AllBrainInfo
Contributor: Looks good to me

current_all_brain_info: AllBrainInfo
brain_name_to_action_info: Optional[Dict[str, ActionInfo]]
brain_name_to_action_info: Dict[str, ActionInfo]

def has_actions_for_brain(self, brain_name: str) -> bool:
return (
self.brain_name_to_action_info is not None
and brain_name in self.brain_name_to_action_info
brain_name in self.brain_name_to_action_info
and self.brain_name_to_action_info[brain_name].outputs is not None
)

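A minimal, self-contained sketch of the post-change EnvironmentStep, with AllBrainInfo stubbed as a plain dict: empty dicts replace None, so the env managers below can index the fields without Optional guards:

```python
# Sketch with simplified types; the real AllBrainInfo lives in mlagents.trainers.brain.
from typing import Any, Dict, NamedTuple

AllBrainInfo = Dict[str, Any]  # stand-in for the real alias


class EnvironmentStep(NamedTuple):
    previous_all_brain_info: AllBrainInfo
    current_all_brain_info: AllBrainInfo
    brain_name_to_action_info: Dict[str, Any]  # really Dict[str, ActionInfo]

    def has_actions_for_brain(self, brain_name: str) -> bool:
        # No "is not None" check is needed on the dict itself any more.
        return (
            brain_name in self.brain_name_to_action_info
            and self.brain_name_to_action_info[brain_name].outputs is not None
        )


# Env managers now seed their state with empty containers rather than None,
# as in simple_env_manager and subprocess_env_manager further down:
previous_step = EnvironmentStep({}, {}, {})
assert not previous_step.has_actions_for_brain("3DBall")  # hypothetical brain name
```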
7 changes: 4 additions & 3 deletions ml-agents/mlagents/trainers/ppo/trainer.py
@@ -67,13 +67,14 @@ def __init__(
self.check_param_keys()

if multi_gpu and len(get_devices()) > 1:
self.policy = MultiGpuPPOPolicy(
self.ppo_policy = MultiGpuPPOPolicy(
Contributor Author: @ervteng This is what we talked about offline

seed, brain, trainer_parameters, self.is_training, load
)
else:
self.policy = PPOPolicy(
self.ppo_policy = PPOPolicy(
seed, brain, trainer_parameters, self.is_training, load
)
self.policy = self.ppo_policy

for _reward_signal in self.policy.reward_signals.keys():
self.collected_rewards[_reward_signal] = {}
@@ -104,7 +105,7 @@ def process_experiences(
else:
bootstrapping_info = next_info
idx = l
value_next = self.policy.get_value_estimates(
value_next = self.ppo_policy.get_value_estimates(
bootstrapping_info,
idx,
next_info.local_done[l] and not next_info.max_reached[l],
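The class and method names below are simplified stand-ins, but they sketch the typing pattern this change adopts: the trainer keeps a subclass-typed attribute (ppo_policy) alongside the base-typed one (policy), so calls to PPO-only methods such as get_value_estimates type-check without casts:

```python
# Simplified stand-ins for TFPolicy/PPOPolicy/PPOTrainer; not the real signatures.
class Policy:
    def get_action(self) -> None:
        ...


class PPOPolicy(Policy):
    def get_value_estimates(self) -> float:  # PPO-specific API
        return 0.0


class PPOTrainer:
    def __init__(self) -> None:
        self.ppo_policy: PPOPolicy = PPOPolicy()  # concrete type for PPO-only calls
        self.policy: Policy = self.ppo_policy     # base type shared with other trainers

    def process_experiences(self) -> None:
        # mypy sees PPOPolicy here, so the PPO-specific method resolves cleanly.
        value_next = self.ppo_policy.get_value_estimates()
        print(value_next)
```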
12 changes: 8 additions & 4 deletions ml-agents/mlagents/trainers/sac/trainer.py
@@ -81,7 +81,10 @@ def __init__(
if "save_replay_buffer" in trainer_parameters
else False
)
self.policy = SACPolicy(seed, brain, trainer_parameters, self.is_training, load)
self.sac_policy = SACPolicy(
seed, brain, trainer_parameters, self.is_training, load
)
self.policy = self.sac_policy

# Load the replay buffer if load
if load and self.checkpoint_replay_buffer:
@@ -293,8 +296,9 @@ def update_sac_policy(self) -> None:
for stat, stat_list in batch_update_stats.items():
self.stats[stat].append(np.mean(stat_list))

if self.policy.bc_module:
update_stats = self.policy.bc_module.update()
bc_module = self.sac_policy.bc_module
if bc_module:
update_stats = bc_module.update()
for stat, val in update_stats.items():
self.stats[stat].append(val)

@@ -325,7 +329,7 @@ def update_reward_signals(self) -> None:
self.trainer_parameters["batch_size"],
sequence_length=self.policy.sequence_length,
)
update_stats = self.policy.update_reward_signals(
update_stats = self.sac_policy.update_reward_signals(
reward_signal_minibatches, n_sequences
)
for stat_name, value in update_stats.items():
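The SAC trainer gets the same two-attribute treatment, plus an Optional-narrowing idiom for bc_module. A hedged sketch with stand-in classes (not the real BCModule or SACPolicy definitions):

```python
# Stand-in classes; the real BCModule/SACPolicy live elsewhere in ml-agents.
from typing import Dict, Optional


class BCModule:
    def update(self) -> Dict[str, float]:
        return {"Losses/Cloning Loss": 0.0}  # hypothetical stat name


class SACPolicy:
    def __init__(self) -> None:
        self.bc_module: Optional[BCModule] = None


sac_policy = SACPolicy()

# Binding the Optional attribute to a local lets mypy narrow it inside the if,
# which is harder to do reliably through repeated attribute access.
bc_module = sac_policy.bc_module
if bc_module:
    update_stats = bc_module.update()  # bc_module is BCModule here, not Optional
```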
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/simple_env_manager.py
@@ -22,7 +22,7 @@ def __init__(self, env: BaseEnv, float_prop_channel: FloatPropertiesChannel):
super().__init__()
self.shared_float_properties = float_prop_channel
self.env = env
self.previous_step: EnvironmentStep = EnvironmentStep(None, {}, None)
self.previous_step: EnvironmentStep = EnvironmentStep({}, {}, {})
self.previous_all_action_info: Dict[str, ActionInfo] = {}

def step(self) -> List[EnvironmentStep]:
@@ -51,7 +51,7 @@ def reset(
self.shared_float_properties.set_property(k, v)
self.env.reset()
all_brain_info = self._generate_all_brain_info()
self.previous_step = EnvironmentStep(None, all_brain_info, None)
self.previous_step = EnvironmentStep({}, all_brain_info, {})
return [self.previous_step]

@property
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/subprocess_env_manager.py
@@ -53,7 +53,7 @@ def __init__(self, process: Process, worker_id: int, conn: Connection):
self.process = process
self.worker_id = worker_id
self.conn = conn
self.previous_step: EnvironmentStep = EnvironmentStep(None, {}, None)
self.previous_step: EnvironmentStep = EnvironmentStep({}, {}, {})
self.previous_all_action_info: Dict[str, ActionInfo] = {}
self.waiting = False

@@ -253,7 +253,7 @@ def reset(self, config: Optional[Dict] = None) -> List[EnvironmentStep]:
ew.send("reset", config)
# Next (synchronously) collect the reset observations from each worker in sequence
for ew in self.env_workers:
ew.previous_step = EnvironmentStep(None, ew.recv().payload, None)
ew.previous_step = EnvironmentStep({}, ew.recv().payload, {})
return list(map(lambda ew: ew.previous_step, self.env_workers))

@property
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/tests/test_policy.py
@@ -20,7 +20,7 @@ def test_take_action_returns_empty_with_no_agents():
policy = TFPolicy(test_seed, basic_mock_brain(), basic_params())
no_agent_brain_info = BrainInfo([], [], [], agents=[])
result = policy.get_action(no_agent_brain_info)
assert result == ActionInfo([], [], None)
assert result == ActionInfo([], [], {})


def test_take_action_returns_nones_on_missing_values():
3 changes: 2 additions & 1 deletion ml-agents/mlagents/trainers/tf_policy.py
@@ -57,6 +57,7 @@ def __init__(self, seed, brain, trainer_parameters):
self.brain = brain
self.use_recurrent = trainer_parameters["use_recurrent"]
self.memory_dict: Dict[str, np.ndarray] = {}
self.reward_signals: Dict[str, "RewardSignal"] = {}
Contributor Author: Used the string form here, because importing RewardSignal was leading to circular imports.

self.num_branches = len(self.brain.vector_action_space_size)
self.previous_action_dict: Dict[str, np.array] = {}
self.normalize = trainer_parameters.get("normalize", False)
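Regarding the quoted annotation above, a common companion pattern is to gate the real import behind typing.TYPE_CHECKING so it only runs for the type checker. This is a sketch of that idea, not something taken from the PR, and the RewardSignal import path is a guess:

```python
# Sketch; the RewardSignal import path is an assumption, not copied from this PR.
from typing import TYPE_CHECKING, Dict

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, so it cannot
    # reintroduce the circular import.
    from mlagents.trainers.components.reward_signals.reward_signal import RewardSignal


class TFPolicy:
    def __init__(self) -> None:
        # The quoted name defers resolution until type-checking time.
        self.reward_signals: Dict[str, "RewardSignal"] = {}
```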
@@ -126,7 +127,7 @@ def get_action(self, brain_info: BrainInfo) -> ActionInfo:
to be passed to add experiences
"""
if len(brain_info.agents) == 0:
return ActionInfo([], [], None)
return ActionInfo([], [], {})

agents_done = [
agent