-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Improve mypy coverage by adding --namespace-packages #3049
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
8bd70f8
5066ec6
4a42120
5939823
821639e
ea216c7
b5ebe67
f4af166
b80ca34
a930c07
4e97a7c
36e8ec8
901b791
9e34716
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,13 +16,14 @@ | |
AgentId, | ||
) | ||
from mlagents.envs.timers import timed, hierarchical_timer | ||
from .exception import ( | ||
from mlagents.envs.exception import ( | ||
UnityEnvironmentException, | ||
UnityCommunicationException, | ||
UnityActionException, | ||
UnityTimeOutException, | ||
) | ||
|
||
from mlagents.envs.communicator_objects.command_pb2 import STEP, RESET | ||
from mlagents.envs.rpc_utils import ( | ||
agent_group_spec_from_proto, | ||
batched_step_result_from_proto, | ||
|
@@ -371,8 +372,8 @@ def set_action_for_agent( | |
action = action.astype(expected_type) | ||
|
||
if agent_group not in self._env_actions: | ||
self._env_actions[agent_group] = self._empty_action( | ||
spec, self._env_state[agent_group].n_agents() | ||
self._env_actions[agent_group] = spec.create_empty_action( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was a bug. |
||
self._env_state[agent_group].n_agents() | ||
) | ||
try: | ||
index = np.where(self._env_state[agent_group].agent_id == agent_id)[0][0] | ||
|
@@ -442,7 +443,7 @@ def _flatten(cls, arr: Any) -> List[float]: | |
|
||
@staticmethod | ||
def _parse_side_channel_message( | ||
side_channels: Dict[int, SideChannel], data: bytearray | ||
side_channels: Dict[int, SideChannel], data: bytes | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
) -> None: | ||
offset = 0 | ||
while offset < len(data): | ||
|
@@ -493,13 +494,13 @@ def _generate_step_input( | |
for i in range(n_agents): | ||
action = AgentActionProto(vector_actions=vector_action[b][i]) | ||
rl_in.agent_actions[b].value.extend([action]) | ||
rl_in.command = 0 | ||
rl_in.command = STEP | ||
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels)) | ||
return self.wrap_unity_input(rl_in) | ||
|
||
def _generate_reset_input(self) -> UnityInputProto: | ||
rl_in = UnityRLInputProto() | ||
rl_in.command = 1 | ||
rl_in.command = RESET | ||
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels)) | ||
return self.wrap_unity_input(rl_in) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
from mlagents.envs.side_channel.side_channel import SideChannel, SideChannelType | ||
import struct | ||
from typing import Tuple, Optional, List | ||
from typing import Dict, Tuple, Optional, List | ||
|
||
|
||
class FloatPropertiesChannel(SideChannel): | ||
|
@@ -10,15 +10,15 @@ class FloatPropertiesChannel(SideChannel): | |
set_property, get_property and list_properties. | ||
""" | ||
|
||
def __init__(self): | ||
self._float_properties = {} | ||
def __init__(self) -> None: | ||
self._float_properties: Dict[str, float] = {} | ||
super().__init__() | ||
|
||
@property | ||
def channel_type(self) -> int: | ||
return SideChannelType.FloatProperties | ||
|
||
def on_message_received(self, data: bytearray) -> None: | ||
def on_message_received(self, data: bytes) -> None: | ||
""" | ||
Is called by the environment to the side channel. Can be called | ||
multiple times per step if multiple messages are meant for that | ||
|
@@ -52,7 +52,14 @@ def list_properties(self) -> List[str]: | |
Returns a list of all the string identifiers of the properties | ||
currently present in the Unity Environment. | ||
""" | ||
return self._float_properties.keys() | ||
return list(self._float_properties.keys()) | ||
|
||
def get_property_dict(self) -> Dict[str, float]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add this since the usage pattern was always list_properties then get_property on each. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we still need get and list properties then? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure we need list. I think get() is still useful. Would |
||
""" | ||
Returns a copy of the float properties. | ||
:return: | ||
""" | ||
return dict(self._float_properties) | ||
|
||
@staticmethod | ||
def serialize_float_prop(key: str, value: float) -> bytearray: | ||
|
@@ -64,7 +71,7 @@ def serialize_float_prop(key: str, value: float) -> bytearray: | |
return result | ||
|
||
@staticmethod | ||
def deserialize_float_prop(data: bytearray) -> Tuple[str, float]: | ||
def deserialize_float_prop(data: bytes) -> Tuple[str, float]: | ||
offset = 0 | ||
encoded_key_len = struct.unpack_from("<i", data, offset)[0] | ||
offset = offset + 4 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,6 +88,10 @@ def __init__( | |
self.running_variance: Optional[tf.Variable] = None | ||
self.update_normalization: Optional[tf.Operation] = None | ||
self.value: Optional[tf.Tensor] = None | ||
self.all_log_probs: Optional[tf.Tensor] = None | ||
self.output: Optional[tf.Tensor] = None | ||
self.selected_actions: Optional[tf.Tensor] = None | ||
self.action_holder: Optional[tf.Tensor] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you intend to add those? It does not seem related to mypy coverage. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, there were places where mypy was complaining (correctly) that e.g. all_log_probs wasn't a member of LearningModel. In practice it was OK because they were always defined in PPOModel or SACModel. This just defines them in the base class too. |
||
|
||
@staticmethod | ||
def create_global_steps(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this was going to be more trouble that it was worth. This still aliases the types so that it's a little clearer than the raw int/str types.