-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Don't drop multiple stats from the same step #4236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,7 +67,7 @@ in python, run: | |
| from mlagents_envs.environment import UnityEnvironment | ||
| # This is a non-blocking call that only loads the environment. | ||
| env = UnityEnvironment(file_name="3DBall", seed=1, side_channels=[]) | ||
| # Start interacting with the evironment. | ||
| # Start interacting with the environment. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Found this when I accidentally typo'd EnvironmentStats too. |
||
| env.reset() | ||
| behavior_names = env.behavior_specs.keys() | ||
| ... | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,9 @@ | ||
| from mlagents_envs.side_channel import SideChannel, IncomingMessage | ||
| import uuid | ||
| from typing import Dict, Tuple | ||
| from typing import Tuple, List, Mapping | ||
| from enum import Enum | ||
| from collections import defaultdict | ||
|
|
||
| from mlagents_envs.side_channel import SideChannel, IncomingMessage | ||
|
|
||
|
|
||
| # Determines the behavior of how multiple stats within the same summary period are combined. | ||
|
|
@@ -13,6 +15,10 @@ class StatsAggregationMethod(Enum): | |
| MOST_RECENT = 1 | ||
|
|
||
|
|
||
| StatList = List[Tuple[float, StatsAggregationMethod]] | ||
| EnvironmentStats = Mapping[str, StatList] | ||
|
Comment on lines
+18
to
+19
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 👍 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I originally had DefaultDict here, but PyCharm was complaining that my unit test was passing a plain dict instead. |
||
|
|
||
|
|
||
| class StatsSideChannel(SideChannel): | ||
| """ | ||
| Side channel that receives (string, float) pairs from the environment, so that they can eventually | ||
|
|
@@ -24,7 +30,7 @@ def __init__(self) -> None: | |
| # UUID('a1d8f7b7-cec8-50f9-b78b-d3e165a78520') | ||
| super().__init__(uuid.UUID("a1d8f7b7-cec8-50f9-b78b-d3e165a78520")) | ||
|
|
||
| self.stats: Dict[str, Tuple[float, StatsAggregationMethod]] = {} | ||
| self.stats: EnvironmentStats = defaultdict(list) | ||
|
|
||
| def on_message_received(self, msg: IncomingMessage) -> None: | ||
| """ | ||
|
|
@@ -36,13 +42,13 @@ def on_message_received(self, msg: IncomingMessage) -> None: | |
| val = msg.read_float32() | ||
| agg_type = StatsAggregationMethod(msg.read_int32()) | ||
|
|
||
| self.stats[key] = (val, agg_type) | ||
| self.stats[key].append((val, agg_type)) | ||
|
|
||
| def get_and_reset_stats(self) -> Dict[str, Tuple[float, StatsAggregationMethod]]: | ||
| def get_and_reset_stats(self) -> EnvironmentStats: | ||
| """ | ||
| Returns the current stats, and resets the internal storage of the stats. | ||
| :return: | ||
| """ | ||
| s = self.stats | ||
| self.stats = {} | ||
| self.stats = defaultdict(list) | ||
| return s | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,7 +9,10 @@ | |
| TerminalSteps, | ||
| TerminalStep, | ||
| ) | ||
| from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod | ||
| from mlagents_envs.side_channel.stats_side_channel import ( | ||
| StatsAggregationMethod, | ||
| EnvironmentStats, | ||
| ) | ||
| from mlagents.trainers.trajectory import Trajectory, AgentExperience | ||
| from mlagents.trainers.policy.tf_policy import TFPolicy | ||
| from mlagents.trainers.policy import Policy | ||
|
|
@@ -306,7 +309,7 @@ def __init__( | |
| self.publish_trajectory_queue(self.trajectory_queue) | ||
|
|
||
| def record_environment_stats( | ||
| self, env_stats: Dict[str, Tuple[float, StatsAggregationMethod]], worker_id: int | ||
| self, env_stats: EnvironmentStats, worker_id: int | ||
| ) -> None: | ||
| """ | ||
| Pass stats from the environment to the StatsReporter. | ||
|
|
@@ -316,11 +319,12 @@ def record_environment_stats( | |
| :param worker_id: | ||
| :return: | ||
| """ | ||
| for stat_name, (val, agg_type) in env_stats.items(): | ||
| if agg_type == StatsAggregationMethod.AVERAGE: | ||
| self.stats_reporter.add_stat(stat_name, val) | ||
| elif agg_type == StatsAggregationMethod.MOST_RECENT: | ||
| # In order to prevent conflicts between multiple environments, | ||
| # only stats from the first environment are recorded. | ||
| if worker_id == 0: | ||
| self.stats_reporter.set_stat(stat_name, val) | ||
| for stat_name, value_list in env_stats.items(): | ||
| for val, agg_type in value_list: | ||
| if agg_type == StatsAggregationMethod.AVERAGE: | ||
| self.stats_reporter.add_stat(stat_name, val) | ||
| elif agg_type == StatsAggregationMethod.MOST_RECENT: | ||
| # In order to prevent conflicts between multiple environments, | ||
| # only stats from the first environment are recorded. | ||
| if worker_id == 0: | ||
| self.stats_reporter.set_stat(stat_name, val) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Kind of weird to call set_stat for each element in value_list, since we will only keep the last one. I would do the for loop There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's possible (but not recommended) to mix MOST_RECENT and AVERAGE for the same key. So if you passed, for example, (1, AVERAGE), (2, MOST_RECENT), (3, AVERAGE) for the same key, I think this is the most sensible way to handle it. You'd end up appending 1, then dropping that and replacing it with 2, and finally appending 3. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
None of the agents had this set, so we never accumulate the score:
ml-agents/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs
Lines 257 to 260 in 2c1f04e