This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Remove local_variables from on_phase_start #416

Closed
wants to merge 5 commits
Changes from all commits
4 changes: 1 addition & 3 deletions classy_vision/hooks/checkpoint_hook.py
@@ -78,9 +78,7 @@ def _save_checkpoint(self, task, filename):
         if checkpoint_file:
             PathManager.copy(checkpoint_file, f"{self.checkpoint_folder}/{filename}")
 
-    def on_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_start(self, task: "tasks.ClassyTask") -> None:
         if not is_master() or getattr(task, "test_only", False):
             return
         if not PathManager.exists(self.checkpoint_folder):
16 changes: 5 additions & 11 deletions classy_vision/hooks/classy_hook.py
@@ -51,7 +51,7 @@ def __init__(self, a, b):
     def __init__(self):
         self.state = ClassyHookState()
 
-    def _noop(self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]) -> None:
+    def _noop(self, *args, **kwargs) -> None:
         """Derived classes can set their hook functions to this.
 
         This is useful if they want those hook functions to not do anything.
@@ -65,23 +65,17 @@ def name(cls) -> str:
         return cls.__name__
 
     @abstractmethod
-    def on_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_start(self, task: "tasks.ClassyTask") -> None:
         """Called at the start of training."""
         pass
 
     @abstractmethod
-    def on_phase_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
         """Called at the start of each phase."""
         pass
 
     @abstractmethod
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Called each time after parameters have been updated by the optimizer."""
         pass
 
@@ -93,7 +87,7 @@ def on_phase_end(
         pass
 
     @abstractmethod
-    def on_end(self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]) -> None:
+    def on_end(self, task: "tasks.ClassyTask") -> None:
         """Called at the end of training."""
         pass
 
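For orientation, a user-defined hook written against the new base-class signatures above would look roughly like the following sketch. The hook name and its print statements are invented for illustration; only the ClassyHook base, the _noop stub, and the task attributes already used elsewhere in this diff (task.train, task.losses) come from the repository.

from classy_vision import tasks
from classy_vision.hooks import ClassyHook


class LossPrinterHook(ClassyHook):
    """Hypothetical hook: no method receives local_variables anymore."""

    # Unused hook points can still be stubbed with _noop, which now accepts
    # *args/**kwargs and therefore matches any of the hook signatures.
    on_phase_start = ClassyHook._noop
    on_phase_end = ClassyHook._noop
    on_end = ClassyHook._noop

    def on_start(self, task: "tasks.ClassyTask") -> None:
        print("training is starting")

    def on_step(self, task: "tasks.ClassyTask") -> None:
        # Everything the hook needs is read off the task object.
        if task.train and task.losses:
            print("step {}: loss = {}".format(len(task.losses), task.losses[-1]))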
6 changes: 3 additions & 3 deletions classy_vision/hooks/exponential_moving_average_model_hook.py
@@ -78,7 +78,7 @@ def _save_current_model_state(self, model: nn.Module, model_state: Dict[str, Any
         for name, param in self.get_model_state_iterator(model):
             model_state[name] = param.detach().clone().to(device=self.device)
 
-    def on_start(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
+    def on_start(self, task: ClassyTask) -> None:
         if self.state.model_state:
             # loaded state from checkpoint, do not re-initialize, only move the state
             # to the right device
@@ -93,7 +93,7 @@ def on_start(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
         self._save_current_model_state(task.base_model, self.state.model_state)
         self._save_current_model_state(task.base_model, self.state.ema_model_state)
 
-    def on_phase_start(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
+    def on_phase_start(self, task: ClassyTask) -> None:
         # restore the right state depending on the phase type
         self.set_model_state(task, use_ema=not task.train)
 
@@ -103,7 +103,7 @@ def on_phase_end(self, task: ClassyTask, local_variables: Dict[str, Any]) -> Non
             # state in the test phase
             self._save_current_model_state(task.base_model, self.state.model_state)
 
-    def on_step(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
+    def on_step(self, task: ClassyTask) -> None:
         if not task.train:
             return
 
20 changes: 7 additions & 13 deletions classy_vision/hooks/loss_lr_meter_logging_hook.py
@@ -45,36 +45,30 @@ def on_phase_end(
             # trainer to implement an unsynced end of phase meter or
             # for meters to not provide a sync function.
             logging.info("End of phase metric values:")
-            self._log_loss_meters(task, local_variables)
+            self._log_loss_meters(task)
             if task.train:
-                self._log_lr(task, local_variables)
+                self._log_lr(task)
 
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """
         Log the LR every log_freq batches, if log_freq is not None.
         """
         if self.log_freq is None or not task.train:
             return
         batches = len(task.losses)
         if batches and batches % self.log_freq == 0:
-            self._log_lr(task, local_variables)
+            self._log_lr(task)
             logging.info("Local unsynced metric values:")
-            self._log_loss_meters(task, local_variables)
+            self._log_loss_meters(task)
 
-    def _log_lr(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_lr(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log the optimizer LR.
         """
         optimizer_lr = task.optimizer.parameters.lr
         logging.info("Learning Rate: {}\n".format(optimizer_lr))
 
-    def _log_loss_meters(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_loss_meters(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log the loss and meters.
         """
4 changes: 1 addition & 3 deletions classy_vision/hooks/model_complexity_hook.py
@@ -28,9 +28,7 @@ class ModelComplexityHook(ClassyHook):
     on_phase_end = ClassyHook._noop
     on_end = ClassyHook._noop
 
-    def on_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_start(self, task: "tasks.ClassyTask") -> None:
         """Measure number of parameters, FLOPs and activations."""
         self.num_flops = 0
         self.num_activations = 0
4 changes: 1 addition & 3 deletions classy_vision/hooks/model_tensorboard_hook.py
@@ -51,9 +51,7 @@ def __init__(self, tb_writer) -> None:
 
         self.tb_writer = tb_writer
 
-    def on_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_start(self, task: "tasks.ClassyTask") -> None:
         """
         Plot the model on Tensorboard.
         """
4 changes: 1 addition & 3 deletions classy_vision/hooks/profiler_hook.py
@@ -25,9 +25,7 @@ class ProfilerHook(ClassyHook):
     on_phase_end = ClassyHook._noop
     on_end = ClassyHook._noop
 
-    def on_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_start(self, task: "tasks.ClassyTask") -> None:
         """Profile the forward pass."""
         logging.info("Profiling forward pass...")
         batchsize_per_replica = getattr(
8 changes: 2 additions & 6 deletions classy_vision/hooks/progress_bar_hook.py
@@ -36,9 +36,7 @@ def __init__(self) -> None:
         self.bar_size: int = 0
         self.batches: int = 0
 
-    def on_phase_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
         """Create and display a progress bar with 0 progress."""
         if not progressbar_available:
             raise RuntimeError(
@@ -51,9 +49,7 @@ def on_phase_start(
         self.progress_bar = progressbar.ProgressBar(self.bar_size)
         self.progress_bar.start()
 
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Update the progress bar with the batch size."""
         if task.train and is_master() and self.progress_bar is not None:
             self.batches += 1
8 changes: 2 additions & 6 deletions classy_vision/hooks/tensorboard_plot_hook.py
@@ -56,17 +56,13 @@ def __init__(self, tb_writer) -> None:
         self.wall_times: Optional[List[float]] = None
         self.num_steps_global: Optional[List[int]] = None
 
-    def on_phase_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
         """Initialize losses and learning_rates."""
         self.learning_rates = []
         self.wall_times = []
         self.num_steps_global = []
 
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Store the observed learning rates."""
         if self.learning_rates is None:
             logging.warning("learning_rates is not initialized")
24 changes: 9 additions & 15 deletions classy_vision/hooks/time_metrics_hook.py
@@ -33,26 +33,22 @@ def __init__(self, log_freq: Optional[int] = None) -> None:
         self.log_freq: Optional[int] = log_freq
         self.start_time: Optional[float] = None
 
-    def on_phase_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
         """
         Initialize start time and reset perf stats
         """
         self.start_time = time.time()
-        local_variables["perf_stats"] = PerfStats()
+        task.perf_stats = PerfStats()
 
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """
         Log metrics every log_freq batches, if log_freq is not None.
         """
         if self.log_freq is None:
             return
         batches = len(task.losses)
         if batches and batches % self.log_freq == 0:
-            self._log_performance_metrics(task, local_variables)
+            self._log_performance_metrics(task)
 
     def on_phase_end(
         self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
@@ -62,11 +58,9 @@ def on_phase_end(
         """
         batches = len(task.losses)
         if batches:
-            self._log_performance_metrics(task, local_variables)
+            self._log_performance_metrics(task)
 
-    def _log_performance_metrics(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_performance_metrics(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log performance metrics.
         """
@@ -85,11 +79,11 @@ def _log_performance_metrics(
             )
 
         # Train step time breakdown
-        if local_variables.get("perf_stats") is None:
-            logging.warning('"perf_stats" not set in local_variables')
+        if not hasattr(task, "perf_stats") or task.perf_stats is None:
+            logging.warning('"perf_stats" not set in task')
         elif task.train:
             logging.info(
                 "Train step time breakdown (rank {}):\n{}".format(
-                    get_rank(), local_variables["perf_stats"].report_str()
+                    get_rank(), task.perf_stats.report_str()
                 )
             )
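The TimeMetricsHook hunks above also relocate the per-phase PerfStats object from the shared local_variables dict onto the task itself. Other code that wants those timings would now read them from the task, roughly as sketched below; ReportPerfStatsHook and its print are hypothetical, and the PerfStats import path is assumed to be classy_vision.generic.perf_stats, while task.perf_stats and report_str() are taken from this diff.

from classy_vision import tasks
from classy_vision.generic.perf_stats import PerfStats  # assumed import path
from classy_vision.hooks import ClassyHook


class ReportPerfStatsHook(ClassyHook):
    """Hypothetical hook that consumes the timing stats now stored on the task."""

    on_start = ClassyHook._noop
    on_phase_end = ClassyHook._noop
    on_end = ClassyHook._noop

    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
        # Same pattern as TimeMetricsHook: attach fresh stats to the task.
        task.perf_stats = PerfStats()

    def on_step(self, task: "tasks.ClassyTask") -> None:
        perf_stats = getattr(task, "perf_stats", None)
        if perf_stats is not None and task.train:
            # report_str() is the same accessor TimeMetricsHook logs with.
            print(perf_stats.report_str())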