Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 0b2368e

Browse files
vreisfacebook-github-bot
authored andcommitted
Remove local_variables from on_phase_start (#416)
Summary: Pull Request resolved: #416 This is part of a series of diffs to eliminate local_variables (see D20171981). Proceed removing local_variables from on_phase_start Reviewed By: mannatsingh Differential Revision: D20178268 fbshipit-source-id: 09f78810228b2fec9faa2205d92b108aea30aff9
1 parent ce31b99 commit 0b2368e

13 files changed

+18
-25
lines changed

classy_vision/hooks/classy_hook.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,7 @@ def on_start(self, task: "tasks.ClassyTask") -> None:
7070
pass
7171

7272
@abstractmethod
73-
def on_phase_start(
74-
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
75-
) -> None:
73+
def on_phase_start(self, task: "tasks.ClassyTask") -> None:
7674
"""Called at the start of each phase."""
7775
pass
7876

classy_vision/hooks/exponential_moving_average_model_hook.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def on_start(self, task: ClassyTask) -> None:
9393
self._save_current_model_state(task.base_model, self.state.model_state)
9494
self._save_current_model_state(task.base_model, self.state.ema_model_state)
9595

96-
def on_phase_start(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
96+
def on_phase_start(self, task: ClassyTask) -> None:
9797
# restore the right state depending on the phase type
9898
self.set_model_state(task, use_ema=not task.train)
9999

classy_vision/hooks/progress_bar_hook.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ def __init__(self) -> None:
3636
self.bar_size: int = 0
3737
self.batches: int = 0
3838

39-
def on_phase_start(
40-
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
41-
) -> None:
39+
def on_phase_start(self, task: "tasks.ClassyTask") -> None:
4240
"""Create and display a progress bar with 0 progress."""
4341
if not progressbar_available:
4442
raise RuntimeError(

classy_vision/hooks/tensorboard_plot_hook.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ def __init__(self, tb_writer) -> None:
5656
self.wall_times: Optional[List[float]] = None
5757
self.num_steps_global: Optional[List[int]] = None
5858

59-
def on_phase_start(
60-
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
61-
) -> None:
59+
def on_phase_start(self, task: "tasks.ClassyTask") -> None:
6260
"""Initialize losses and learning_rates."""
6361
self.learning_rates = []
6462
self.wall_times = []

classy_vision/hooks/time_metrics_hook.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ def __init__(self, log_freq: Optional[int] = None) -> None:
3333
self.log_freq: Optional[int] = log_freq
3434
self.start_time: Optional[float] = None
3535

36-
def on_phase_start(
37-
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
38-
) -> None:
36+
def on_phase_start(self, task: "tasks.ClassyTask") -> None:
3937
"""
4038
Initialize start time and reset perf stats
4139
"""

classy_vision/tasks/classification_task.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -858,12 +858,13 @@ def on_start(self):
858858
for hook in self.hooks:
859859
hook.on_start(self)
860860

861-
def on_phase_start(self, local_variables):
861+
def on_phase_start(self):
862862
self.phase_start_time_total = time.perf_counter()
863863

864864
self.advance_phase()
865865

866-
self.run_hooks(local_variables, ClassyHookFunctions.on_phase_start.name)
866+
for hook in self.hooks:
867+
hook.on_phase_start(self)
867868

868869
self.phase_start_time_train = time.perf_counter()
869870

classy_vision/tasks/classy_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def on_start(self):
128128
pass
129129

130130
@abstractmethod
131-
def on_phase_start(self, local_variables):
131+
def on_phase_start(self):
132132
"""
133133
Epoch start.
134134

classy_vision/trainer/classy_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def train(self, task: ClassyTask):
7474

7575
task.on_start()
7676
while not task.done_training():
77-
task.on_phase_start(local_variables)
77+
task.on_phase_start()
7878
while True:
7979
try:
8080
task.step(self.use_gpu)

classy_vision/trainer/elastic_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _run_step(self, state, local_variables, use_gpu):
9696
if state.advance_to_next_phase:
9797
self.elastic_coordinator.barrier()
9898
self.elastic_coordinator._log_event("on_phase_start")
99-
state.task.on_phase_start(local_variables)
99+
state.task.on_phase_start()
100100

101101
state.advance_to_next_phase = False
102102

test/hooks_exponential_moving_average_model_hook_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device):
4848
)
4949

5050
exponential_moving_average_hook.on_start(task)
51-
exponential_moving_average_hook.on_phase_start(task, local_variables)
51+
exponential_moving_average_hook.on_phase_start(task)
5252
# set the weights to all ones and simulate 10 updates
5353
task.base_model.update_fc_weight()
5454
fc_weight = model.fc.weight.clone()
@@ -60,7 +60,7 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device):
6060

6161
# simulate a test phase now
6262
task.train = False
63-
exponential_moving_average_hook.on_phase_start(task, local_variables)
63+
exponential_moving_average_hook.on_phase_start(task)
6464
exponential_moving_average_hook.on_phase_end(task, local_variables)
6565

6666
# the model weights should be updated to the ema weights
@@ -72,7 +72,7 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device):
7272

7373
# simulate a train phase again
7474
task.train = True
75-
exponential_moving_average_hook.on_phase_start(task, local_variables)
75+
exponential_moving_average_hook.on_phase_start(task)
7676

7777
# the model weights should be back to the old value
7878
self.assertTrue(torch.allclose(model.fc.weight, fc_weight))

test/hooks_time_metrics_hook_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_time_metrics(
4747
# on_phase_start() should set the start time and perf_stats
4848
start_time = 1.2
4949
mock_time.return_value = start_time
50-
time_metrics_hook.on_phase_start(task, local_variables)
50+
time_metrics_hook.on_phase_start(task)
5151
self.assertEqual(time_metrics_hook.start_time, start_time)
5252
self.assertTrue(isinstance(task.perf_stats, PerfStats))
5353

test/manual/hooks_progress_bar_hook_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_progress_bar(
4040
progress_bar_hook = ProgressBarHook()
4141

4242
# progressbar.ProgressBar should be init-ed with num_batches
43-
progress_bar_hook.on_phase_start(task, local_variables)
43+
progress_bar_hook.on_phase_start(task)
4444
mock_progressbar_pkg.ProgressBar.assert_called_once_with(num_batches)
4545
mock_progress_bar.start.assert_called_once_with()
4646
mock_progress_bar.start.reset_mock()
@@ -80,7 +80,7 @@ def test_progress_bar(
8080
mock_is_master.return_value = False
8181
progress_bar_hook = ProgressBarHook()
8282
try:
83-
progress_bar_hook.on_phase_start(task, local_variables)
83+
progress_bar_hook.on_phase_start(task)
8484
progress_bar_hook.on_step(task)
8585
progress_bar_hook.on_phase_end(task, local_variables)
8686
except Exception as e:

test/manual/hooks_tensorboard_plot_hook_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:
8484
summary_writer.add_scalar.reset_mock()
8585

8686
# run the hook in the correct order
87-
tensorboard_plot_hook.on_phase_start(task, local_variables)
87+
tensorboard_plot_hook.on_phase_start(task)
8888

8989
for loss in losses:
9090
task.losses.append(loss)

0 commit comments

Comments
 (0)