Commit 0d4836d

Author: Chris Elion

Improve mypy coverage by adding --namespace-packages (#3049)
1 parent 1a240bb commit 0d4836d

23 files changed, +102 -54 lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@ repos:
         )$

 -   repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v0.740
+    rev: v0.750
     hooks:
     -   id: mypy
         name: mypy-ml-agents
@@ -21,11 +21,11 @@ repos:
         files: "ml-agents-envs/.*"
         # Exclude protobuf files and don't follow them when imported
         exclude: ".*_pb2.py"
-        args: [--ignore-missing-imports, --disallow-incomplete-defs]
+        args: [--ignore-missing-imports, --disallow-incomplete-defs, --namespace-packages]
     -   id: mypy
         name: mypy-gym-unity
         files: "gym-unity/.*"
-        args: [--ignore-missing-imports, --disallow-incomplete-defs]
+        args: [--ignore-missing-imports, --disallow-incomplete-defs, --namespace-packages]

 -   repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v2.4.0
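The --namespace-packages flag makes mypy follow PEP 420 implicit namespace packages, i.e. package directories without an __init__.py. A sketch of the kind of layout this enables (the file names below are illustrative, but the mlagents namespace really is split across the two distributions in this repo):

    ml-agents-envs/
        mlagents/          # no __init__.py: an implicit namespace package
            envs/
                environment.py
    ml-agents/
        mlagents/          # the same namespace, shipped by a second distribution
            trainers/
                trainer.py

Without the flag, mypy cannot resolve imports through the __init__.py-less mlagents directories, and combined with --ignore-missing-imports those imports silently degrade to Any, which is the coverage gap this commit closes.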

.pylintrc

Lines changed: 6 additions & 1 deletion
@@ -40,4 +40,9 @@ disable =
     # E0401: Unable to import...
     # E0611: No name '...' in module '...'
     # need to look into these, probably namespace packages
-    E0401, E0611
+    E0401, E0611,
+
+    # This was causing false positives
+    # Appears to be https://github.com/PyCQA/pylint/issues/2981
+    W0201,

ml-agents-envs/mlagents/envs/base_env.py

Lines changed: 3 additions & 3 deletions
@@ -18,12 +18,12 @@
 """

 from abc import ABC, abstractmethod
-from typing import List, NamedTuple, Tuple, Optional, Union, Dict, NewType
+from typing import List, NamedTuple, Tuple, Optional, Union, Dict
 import numpy as np
 from enum import Enum

-AgentId = NewType("AgentId", int)
-AgentGroup = NewType("AgentGroup", str)
+AgentId = int
+AgentGroup = str


 class StepResult(NamedTuple):
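Dropping NewType here trades strictness for convenience: a NewType is a distinct type to mypy and requires explicit wrapping, while a plain alias is interchangeable with its base type. A minimal sketch of the difference (the lookup_* functions are hypothetical):

    from typing import NewType

    StrictId = NewType("StrictId", int)  # distinct type: plain ints must be wrapped
    AgentId = int                        # plain alias: any int qualifies

    def lookup_strict(agent_id: StrictId) -> None: ...
    def lookup_alias(agent_id: AgentId) -> None: ...

    lookup_strict(StrictId(3))   # OK
    # lookup_strict(3)           # rejected by mypy: "int" is not "StrictId"
    lookup_alias(3)              # OK: AgentId is just int

Since agent ids arrive as plain ints from protobuf messages, the alias avoids wrapping at every call site, which is presumably the motivation here.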

ml-agents-envs/mlagents/envs/environment.py

Lines changed: 7 additions & 6 deletions
@@ -16,13 +16,14 @@
     AgentId,
 )
 from mlagents.envs.timers import timed, hierarchical_timer
-from .exception import (
+from mlagents.envs.exception import (
     UnityEnvironmentException,
     UnityCommunicationException,
     UnityActionException,
     UnityTimeOutException,
 )

+from mlagents.envs.communicator_objects.command_pb2 import STEP, RESET
 from mlagents.envs.rpc_utils import (
     agent_group_spec_from_proto,
     batched_step_result_from_proto,
@@ -371,8 +372,8 @@ def set_action_for_agent(
             action = action.astype(expected_type)

         if agent_group not in self._env_actions:
-            self._env_actions[agent_group] = self._empty_action(
-                spec, self._env_state[agent_group].n_agents()
+            self._env_actions[agent_group] = spec.create_empty_action(
+                self._env_state[agent_group].n_agents()
             )
         try:
             index = np.where(self._env_state[agent_group].agent_id == agent_id)[0][0]
@@ -442,7 +443,7 @@ def _flatten(cls, arr: Any) -> List[float]:

     @staticmethod
     def _parse_side_channel_message(
-        side_channels: Dict[int, SideChannel], data: bytearray
+        side_channels: Dict[int, SideChannel], data: bytes
     ) -> None:
         offset = 0
         while offset < len(data):
@@ -493,13 +494,13 @@ def _generate_step_input(
             for i in range(n_agents):
                 action = AgentActionProto(vector_actions=vector_action[b][i])
                 rl_in.agent_actions[b].value.extend([action])
-        rl_in.command = 0
+        rl_in.command = STEP
         rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels))
         return self.wrap_unity_input(rl_in)

     def _generate_reset_input(self) -> UnityInputProto:
         rl_in = UnityRLInputProto()
-        rl_in.command = 1
+        rl_in.command = RESET
         rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels))
         return self.wrap_unity_input(rl_in)
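The magic numbers 0 and 1 become the generated STEP and RESET constants from command_pb2. Generated protobuf enum values are plain module-level int constants at runtime, so the wire format is unchanged; only readability improves. A small sketch of the equivalence, using the values implied by the literals replaced above:

    from mlagents.envs.communicator_objects.command_pb2 import STEP, RESET

    # Protobuf enum members are ints at runtime, so assigning STEP writes the
    # same value as the old literal 0; the name just documents the intent.
    assert STEP == 0
    assert RESET == 1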

ml-agents-envs/mlagents/envs/mock_communicator.py

Lines changed: 2 additions & 1 deletion
@@ -13,6 +13,7 @@
     NONE as COMPRESSION_TYPE_NONE,
     PNG as COMPRESSION_TYPE_PNG,
 )
+from mlagents.envs.communicator_objects.space_type_pb2 import discrete, continuous


 class MockCommunicator(Communicator):
@@ -43,7 +44,7 @@ def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
         bp = BrainParametersProto(
             vector_action_size=[2],
             vector_action_descriptions=["", ""],
-            vector_action_space_type=int(not self.is_discrete),
+            vector_action_space_type=discrete if self.is_discrete else continuous,
             brain_name=self.brain_name,
             is_training=True,
         )

ml-agents-envs/mlagents/envs/rpc_utils.py

Lines changed: 19 additions & 6 deletions
@@ -5,7 +5,7 @@
 import logging
 import numpy as np
 import io
-from typing import List, Tuple
+from typing import cast, List, Tuple, Union, Collection
 from PIL import Image

 logger = logging.getLogger("mlagents.envs")
@@ -26,9 +26,10 @@ def agent_group_spec_from_proto(
         if brain_param_proto.vector_action_space_type == 0
         else ActionType.CONTINUOUS
     )
-    action_shape = None
     if action_type == ActionType.CONTINUOUS:
-        action_shape = brain_param_proto.vector_action_size[0]
+        action_shape: Union[
+            int, Tuple[int, ...]
+        ] = brain_param_proto.vector_action_size[0]
     else:
         action_shape = tuple(brain_param_proto.vector_action_size)
     return AgentGroupSpec(observation_shape, action_type, action_shape)
@@ -57,7 +58,11 @@ def process_pixels(image_bytes: bytes, gray_scale: bool) -> np.ndarray:

 @timed
 def _process_visual_observation(
-    obs_index: int, shape: Tuple[int, int, int], agent_info_list: List[AgentInfoProto]
+    obs_index: int,
+    shape: Tuple[int, int, int],
+    agent_info_list: Collection[
+        AgentInfoProto
+    ],  # pylint: disable=unsubscriptable-object
 ) -> np.ndarray:
     if len(agent_info_list) == 0:
         return np.zeros((0, shape[0], shape[1], shape[2]), dtype=np.float32)
@@ -72,7 +77,11 @@

 @timed
 def _process_vector_observation(
-    obs_index: int, shape: Tuple[int, ...], agent_info_list: List[AgentInfoProto]
+    obs_index: int,
+    shape: Tuple[int, ...],
+    agent_info_list: Collection[
+        AgentInfoProto
+    ],  # pylint: disable=unsubscriptable-object
 ) -> np.ndarray:
     if len(agent_info_list) == 0:
         return np.zeros((0, shape[0]), dtype=np.float32)
@@ -104,12 +113,16 @@

 @timed
 def batched_step_result_from_proto(
-    agent_info_list: List[AgentInfoProto], group_spec: AgentGroupSpec
+    agent_info_list: Collection[
+        AgentInfoProto
+    ],  # pylint: disable=unsubscriptable-object
+    group_spec: AgentGroupSpec,
 ) -> BatchedStepResult:
     obs_list: List[np.ndarray] = []
     for obs_index, obs_shape in enumerate(group_spec.observation_shapes):
         is_visual = len(obs_shape) == 3
         if is_visual:
+            obs_shape = cast(Tuple[int, int, int], obs_shape)
             obs_list += [
                 _process_visual_observation(obs_index, obs_shape, agent_info_list)
             ]
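typing.cast performs no runtime conversion; it only tells mypy to treat a value as a narrower type. Here mypy cannot infer from len(obs_shape) == 3 that a Tuple[int, ...] has exactly three elements, so the cast bridges the gap. A minimal standalone sketch (expects_3d is hypothetical):

    from typing import cast, Tuple

    def expects_3d(shape: Tuple[int, int, int]) -> None:
        print(shape)

    shape: Tuple[int, ...] = (84, 84, 3)
    if len(shape) == 3:
        # cast() is a no-op at runtime; it only narrows the static type so the
        # call below type-checks against the Tuple[int, int, int] signature.
        expects_3d(cast(Tuple[int, int, int], shape))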

ml-agents-envs/mlagents/envs/side_channel/engine_configuration_channel.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ class EngineConfigurationChannel(SideChannel):
     def channel_type(self) -> int:
         return SideChannelType.EngineSettings

-    def on_message_received(self, data: bytearray) -> None:
+    def on_message_received(self, data: bytes) -> None:
         """
         Is called by the environment to the side channel. Can be called
         multiple times per step if multiple messages are meant for that

ml-agents-envs/mlagents/envs/side_channel/float_properties_channel.py

Lines changed: 13 additions & 6 deletions
@@ -1,6 +1,6 @@
 from mlagents.envs.side_channel.side_channel import SideChannel, SideChannelType
 import struct
-from typing import Tuple, Optional, List
+from typing import Dict, Tuple, Optional, List


 class FloatPropertiesChannel(SideChannel):
@@ -10,15 +10,15 @@ class FloatPropertiesChannel(SideChannel):
     set_property, get_property and list_properties.
     """

-    def __init__(self):
-        self._float_properties = {}
+    def __init__(self) -> None:
+        self._float_properties: Dict[str, float] = {}
         super().__init__()

     @property
     def channel_type(self) -> int:
         return SideChannelType.FloatProperties

-    def on_message_received(self, data: bytearray) -> None:
+    def on_message_received(self, data: bytes) -> None:
         """
         Is called by the environment to the side channel. Can be called
         multiple times per step if multiple messages are meant for that
@@ -52,7 +52,14 @@ def list_properties(self) -> List[str]:
         Returns a list of all the string identifiers of the properties
         currently present in the Unity Environment.
         """
-        return self._float_properties.keys()
+        return list(self._float_properties.keys())
+
+    def get_property_dict_copy(self) -> Dict[str, float]:
+        """
+        Returns a copy of the float properties.
+        :return:
+        """
+        return dict(self._float_properties)

     @staticmethod
     def serialize_float_prop(key: str, value: float) -> bytearray:
@@ -64,7 +71,7 @@ def serialize_float_prop(key: str, value: float) -> bytearray:
         return result

     @staticmethod
-    def deserialize_float_prop(data: bytearray) -> Tuple[str, float]:
+    def deserialize_float_prop(data: bytes) -> Tuple[str, float]:
         offset = 0
         encoded_key_len = struct.unpack_from("<i", data, offset)[0]
         offset = offset + 4
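The new get_property_dict_copy returns dict(self._float_properties), a shallow copy, so callers can inspect every property without being able to mutate the channel's internal state. A hypothetical usage sketch, assuming set_property and get_property store and read values locally as the class docstring above suggests:

    from mlagents.envs.side_channel.float_properties_channel import FloatPropertiesChannel

    channel = FloatPropertiesChannel()
    channel.set_property("gravity", -9.81)   # assumed to record the value locally

    props = channel.get_property_dict_copy()
    props["gravity"] = 0.0                   # mutates only the caller's copy

    assert channel.get_property("gravity") == -9.81  # channel state untouched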

ml-agents-envs/mlagents/envs/side_channel/raw_bytes_channel.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ def __init__(self, channel_id=0):
     def channel_type(self) -> int:
         return SideChannelType.RawBytesChannelStart + self._channel_id

-    def on_message_received(self, data: bytearray) -> None:
+    def on_message_received(self, data: bytes) -> None:
         """
         Is called by the environment to the side channel. Can be called
         multiple times per step if multiple messages are meant for that

ml-agents-envs/mlagents/envs/side_channel/side_channel.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def queue_message_to_send(self, data: bytearray) -> None:
         self.message_queue.append(data)

     @abstractmethod
-    def on_message_received(self, data: bytearray) -> None:
+    def on_message_received(self, data: bytes) -> None:
         """
         Is called by the environment to the side channel. Can be called
         multiple times per step if multiple messages are meant for that
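Re-typing on_message_received from bytearray to bytes (here and in every subclass) is the more permissive annotation under mypy, which applies type promotion: bytearray and memoryview are accepted wherever bytes is expected, so both immutable payloads and mutable buffers type-check. A standalone sketch of the promotion rule (a free function stands in for the method above):

    def handle(data: bytes) -> None:
        # mypy promotes bytearray (and memoryview) to bytes, so annotating the
        # parameter as bytes accepts both kinds of buffer.
        print(len(data))

    handle(b"\x01\x02")             # bytes: OK
    handle(bytearray(b"\x01\x02"))  # bytearray: OK via promotion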

ml-agents-envs/mlagents/envs/tests/test_rpc_utils.py

Lines changed: 7 additions & 3 deletions
@@ -1,6 +1,10 @@
 from typing import List, Tuple
 from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto
-from mlagents.envs.communicator_objects.observation_pb2 import ObservationProto
+from mlagents.envs.communicator_objects.observation_pb2 import (
+    ObservationProto,
+    NONE,
+    PNG,
+)
 from mlagents.envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
 import numpy as np
 from mlagents.envs.base_env import AgentGroupSpec, ActionType
@@ -30,7 +34,7 @@ def generate_list_agent_proto(
     for obs_index in range(len(shape)):
         obs_proto = ObservationProto()
         obs_proto.shape.extend(list(shape[obs_index]))
-        obs_proto.compression_type = 0
+        obs_proto.compression_type = NONE
         obs_proto.float_data.data.extend([0.1] * np.prod(shape[obs_index]))
         obs_proto_list.append(obs_proto)
     ap.observations.extend(obs_proto_list)
@@ -49,7 +53,7 @@ def generate_compressed_data(in_array: np.ndarray) -> bytes:
 def generate_compressed_proto_obs(in_array: np.ndarray) -> ObservationProto:
     obs_proto = ObservationProto()
     obs_proto.compressed_data = generate_compressed_data(in_array)
-    obs_proto.compression_type = 1
+    obs_proto.compression_type = PNG
     obs_proto.shape.extend(in_array.shape)
     return obs_proto
ml-agents-envs/mlagents/envs/tests/test_side_channel.py

Lines changed: 3 additions & 0 deletions
@@ -69,6 +69,9 @@ def test_float_properties():
     val = sender.get_property("prop1")
     assert val == 1.0

+    assert receiver.get_property_dict_copy() == {"prop1": 1.0, "prop2": 2.0}
+    assert receiver.get_property_dict_copy() == sender.get_property_dict_copy()
+

 def test_raw_bytes():
     sender = RawBytesChannel()

ml-agents/mlagents/trainers/agent_processor.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Union

 from mlagents.trainers.buffer import AgentBuffer, BufferException

@@ -28,7 +28,7 @@ def reset_local_buffers(self) -> None:
     def append_to_update_buffer(
         self,
         update_buffer: AgentBuffer,
-        agent_id: str,
+        agent_id: Union[int, str],
         key_list: List[str] = None,
         batch_size: int = None,
         training_length: int = None,

ml-agents/mlagents/trainers/brain.py

Lines changed: 12 additions & 4 deletions
@@ -6,7 +6,7 @@
 from mlagents.envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
 from mlagents.envs.communicator_objects.observation_pb2 import ObservationProto
 from mlagents.envs.timers import hierarchical_timer, timed
-from typing import Dict, List, NamedTuple
+from typing import Dict, List, NamedTuple, Collection
 from PIL import Image

 logger = logging.getLogger("mlagents.envs")
@@ -144,7 +144,9 @@ def process_pixels(image_bytes: bytes, gray_scale: bool) -> np.ndarray:
 @timed
 def from_agent_proto(
     worker_id: int,
-    agent_info_list: List[AgentInfoProto],
+    agent_info_list: Collection[
+        AgentInfoProto
+    ],  # pylint: disable=unsubscriptable-object
     brain_params: BrainParameters,
 ) -> "BrainInfo":
     """
@@ -186,7 +188,10 @@ def from_agent_proto(

     @staticmethod
     def _process_visual_observations(
-        brain_params: BrainParameters, agent_info_list: List[AgentInfoProto]
+        brain_params: BrainParameters,
+        agent_info_list: Collection[
+            AgentInfoProto
+        ],  # pylint: disable=unsubscriptable-object
     ) -> List[np.ndarray]:

         visual_observation_protos: List[List[ObservationProto]] = []
@@ -215,7 +220,10 @@ def _process_visual_observations(

     @staticmethod
     def _process_vector_observations(
-        brain_params: BrainParameters, agent_info_list: List[AgentInfoProto]
+        brain_params: BrainParameters,
+        agent_info_list: Collection[
+            AgentInfoProto
+        ],  # pylint: disable=unsubscriptable-object
     ) -> np.ndarray:
         if len(agent_info_list) == 0:
             vector_obs = np.zeros(
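As in rpc_utils.py, these signatures relax List[AgentInfoProto] to Collection[AgentInfoProto], likely because the agent infos arrive in protobuf repeated-field containers, which support iteration and len() but are not list instances. The inline pylint disables work around pylint flagging subscripted typing.Collection as unsubscriptable (the error its comment names). A minimal sketch of why Collection is sufficient:

    from typing import Collection

    from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto

    def count_agents(agent_info_list: Collection[AgentInfoProto]) -> int:
        # Collection only requires __len__, __iter__ and __contains__, which is
        # all these helpers use, so any sized iterable of protos qualifies.
        return len(agent_info_list)

    count_agents([AgentInfoProto(), AgentInfoProto()])  # a list still works
    count_agents((AgentInfoProto(),))                   # so does a tuple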

ml-agents/mlagents/trainers/models.py

Lines changed: 4 additions & 0 deletions
@@ -88,6 +88,10 @@ def __init__(
         self.running_variance: Optional[tf.Variable] = None
         self.update_normalization: Optional[tf.Operation] = None
         self.value: Optional[tf.Tensor] = None
+        self.all_log_probs: Optional[tf.Tensor] = None
+        self.output: Optional[tf.Tensor] = None
+        self.selected_actions: Optional[tf.Tensor] = None
+        self.action_holder: Optional[tf.Tensor] = None

     @staticmethod
     def create_global_steps():
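The four new declarations follow the pattern already used just above them: assign None with an explicit Optional[...] annotation in __init__ so mypy knows the attribute's type even though subclasses assign the real tensors later during graph construction. A minimal sketch of the pattern, with float standing in for tf.Tensor and hypothetical class names:

    from typing import Optional

    class BaseModel:
        def __init__(self) -> None:
            # Declare the attribute up front so mypy knows its type even though
            # the real value is assigned later, outside __init__.
            self.output: Optional[float] = None

    class PolicyModel(BaseModel):
        def create_outputs(self) -> None:
            self.output = 0.5  # consistent with the declared Optional[float]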

ml-agents/mlagents/trainers/rl_trainer.py

Lines changed: 6 additions & 0 deletions
@@ -264,3 +264,9 @@ def add_rewards_outputs(
         raise UnityTrainerException(
             "The add_rewards_outputs method was not implemented."
         )
+
+    def advance(self):
+        """
+        Eventually logic from TrainerController.advance() will live here.
+        """
+        self.clear_update_buffer()
