Improve mypy coverage by adding --namespace-packages #3049


Merged: 14 commits, merged Dec 11, 2019

Changes from 13 commits
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -10,7 +10,7 @@ repos:
)$

- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v0.740
+ rev: v0.750
hooks:
- id: mypy
name: mypy-ml-agents
Expand All @@ -21,11 +21,11 @@ repos:
files: "ml-agents-envs/.*"
# Exclude protobuf files and don't follow them when imported
exclude: ".*_pb2.py"
- args: [--ignore-missing-imports, --disallow-incomplete-defs]
+ args: [--ignore-missing-imports, --disallow-incomplete-defs, --namespace-packages]
- id: mypy
name: mypy-gym-unity
files: "gym-unity/.*"
- args: [--ignore-missing-imports, --disallow-incomplete-defs]
+ args: [--ignore-missing-imports, --disallow-incomplete-defs, --namespace-packages]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.4.0
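For context on the headline flag: the `mlagents` package is contributed to by both the ml-agents and ml-agents-envs distributions without a shared top-level `__init__.py`, i.e. a PEP 420 implicit namespace package, which mypy only resolves when `--namespace-packages` is set. A minimal sketch of the idea (paths are illustrative, and running it assumes mlagents is installed):

```python
# Hypothetical two-distribution layout:
#   ml-agents-envs/mlagents/envs/__init__.py      <- no mlagents/__init__.py
#   ml-agents/mlagents/trainers/__init__.py       <- same namespace, other dir
#
# Without --namespace-packages, mypy treats only directories containing
# __init__.py as packages, so it cannot map these files to "mlagents.envs".
# At runtime, Python 3 merges both directories under one namespace package:
import mlagents

print(list(mlagents.__path__))                  # composite path, one entry per distribution
print(getattr(mlagents, "__file__", None))      # typically None for PEP 420 packages
```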
7 changes: 6 additions & 1 deletion .pylintrc
@@ -40,4 +40,9 @@ disable =
# E0401: Unable to import...
# E0611: No name '...' in module '...'
# need to look into these, probably namespace packages
- E0401, E0611
+ E0401, E0611,
+
+ # This was causing false positives
+ # Appears to be https://github.com/PyCQA/pylint/issues/2981
+ W0201,
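For reference, W0201 (`attribute-defined-outside-init`) fires when an attribute is first assigned outside `__init__`. A minimal sketch of the pattern the check targets, using a hypothetical class:

```python
class Model:
    def build(self) -> None:
        # pylint reports W0201 (attribute-defined-outside-init) here,
        # because self.output is never assigned in __init__:
        self.output = 1.0
```

The linked pylint issue tracks cases where this check misfires, hence disabling it repo-wide rather than per line.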

6 changes: 3 additions & 3 deletions ml-agents-envs/mlagents/envs/base_env.py
@@ -18,12 +18,12 @@
"""

from abc import ABC, abstractmethod
- from typing import List, NamedTuple, Tuple, Optional, Union, Dict, NewType
+ from typing import List, NamedTuple, Tuple, Optional, Union, Dict
import numpy as np
from enum import Enum

- AgentId = NewType("AgentId", int)
- AgentGroup = NewType("AgentGroup", str)
+ AgentId = int
Contributor (author): I think this was going to be more trouble than it was worth. This still aliases the types so that it's a little clearer than the raw int/str types.

+ AgentGroup = str
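The difference the comment alludes to, as a minimal sketch: `NewType` creates a distinct type that mypy enforces at every call site, while a plain alias is transparent to the checker.

```python
from typing import NewType

StrictId = NewType("StrictId", int)        # hypothetical strict variant

def use_strict(agent_id: StrictId) -> None: ...
def use_alias(agent_id: int) -> None: ...  # AgentId = int behaves like this

use_strict(StrictId(3))   # OK, but every caller must wrap explicitly
# use_strict(3)           # mypy error: expected "StrictId", got "int"
use_alias(3)              # a plain alias accepts ints everywhere
```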


class StepResult(NamedTuple):
13 changes: 7 additions & 6 deletions ml-agents-envs/mlagents/envs/environment.py
@@ -16,13 +16,14 @@
AgentId,
)
from mlagents.envs.timers import timed, hierarchical_timer
- from .exception import (
+ from mlagents.envs.exception import (
UnityEnvironmentException,
UnityCommunicationException,
UnityActionException,
UnityTimeOutException,
)

+ from mlagents.envs.communicator_objects.command_pb2 import STEP, RESET
from mlagents.envs.rpc_utils import (
agent_group_spec_from_proto,
batched_step_result_from_proto,
@@ -371,8 +372,8 @@ def set_action_for_agent(
action = action.astype(expected_type)

if agent_group not in self._env_actions:
- self._env_actions[agent_group] = self._empty_action(
- spec, self._env_state[agent_group].n_agents()
+ self._env_actions[agent_group] = spec.create_empty_action(
+ self._env_state[agent_group].n_agents()
)

Contributor (author): This was a bug.
try:
index = np.where(self._env_state[agent_group].agent_id == agent_id)[0][0]
@@ -442,7 +443,7 @@ def _flatten(cls, arr: Any) -> List[float]:

@staticmethod
def _parse_side_channel_message(
- side_channels: Dict[int, SideChannel], data: bytearray
+ side_channels: Dict[int, SideChannel], data: bytes
) -> None:

Contributor (author): bytes is immutable - these seemed better for processing inputs. I left outputs as bytearrays though.
offset = 0
while offset < len(data):
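A sketch of why `bytes` is the better input annotation here: the parser only reads from `data`, and mypy promotes `bytearray` arguments to `bytes`, so existing call sites still type-check. The framing below (little-endian int32 channel id, int32 payload length, then payload) is an assumption inferred from this function's loop, not a spec quote:

```python
import struct
from typing import Iterator, Tuple

def iter_messages(data: bytes) -> Iterator[Tuple[int, bytes]]:
    # Assumed framing: <int32 channel id><int32 payload length><payload>...
    offset = 0
    while offset < len(data):
        channel_id = struct.unpack_from("<i", data, offset)[0]
        message_len = struct.unpack_from("<i", data, offset + 4)[0]
        offset += 8
        yield channel_id, data[offset : offset + message_len]
        offset += message_len

# Accepts bytes or bytearray; the signature guarantees no mutation of input.
list(iter_messages(bytearray()))
```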
@@ -493,13 +494,13 @@ def _generate_step_input(
for i in range(n_agents):
action = AgentActionProto(vector_actions=vector_action[b][i])
rl_in.agent_actions[b].value.extend([action])
- rl_in.command = 0
+ rl_in.command = STEP
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels))
return self.wrap_unity_input(rl_in)

def _generate_reset_input(self) -> UnityInputProto:
rl_in = UnityRLInputProto()
- rl_in.command = 1
+ rl_in.command = RESET
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels))
return self.wrap_unity_input(rl_in)
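The wire values are unchanged; the named constants are just self-documenting and checkable. A small illustration (assuming the generated `command_pb2` module from this repo is importable):

```python
from mlagents.envs.communicator_objects.command_pb2 import STEP, RESET

# Same integers the magic numbers encoded, but a typo like RESTE now fails
# at import time instead of silently sending command 0.
assert STEP == 0
assert RESET == 1
```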

3 changes: 2 additions & 1 deletion ml-agents-envs/mlagents/envs/mock_communicator.py
@@ -13,6 +13,7 @@
NONE as COMPRESSION_TYPE_NONE,
PNG as COMPRESSION_TYPE_PNG,
)
+ from mlagents.envs.communicator_objects.space_type_pb2 import discrete, continuous


class MockCommunicator(Communicator):
@@ -43,7 +44,7 @@ def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
bp = BrainParametersProto(
vector_action_size=[2],
vector_action_descriptions=["", ""],
- vector_action_space_type=int(not self.is_discrete),
+ vector_action_space_type=discrete if self.is_discrete else continuous,
brain_name=self.brain_name,
is_training=True,
)
25 changes: 19 additions & 6 deletions ml-agents-envs/mlagents/envs/rpc_utils.py
@@ -5,7 +5,7 @@
import logging
import numpy as np
import io
- from typing import List, Tuple
+ from typing import cast, List, Tuple, Union, Collection
from PIL import Image

logger = logging.getLogger("mlagents.envs")
@@ -26,9 +26,10 @@ def agent_group_spec_from_proto(
if brain_param_proto.vector_action_space_type == 0
else ActionType.CONTINUOUS
)
- action_shape = None
if action_type == ActionType.CONTINUOUS:
- action_shape = brain_param_proto.vector_action_size[0]
+ action_shape: Union[
+     int, Tuple[int, ...]
+ ] = brain_param_proto.vector_action_size[0]
else:
action_shape = tuple(brain_param_proto.vector_action_size)
return AgentGroupSpec(observation_shape, action_type, action_shape)
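The explicit `action_shape` annotation is needed because mypy infers a variable's type from its first assignment; when two branches assign different types, you declare the union up front instead of initializing to `None`. A minimal sketch:

```python
from typing import Tuple, Union

vector_action_size = [2, 3, 1]  # stand-in for the proto field

action_shape: Union[int, Tuple[int, ...]]
if len(vector_action_size) == 1:
    action_shape = vector_action_size[0]      # int branch
else:
    action_shape = tuple(vector_action_size)  # tuple branch
```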
@@ -57,7 +58,11 @@ def process_pixels(image_bytes: bytes, gray_scale: bool) -> np.ndarray:

@timed
def _process_visual_observation(
- obs_index: int, shape: Tuple[int, int, int], agent_info_list: List[AgentInfoProto]
+ obs_index: int,
+ shape: Tuple[int, int, int],
+ agent_info_list: Collection[
+     AgentInfoProto
+ ],  # pylint: disable=unsubscriptable-object
) -> np.ndarray:
if len(agent_info_list) == 0:
return np.zeros((0, shape[0], shape[1], shape[2]), dtype=np.float32)
@@ -72,7 +77,11 @@ def _process_visual_observation(

@timed
def _process_vector_observation(
- obs_index: int, shape: Tuple[int, ...], agent_info_list: List[AgentInfoProto]
+ obs_index: int,
+ shape: Tuple[int, ...],
+ agent_info_list: Collection[
+     AgentInfoProto
+ ],  # pylint: disable=unsubscriptable-object
) -> np.ndarray:
if len(agent_info_list) == 0:
return np.zeros((0, shape[0]), dtype=np.float32)
@@ -104,12 +113,16 @@ def _process_vector_observation(

@timed
def batched_step_result_from_proto(
- agent_info_list: List[AgentInfoProto], group_spec: AgentGroupSpec
+ agent_info_list: Collection[
+     AgentInfoProto
+ ],  # pylint: disable=unsubscriptable-object
+ group_spec: AgentGroupSpec,
) -> BatchedStepResult:
obs_list: List[np.ndarray] = []
for obs_index, obs_shape in enumerate(group_spec.observation_shapes):
is_visual = len(obs_shape) == 3
if is_visual:
+ obs_shape = cast(Tuple[int, int, int], obs_shape)
obs_list += [
_process_visual_observation(obs_index, obs_shape, agent_info_list)
]
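Two idioms recur in this file. `Collection[AgentInfoProto]` widens the accepted argument type: only `__len__`, `__iter__`, and `__contains__` are required, so lists, tuples, and protobuf repeated fields all qualify. `cast` then narrows a general tuple to the fixed arity a helper expects, at zero runtime cost. (The `# pylint: disable=unsubscriptable-object` comments evidently work around a pylint false positive of that era when subscripting `typing.Collection`.) A sketch with a stand-in proto class:

```python
from typing import cast, Collection, Tuple

class FakeAgentInfo:  # stand-in for the generated AgentInfoProto
    reward = 1.0

def total_reward(agent_info_list: Collection[FakeAgentInfo]) -> float:
    # Collection needs only len() and iteration, so any container works:
    if len(agent_info_list) == 0:
        return 0.0
    return sum(a.reward for a in agent_info_list)

total_reward([FakeAgentInfo()])    # list OK
total_reward((FakeAgentInfo(),))   # tuple OK too

obs_shape: Tuple[int, ...] = (84, 84, 3)
if len(obs_shape) == 3:
    # No-op at runtime; tells mypy this is exactly a 3-tuple:
    visual_shape = cast(Tuple[int, int, int], obs_shape)
```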
ml-agents-envs/mlagents/envs/side_channel/engine_configuration_channel.py
@@ -31,7 +31,7 @@ class EngineConfigurationChannel(SideChannel):
def channel_type(self) -> int:
return SideChannelType.EngineSettings

- def on_message_received(self, data: bytearray) -> None:
+ def on_message_received(self, data: bytes) -> None:
"""
Is called by the environment to the side channel. Can be called
multiple times per step if multiple messages are meant for that
ml-agents-envs/mlagents/envs/side_channel/float_properties_channel.py
@@ -1,6 +1,6 @@
from mlagents.envs.side_channel.side_channel import SideChannel, SideChannelType
import struct
- from typing import Tuple, Optional, List
+ from typing import Dict, Tuple, Optional, List


class FloatPropertiesChannel(SideChannel):
@@ -10,15 +10,15 @@ class FloatPropertiesChannel(SideChannel):
set_property, get_property and list_properties.
"""

- def __init__(self):
-     self._float_properties = {}
+ def __init__(self) -> None:
+     self._float_properties: Dict[str, float] = {}
super().__init__()

@property
def channel_type(self) -> int:
return SideChannelType.FloatProperties

- def on_message_received(self, data: bytearray) -> None:
+ def on_message_received(self, data: bytes) -> None:
"""
Is called by the environment to the side channel. Can be called
multiple times per step if multiple messages are meant for that
@@ -52,7 +52,14 @@ def list_properties(self) -> List[str]:
Returns a list of all the string identifiers of the properties
currently present in the Unity Environment.
"""
- return self._float_properties.keys()
+ return list(self._float_properties.keys())

+ def get_property_dict(self) -> Dict[str, float]:
Contributor (author): Add this since the usage pattern was always list_properties then get_property on each.

Contributor (reviewer): Do we still need get and list properties then? Also, it is not clear that modifying this dict will not modify the unity environment...

Contributor (author): Not sure we need list. I think get() is still useful. Would get_property_dict_copy() make it clearer? Unfortunately there's no clean way to return a read-only dictionary.

"""
Returns a copy of the float properties.
:return:
"""
return dict(self._float_properties)
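On the mutation concern above: because `get_property_dict` returns `dict(...)`, callers get a snapshot, so mutating it cannot affect the channel or the Unity side. A hypothetical usage sketch (offline, assuming `set_property` only queues an outgoing message and records the value locally):

```python
from mlagents.envs.side_channel.float_properties_channel import FloatPropertiesChannel

channel = FloatPropertiesChannel()
channel.set_property("gravity", -9.81)

snapshot = channel.get_property_dict()
snapshot["gravity"] = 0.0  # mutates only the copy

assert channel.get_property("gravity") == -9.81
```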

@staticmethod
def serialize_float_prop(key: str, value: float) -> bytearray:
@@ -64,7 +71,7 @@ def serialize_float_prop(key: str, value: float) -> bytearray:
return result

@staticmethod
- def deserialize_float_prop(data: bytearray) -> Tuple[str, float]:
+ def deserialize_float_prop(data: bytes) -> Tuple[str, float]:
offset = 0
encoded_key_len = struct.unpack_from("<i", data, offset)[0]
offset = offset + 4
ml-agents-envs/mlagents/envs/side_channel/raw_bytes_channel.py
@@ -17,7 +17,7 @@ def __init__(self, channel_id=0):
def channel_type(self) -> int:
return SideChannelType.RawBytesChannelStart + self._channel_id

- def on_message_received(self, data: bytearray) -> None:
+ def on_message_received(self, data: bytes) -> None:
"""
Is called by the environment to the side channel. Can be called
multiple times per step if multiple messages are meant for that
2 changes: 1 addition & 1 deletion ml-agents-envs/mlagents/envs/side_channel/side_channel.py
@@ -33,7 +33,7 @@ def queue_message_to_send(self, data: bytearray) -> None:
self.message_queue.append(data)

@abstractmethod
- def on_message_received(self, data: bytearray) -> None:
+ def on_message_received(self, data: bytes) -> None:
"""
Is called by the environment to the side channel. Can be called
multiple times per step if multiple messages are meant for that
10 changes: 7 additions & 3 deletions ml-agents-envs/mlagents/envs/tests/test_rpc_utils.py
@@ -1,6 +1,10 @@
from typing import List, Tuple
from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto
- from mlagents.envs.communicator_objects.observation_pb2 import ObservationProto
+ from mlagents.envs.communicator_objects.observation_pb2 import (
+     ObservationProto,
+     NONE,
+     PNG,
+ )
from mlagents.envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
import numpy as np
from mlagents.envs.base_env import AgentGroupSpec, ActionType
@@ -30,7 +34,7 @@ def generate_list_agent_proto(
for obs_index in range(len(shape)):
obs_proto = ObservationProto()
obs_proto.shape.extend(list(shape[obs_index]))
- obs_proto.compression_type = 0
+ obs_proto.compression_type = NONE
obs_proto.float_data.data.extend([0.1] * np.prod(shape[obs_index]))
obs_proto_list.append(obs_proto)
ap.observations.extend(obs_proto_list)
@@ -49,7 +53,7 @@ def generate_compressed_data(in_array: np.ndarray) -> bytes:
def generate_compressed_proto_obs(in_array: np.ndarray) -> ObservationProto:
obs_proto = ObservationProto()
obs_proto.compressed_data = generate_compressed_data(in_array)
- obs_proto.compression_type = 1
+ obs_proto.compression_type = PNG
obs_proto.shape.extend(in_array.shape)
return obs_proto

3 changes: 3 additions & 0 deletions ml-agents-envs/mlagents/envs/tests/test_side_channel.py
@@ -69,6 +69,9 @@ def test_float_properties():
val = sender.get_property("prop1")
assert val == 1.0

+ assert receiver.get_property_dict() == {"prop1": 1.0, "prop2": 2.0}
+ assert receiver.get_property_dict() == sender.get_property_dict()


def test_raw_bytes():
sender = RawBytesChannel()
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/agent_processor.py
@@ -1,4 +1,4 @@
- from typing import List
+ from typing import List, Union

from mlagents.trainers.buffer import AgentBuffer, BufferException

@@ -28,7 +28,7 @@ def reset_local_buffers(self) -> None:
def append_to_update_buffer(
self,
update_buffer: AgentBuffer,
- agent_id: str,
+ agent_id: Union[int, str],
key_list: List[str] = None,
batch_size: int = None,
training_length: int = None,
16 changes: 12 additions & 4 deletions ml-agents/mlagents/trainers/brain.py
@@ -6,7 +6,7 @@
from mlagents.envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents.envs.communicator_objects.observation_pb2 import ObservationProto
from mlagents.envs.timers import hierarchical_timer, timed
- from typing import Dict, List, NamedTuple
+ from typing import Dict, List, NamedTuple, Collection
from PIL import Image

logger = logging.getLogger("mlagents.envs")
@@ -144,7 +144,9 @@ def process_pixels(image_bytes: bytes, gray_scale: bool) -> np.ndarray:
@timed
def from_agent_proto(
worker_id: int,
- agent_info_list: List[AgentInfoProto],
+ agent_info_list: Collection[
+     AgentInfoProto
+ ],  # pylint: disable=unsubscriptable-object
brain_params: BrainParameters,
) -> "BrainInfo":
"""
@@ -186,7 +188,10 @@ def from_agent_proto(

@staticmethod
def _process_visual_observations(
- brain_params: BrainParameters, agent_info_list: List[AgentInfoProto]
+ brain_params: BrainParameters,
+ agent_info_list: Collection[
+     AgentInfoProto
+ ],  # pylint: disable=unsubscriptable-object
) -> List[np.ndarray]:

visual_observation_protos: List[List[ObservationProto]] = []
@@ -215,7 +220,10 @@ def _process_visual_observations(

@staticmethod
def _process_vector_observations(
- brain_params: BrainParameters, agent_info_list: List[AgentInfoProto]
+ brain_params: BrainParameters,
+ agent_info_list: Collection[
+     AgentInfoProto
+ ],  # pylint: disable=unsubscriptable-object
) -> np.ndarray:
if len(agent_info_list) == 0:
vector_obs = np.zeros(
4 changes: 4 additions & 0 deletions ml-agents/mlagents/trainers/models.py
@@ -88,6 +88,10 @@ def __init__(
self.running_variance: Optional[tf.Variable] = None
self.update_normalization: Optional[tf.Operation] = None
self.value: Optional[tf.Tensor] = None
+ self.all_log_probs: Optional[tf.Tensor] = None
+ self.output: Optional[tf.Tensor] = None
+ self.selected_actions: Optional[tf.Tensor] = None
+ self.action_holder: Optional[tf.Tensor] = None
Contributor (reviewer): Did you intend to add those? It does not seem related to mypy coverage.

Contributor (author): Yes, there were places where mypy was complaining (correctly) that e.g. all_log_probs wasn't a member of LearningModel. In practice it was OK because they were always defined in PPOModel or SACModel. This just defines them in the base class too.
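The pattern here, as a minimal analogue with a plain `float` standing in for `tf.Tensor`: declaring the attribute as `Optional` in the base class tells mypy it exists on every instance, while subclasses assign the real value.

```python
from typing import Optional

class LearningModelBase:
    def __init__(self) -> None:
        # Without this line, mypy flags `model.output` below:
        # error: "LearningModelBase" has no attribute "output"
        self.output: Optional[float] = None

class PPOLikeModel(LearningModelBase):
    def build(self) -> None:
        self.output = 1.0

def run(model: LearningModelBase) -> None:
    if model.output is not None:  # Optional also forces a None check
        print(model.output)
```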


@staticmethod
def create_global_steps():
6 changes: 6 additions & 0 deletions ml-agents/mlagents/trainers/rl_trainer.py
@@ -264,3 +264,9 @@ def add_rewards_outputs(
raise UnityTrainerException(
"The add_rewards_outputs method was not implemented."
)

+ def advance(self):
+     """
+     Eventually logic from TrainerController.advance() will live here.
+     """
+     self.clear_update_buffer()