Skip to content

Commit d503e8f

Browse files
feat: aws sagemaker compatible image (#147)
The only difference is that now it pushes to registry.internal.huggingface.tech/api-inference/community/text-generation-inference/sagemaker:... instead of registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sagemaker-... --------- Co-authored-by: Philipp Schmid <[email protected]>
1 parent c9bdaa8 commit d503e8f

File tree

4 files changed

+86
-3
lines changed

4 files changed

+86
-3
lines changed

.github/workflows/build.yaml

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,48 @@ jobs:
8383
tags: ${{ steps.meta.outputs.tags }}
8484
labels: ${{ steps.meta.outputs.labels }}
8585
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
86-
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
86+
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
87+
88+
build-and-push-sagemaker-image:
89+
needs:
90+
- build-and-push-image
91+
runs-on: ubuntu-latest
92+
steps:
93+
- name: Initialize Docker Buildx
94+
uses: docker/setup-buildx-action@v2.0.0
95+
with:
96+
install: true
97+
- name: Checkout repository
98+
uses: actions/checkout@v3
99+
- name: Inject slug/short variables
100+
uses: rlespinasse/github-slug-action@v4
101+
- name: Login to internal Container Registry
102+
uses: docker/login-action@v2.1.0
103+
with:
104+
username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
105+
password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
106+
registry: registry.internal.huggingface.tech
107+
- name: Extract metadata (tags, labels) for Docker
108+
id: meta
109+
uses: docker/metadata-action@v4.3.0
110+
with:
111+
flavor: |
112+
latest=auto
113+
images: |
114+
registry.internal.huggingface.tech/api-inference/community/text-generation-inference/sagemaker
115+
tags: |
116+
type=semver,pattern={{version}}
117+
type=semver,pattern={{major}}.{{minor}}
118+
type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
119+
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
120+
- name: Build and push Docker image
121+
uses: docker/build-push-action@v2
122+
with:
123+
context: .
124+
file: Dockerfile
125+
push: ${{ github.event_name != 'pull_request' }}
126+
platforms: 'linux/amd64'
127+
target: sagemaker
128+
tags: ${{ steps.meta.outputs.tags }}
129+
labels: ${{ steps.meta.outputs.labels }}
130+
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max

Dockerfile

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ COPY router router
2727
COPY launcher launcher
2828
RUN cargo build --release
2929

30-
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
30+
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as base
3131

3232
ENV LANG=C.UTF-8 \
3333
LC_ALL=C.UTF-8 \
@@ -76,5 +76,16 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi
7676
# Install launcher
7777
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
7878

79+
# AWS Sagemaker compatible image
80+
FROM base as sagemaker
81+
82+
COPY sagemaker-entrypoint.sh entrypoint.sh
83+
RUN chmod +x entrypoint.sh
84+
85+
ENTRYPOINT ["./entrypoint.sh"]
86+
87+
# Original image
88+
FROM base
89+
7990
ENTRYPOINT ["text-generation-launcher"]
8091
CMD ["--json-output"]

router/src/server.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,11 +529,19 @@ pub async fn run(
529529
// Create router
530530
let app = Router::new()
531531
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
532+
// Base routes
532533
.route("/", post(compat_generate))
533534
.route("/generate", post(generate))
534535
.route("/generate_stream", post(generate_stream))
535-
.route("/", get(health))
536+
// AWS Sagemaker route
537+
.route("/invocations", post(compat_generate))
538+
// Base Health route
536539
.route("/health", get(health))
540+
// Inference API health route
541+
.route("/", get(health))
542+
// AWS Sagemaker health route
543+
.route("/ping", get(health))
544+
// Prometheus metrics route
537545
.route("/metrics", get(metrics))
538546
.layer(Extension(compat_return_full_text))
539547
.layer(Extension(infer))

sagemaker-entrypoint.sh

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#!/bin/bash
2+
3+
if [[ -z "${HF_MODEL_ID}" ]]; then
4+
echo "HF_MODEL_ID must be set"
5+
exit 1
6+
fi
7+
8+
if [[ -n "${HF_MODEL_REVISION}" ]]; then
9+
export REVISION="${HF_MODEL_REVISION}"
10+
fi
11+
12+
if [[ -n "${SM_NUM_GPUS}" ]]; then
13+
export NUM_SHARD="${SM_NUM_GPUS}"
14+
fi
15+
16+
if [[ -n "${HF_MODEL_QUANTIZE}" ]]; then
17+
export QUANTIZE="${HF_MODEL_QUANTIZE}"
18+
fi
19+
20+
text-generation-launcher --port 8080

0 commit comments

Comments
 (0)