diff --git a/replicate/use.py b/replicate/use.py index 596674e..28d94c6 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -210,6 +210,12 @@ def _resolve_ref(obj: Any) -> Any: return result +def _log_prediction_url(id: str) -> None: + if os.environ.get("R8_LOG_PREDICTION_URL") != "1": + return + print(f"Running prediction https://replicate.com/p/{id}") + + T = TypeVar("T") @@ -436,6 +442,8 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]: model=self._model, input=processed_inputs ) + _log_prediction_url(prediction.id) + return Run( prediction=prediction, schema=self.openapi_schema(), @@ -649,6 +657,8 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu model=model, input=processed_inputs ) + _log_prediction_url(prediction.id) + return AsyncRun( prediction=prediction, schema=await self.openapi_schema(),