Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 8f6b512

Browse files
mannatsingh
authored and facebook-github-bot committed
Support specifying update interval in the parameter schedulers (#418)
Summary: Pull Request resolved: #418 Most parameter schedulers used the default "epoch" update interval and didn't allow users to override this. - This diff makes the update interval configurable by users - The schedulers use a "step" interval by default, except for "step" and "multistep" - Removed the default "epoch" update interval inside the base class - this isn't what most users would expect if they didn't specify the argument Reviewed By: vreis Differential Revision: D20256209 fbshipit-source-id: 1797d5d833dc20d173eed9ac53117e1e2de54ef1
1 parent 4b40c8d commit 8f6b512

11 files changed

+133
-52
lines changed

classy_vision/optim/param_scheduler/classy_vision_param_scheduler.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,25 @@ class UpdateInterval(Enum):
2020
EPOCH = "epoch"
2121
STEP = "step"
2222

23+
@classmethod
24+
def from_config(
25+
cls, config: Dict[str, Any], default: "UpdateInterval" = None
26+
) -> "UpdateInterval":
27+
"""Fetches the update interval from a config
28+
29+
Args:
30+
config: The config for the parameter scheduler
31+
default: The value to use if the config doesn't specify an update interval.
32+
If not set, STEP is used.
33+
"""
34+
if default is None:
35+
default = cls.STEP
36+
if "update_interval" not in config:
37+
return default
38+
if config.get("update_interval").lower() not in ["step", "epoch"]:
39+
raise ValueError("Choices for update interval are 'step' or 'epoch'")
40+
return cls[config["update_interval"].upper()]
41+
2342

2443
class ClassyParamScheduler(object):
2544
"""
@@ -33,7 +52,7 @@ class ClassyParamScheduler(object):
3352
# To be used for comparisons with where
3453
WHERE_EPSILON = 1e-6
3554

36-
def __init__(self, update_interval: UpdateInterval = UpdateInterval.EPOCH):
55+
def __init__(self, update_interval: UpdateInterval):
3756
"""
3857
Constructor for ClassyParamScheduler
3958

classy_vision/optim/param_scheduler/composite_scheduler.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ class CompositeParamScheduler(ClassyParamScheduler):
4949
The parameter value will be 0.42 for the first [0%, 30%) of steps,
5050
and then will cosine decay from 0.42 to 0.0001 for [30%, 100%) of
5151
training.
52+
The schedule is updated after every train step by default.
5253
"""
5354

5455
def __init__(
5556
self,
5657
schedulers: Sequence[ClassyParamScheduler],
5758
lengths: Sequence[float],
58-
update_interval: UpdateInterval,
5959
interval_scaling: Sequence[IntervalScaling],
60+
update_interval: UpdateInterval = UpdateInterval.STEP,
6061
):
61-
super().__init__()
62-
self.update_interval = update_interval
62+
super().__init__(update_interval=update_interval)
6363
self._lengths = lengths
6464
self._schedulers = schedulers
6565
self._interval_scaling = interval_scaling
@@ -89,13 +89,6 @@ def from_config(cls, config: Dict[str, Any]) -> "CompositeParamScheduler":
8989
), "The sum of all values in lengths must be 1"
9090
if sum(config["lengths"]) != 1.0:
9191
config["lengths"][-1] = 1.0 - sum(config["lengths"][:-1])
92-
update_interval = UpdateInterval.STEP
93-
if "update_interval" in config:
94-
assert config["update_interval"] in {
95-
"step",
96-
"epoch",
97-
}, "Choices for update interval are 'step' or 'epoch'"
98-
update_interval = UpdateInterval[config["update_interval"].upper()]
9992
interval_scaling = []
10093
if "interval_scaling" in config:
10194
assert len(config["schedulers"]) == len(
@@ -119,7 +112,7 @@ def from_config(cls, config: Dict[str, Any]) -> "CompositeParamScheduler":
119112
build_param_scheduler(scheduler) for scheduler in config["schedulers"]
120113
],
121114
lengths=config["lengths"],
122-
update_interval=update_interval,
115+
update_interval=UpdateInterval.from_config(config, UpdateInterval.STEP),
123116
interval_scaling=interval_scaling,
124117
)
125118

classy_vision/optim/param_scheduler/constant_scheduler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import Any, Dict
88

9-
from . import ClassyParamScheduler, register_param_scheduler
9+
from . import ClassyParamScheduler, UpdateInterval, register_param_scheduler
1010

1111

1212
@register_param_scheduler("constant")
@@ -16,7 +16,7 @@ class ConstantParamScheduler(ClassyParamScheduler):
1616
"""
1717

1818
def __init__(self, value: float):
19-
super().__init__()
19+
super().__init__(update_interval=UpdateInterval.EPOCH)
2020
self._value = value
2121

2222
@classmethod

classy_vision/optim/param_scheduler/cosine_scheduler.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import math
88
from typing import Any, Dict
99

10-
from . import ClassyParamScheduler, register_param_scheduler
10+
from . import ClassyParamScheduler, UpdateInterval, register_param_scheduler
1111

1212

1313
@register_param_scheduler("cosine")
@@ -17,6 +17,7 @@ class CosineParamScheduler(ClassyParamScheduler):
1717
//arxiv.org/pdf/1608.03983.pdf>`_.
1818
Can be used for either cosine decay or cosine warmup schedules based on
1919
start and end values.
20+
The schedule is updated after every train step by default.
2021
2122
Example:
2223
@@ -26,8 +27,13 @@ class CosineParamScheduler(ClassyParamScheduler):
2627
end_value: 0.0001
2728
"""
2829

29-
def __init__(self, start_value: float, end_value: float):
30-
super().__init__()
30+
def __init__(
31+
self,
32+
start_value: float,
33+
end_value: float,
34+
update_interval: UpdateInterval = UpdateInterval.STEP,
35+
):
36+
super().__init__(update_interval=update_interval)
3137
self._start_value = start_value
3238
self._end_value = end_value
3339

@@ -46,7 +52,11 @@ def from_config(cls, config: Dict[str, Any]) -> "CosineParamScheduler":
4652
"start_value" in config and "end_value" in config
4753
), "Cosine scheduler requires a start_value and a end_value"
4854

49-
return cls(start_value=config["start_value"], end_value=config["end_value"])
55+
return cls(
56+
start_value=config["start_value"],
57+
end_value=config["end_value"],
58+
update_interval=UpdateInterval.from_config(config, UpdateInterval.STEP),
59+
)
5060

5161
def __call__(self, where: float):
5262
return self._end_value + 0.5 * (self._start_value - self._end_value) * (

classy_vision/optim/param_scheduler/linear_scheduler.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66

77
from typing import Any, Dict
88

9-
from . import ClassyParamScheduler, register_param_scheduler
9+
from . import ClassyParamScheduler, UpdateInterval, register_param_scheduler
1010

1111

1212
@register_param_scheduler("linear")
1313
class LinearParamScheduler(ClassyParamScheduler):
1414
"""
1515
Linearly interpolates parameter between ``start_value`` and ``end_value``.
1616
Can be used for either warmup or decay based on start and end values.
17+
The schedule is updated after every train step by default.
1718
1819
Example:
1920
@@ -24,8 +25,13 @@ class LinearParamScheduler(ClassyParamScheduler):
2425
Corresponds to a linear increasing schedule with values in [0.0001, 0.01)
2526
"""
2627

27-
def __init__(self, start_value: float, end_value: float):
28-
super().__init__()
28+
def __init__(
29+
self,
30+
start_value: float,
31+
end_value: float,
32+
update_interval: UpdateInterval = UpdateInterval.STEP,
33+
):
34+
super().__init__(update_interval=update_interval)
2935
self._start_value = start_value
3036
self._end_value = end_value
3137

@@ -43,7 +49,12 @@ def from_config(cls, config: Dict[str, Any]) -> "LinearParamScheduler":
4349
assert (
4450
"start_value" in config and "end_value" in config
4551
), "Linear scheduler requires a start and a end"
46-
return cls(start_value=config["start_value"], end_value=config["end_value"])
52+
53+
return cls(
54+
start_value=config["start_value"],
55+
end_value=config["end_value"],
56+
update_interval=UpdateInterval.from_config(config, UpdateInterval.STEP),
57+
)
4758

4859
def __call__(self, where: float):
4960
# interpolate between start and end values

classy_vision/optim/param_scheduler/multi_step_scheduler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class MultiStepParamScheduler(ClassyParamScheduler):
1818
"""
1919
Takes a predefined schedule for a param value, and a list of epochs
2020
which stand for the upper boundary (excluded) of each range.
21+
The schedule is updated after every train epoch by default.
2122
2223
Example:
2324
@@ -37,10 +38,10 @@ def __init__(
3738
self,
3839
values,
3940
num_epochs: int,
40-
update_interval: UpdateInterval,
4141
milestones: Optional[List[int]] = None,
42+
update_interval: UpdateInterval = UpdateInterval.EPOCH,
4243
):
43-
super().__init__(update_interval)
44+
super().__init__(update_interval=update_interval)
4445
self._param_schedule = values
4546
self._num_epochs = num_epochs
4647
self._milestones = milestones
@@ -96,11 +97,12 @@ def from_config(cls, config: Dict[str, Any]) -> "MultiStepParamScheduler":
9697
"Non-Equi Step scheduler requires a list of %d epochs"
9798
% (len(config["values"]) - 1)
9899
)
100+
99101
return cls(
100102
values=config["values"],
101103
num_epochs=config["num_epochs"],
102-
update_interval=UpdateInterval(config.get("update_interval", "epoch")),
103104
milestones=milestones,
105+
update_interval=UpdateInterval.from_config(config, UpdateInterval.EPOCH),
104106
)
105107

