-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Improve mypy coverage by adding --namespace-packages #3049
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8bd70f8
5066ec6
4a42120
5939823
821639e
ea216c7
b5ebe67
f4af166
b80ca34
a930c07
4e97a7c
36e8ec8
901b791
9e34716
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,13 +16,14 @@ | |
AgentId, | ||
) | ||
from mlagents.envs.timers import timed, hierarchical_timer | ||
from .exception import ( | ||
from mlagents.envs.exception import ( | ||
UnityEnvironmentException, | ||
UnityCommunicationException, | ||
UnityActionException, | ||
UnityTimeOutException, | ||
) | ||
|
||
from mlagents.envs.communicator_objects.command_pb2 import STEP, RESET | ||
from mlagents.envs.rpc_utils import ( | ||
agent_group_spec_from_proto, | ||
batched_step_result_from_proto, | ||
|
@@ -371,8 +372,8 @@ def set_action_for_agent( | |
action = action.astype(expected_type) | ||
|
||
if agent_group not in self._env_actions: | ||
self._env_actions[agent_group] = self._empty_action( | ||
spec, self._env_state[agent_group].n_agents() | ||
self._env_actions[agent_group] = spec.create_empty_action( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was a bug. |
||
self._env_state[agent_group].n_agents() | ||
) | ||
try: | ||
index = np.where(self._env_state[agent_group].agent_id == agent_id)[0][0] | ||
|
@@ -442,7 +443,7 @@ def _flatten(cls, arr: Any) -> List[float]: | |
|
||
@staticmethod | ||
def _parse_side_channel_message( | ||
side_channels: Dict[int, SideChannel], data: bytearray | ||
side_channels: Dict[int, SideChannel], data: bytes | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
) -> None: | ||
offset = 0 | ||
while offset < len(data): | ||
|
@@ -493,13 +494,13 @@ def _generate_step_input( | |
for i in range(n_agents): | ||
action = AgentActionProto(vector_actions=vector_action[b][i]) | ||
rl_in.agent_actions[b].value.extend([action]) | ||
rl_in.command = 0 | ||
rl_in.command = STEP | ||
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels)) | ||
return self.wrap_unity_input(rl_in) | ||
|
||
def _generate_reset_input(self) -> UnityInputProto: | ||
rl_in = UnityRLInputProto() | ||
rl_in.command = 1 | ||
rl_in.command = RESET | ||
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels)) | ||
return self.wrap_unity_input(rl_in) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,6 +88,10 @@ def __init__( | |
self.running_variance: Optional[tf.Variable] = None | ||
self.update_normalization: Optional[tf.Operation] = None | ||
self.value: Optional[tf.Tensor] = None | ||
self.all_log_probs: Optional[tf.Tensor] = None | ||
self.output: Optional[tf.Tensor] = None | ||
self.selected_actions: Optional[tf.Tensor] = None | ||
self.action_holder: Optional[tf.Tensor] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you intend to add those? It does not seem related to mypy coverage. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, there were places where mypy was complaining (correctly) that e.g. all_log_probs wasn't a member of LearningModel. In practice it was OK because they were always defined in PPOModel or SACModel. This just defines them in the base class too. |
||
|
||
@staticmethod | ||
def create_global_steps(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this was going to be more trouble that it was worth. This still aliases the types so that it's a little clearer than the raw int/str types.