|
| 1 | +import httpx |
1 | 2 | import pytest
|
| 3 | +import respx |
2 | 4 |
|
3 | 5 | import replicate
|
4 | 6 | from replicate.exceptions import ReplicateException
|
@@ -188,3 +190,54 @@ async def test_trainings_cancel_instance_method(async_flag, mock_replicate_api_t
|
188 | 190 |
|
189 | 191 | training.cancel()
|
190 | 192 | assert training.status == "canceled"
|
| 193 | + |
| 194 | + |
| 195 | +router = respx.Router(base_url="https://api.replicate.com/v1") |
| 196 | + |
| 197 | +router.route( |
| 198 | + method="GET", |
| 199 | + path="/trainings/zz4ibbonubfz7carwiefibzgga", |
| 200 | + name="trainings.get", |
| 201 | +).mock( |
| 202 | + return_value=httpx.Response( |
| 203 | + 201, |
| 204 | + json={ |
| 205 | + "completed_at": "2023-09-08T16:41:19.826523Z", |
| 206 | + "created_at": "2023-09-08T16:32:57.018467Z", |
| 207 | + "error": None, |
| 208 | + "id": "zz4ibbonubfz7carwiefibzgga", |
| 209 | + "input": {"input_images": "https://example.com/my-input-images.zip"}, |
| 210 | + "logs": "...", |
| 211 | + "metrics": {"predict_time": 502.713876}, |
| 212 | + "output": { |
| 213 | + "version": "replicate/my-app-image-generator:8a43525956ef4039702e509c789964a7ea873697be9033abf9fd2badfe68c9e3", |
| 214 | + "weights": "https://weights.replicate.com/example.tar", |
| 215 | + }, |
| 216 | + "started_at": "2023-09-08T16:32:57.112647Z", |
| 217 | + "status": "succeeded", |
| 218 | + "urls": { |
| 219 | + "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga", |
| 220 | + "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel", |
| 221 | + }, |
| 222 | + "model": "stability-ai/sdxl", |
| 223 | + "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", |
| 224 | + }, |
| 225 | + ) |
| 226 | +) |
| 227 | + |
| 228 | +router.route(host="api.replicate.com").pass_through() |
| 229 | + |
| 230 | + |
| 231 | +@pytest.mark.asyncio |
| 232 | +@pytest.mark.parametrize("async_flag", [True, False]) |
| 233 | +async def test_training_gets_destination_from_output(async_flag): |
| 234 | + client = replicate.Client( |
| 235 | + api_token="test-token", transport=httpx.MockTransport(router.handler) |
| 236 | + ) |
| 237 | + |
| 238 | + if async_flag: |
| 239 | + training = await client.trainings.async_get("zz4ibbonubfz7carwiefibzgga") |
| 240 | + else: |
| 241 | + training = client.trainings.get("zz4ibbonubfz7carwiefibzgga") |
| 242 | + |
| 243 | + assert training.destination == "replicate/my-app-image-generator" |
0 commit comments