106108
def __call__(self, where: float):

classy_vision/optim/param_scheduler/polynomial_decay_scheduler.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66

77
from typing import Any, Dict
88

9-
from . import ClassyParamScheduler, register_param_scheduler
9+
from . import ClassyParamScheduler, UpdateInterval, register_param_scheduler
1010

1111

1212
@register_param_scheduler("polynomial")
1313
class PolynomialDecayParamScheduler(ClassyParamScheduler):
1414
"""
1515
Decays the param value after every epoch according to a
1616
polynomial function with a fixed power.
17+
The schedule is updated after every train step by default.
1718
1819
Example:
1920
@@ -26,8 +27,13 @@ class PolynomialDecayParamScheduler(ClassyParamScheduler):
2627
so on.
2728
"""
2829

29-
def __init__(self, base_value, power):
30-
super().__init__()
30+
def __init__(
31+
self,
32+
base_value: float,
33+
power: float,
34+
update_interval: UpdateInterval = UpdateInterval.STEP,
35+
):
36+
super().__init__(update_interval=update_interval)
3137

3238
self._base_value = base_value
3339
self._power = power
@@ -46,7 +52,11 @@ def from_config(cls, config: Dict[str, Any]) -> "PolynomialDecayParamScheduler":
4652
assert (
4753
"base_value" in config and "power" in config
4854
), "Polynomial decay scheduler requires a base lr and a power of decay"
49-
return cls(base_value=config["base_value"], power=config["power"])
55+
return cls(
56+
base_value=config["base_value"],
57+
power=config["power"],
58+
update_interval=UpdateInterval.from_config(config, UpdateInterval.STEP),
59+
)
5060

