Skip to content

Commit 30dfaeb

Browse files
authored
Populate training destinations from output (#311)
Resolves #293 Signed-off-by: Mattt Zmuda <[email protected]>
1 parent d8d6a4e commit 30dfaeb

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

replicate/training.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,4 +418,14 @@ def _create_training_url_from_model_and_version(
418418
def _json_to_training(client: "Client", json: Dict[str, Any]) -> Training:
419419
training = Training(**json)
420420
training._client = client
421+
422+
# FIXME: This should be populated by the API
423+
if (
424+
training.output
425+
and isinstance(training.output, dict)
426+
and "version" in training.output
427+
):
428+
id = ModelVersionIdentifier.parse(training.output["version"])
429+
training.destination = f"{id.owner}/{id.name}"
430+
421431
return training

tests/test_training.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import httpx
12
import pytest
3+
import respx
24

35
import replicate
46
from replicate.exceptions import ReplicateException
@@ -188,3 +190,54 @@ async def test_trainings_cancel_instance_method(async_flag, mock_replicate_api_t
188190

189191
training.cancel()
190192
assert training.status == "canceled"
193+
194+
195+
router = respx.Router(base_url="https://api.replicate.com/v1")
196+
197+
router.route(
198+
method="GET",
199+
path="/trainings/zz4ibbonubfz7carwiefibzgga",
200+
name="trainings.get",
201+
).mock(
202+
return_value=httpx.Response(
203+
201,
204+
json={
205+
"completed_at": "2023-09-08T16:41:19.826523Z",
206+
"created_at": "2023-09-08T16:32:57.018467Z",
207+
"error": None,
208+
"id": "zz4ibbonubfz7carwiefibzgga",
209+
"input": {"input_images": "https://example.com/my-input-images.zip"},
210+
"logs": "...",
211+
"metrics": {"predict_time": 502.713876},
212+
"output": {
213+
"version": "replicate/my-app-image-generator:8a43525956ef4039702e509c789964a7ea873697be9033abf9fd2badfe68c9e3",
214+
"weights": "https://weights.replicate.com/example.tar",
215+
},
216+
"started_at": "2023-09-08T16:32:57.112647Z",
217+
"status": "succeeded",
218+
"urls": {
219+
"get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga",
220+
"cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel",
221+
},
222+
"model": "stability-ai/sdxl",
223+
"version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
224+
},
225+
)
226+
)
227+
228+
router.route(host="api.replicate.com").pass_through()
229+
230+
231+
@pytest.mark.asyncio
232+
@pytest.mark.parametrize("async_flag", [True, False])
233+
async def test_training_gets_destination_from_output(async_flag):
234+
client = replicate.Client(
235+
api_token="test-token", transport=httpx.MockTransport(router.handler)
236+
)
237+
238+
if async_flag:
239+
training = await client.trainings.async_get("zz4ibbonubfz7carwiefibzgga")
240+
else:
241+
training = client.trainings.get("zz4ibbonubfz7carwiefibzgga")
242+
243+
assert training.destination == "replicate/my-app-image-generator"

0 commit comments

Comments
 (0)