Commit 19c9ff0

Ervin T authored
[feature] Fix TF tests, add --torch CLI option, allow running TF without torch installed (#4305)
1 parent 9d0fad2 · commit 19c9ff0

File tree: 13 files changed (+74, -24 lines)


.circleci/config.yml

Lines changed: 1 addition & 1 deletion

@@ -70,7 +70,7 @@ jobs:
             . venv/bin/activate
             mkdir test-reports
             pip freeze > test-reports/pip_versions.txt
-            pytest -n 2 --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=test-reports/junit.xml -p no:warnings
+            pytest --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=test-reports/junit.xml -p no:warnings

       - run:
           name: Verify there are no hidden/missing metafiles.
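
Note: the dropped -n 2 flag is pytest-xdist's worker-count option, so the CI suite now runs in a single pytest process rather than two parallel workers.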

experiment_torch.py

Lines changed: 2 additions & 2 deletions

@@ -98,8 +98,8 @@ def run_experiment(
         evaluate_count = evaluate["TorchPolicy.evaluate"]["count"]
     else:
         if algo == "ppo":
-            update_total = update["TFPPOOptimizer.update"]["total"]
-            update_count = update["TFPPOOptimizer.update"]["count"]
+            update_total = update["PPOOptimizer.update"]["total"]
+            update_count = update["PPOOptimizer.update"]["count"]
         else:
            update_total = update["SACTrainer._update_policy"]["total"]
            update_count = update["SACTrainer._update_policy"]["count"]

ml-agents/mlagents/trainers/cli_utils.py

Lines changed: 7 additions & 0 deletions

@@ -168,6 +168,13 @@ def _create_parser() -> argparse.ArgumentParser:
         action=DetectDefaultStoreTrue,
         help="Forces training using CPU only",
     )
+    argparser.add_argument(
+        "--torch",
+        default=False,
+        action=DetectDefaultStoreTrue,
+        help="(Experimental) Use the PyTorch framework instead of TensorFlow. Install PyTorch "
+        "before using this option",
+    )

     eng_conf = argparser.add_argument_group(title="Engine Configuration")
     eng_conf.add_argument(
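
With the flag registered, opting a run into the experimental PyTorch backend becomes a single switch, e.g. mlagents-learn <trainer_config.yaml> --run-id=<run_id> --torch (the config path and run id here are placeholders).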

ml-agents/mlagents/trainers/ppo/optimizer_tf.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 from mlagents.trainers.settings import TrainerSettings, PPOSettings


-class TFPPOOptimizer(TFOptimizer):
+class PPOOptimizer(TFOptimizer):
     def __init__(self, policy: TFPolicy, trainer_params: TrainerSettings):
         """
         Takes a Policy and a Dict of trainer parameters and creates an Optimizer around the policy.
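
The TF prefix is presumably dropped because the module path (ppo/optimizer_tf.py) already distinguishes this class from the torch-based TorchPPOOptimizer.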

ml-agents/mlagents/trainers/ppo/trainer.py

Lines changed: 11 additions & 6 deletions

@@ -10,19 +10,25 @@
 from mlagents_envs.logging_util import get_logger
 from mlagents_envs.base_env import BehaviorSpec
 from mlagents.trainers.trainer.rl_trainer import RLTrainer
-from mlagents.trainers.policy.torch_policy import TorchPolicy
 from mlagents.trainers.policy import Policy
 from mlagents.trainers.policy.tf_policy import TFPolicy
-from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
-from mlagents.trainers.ppo.optimizer_tf import TFPPOOptimizer
+from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
 from mlagents.trainers.trajectory import Trajectory
 from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
 from mlagents.trainers.settings import (
     TrainerSettings,
     PPOSettings,
     TestingConfiguration,
+    FrameworkType,
 )

+try:
+    from mlagents.trainers.policy.torch_policy import TorchPolicy
+    from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
+except ModuleNotFoundError:
+    TorchPolicy = None  # type: ignore
+    TorchPPOOptimizer = None  # type: ignore


 logger = get_logger(__name__)

@@ -58,7 +64,6 @@ def __init__(
         )
         self.load = load
         self.seed = seed
-        self.framework = "torch" if TestingConfiguration.use_torch else "tf"
         if TestingConfiguration.max_steps > 0:
             self.trainer_settings.max_steps = TestingConfiguration.max_steps
         self.policy: Policy = None  # type: ignore

@@ -254,12 +259,12 @@ def add_policy(
         )
         self.policy = policy
         self.policies[parsed_behavior_id.behavior_id] = policy
-        if self.framework == "torch":
+        if self.framework == FrameworkType.PYTORCH:
             self.optimizer = TorchPPOOptimizer(  # type: ignore
                 self.policy, self.trainer_settings  # type: ignore
             )  # type: ignore
         else:
-            self.optimizer = TFPPOOptimizer(  # type: ignore
+            self.optimizer = PPOOptimizer(  # type: ignore
                 self.policy, self.trainer_settings  # type: ignore
             )  # type: ignore
         for _reward_signal in self.optimizer.reward_signals.keys():
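
The try/except around the torch imports is the usual optional-dependency guard in Python: the names are bound to None when torch is missing, and only the PYTORCH code path ever dereferences them. A minimal self-contained sketch of the same pattern (the function and strings here are illustrative, not part of this commit):

    try:
        import torch  # optional heavy dependency; may not be installed
    except ModuleNotFoundError:
        torch = None  # type: ignore

    def select_optimizer(framework: str) -> str:
        # Only the torch branch touches the guarded name, mirroring add_policy above.
        if framework == "pytorch":
            if torch is None:
                raise ImportError("PyTorch was requested but is not installed")
            return "torch-backed optimizer"
        return "tensorflow-backed optimizer"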

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 8 additions & 4 deletions

@@ -18,10 +18,14 @@
 from mlagents.trainers.trainer.rl_trainer import RLTrainer
 from mlagents.trainers.trajectory import Trajectory, SplitObservations
 from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
-from mlagents.trainers.policy.torch_policy import TorchPolicy
-from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
-from mlagents.trainers.settings import TrainerSettings, SACSettings
+from mlagents.trainers.settings import TrainerSettings, SACSettings, FrameworkType

+try:
+    from mlagents.trainers.policy.torch_policy import TorchPolicy
+    from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
+except ModuleNotFoundError:
+    TorchPolicy = None  # type: ignore
+    TorchSACOptimizer = None  # type: ignore

 logger = get_logger(__name__)

@@ -353,7 +357,7 @@ def add_policy(
         )
         self.policy = policy
         self.policies[parsed_behavior_id.behavior_id] = policy
-        if self.framework == "torch":
+        if self.framework == FrameworkType.PYTORCH:
             self.optimizer = TorchSACOptimizer(  # type: ignore
                 self.policy, self.trainer_settings  # type: ignore
             )  # type: ignore
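
The SAC trainer gets the identical treatment: the same optional-import guard around the torch classes and the same FrameworkType comparison in add_policy.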

ml-agents/mlagents/trainers/settings.py

Lines changed: 13 additions & 1 deletion

@@ -524,6 +524,11 @@ def to_settings(self) -> type:
         return _mapping[self]


+class FrameworkType(Enum):
+    TENSORFLOW: str = "tensorflow"
+    PYTORCH: str = "pytorch"
+
+
 @attr.s(auto_attribs=True)
 class TrainerSettings(ExportableSettings):
     trainer_type: TrainerType = TrainerType.PPO

@@ -546,6 +551,7 @@ def _set_default_hyperparameters(self):
     threaded: bool = True
     self_play: Optional[SelfPlaySettings] = None
     behavioral_cloning: Optional[BehavioralCloningSettings] = None
+    framework: FrameworkType = FrameworkType.TENSORFLOW

     cattr.register_structure_hook(
         Dict[RewardSignalType, RewardSignalSettings], RewardSignalSettings.structure

@@ -713,7 +719,13 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
                 configured_dict["engine_settings"][key] = val
             else:  # Base options
                 configured_dict[key] = val
-        return RunOptions.from_dict(configured_dict)
+
+        # Apply --torch retroactively
+        final_runoptions = RunOptions.from_dict(configured_dict)
+        if "torch" in DetectDefault.non_default_args:
+            for trainer_set in final_runoptions.behaviors.values():
+                trainer_set.framework = FrameworkType.PYTORCH
+        return final_runoptions

     @staticmethod
     def from_dict(options_dict: Dict[str, Any]) -> "RunOptions":
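
Because the enum values are plain strings, a per-behavior framework can presumably also be set directly in the trainer YAML (framework: pytorch), with --torch acting as a blanket post-parse override for every behavior. A small sketch of the string round-trip the enum supports (the values follow directly from the definition above):

    from mlagents.trainers.settings import FrameworkType

    # Enum members are looked up by their string value, which is what
    # lets a structured config map "pytorch" onto FrameworkType.PYTORCH.
    assert FrameworkType("pytorch") is FrameworkType.PYTORCH
    assert FrameworkType.TENSORFLOW.value == "tensorflow"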

ml-agents/mlagents/trainers/tests/test_ppo.py

Lines changed: 2 additions & 2 deletions

@@ -8,7 +8,7 @@
 from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers

 from mlagents.trainers.ppo.trainer import PPOTrainer, discount_rewards
-from mlagents.trainers.ppo.optimizer_tf import TFPPOOptimizer
+from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.agent_processor import AgentManagerQueue
 from mlagents.trainers.tests import mock_brain as mb

@@ -52,7 +52,7 @@ def _create_ppo_optimizer_ops_mock(dummy_config, use_rnn, use_discrete, use_visu
     policy = TFPolicy(
         0, mock_specs, trainer_settings, "test", False, create_tf_graph=False
     )
-    optimizer = TFPPOOptimizer(policy, trainer_settings)
+    optimizer = PPOOptimizer(policy, trainer_settings)
     return optimizer

ml-agents/mlagents/trainers/tests/test_reward_signals.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 import mlagents.trainers.tests.mock_brain as mb
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.sac.optimizer import SACOptimizer
-from mlagents.trainers.ppo.optimizer_tf import TFPPOOptimizer
+from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
 from mlagents.trainers.tests.test_simple_rl import PPO_CONFIG, SAC_CONFIG
 from mlagents.trainers.settings import (
     GAILSettings,

@@ -75,7 +75,7 @@ def create_optimizer_mock(
     if trainer_settings.trainer_type == TrainerType.SAC:
         optimizer = SACOptimizer(policy, trainer_settings)
     else:
-        optimizer = TFPPOOptimizer(policy, trainer_settings)
+        optimizer = PPOOptimizer(policy, trainer_settings)
     return optimizer

ml-agents/mlagents/trainers/tests/test_rl_trainer.py

Lines changed: 4 additions & 1 deletion

@@ -25,7 +25,10 @@ def _update_policy(self):
     def add_policy(self, mock_behavior_id, mock_policy):
         self.policies[mock_behavior_id] = mock_policy

-    def create_policy(self):
+    def create_tf_policy(self):
+        return mock.Mock()
+
+    def create_torch_policy(self):
         return mock.Mock()

     def _process_trajectory(self, trajectory):
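
The stub now provides both factory methods, presumably because the RLTrainer base class selects create_tf_policy or create_torch_policy according to the framework setting.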