5161
def __call__(self, where: float):
5262
return self._base_value * (1 - where) ** self._power

classy_vision/optim/param_scheduler/step_scheduler.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import Any, Dict, List, NamedTuple, Optional, Union
88

9-
from . import ClassyParamScheduler, register_param_scheduler
9+
from . import ClassyParamScheduler, UpdateInterval, register_param_scheduler
1010

1111

1212
@register_param_scheduler("step")
@@ -15,6 +15,7 @@ class StepParamScheduler(ClassyParamScheduler):
1515
Takes a fixed schedule for a param value. If the length of the
1616
fixed schedule is less than the number of epochs, then the epochs
1717
are divided evenly among the param schedule.
18+
The schedule is updated after every train epoch by default.
1819
1920
Example:
2021
@@ -27,8 +28,13 @@ class StepParamScheduler(ClassyParamScheduler):
2728
epochs 30-59, 0.001 for epoch 60-89, 0.0001 for epochs 90-119.
2829
"""
2930

30-
def __init__(self, num_epochs: Union[int, float], values: List[float]):
31-
super().__init__()
31+
def __init__(
32+
self,
33+
num_epochs: Union[int, float],
34+
values: List[float],
35+
update_interval: UpdateInterval = UpdateInterval.EPOCH,
36+
):
37+
super().__init__(update_interval=update_interval)
3238

3339
self._param_schedule = values
3440

@@ -50,7 +56,11 @@ def from_config(cls, config: Dict[str, Any]) -> "StepParamScheduler":
5056
), "Step scheduler requires a list of at least one param value"
5157
assert config["num_epochs"] > 0, "Num epochs must be greater than 0"
5258

53-
return cls(num_epochs=config["num_epochs"], values=config["values"])
59+
return cls(
60+
num_epochs=config["num_epochs"],
61+
values=config["values"],
62+
update_interval=UpdateInterval.from_config(config, UpdateInterval.EPOCH),
63+
)
5464

5565
def __call__(self, where: float):
5666
ind = int((where + self.WHERE_EPSILON) * len(self._param_schedule))

classy_vision/optim/param_scheduler/step_with_fixed_gamma_scheduler.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class StepWithFixedGammaParamScheduler(ClassyParamScheduler):
1515
"""
1616
Decays the param value by gamma at equal number of steps so as to have the
1717
specified total number of decays.
18+
The schedule is updated after every train step by default.
1819
1920
Example:
2021
@@ -29,6 +30,31 @@ class StepWithFixedGammaParamScheduler(ClassyParamScheduler):
2930
epochs 30-59, 0.001 for epoch 60-89, 0.0001 for epochs 90-119.
3031
"""
3132

33+
def __init__(
34+
self,
35+
base_value: float,
36+
num_decays: int,
37+
gamma: float,
38+
num_epochs: int,
39+
update_interval: UpdateInterval = UpdateInterval.STEP,
40+
):
41+
super().__init__(update_interval=update_interval)
42+
43+
self.base_value = base_value
44+
self.num_decays = num_decays
45+
self.gamma = gamma
46+
self.num_epochs = num_epochs
47+
values = [base_value]
48+
for _ in range(num_decays):
49+
values.append(values[-1] * gamma)
50+
51+
self._step_param_scheduler = StepParamScheduler(
52+
num_epochs=num_epochs, values=values
53+
)
54+
55+
# make this a STEP scheduler
56+
self.update_interval = UpdateInterval.STEP
57+
3258
@classmethod
3359
def from_config(cls, config: Dict[str, Any]) -> "StepWithFixedGammaParamScheduler":
3460
"""Instantiates a StepWithFixedGammaParamScheduler from a configuration.
@@ -56,25 +82,8 @@ def from_config(cls, config: Dict[str, Any]) -> "StepWithFixedGammaParamSchedule
5682
num_decays=config["num_decays"],
5783
gamma=config["gamma"],
5884
num_epochs=config["num_epochs"],
85+
update_interval=UpdateInterval.from_config(config, UpdateInterval.STEP),
5986
)
6087

61-
def __init__(self, base_value, num_decays, gamma, num_epochs):
62-
super().__init__()
63-
64-
self.base_value = base_value
65-
self.num_decays = num_decays
66-
self.gamma = gamma
67-
self.num_epochs = num_epochs
68-
values = [base_value]
69-
for _ in range(num_decays):
70-
values.append(values[-1] * gamma)
71-
72-
self._step_param_scheduler = StepParamScheduler(
73-
num_epochs=num_epochs, values=values
74-
)
75-
76-
# make this a STEP scheduler
77-
self.update_interval = UpdateInterval.STEP
78-
7988
def __call__(self, where: float) -> float:
8089
return self._step_param_scheduler(where)

test/optim_param_scheduler_composite_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_invalid_config(self):
9393
# Bad value for update_interval
9494
bad_config["lengths"] = copy.deepcopy(config["lengths"])
9595
bad_config["update_interval"] = "epochs"
96-
with self.assertRaises(AssertionError):
96+
with self.assertRaises(Exception):
9797
CompositeParamScheduler.from_config(bad_config)
9898

9999
# Bad value for composition_mode

0 commit comments

Comments
 (0)