diff --git a/python/cog/server/eventtypes.py b/python/cog/server/eventtypes.py index da7925ea17..2ed5a69a89 100644 --- a/python/cog/server/eventtypes.py +++ b/python/cog/server/eventtypes.py @@ -32,7 +32,7 @@ class Log: @define class PredictionMetric: name: str - value: Union[float, int] + value: Optional[Union[bool, float, int, str]] @define diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index e7055abb27..dbb099a47e 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -390,10 +390,15 @@ def append_logs(self, logs: str) -> None: self._p.logs += logs self._send_webhook(schema.WebhookEvent.LOGS) - def set_metric(self, key: str, value: Union[float, int]) -> None: + def set_metric(self, key: str, value: Optional[Union[bool, float, int, str]]) -> None: if self._p.metrics is None: self._p.metrics = {} - self._p.metrics[key] = value + + if value is None: + if key in self._p.metrics: + del self._p.metrics[key] + else: + self._p.metrics[key] = value def succeeded(self) -> None: self._log.info(("prediction" if not self._is_train else "train") + " succeeded") diff --git a/python/cog/server/scope.py b/python/cog/server/scope.py index 54a1f8b014..9b6052d814 100644 --- a/python/cog/server/scope.py +++ b/python/cog/server/scope.py @@ -10,7 +10,7 @@ @frozen class Scope: - record_metric: Callable[[str, Union[float, int]], None] + record_metric: Callable[[str, Optional[Union[bool, float, int, str]]], None] context: Dict[str, str] = {} _tag: Optional[str] = None diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 963dd83a9d..94537f8a92 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -489,7 +489,7 @@ def send_cancel_signal(self) -> None: if self.is_alive() and self.pid: os.kill(self.pid, signal.SIGUSR1) - def record_metric(self, name: str, value: Union[float, int]) -> None: + def record_metric(self, name: str, value: Optional[Union[bool, float, int, str]]) -> None: self._events.send( Envelope(PredictionMetric(name, value), tag=self._current_tag) ) diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index ca5ce910b9..069f63c784 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -194,7 +194,12 @@ def handle_event(self, event: _PublicEventType): elif isinstance(event, PredictionMetric): if self.metrics is None: self.metrics = {} - self.metrics[event.name] = event.value + + if event.value is None: + if event.name in self.metrics: + del self.metrics[event.name] + else: + self.metrics[event.name] = event.value elif isinstance(event, PredictionOutput): assert self.output_type, "Should get output type before any output" if self.output_type.multi: