
Commit c3fae3a

Removing the experiment script from add fire (#4373)
* Removing the experiment script
* Removing the script
1 parent d37960c commit c3fae3a

File tree

6 files changed: 17 additions & 301 deletions

experiment_torch.py

Lines changed: 0 additions & 248 deletions
This file was deleted.

ml-agents/mlagents/trainers/learn.py

Lines changed: 11 additions & 24 deletions
@@ -21,7 +21,7 @@
 )
 from mlagents.trainers.cli_utils import parser
 from mlagents_envs.environment import UnityEnvironment
-from mlagents.trainers.settings import RunOptions, TestingConfiguration
+from mlagents.trainers.settings import RunOptions

 from mlagents.trainers.training_status import GlobalTrainingStatus
 from mlagents_envs.base_env import BaseEnv
@@ -35,8 +35,6 @@
 )
 from mlagents_envs import logging_util

-from mlagents_envs.registry import default_registry
-
 logger = logging_util.get_logger(__name__)

 TRAINING_STATUS_FILE_NAME = "training_status.json"
@@ -198,27 +196,16 @@ def create_unity_environment(
     ) -> UnityEnvironment:
         # Make sure that each environment gets a different seed
         env_seed = seed + worker_id
-        if TestingConfiguration.env_name == "":
-            return UnityEnvironment(
-                file_name=env_path,
-                worker_id=worker_id,
-                seed=env_seed,
-                no_graphics=no_graphics,
-                base_port=start_port,
-                additional_args=env_args,
-                side_channels=side_channels,
-                log_folder=log_folder,
-            )
-        else:
-            return default_registry[TestingConfiguration.env_name].make(
-                seed=env_seed,
-                no_graphics=no_graphics,
-                base_port=start_port,
-                worker_id=worker_id,
-                additional_args=env_args,
-                side_channels=side_channels,
-                log_folder=log_folder,
-            )
+        return UnityEnvironment(
+            file_name=env_path,
+            worker_id=worker_id,
+            seed=env_seed,
+            no_graphics=no_graphics,
+            base_port=start_port,
+            additional_args=env_args,
+            side_channels=side_channels,
+            log_folder=log_folder,
+        )

     return create_unity_environment
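
Note: create_unity_environment now always launches the environment from env_path; the registry branch keyed on TestingConfiguration.env_name is gone. For quick local experiments, a registry environment can still be built directly. A minimal sketch, assuming mlagents_envs.registry.default_registry contains a "3DBall" entry (illustrative id, not part of this commit); the keyword arguments mirror the removed branch:

from mlagents_envs.registry import default_registry

# Build a registered environment by name instead of from a file path.
env = default_registry["3DBall"].make(
    seed=0,
    no_graphics=True,
    worker_id=0,
)
env.reset()
env.close()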

ml-agents/mlagents/trainers/policy/torch_policy.py

Lines changed: 3 additions & 6 deletions
@@ -8,7 +8,7 @@
 from mlagents_envs.base_env import DecisionSteps, BehaviorSpec
 from mlagents_envs.timers import timed

-from mlagents.trainers.settings import TrainerSettings, TestingConfiguration
+from mlagents.trainers.settings import TrainerSettings
 from mlagents.trainers.trajectory import SplitObservations
 from mlagents.trainers.torch.networks import (
     SharedActorCritic,
@@ -57,10 +57,7 @@ def __init__(
         )  # could be much simpler if TorchPolicy is nn.Module
         self.grads = None

-        if TestingConfiguration.device != "cpu":
-            torch.set_default_tensor_type(torch.cuda.FloatTensor)
-        else:
-            torch.set_default_tensor_type(torch.FloatTensor)
+        torch.set_default_tensor_type(torch.FloatTensor)

         reward_signal_configs = trainer_settings.reward_signals
         reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
@@ -83,7 +80,7 @@ def __init__(
             tanh_squash=tanh_squash,
         )

-        self.actor_critic.to(TestingConfiguration.device)
+        self.actor_critic.to("cpu")

     def split_decision_step(self, decision_requests):
         vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
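
Note: TorchPolicy now pins tensors and the actor-critic to the CPU, since the device hook in TestingConfiguration is gone. If a GPU path is wanted outside this commit, the usual PyTorch pattern is an explicit device choice; a minimal sketch using only stock torch APIs (the Linear module below is a placeholder, not the real actor-critic):

import torch

# Pick a device explicitly instead of relying on a global test flag.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(8, 2)          # placeholder standing in for the actor-critic module
model.to(device)                       # move parameters to the chosen device

x = torch.randn(1, 8, device=device)   # inputs must live on the same device
y = model(x)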

ml-agents/mlagents/trainers/ppo/trainer.py

Lines changed: 1 addition & 8 deletions
@@ -15,12 +15,7 @@
 from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
 from mlagents.trainers.trajectory import Trajectory
 from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
-from mlagents.trainers.settings import (
-    TrainerSettings,
-    PPOSettings,
-    TestingConfiguration,
-    FrameworkType,
-)
+from mlagents.trainers.settings import TrainerSettings, PPOSettings, FrameworkType
 from mlagents.trainers.components.reward_signals import RewardSignal

 try:
@@ -64,8 +59,6 @@ def __init__(
             PPOSettings, self.trainer_settings.hyperparameters
         )
         self.seed = seed
-        if TestingConfiguration.max_steps > 0:
-            self.trainer_settings.max_steps = TestingConfiguration.max_steps
         self.policy: Policy = None  # type: ignore

     def _process_trajectory(self, trajectory: Trajectory) -> None:
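
Note: the max_steps override from TestingConfiguration is removed, so the step budget now comes only from the trainer configuration. A minimal sketch of setting it programmatically on TrainerSettings (the attribute name comes from this codebase; default-constructing TrainerSettings and the value shown are assumptions for illustration, and in normal runs max_steps is supplied per behavior in the trainer config YAML):

from mlagents.trainers.settings import TrainerSettings

# The step budget lives solely on the trainer settings object.
settings = TrainerSettings()
settings.max_steps = 50000  # illustrative value; usually set via the YAML config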

ml-agents/mlagents/trainers/settings.py

Lines changed: 0 additions & 7 deletions
@@ -46,13 +46,6 @@ def defaultdict_to_dict(d: DefaultDict) -> Dict:
     return {key: cattr.unstructure(val) for key, val in d.items()}


-class TestingConfiguration:
-    use_torch = True
-    max_steps = 0
-    env_name = ""
-    device = "cpu"
-
-
 class SerializationSettings:
     convert_to_barracuda = True
     convert_to_onnx = True
