diff --git a/ml-agents/mlagents/trainers/torch/model_serialization.py b/ml-agents/mlagents/trainers/torch/model_serialization.py
index dffc4677ed..ca2745f402 100644
--- a/ml-agents/mlagents/trainers/torch/model_serialization.py
+++ b/ml-agents/mlagents/trainers/torch/model_serialization.py
@@ -32,16 +32,8 @@ def __init__(self, policy):
             + ["action_masks", "memories"]
         )
 
-        if self.policy.use_continuous_act:
-            action_name = "action"
-            action_prob_name = "action_probs"
-        else:
-            action_name = "action_unused"
-            action_prob_name = "action"
-
         self.output_names = [
-            action_name,
-            action_prob_name,
+            "action",
             "version_number",
             "memory_size",
             "is_continuous_control",
@@ -49,7 +41,7 @@ def __init__(self, policy):
         ]
 
         self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}
-        self.dynamic_axes.update({"action": {0: "batch"}, "action_probs": {0: "batch"}})
+        self.dynamic_axes.update({"action": {0: "batch"}})
 
     def export_policy_model(self, output_filepath: str) -> None:
         """
diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py
index 27c8425a30..526407b2e2 100644
--- a/ml-agents/mlagents/trainers/torch/networks.py
+++ b/ml-agents/mlagents/trainers/torch/networks.py
@@ -193,7 +193,7 @@ def forward(
         vis_inputs: List[torch.Tensor],
         masks: Optional[torch.Tensor] = None,
         memories: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
+    ) -> Tuple[torch.Tensor, int, int, int, int]:
         """
         Forward pass of the Actor for inference. This is required for export to ONNX, and
         the inputs and outputs of this method should not be changed without a respective change
@@ -325,7 +325,7 @@ def forward(
         vis_inputs: List[torch.Tensor],
         masks: Optional[torch.Tensor] = None,
         memories: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
+    ) -> Tuple[torch.Tensor, int, int, int, int]:
         """
         Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
         """
@@ -333,12 +333,11 @@ def forward(
         action_list = self.sample_action(dists)
         sampled_actions = torch.stack(action_list, dim=-1)
         if self.act_type == ActionType.CONTINUOUS:
-            log_probs = dists[0].log_prob(sampled_actions)
+            action_out = sampled_actions
         else:
-            log_probs = dists[0].all_log_prob()
+            action_out = dists[0].all_log_prob()
         return (
-            sampled_actions,
-            log_probs,
+            action_out,
             self.version_number,
             torch.Tensor([self.network_body.memory_size]),
             self.is_continuous_int,
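
Note (not part of the patch): a minimal, hypothetical sketch of what an ONNX export driven by output_names/dynamic_axes shaped like the ones above could look like after this change, i.e. a single "action" output and no "action_probs". ExampleActor, its sizes, and the output path are placeholders, not the real ModelSerializer.export_policy_model code.

# Hypothetical sketch; only the "action" output name and its dynamic batch axis
# mirror the patch above.
import torch


class ExampleActor(torch.nn.Module):
    """Stand-in for the real Actor; returns only the consolidated action tensor."""

    def __init__(self, obs_size: int = 8, act_size: int = 2):
        super().__init__()
        self.policy_head = torch.nn.Linear(obs_size, act_size)

    def forward(self, vec_obs: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.policy_head(vec_obs))


input_names = ["vector_observation"]
output_names = ["action"]  # "action_probs" is no longer an exported output
dynamic_axes = {name: {0: "batch"} for name in input_names}
dynamic_axes.update({"action": {0: "batch"}})

torch.onnx.export(
    ExampleActor(),
    (torch.zeros(1, 8),),  # dummy batch of vector observations
    "example_policy.onnx",
    opset_version=9,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)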