Sacred logger #656
@@ -0,0 +1,78 @@
""" | ||
Log using `sacred <https://sacred.readthedocs.io/en/stable/index.html>'_ | ||
.. code-block:: python | ||
from pytorch_lightning.logging import SacredLogger | ||
ex = Experiment() # initialize however you like | ||
ex.main(your_main_fct) | ||
ex.observers.append( | ||
# add any observer you like | ||
) | ||
sacred_logger = SacredLogger(ex) | ||
trainer = Trainer(logger=sacred_logger) | ||
Use the logger anywhere in you LightningModule as follows: | ||
.. code-block:: python | ||
def train_step(...): | ||
# example | ||
self.logger.experiment.whatever_sacred_supports(...) | ||
def any_lightning_module_function_or_hook(...): | ||
self.logger.experiment.whatever_sacred_supports(...) | ||
""" | ||
|
||
from logging import getLogger | ||
|
||
try: | ||
import sacred | ||
except ImportError: | ||
raise ImportError('Missing sacred package. Run `pip install sacred`') | ||
|
||
from pytorch_lightning.logging.base import LightningLoggerBase, rank_zero_only | ||
|
||
logger = getLogger(__name__) | ||
|
||
|
||
class SacredLogger(LightningLoggerBase): | ||
    def __init__(self, sacred_experiment):
        """Initialize a sacred logger.

        :param sacred.experiment.Experiment sacred_experiment: Required. Experiment object
            with desired observers already appended.
        """
        super().__init__()
        self.sacred_experiment = sacred_experiment
        self.experiment_name = sacred_experiment.path
        self._run_id = None

[Review comment] use the updated docs formatting (a sketch in the newer style follows below).
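For reference, a sketch of the same constructor docstring in the Args-style format the reviewer appears to be asking for; the exact house convention is an assumption here, not part of this PR:

    # Sketch only -- alternative docstring formatting, not part of the diff above.
    def __init__(self, sacred_experiment):
        r"""Initialize a sacred logger.

        Args:
            sacred_experiment (sacred.experiment.Experiment): Experiment object with the
                desired observers already appended.
        """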
    @property
    def experiment(self):
        return self.sacred_experiment

    @property
    def run_id(self):
        if self._run_id is not None:
            return self._run_id

        self._run_id = self.sacred_experiment.current_run._id
        return self._run_id

    @rank_zero_only
    def log_hyperparams(self, params):
        # probably not needed because it is dealt with by sacred
        pass

[Review comment] I just wanted to share my workflow for your consideration. I use Sacred for general configuration, then Ax for hyperparameter optimization. There is one Sacred experiment and one logger during the whole process, but trainer.fit() is called many times by Ax with different hyperparameters. Logging these in log_hyperparams may therefore be beneficial (one possible sketch follows below).
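Responding to the workflow above, a minimal sketch (not part of this PR) of what log_hyperparams could do so that each trainer.fit() call initiated by Ax leaves a trace. It assumes `params` is a dict or an argparse Namespace, and stores it on the active run's `info` dict, which sacred's observers persist:

    # Sketch only -- one possible log_hyperparams, not part of the diff above.
    @rank_zero_only
    def log_hyperparams(self, params):
        params = params if isinstance(params, dict) else vars(params)
        run = self.sacred_experiment.current_run
        if run is not None:
            # run.info is persisted by sacred's observers; append one entry per fit() call
            run.info.setdefault("hyperparams", []).append(dict(params))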
    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        for k, v in metrics.items():
            if isinstance(v, str):
                logger.warning(f"Discarding metric with string value {k}={v}")
                continue
            self.experiment.log_scalar(k, v, step)

    @property
    def name(self):
        return self.experiment_name

    @property
    def version(self):
        return self.run_id
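For orientation before the test diff below, a minimal end-to-end sketch of the usage described in the module docstring. `MyModel` is a hypothetical LightningModule, and FileStorageObserver is only one example of a sacred observer (older sacred releases spell it FileStorageObserver.create(...)):

    from sacred import Experiment
    from sacred.observers import FileStorageObserver

    from pytorch_lightning import Trainer
    from pytorch_lightning.logging import SacredLogger


    ex = Experiment("sacred_logger_example")
    ex.observers.append(FileStorageObserver("sacred_runs"))  # any sacred observer works here


    @ex.main
    def train():
        model = MyModel()  # hypothetical LightningModule
        trainer = Trainer(logger=SacredLogger(ex), max_epochs=1)
        return trainer.fit(model)


    if __name__ == "__main__":
        ex.run()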
@@ -386,3 +386,45 @@ def version(self):
    assert logger.hparams_logged == hparams
    assert logger.metrics_logged != {}
    assert logger.finalized_status == "success"


def test_sacred_logger(tmpdir):
    """Verify that basic functionality of the sacred logger works."""
    tutils.reset_seed()
    try:
        from pytorch_lightning.logging import SacredLogger
    except ModuleNotFoundError:
        return

    try:
        from sacred import Experiment
    except ModuleNotFoundError:
        return

[Review comment] this should not be here, the logger shall be tested (a sketch of skipping via pytest instead of silently returning follows below).
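A possible way to address the comment above (a sketch, assuming the suite runs under pytest): make a missing optional dependency show up as a skipped test instead of one that silently passes:

    import pytest


    def test_sacred_logger(tmpdir):
        """Verify that basic functionality of the sacred logger works."""
        pytest.importorskip("sacred")  # reported as SKIPPED rather than silently passing
        from sacred import Experiment
        from pytorch_lightning.logging import SacredLogger
        # ... rest of the test body unchanged ...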
    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)
    sacred_dir = os.path.join(tmpdir, "sacredruns")

    ex = Experiment()
    ex_config = vars(hparams)
    ex.add_config(ex_config)

    @ex.main
    def run_fct():
        logger = SacredLogger(ex)

        trainer_options = dict(
            default_save_path=sacred_dir,
            max_epochs=1,
            train_percent_check=0.01,
            logger=logger
        )
        trainer = Trainer(**trainer_options)
        result = trainer.fit(model)
        return result

    result = ex.run()

    print('result finished')
    assert result.status == "COMPLETED", "Training failed"
[Review comment] pls check the docs usage.