Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,48 @@ jobs:
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max

build-and-push-sagemaker-image:
needs:
- build-and-push-image
runs-on: ubuntu-latest
steps:
- name: Initialize Docker Buildx
uses: docker/[email protected]
with:
install: true
- name: Checkout repository
uses: actions/checkout@v3
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4
- name: Login to internal Container Registry
uses: docker/[email protected]
with:
username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
registry: registry.internal.huggingface.tech
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/[email protected]
with:
flavor: |
latest=auto
images: |
registry.internal.huggingface.tech/api-inference/community/text-generation-inference/sagemaker
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
- name: Build and push Docker image
uses: docker/build-push-action@v2
with:
context: .
file: Dockerfile
push: ${{ github.event_name != 'pull_request' }}
platforms: 'linux/amd64'
target: sagemaker
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
13 changes: 12 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ COPY router router
COPY launcher launcher
RUN cargo build --release

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as base

ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
Expand Down Expand Up @@ -76,5 +76,16 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

# AWS Sagemaker compatbile image
FROM base as sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Original image
FROM base

ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
10 changes: 9 additions & 1 deletion router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,11 +529,19 @@ pub async fn run(
// Create router
let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
// Base routes
.route("/", post(compat_generate))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/", get(health))
// AWS Sagemaker route
.route("/invocations", post(compat_generate))
// Base Health route
Comment on lines +536 to +538
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we an a cfg for that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would complexify things for no real gains. Maybe the runtime RAM would be a bit lower but I don't think it matters too much compared to the python servers used RAM.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was more thinking about the future, what if add routes for X and Y. But feel free to keep it as it is :)

.route("/health", get(health))
// Inference API health route
.route("/", get(health))
// AWS Sagemaker health route
.route("/ping", get(health))
// Prometheus metrics route
.route("/metrics", get(metrics))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
Expand Down
20 changes: 20 additions & 0 deletions sagemaker-entrypoint.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/bin/bash

if [[ -z "${HF_MODEL_ID}" ]]; then
echo "HF_MODEL_ID must be set"
exit 1
fi

if [[ -n "${HF_MODEL_REVISION}" ]]; then
export REVISION="${HF_MODEL_REVISION}"
fi

if [[ -n "${SM_NUM_GPUS}" ]]; then
export NUM_SHARD="${SM_NUM_GPUS}"
fi

if [[ -n "${HF_MODEL_QUANTIZE}" ]]; then
export QUANTIZE="${HF_MODEL_QUANTIZE}"
fi

text-generation-launcher --port 8080