diff --git a/docs/source/conf.py b/docs/source/conf.py index fa3558ab7968c..0e8af52262354 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -295,7 +295,8 @@ def setup(app): MOCK_REQUIRE_PACKAGES.append(pkg.rstrip()) # TODO: better parse from package since the import name and package name may differ -MOCK_MANUAL_PACKAGES = ['torch', 'torchvision', 'sklearn', 'test_tube', 'mlflow', 'comet_ml', 'wandb', 'neptune'] +MOCK_MANUAL_PACKAGES = ['torch', 'torchvision', 'sklearn', 'test_tube', 'mlflow', 'comet_ml', 'wandb', 'neptune', + 'sacred'] autodoc_mock_imports = MOCK_REQUIRE_PACKAGES + MOCK_MANUAL_PACKAGES # for mod_name in MOCK_REQUIRE_PACKAGES: # sys.modules[mod_name] = mock.Mock() diff --git a/pytorch_lightning/logging/__init__.py b/pytorch_lightning/logging/__init__.py index 5fbb93cddc14d..4fcdeab038202 100644 --- a/pytorch_lightning/logging/__init__.py +++ b/pytorch_lightning/logging/__init__.py @@ -114,4 +114,10 @@ def any_lightning_module_function_or_hook(...): except ImportError: pass +try: + from .sacred import SacredLogger + all.append("SacredLogger") +except ImportError: + pass + __all__ = all diff --git a/pytorch_lightning/logging/sacred.py b/pytorch_lightning/logging/sacred.py new file mode 100644 index 0000000000000..76e11a0f185b3 --- /dev/null +++ b/pytorch_lightning/logging/sacred.py @@ -0,0 +1,78 @@ +""" +Log using `sacred <https://sacred.readthedocs.io>`_ +.. code-block:: python + from pytorch_lightning.logging import SacredLogger + ex = Experiment() # initialize however you like + ex.main(your_main_fct) + ex.observers.append( + # add any observer you like + ) + sacred_logger = SacredLogger(ex) + trainer = Trainer(logger=sacred_logger) +Use the logger anywhere in your LightningModule as follows: +.. code-block:: python + def train_step(...): + # example + self.logger.experiment.whatever_sacred_supports(...) + def any_lightning_module_function_or_hook(...): + self.logger.experiment.whatever_sacred_supports(...) 
+""" + +from logging import getLogger + +try: + import sacred +except ImportError: + raise ImportError('Missing sacred package. Run `pip install sacred`') + +from pytorch_lightning.logging.base import LightningLoggerBase, rank_zero_only + +logger = getLogger(__name__) + + +class SacredLogger(LightningLoggerBase): + def __init__(self, sacred_experiment): + """Initialize a sacred logger. + + :param sacred.experiment.Experiment sacred_experiment: Required. Experiment object with desired observers + already appended. + """ + super().__init__() + self.sacred_experiment = sacred_experiment + self.experiment_name = sacred_experiment.path + self._run_id = None + + @property + def experiment(self): + return self.sacred_experiment + + @property + def run_id(self): + if self._run_id is not None: + return self._run_id + + self._run_id = self.sacred_experiment.current_run._id + return self._run_id + + @rank_zero_only + def log_hyperparams(self, params): + # probably not needed bc. it is dealt with by sacred + pass + + @rank_zero_only + def log_metrics(self, metrics, step=None): + for k, v in metrics.items(): + if isinstance(v, str): + logger.warning( + f"Discarding metric with string value {k}={v}" + ) + continue + self.experiment.log_scalar(k, v, step) + + @property + def name(self): + return self.experiment_name + + @property + def version(self): + return self.run_id diff --git a/tests/test_logging.py b/tests/test_logging.py index 5467d0aab3c1c..a8602a0a9a626 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -386,3 +386,45 @@ def version(self): assert logger.hparams_logged == hparams assert logger.metrics_logged != {} assert logger.finalized_status == "success" + + +def test_sacred_logger(tmpdir): + """Verify that basic functionality of sacred logger works.""" + tutils.reset_seed() + + try: + from pytorch_lightning.logging import SacredLogger + except ModuleNotFoundError: + return + + try: + from sacred import Experiment + except ModuleNotFoundError: + return 
+ + hparams = tutils.get_hparams() + model = LightningTestModel(hparams) + sacred_dir = os.path.join(tmpdir, "sacredruns") + + ex = Experiment() + ex_config = vars(hparams) + ex.add_config(ex_config) + + @ex.main + def run_fct(): + logger = SacredLogger(ex) + + trainer_options = dict( + default_save_path=sacred_dir, + max_epochs=1, + train_percent_check=0.01, + logger=logger + ) + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + return result + + result = ex.run() + + print('result finished') + assert result.status == "COMPLETED", "Training failed"