diff --git a/examples/hello-context/.dockerignore b/examples/hello-context/.dockerignore new file mode 100644 index 0000000000..1d4c71fdac --- /dev/null +++ b/examples/hello-context/.dockerignore @@ -0,0 +1,20 @@ +# The .dockerignore file excludes files from the container build process. +# +# https://docs.docker.com/engine/reference/builder/#dockerignore-file + +# Exclude Git files +**/.git +**/.github +**/.gitignore + +# Exclude Python tooling +.python-version + +# Exclude Python cache files +__pycache__ +.mypy_cache +.pytest_cache +.ruff_cache + +# Exclude Python virtual environment +/venv diff --git a/examples/hello-context/README.md b/examples/hello-context/README.md new file mode 100644 index 0000000000..dbf747fef5 --- /dev/null +++ b/examples/hello-context/README.md @@ -0,0 +1,4 @@ +hello-context +------------- + +A simple model that takes no inputs but will echo back any context provided with the prediction as the output. diff --git a/examples/hello-context/cog.yaml b/examples/hello-context/cog.yaml new file mode 100644 index 0000000000..0107326ca8 --- /dev/null +++ b/examples/hello-context/cog.yaml @@ -0,0 +1,28 @@ +# Configuration for Cog ⚙️ +# Reference: https://cog.run/yaml + +build: + # set to true if your model requires a GPU + gpu: false + + # a list of ubuntu apt packages to install + # system_packages: + # - "libgl1-mesa-glx" + # - "libglib2.0-0" + + # python version in the form '3.11' or '3.11.4' + python_version: "3.11" + + # path to a Python requirements.txt file + python_requirements: requirements.txt + + # enable fast boots + fast: true + + # commands run after the environment is setup + # run: + # - "echo env is ready!" + # - "echo another command if needed" + +# predict.py defines how predictions are run on your model +predict: "predict.py:run" diff --git a/examples/hello-context/predict.py b/examples/hello-context/predict.py new file mode 100644 index 0000000000..9e6031fd8a --- /dev/null +++ b/examples/hello-context/predict.py @@ -0,0 +1,5 @@ +from cog import current_scope + + +def run() -> dict[str, str]: + return current_scope().context diff --git a/examples/hello-context/requirements.txt b/examples/hello-context/requirements.txt new file mode 100644 index 0000000000..ea75294519 --- /dev/null +++ b/examples/hello-context/requirements.txt @@ -0,0 +1,23 @@ +# This is a normal Python requirements.txt file. + +# You can add dependencies directly from PyPI: +# +# numpy==1.26.4 +# torch==2.2.1 +# torchvision==0.17.1 + + +# You can also add Git repos as dependencies, but you'll need to add git to the system_packages list in cog.yaml: +# +# build: +# system_packages: +# - "git" +# +# Then you can use a URL like this: +# +# git+https://github.com/huggingface/transformers + + +# You can also pin Git repos to a specific commit: +# +# git+https://github.com/huggingface/transformers@2d1602a diff --git a/examples/hello-procedure/README.md b/examples/hello-procedure/README.md new file mode 100644 index 0000000000..5e879f5599 --- /dev/null +++ b/examples/hello-procedure/README.md @@ -0,0 +1,22 @@ +# hello + +A simple pipeline that transforms your text input by converting it to uppercase and prefixing it with "HELLO". + +https://replicate.com/pipelines-beta/hello + +## Features + +- Converts any text input to uppercase +- Adds a friendly "HELLO" prefix to your text +- Simple, single-input interface + +## Models + +Under the hood it uses these models: + +- [pipelines-beta/upcase](https://replicate.com/pipelines-beta/upcase): A utility model that converts text to uppercase + +## How it works + +The pipeline takes a text prompt as input, passes it to the `upcase` model to convert the text to uppercase, and then adds "HELLO" as a prefix to the transformed text. This creates a greeting-style output from any input text. +Edit model diff --git a/examples/hello-procedure/cog.yaml b/examples/hello-procedure/cog.yaml new file mode 100644 index 0000000000..b8fac5970a --- /dev/null +++ b/examples/hello-procedure/cog.yaml @@ -0,0 +1,4 @@ +predict: "function.py:run" +build: + system_packages: + - tini diff --git a/examples/hello-procedure/function.py b/examples/hello-procedure/function.py new file mode 100644 index 0000000000..49a1905a1c --- /dev/null +++ b/examples/hello-procedure/function.py @@ -0,0 +1,12 @@ +from cog import Input +from cog.ext.pipelines import include + +# with run_state("load"): +upcase = include("pipelines-beta/upcase") + + +def run( + prompt: str = Input(), +) -> str: + upcased_prompt = upcase(prompt=prompt) + return f"HELLO {upcased_prompt}" diff --git a/go.mod b/go.mod index ff267f6647..19ec2b4b69 100644 --- a/go.mod +++ b/go.mod @@ -7,18 +7,22 @@ require ( github.com/aws/aws-sdk-go-v2 v1.36.3 github.com/aws/aws-sdk-go-v2/credentials v1.17.67 github.com/aws/aws-sdk-go-v2/service/s3 v1.79.3 + github.com/containerd/containerd/api v1.8.0 github.com/creack/pty v1.1.24 + github.com/distribution/reference v0.6.0 github.com/docker/cli v28.1.1+incompatible github.com/docker/docker v28.1.1+incompatible github.com/docker/go-connections v0.5.0 github.com/getkin/kin-openapi v0.128.0 github.com/google/go-containerregistry v0.20.5 + github.com/google/uuid v1.6.0 github.com/hashicorp/go-version v1.7.0 github.com/logrusorgru/aurora v2.0.3+incompatible github.com/mattn/go-isatty v0.0.20 github.com/mitchellh/go-homedir v1.1.0 github.com/moby/buildkit v0.22.0 github.com/moby/term v0.5.2 + github.com/opencontainers/go-digest v1.0.0 github.com/opencontainers/image-spec v1.1.1 github.com/pkg/errors v0.9.1 github.com/replicate/go v0.0.0-20250205165008-b772d7cd506b @@ -28,6 +32,7 @@ require ( github.com/stretchr/testify v1.10.0 github.com/testcontainers/testcontainers-go v0.37.0 github.com/testcontainers/testcontainers-go/modules/registry v0.37.0 + github.com/tonistiigi/fsutil v0.0.0-20250417144416-3f76f8130144 github.com/tonistiigi/go-csvvalue v0.0.0-20240710180619-ddb21b71c0b4 github.com/vbauerster/mpb/v8 v8.10.1 github.com/vincent-petithory/dataurl v1.0.0 @@ -97,7 +102,6 @@ require ( github.com/chavacava/garif v0.1.0 // indirect github.com/ckaznocha/intrange v0.3.0 // indirect github.com/containerd/console v1.0.4 // indirect - github.com/containerd/containerd/api v1.8.0 // indirect github.com/containerd/containerd/v2 v2.0.5 // indirect github.com/containerd/continuity v0.4.5 // indirect github.com/containerd/errdefs v1.0.0 // indirect @@ -112,7 +116,6 @@ require ( github.com/daixiang0/gci v0.13.5 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/denis-tingaikin/go-header v0.5.0 // indirect - github.com/distribution/reference v0.6.0 // indirect github.com/dnephin/pflag v1.0.7 // indirect github.com/docker/distribution v2.8.3+incompatible // indirect github.com/docker/docker-credential-helpers v0.9.3 // indirect @@ -156,7 +159,6 @@ require ( github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/gordonklaus/ineffassign v0.1.0 // indirect github.com/gostaticanalysis/analysisutil v0.7.1 // indirect github.com/gostaticanalysis/comment v1.5.0 // indirect @@ -216,7 +218,6 @@ require ( github.com/nishanths/predeclared v0.2.2 // indirect github.com/nunnatsa/ginkgolinter v0.19.1 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect - github.com/opencontainers/go-digest v1.0.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect @@ -267,7 +268,6 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/tomarrell/wrapcheck/v2 v2.10.0 // indirect github.com/tommy-muehle/go-mnd/v2 v2.5.1 // indirect - github.com/tonistiigi/fsutil v0.0.0-20250417144416-3f76f8130144 // indirect github.com/tonistiigi/units v0.0.0-20180711220420-6950e57a87ea // indirect github.com/tonistiigi/vt100 v0.0.0-20240514184818-90bafcd6abab // indirect github.com/ultraware/funlen v0.2.0 // indirect diff --git a/pkg/cli/build.go b/pkg/cli/build.go index fceecbe47c..5401d9990c 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -12,8 +12,7 @@ import ( "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/http" - "github.com/replicate/cog/pkg/image" - "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/model/factory" "github.com/replicate/cog/pkg/util/console" ) @@ -97,32 +96,22 @@ func buildCommand(cmd *cobra.Command, args []string) error { logClient.EndBuild(ctx, err, logCtx) return err } - registryClient := registry.NewRegistryClient() - if err := image.Build( - ctx, - cfg, - projectDir, - imageName, - buildSecrets, - buildNoCache, - buildSeparateWeights, - buildUseCudaBaseImage, - buildProgressOutput, - buildSchemaFile, - buildDockerfileFile, - DetermineUseCogBaseImage(cmd), - buildStrip, - buildPrecompile, - buildFast, - nil, - buildLocalImage, - dockerClient, - registryClient); err != nil { + + modelFactory, err := factory.New(dockerClient) + if err != nil { + logClient.EndBuild(ctx, err, logCtx) + return err + } + + settings := buildSettings(cmd, cfg, false, projectDir) + + model, _, err := modelFactory.Build(ctx, settings) + if err != nil { logClient.EndBuild(ctx, err, logCtx) return err } - console.Infof("\nImage built as %s", imageName) + console.Infof("\nImage built as %s", model.ImageRef()) logClient.EndBuild(ctx, nil, logCtx) return nil diff --git a/pkg/cli/image_helpers.go b/pkg/cli/image_helpers.go new file mode 100644 index 0000000000..4dce3a2428 --- /dev/null +++ b/pkg/cli/image_helpers.go @@ -0,0 +1,35 @@ +package cli + +import ( + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/spf13/cobra" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/model/factory" +) + +func buildSettings(cmd *cobra.Command, cfg *config.Config, isPredict bool, projectDir string) factory.BuildSettings { + return factory.BuildSettings{ + Tag: config.DockerImageName(projectDir), + WorkingDir: projectDir, + Config: cfg, + Platform: ocispec.Platform{ + Architecture: "amd64", + OS: "linux", + }, + Monobase: buildFast || cfg.Build.Fast, + NoCache: buildNoCache, + BuildSecrets: buildSecrets, + SeparateWeights: buildSeparateWeights, + UseCudaBaseImage: buildUseCudaBaseImage, + SchemaFile: buildSchemaFile, + DockerfileFile: buildDockerfileFile, + Precompile: buildPrecompile, + ProgressOutput: buildProgressOutput, + Strip: buildStrip, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + LocalImage: buildLocalImage, + PredictBuild: isPredict, + Annotations: map[string]string{}, + } +} diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index 2c0668484c..3a8a31f4b9 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -22,9 +22,10 @@ import ( "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/docker/command" - "github.com/replicate/cog/pkg/image" + "github.com/replicate/cog/pkg/model" + "github.com/replicate/cog/pkg/model/factory" "github.com/replicate/cog/pkg/predict" - "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/util" "github.com/replicate/cog/pkg/util/console" "github.com/replicate/cog/pkg/util/mime" ) @@ -79,98 +80,67 @@ func cmdPredict(cmd *cobra.Command, args []string) error { return err } - imageName := "" volumes := []command.Volume{} gpus := gpusFlag + var predictModel *model.Model + if len(args) == 0 { // Build image - cfg, projectDir, err := config.GetConfig(configFilename) if err != nil { return err } - if cfg.Build.Fast { - buildFast = cfg.Build.Fast - } + settings := buildSettings(cmd, cfg, true, projectDir) - client := registry.NewRegistryClient() - if buildFast { - imageName = config.DockerImageName(projectDir) - if err := image.Build( - ctx, - cfg, - projectDir, - imageName, - buildSecrets, - buildNoCache, - buildSeparateWeights, - buildUseCudaBaseImage, - buildProgressOutput, - buildSchemaFile, - buildDockerfileFile, - DetermineUseCogBaseImage(cmd), - buildStrip, - buildPrecompile, - buildFast, - nil, - buildLocalImage, - dockerClient, - client); err != nil { - return err - } - } else { - if imageName, err = image.BuildBase(ctx, dockerClient, cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput, client); err != nil { - return err - } + modelFactory, err := factory.New(dockerClient) + if err != nil { + return err + } + builtModel, buildInfo, err := modelFactory.Build(ctx, settings) + if err != nil { + return err + } + predictModel = builtModel - // Base image doesn't have /src in it, so mount as volume + // dockerfile images for predict don't have /src in it, so mount as volume + if buildInfo.BaseImageOnly { volumes = append(volumes, command.Volume{ Source: projectDir, Destination: "/src", }) - - if gpus == "" && cfg.Build.GPU { - gpus = "all" - } } - } else { // Use existing image - imageName = args[0] - - // If the image name contains '=', then it's probably a mistake - if strings.Contains(imageName, "=") { - return fmt.Errorf("Invalid image name '%s'. Did you forget `-i`?", imageName) - } + imageName := args[0] - inspectResp, err := dockerClient.Pull(ctx, imageName, false) + _, err := dockerClient.Pull(ctx, imageName, false) if err != nil { return fmt.Errorf("Failed to pull image %q: %w", imageName, err) } - conf, err := image.CogConfigFromManifest(ctx, inspectResp) + predictModel, err = model.Resolve(ctx, imageName, model.WithProvider(dockerClient), model.WithResolveMode(docker.ResolveModeLocal)) if err != nil { - return err - } - if gpus == "" && conf.Build.GPU { - gpus = "all" - } - if conf.Build.Fast { - buildFast = conf.Build.Fast + return fmt.Errorf("Failed to resolve model %q: %w", imageName, err) } } + util.PrettyPrintJSON(predictModel) + + if gpus == "" && predictModel.Config.Build.GPU { + gpus = "all" + } + console.Info("") - console.Infof("Starting Docker image %s and running setup()...", imageName) + console.Infof("Starting Docker image %s and running setup()...", predictModel.Name()) predictor, err := predict.NewPredictor(ctx, command.RunOptions{ GPUs: gpus, - Image: imageName, + Image: predictModel.ImageRef(), Volumes: volumes, Env: envFlags, - }, false, buildFast, dockerClient) + }, false, dockerClient) if err != nil { return err } @@ -196,10 +166,10 @@ func cmdPredict(cmd *cobra.Command, args []string) error { _ = predictor.Stop(ctx) predictor, err = predict.NewPredictor(ctx, command.RunOptions{ - Image: imageName, + Image: predictModel.ImageRef(), Volumes: volumes, Env: envFlags, - }, false, buildFast, dockerClient) + }, false, dockerClient) if err != nil { return err } diff --git a/pkg/cli/push.go b/pkg/cli/push.go index 3f74f607fd..d5350435a5 100644 --- a/pkg/cli/push.go +++ b/pkg/cli/push.go @@ -3,19 +3,15 @@ package cli import ( "fmt" "strings" - "time" "github.com/spf13/cobra" - "github.com/replicate/go/uuid" - "github.com/replicate/cog/pkg/coglog" "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/http" - "github.com/replicate/cog/pkg/image" - "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/model/factory" "github.com/replicate/cog/pkg/util/console" ) @@ -90,50 +86,26 @@ func push(cmd *cobra.Command, args []string) error { return err } - annotations := map[string]string{} - buildID, err := uuid.NewV7() + buildSettings := buildSettings(cmd, cfg, false, projectDir) + + modelFactory, err := factory.New(dockerClient) if err != nil { - // Don't insert build ID but continue anyways - console.Debugf("Failed to create build ID %v", err) - } else { - annotations["run.cog.push_id"] = buildID.String() + return err } - startBuildTime := time.Now() - registryClient := registry.NewRegistryClient() - if err := image.Build( - ctx, - cfg, - projectDir, - imageName, - buildSecrets, - buildNoCache, - buildSeparateWeights, - buildUseCudaBaseImage, - buildProgressOutput, - buildSchemaFile, - buildDockerfileFile, - DetermineUseCogBaseImage(cmd), - buildStrip, - buildPrecompile, - buildFast, - annotations, - buildLocalImage, - dockerClient, - registryClient); err != nil { + model, buildInfo, err := modelFactory.Build(ctx, buildSettings) + if err != nil { return err } - buildDuration := time.Since(startBuildTime) - - console.Infof("\nPushing image '%s'...", imageName) + console.Infof("\nPushing image '%s'...", model.ImageRef()) if buildFast { console.Info("Fast push enabled.") } err = docker.Push(ctx, imageName, buildFast, projectDir, dockerClient, docker.BuildInfo{ - BuildTime: buildDuration, - BuildID: buildID.String(), + BuildTime: buildInfo.Duration, + BuildID: buildInfo.BuildID, Pipeline: pushPipeline, }, client, cfg) if err != nil { diff --git a/pkg/cli/train.go b/pkg/cli/train.go index 2e1212b95a..3b728d828c 100644 --- a/pkg/cli/train.go +++ b/pkg/cli/train.go @@ -13,9 +13,9 @@ import ( "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/docker/command" - "github.com/replicate/cog/pkg/image" + "github.com/replicate/cog/pkg/model" + "github.com/replicate/cog/pkg/model/factory" "github.com/replicate/cog/pkg/predict" - "github.com/replicate/cog/pkg/registry" "github.com/replicate/cog/pkg/util/console" ) @@ -63,7 +63,6 @@ func cmdTrain(cmd *cobra.Command, args []string) error { return err } - imageName := "" volumes := []command.Volume{} gpus := gpusFlag @@ -72,58 +71,59 @@ func cmdTrain(cmd *cobra.Command, args []string) error { return err } + var cogModel *model.Model + if len(args) == 0 { // Build image - if cfg.Build.Fast { - buildFast = cfg.Build.Fast - } + settings := buildSettings(cmd, cfg, true, projectDir) - client := registry.NewRegistryClient() - if imageName, err = image.BuildBase(ctx, dockerClient, cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput, client); err != nil { + modelFactory, err := factory.New(dockerClient) + if err != nil { return err } - - // Base image doesn't have /src in it, so mount as volume - volumes = append(volumes, command.Volume{ - Source: projectDir, - Destination: "/src", - }) - - if gpus == "" && cfg.Build.GPU { - gpus = "all" + builtModel, buildInfo, err := modelFactory.Build(ctx, settings) + if err != nil { + return err + } + cogModel = builtModel + + // dockerfile images for predict don't have /src in it, so mount as volume + if buildInfo.BaseImageOnly { + volumes = append(volumes, command.Volume{ + Source: projectDir, + Destination: "/src", + }) } } else { // Use existing image - imageName = args[0] + imageName := args[0] - inspectResp, err := dockerClient.Pull(ctx, imageName, false) + _, err := dockerClient.Pull(ctx, imageName, false) if err != nil { return fmt.Errorf("Failed to pull image %q: %w", imageName, err) } - conf, err := image.CogConfigFromManifest(ctx, inspectResp) + cogModel, err = model.Resolve(ctx, imageName, model.WithProvider(dockerClient), model.WithResolveMode(docker.ResolveModeLocal)) if err != nil { - return err - } - if gpus == "" && conf.Build.GPU { - gpus = "all" - } - if conf.Build.Fast { - buildFast = conf.Build.Fast + return fmt.Errorf("Failed to resolve model %q: %w", imageName, err) } } + if gpus == "" && cogModel.Config.Build.GPU { + gpus = "all" + } + console.Info("") - console.Infof("Starting Docker image %s...", imageName) + console.Infof("Starting Docker image %s...", cogModel.Name()) predictor, err := predict.NewPredictor(ctx, command.RunOptions{ GPUs: gpus, - Image: imageName, + Image: cogModel.ImageRef(), Volumes: volumes, Env: trainEnvFlags, Args: []string{"python", "-m", "cog.server.http", "--x-mode", "train"}, - }, true, buildFast, dockerClient) + }, true, dockerClient) if err != nil { return err } diff --git a/pkg/docker/api_client.go b/pkg/docker/api_client.go index 7fa128e5ee..39c2028904 100644 --- a/pkg/docker/api_client.go +++ b/pkg/docker/api_client.go @@ -274,12 +274,7 @@ func (c *apiClient) ImageBuild(ctx context.Context, options command.ImageBuildOp } defer os.RemoveAll(buildDir) - bc, err := buildkitclient.New(ctx, "", - // Connect to Docker Engine's embedded Buildkit. - buildkitclient.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { - return c.client.DialHijack(ctx, "/grpc", "h2c", map[string][]string{}) - }), - ) + bc, err := c.BuildKitClient(ctx) if err != nil { return err } @@ -307,7 +302,7 @@ func (c *apiClient) ImageBuild(ctx context.Context, options command.ImageBuildOp } // run the display in a goroutine _after_ we've built SolveOpt - eg.Go(newDisplay(statusCh, displayMode)) + eg.Go(NewBuildKitSolveDisplay(statusCh, displayMode)) res, err = bc.Solve(ctx, nil, options, statusCh) if err != nil { @@ -538,3 +533,21 @@ func parseGPURequest(opts command.RunOptions) (container.DeviceRequest, error) { return deviceRequest, nil } + +func (c *apiClient) DockerClient() (client.APIClient, error) { + return c.client, nil +} + +func (c *apiClient) BuildKitClient(ctx context.Context) (*buildkitclient.Client, error) { + bc, err := buildkitclient.New(ctx, "", + // Connect to Docker Engine's embedded Buildkit. + buildkitclient.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + return c.client.DialHijack(ctx, "/grpc", "h2c", map[string][]string{}) + }), + ) + if err != nil { + return nil, err + } + + return bc, nil +} diff --git a/pkg/docker/buildkit.go b/pkg/docker/buildkit.go index d961f0b777..fffc4c8b05 100644 --- a/pkg/docker/buildkit.go +++ b/pkg/docker/buildkit.go @@ -143,7 +143,7 @@ func solveOptFromImageOptions(buildDir string, opts command.ImageBuildOptions) ( return solveOpts, nil } -func newDisplay(statusCh chan *buildkitclient.SolveStatus, displayMode string) func() error { +func NewBuildKitSolveDisplay(statusCh chan *buildkitclient.SolveStatus, displayMode string) func() error { return func() error { display, err := progressui.NewDisplay( os.Stderr, diff --git a/pkg/docker/command/command.go b/pkg/docker/command/command.go index 6be650c0ee..39abb6a13a 100644 --- a/pkg/docker/command/command.go +++ b/pkg/docker/command/command.go @@ -6,6 +6,8 @@ import ( "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/image" + "github.com/docker/docker/client" + buildkitclient "github.com/moby/buildkit/client" ) type Command interface { @@ -68,3 +70,9 @@ type Volume struct { Source string Destination string } + +type ClientProvider interface { + Command + DockerClient() (client.APIClient, error) + BuildKitClient(ctx context.Context) (*buildkitclient.Client, error) +} diff --git a/pkg/docker/resolver.go b/pkg/docker/resolver.go new file mode 100644 index 0000000000..21d2483480 --- /dev/null +++ b/pkg/docker/resolver.go @@ -0,0 +1,129 @@ +package docker + +import ( + "context" + "errors" + "fmt" + + "github.com/google/go-containerregistry/pkg/name" + containerregistryv1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/daemon" + "github.com/google/go-containerregistry/pkg/v1/remote" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + + "github.com/replicate/cog/pkg/docker/command" +) + +type ImageSource string + +const ( + ImageSourceLocal ImageSource = "local" + ImageSourceRemote ImageSource = "remote" +) + +type ResolveMode int + +const ( + ResolveModeAuto ResolveMode = iota + ResolveModePreferLocal + ResolveModePreferRemote + ResolveModeLocal + ResolveModeRemote +) + +func ResolveImage(ctx context.Context, ref string, provider command.Command, platform *ocispec.Platform, mode ResolveMode) (containerregistryv1.Image, ImageSource, error) { + parsedRef, err := name.ParseReference(ref) + if err != nil { + return nil, ImageSourceRemote, err + } + + resolver := &resolver{ + mode: mode, + platform: platform, + provider: provider, + } + + return resolver.resolve(ctx, parsedRef) +} + +type resolver struct { + provider command.Command + mode ResolveMode + platform *ocispec.Platform +} + +func (resolver *resolver) resolve(ctx context.Context, ref name.Reference) (containerregistryv1.Image, ImageSource, error) { + var preferred, fallback func(ctx context.Context, ref name.Reference) (containerregistryv1.Image, error) + var preferredSource, fallbackSource ImageSource + + switch resolver.mode { + case ResolveModePreferLocal, ResolveModeAuto: + preferred, fallback = resolver.localImage, resolver.remoteImage + preferredSource, fallbackSource = ImageSourceLocal, ImageSourceRemote + case ResolveModePreferRemote: + preferred, fallback = resolver.remoteImage, resolver.localImage + preferredSource, fallbackSource = ImageSourceRemote, ImageSourceLocal + case ResolveModeLocal: + preferred = resolver.localImage + preferredSource = ImageSourceLocal + case ResolveModeRemote: + preferred = resolver.remoteImage + preferredSource = ImageSourceRemote + } + + img, err := preferred(ctx, ref) + if err == nil { + return img, preferredSource, nil + } + if fallback == nil { + return nil, preferredSource, err + } + + img, fallbackErr := fallback(ctx, ref) + if fallbackErr == nil { + return img, fallbackSource, nil + } + return nil, fallbackSource, errors.Join(fallbackErr, err) +} + +func (resolver *resolver) localImage(ctx context.Context, ref name.Reference) (containerregistryv1.Image, error) { + opts := []daemon.Option{ + daemon.WithContext(ctx), + } + + if clientProvider, ok := resolver.provider.(command.ClientProvider); ok { + client, err := clientProvider.DockerClient() + if err != nil { + return nil, err + } + opts = append(opts, daemon.WithClient(client)) + } + + img, err := daemon.Image(ref, opts...) + if err != nil { + return nil, fmt.Errorf("failed to get local image: %w", err) + } + + return img, nil +} + +func (resolver *resolver) remoteImage(ctx context.Context, ref name.Reference) (containerregistryv1.Image, error) { + opts := []remote.Option{ + remote.WithContext(ctx), + } + if resolver.platform != nil { + opts = append(opts, remote.WithPlatform(containerregistryv1.Platform{ + Architecture: resolver.platform.Architecture, + OS: resolver.platform.OS, + OSVersion: resolver.platform.OSVersion, + OSFeatures: resolver.platform.OSFeatures, + Variant: resolver.platform.Variant, + })) + } + + img, err := remote.Image(ref, opts...) + if err != nil { + return nil, fmt.Errorf("failed to get remote image: %w", err) + } + return img, nil +} diff --git a/pkg/model/factory/buildkit_factory.go b/pkg/model/factory/buildkit_factory.go new file mode 100644 index 0000000000..95a8ee8802 --- /dev/null +++ b/pkg/model/factory/buildkit_factory.go @@ -0,0 +1,249 @@ +package factory + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + + "github.com/containerd/containerd/api/services/content/v1" + "github.com/google/go-containerregistry/pkg/name" + buildkitclient "github.com/moby/buildkit/client" + "github.com/moby/buildkit/frontend/gateway/client" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/tonistiigi/fsutil" + "golang.org/x/sync/errgroup" + + "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/model" + "github.com/replicate/cog/pkg/model/factory/state" + "github.com/replicate/cog/pkg/model/factory/types" + "github.com/replicate/cog/pkg/util" +) + +func newBuildkitFactory(provider command.ClientProvider) (*buildkitFactory, error) { + return &buildkitFactory{ + provider: provider, + }, nil +} + +type buildkitFactory struct { + provider command.ClientProvider +} + +func (f *buildkitFactory) Build(ctx context.Context, settings BuildSettings) (*model.Model, BuildInfo, error) { + buildInfo := BuildInfo{ + FactoryBackend: "buildkit", + } + + bkClient, err := f.provider.BuildKitClient(ctx) + if err != nil { + return nil, buildInfo, err + } + defer bkClient.Close() + + contextFS, err := fsutil.NewFS(settings.WorkingDir) + if err != nil { + return nil, buildInfo, fmt.Errorf("failed to create context FS: %w", err) + } + + // define the root solve options + solveOpt := buildkitclient.SolveOpt{ + Exports: []buildkitclient.ExportEntry{ + { + Type: "moby", + Attrs: map[string]string{ + "name": settings.Tag, + }, + }, + }, + LocalMounts: map[string]fsutil.FS{ + "context": contextFS, + }, + } + + productID := fmt.Sprintf("cog-model:%s", settings.Tag) + + // Create a status channel for build progress + statusCh := make(chan *buildkitclient.SolveStatus) + + eg, egctx := errgroup.WithContext(ctx) + eg.Go(docker.NewBuildKitSolveDisplay(statusCh, "plain")) + + var solveResp *buildkitclient.SolveResponse + + eg.Go(func() error { + resp, err := bkClient.Build( + egctx, + solveOpt, + productID, + func(ctx context.Context, c client.Client) (*client.Result, error) { + buildCtx := types.Context{ + Context: ctx, + Config: settings.Config, + WorkingDir: settings.WorkingDir, + Platform: settings.Platform, + Client: c, + } + + stack := PythonStack{} + buildInfo.Builder = "python" + + finalState, err := stack.Solve(buildCtx, c) + if err != nil { + return nil, err + } + + def, err := finalState.Marshal(ctx) + if err != nil { + return nil, err + } + + result, err := c.Solve(ctx, client.SolveRequest{ + Definition: def.ToPB(), + }) + if err != nil { + return nil, err + } + + outputMeta, err := state.GetMeta(buildCtx, finalState) + if err != nil { + return nil, err + } + outputMeta.Labels[types.LabelVersion] = global.Version + + configJSON, err := json.Marshal(buildCtx.Config) + if err != nil { + return nil, fmt.Errorf("Failed to convert config to JSON: %w", err) + } + outputMeta.Labels[types.LabelConfig] = string(configJSON) + + fmt.Println("outputMeta") + util.PrettyPrintJSON(outputMeta) + + outputImage := ocispec.Image{ + Config: outputMeta.ToImageConfig(), + Platform: settings.Platform, + Author: "cog", + } + + iamgeBlob, err := json.Marshal(outputImage) + if err != nil { + return nil, fmt.Errorf("failed to marshal image config: %w", err) + } + + out := &client.Result{} + // out.AddMeta("yo", []byte("yo")) + out.SetRef(result.Ref) // filesystem + out.AddMeta("containerimage.config", iamgeBlob) // config blob + + result.AddMeta("containerimage.config", iamgeBlob) + + return out, nil + }, + statusCh, + ) + if err != nil { + return fmt.Errorf("failed to solve build: %w", err) + } + solveResp = resp + return nil + }) + + if err := eg.Wait(); err != nil { + return nil, buildInfo, err + } + + fmt.Println("solveResp") + util.PrettyPrintJSON(solveResp) + + descriptor, manifest, image, err := imageFromExporterResp(ctx, bkClient, solveResp.ExporterResponse) + if err != nil { + return nil, buildInfo, err + } + util.PrettyPrintJSON(descriptor) + util.PrettyPrintJSON(manifest) + util.PrettyPrintJSON(image) + + ref, err := name.ParseReference(settings.Tag) + if err != nil { + return nil, buildInfo, fmt.Errorf("failed to parse reference: %w", err) + } + + return &model.Model{ + Ref: ref, + Source: model.ModelSourceLocal, + Config: settings.Config, + Manifest: manifest, + Image: *image, + }, buildInfo, nil +} + +func imageFromExporterResp(ctx context.Context, bkClient *buildkitclient.Client, exporterResp map[string]string) (*ocispec.Descriptor, *ocispec.Manifest, *ocispec.Image, error) { + manifestDesc := exporterResp["containerimage.descriptor"] + if manifestDesc == "" { + return nil, nil, nil, fmt.Errorf("no manifest descriptor found in response") + } + + data, err := base64.StdEncoding.DecodeString(manifestDesc) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to decode manifest descriptor: %w", err) + } + + var descriptor ocispec.Descriptor + if err := json.Unmarshal(data, &descriptor); err != nil { + return nil, nil, nil, fmt.Errorf("failed to parse manifest descriptor: %w", err) + } + + manifestContent, err := readContent(ctx, bkClient, descriptor.Digest.String()) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to read manifest content: %w", err) + } + + var manifest ocispec.Manifest + if err := json.Unmarshal(manifestContent, &manifest); err != nil { + return nil, nil, nil, fmt.Errorf("failed to parse manifest: %w", err) + } + + // Get the config digest from the response + configDigest := exporterResp["containerimage.config.digest"] + if configDigest == "" { + return nil, nil, nil, fmt.Errorf("no config digest found in response") + } + + imageConfigData, err := readContent(ctx, bkClient, configDigest) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to read image config: %w", err) + } + + var imageConfig ocispec.Image + if err := json.Unmarshal(imageConfigData, &imageConfig); err != nil { + return nil, nil, nil, fmt.Errorf("failed to parse image config: %w", err) + } + + return &descriptor, &manifest, &imageConfig, nil +} + +func readContent(ctx context.Context, bkClient *buildkitclient.Client, digest string) ([]byte, error) { + // Read the config content + readClient, err := bkClient.ContentClient().Read(ctx, &content.ReadContentRequest{Digest: digest}) + if err != nil { + return nil, fmt.Errorf("failed to read content: %w", err) + } + + var buf bytes.Buffer + + // Read the config content + for { + msg, err := readClient.Recv() + if err != nil { + break + } + buf.Write(msg.Data) + } + + return buf.Bytes(), nil +} diff --git a/pkg/model/factory/dockerfile_factory.go b/pkg/model/factory/dockerfile_factory.go new file mode 100644 index 0000000000..4bbf56f087 --- /dev/null +++ b/pkg/model/factory/dockerfile_factory.go @@ -0,0 +1,88 @@ +package factory + +import ( + "context" + + "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/image" + "github.com/replicate/cog/pkg/model" + "github.com/replicate/cog/pkg/registry" +) + +func newDockerfileFactory(provider command.Command) *dockerfileFactory { + return &dockerfileFactory{ + provider: provider, + } +} + +type dockerfileFactory struct { + provider command.Command +} + +func (f *dockerfileFactory) Build(ctx context.Context, settings BuildSettings) (*model.Model, BuildInfo, error) { + buildInfo := BuildInfo{ + FactoryBackend: "dockerfile", + } + + resolveOpts := []model.ResolveOption{ + model.WithProvider(f.provider), + model.WithResolveMode(docker.ResolveModeLocal), + } + registryClient := registry.NewRegistryClient() + + var imageName string + // if we're building for predict, build a base image instead of a full image + if settings.PredictBuild { + buildInfo.BaseImageOnly = true + + // base images don't have a config, so pass it through to the model resolver + resolveOpts = append(resolveOpts, model.WithConfig(settings.Config)) + + baseImageName, err := image.BuildBase(ctx, + f.provider, + settings.Config, + settings.WorkingDir, + settings.UseCudaBaseImage, + settings.UseCogBaseImage, + settings.ProgressOutput, + registryClient, + ) + if err != nil { + return nil, buildInfo, err + } + imageName = baseImageName + } else { + imageName = settings.Tag + err := image.Build(ctx, + settings.Config, + settings.WorkingDir, + imageName, + settings.BuildSecrets, + settings.NoCache, + settings.SeparateWeights, + settings.UseCudaBaseImage, + settings.ProgressOutput, + settings.SchemaFile, + settings.DockerfileFile, + settings.UseCogBaseImage, + settings.Strip, + settings.Precompile, + settings.Monobase, + settings.Annotations, + settings.LocalImage, + f.provider, + registryClient, + ) + if err != nil { + return nil, buildInfo, err + } + } + + model, err := model.Resolve(ctx, imageName, resolveOpts...) + if err != nil { + return nil, buildInfo, err + } + + return model, buildInfo, nil +} diff --git a/pkg/model/factory/factory.go b/pkg/model/factory/factory.go new file mode 100644 index 0000000000..3f037c0fdd --- /dev/null +++ b/pkg/model/factory/factory.go @@ -0,0 +1,113 @@ +package factory + +import ( + "context" + "crypto/sha256" + "fmt" + "time" + + "github.com/google/uuid" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/model" + "github.com/replicate/cog/pkg/util" + "github.com/replicate/cog/pkg/util/console" +) + +type BuildSettings struct { + Tag string + WorkingDir string + Config *config.Config + Platform ocispec.Platform + + // dockerfile factory settings, many will get moved to the top section once buildkit factory supports it + Monobase bool + SeparateWeights bool + UseCudaBaseImage string + SchemaFile string + DockerfileFile string + Precompile bool + Strip bool + UseCogBaseImage *bool + LocalImage bool + NoCache bool + BuildSecrets []string + ProgressOutput string + PredictBuild bool // tell the dockerfile factory that this model is for predict so it can cut corners + Annotations map[string]string +} + +// type BuildInfo map[string]any + +type BuildInfo struct { + Duration time.Duration + BuildID string + + Builder string + + FactoryBackend string + + // does the build include the model source? this only applies to dockerfile builds for predict/train + BaseImageOnly bool +} + +type Factory struct { + impl factoryImpl +} + +func (f *Factory) Build(ctx context.Context, settings BuildSettings) (*model.Model, BuildInfo, error) { + // buildInfo := BuildInfo{} + + startTime := time.Now() + buildID := f.newBuildID(settings) + + // TODO[md]: not sure we want this in every label since it changes the image even though the underlying layers are identical... + if settings.Annotations == nil { + settings.Annotations = map[string]string{} + } + settings.Annotations["build_id"] = buildID + + model, buildInfo, err := f.impl.Build(ctx, settings) + buildInfo.Duration = time.Since(startTime) + buildInfo.BuildID = buildID + + return model, buildInfo, err +} + +func (f *Factory) newBuildID(settings BuildSettings) string { + // generating a uuid v7 only errors in extreme cases, like system clock issues, + // resource exhaustion, or entropy exhaustion. + if id, err := uuid.NewV7(); err == nil { + return id.String() + } + + // fallback to a uuid v4 which is even less likely to fail + if id, err := uuid.NewRandom(); err == nil { + return id.String() + } + + // finally, return a best-effort unique string from the build context & timestamp + hash := sha256.Sum256([]byte(settings.WorkingDir + settings.Tag)) + return fmt.Sprintf("build-%x-%d", hash[:8], time.Now().UnixNano()) +} + +func New(provider command.Command) (Factory, error) { + if util.EnvIsTruthy("COG_BUILDKIT_FACTORY") { + if clientProvider, ok := provider.(command.ClientProvider); ok { + impl, err := newBuildkitFactory(clientProvider) + if err != nil { + return Factory{}, err + } + return Factory{impl: impl}, nil + } + console.Warnf("COG_BUILDKIT_FACTORY is set, but provider does not implement command.ClientProvider. Falling back to dockerfile factory.") + } + return Factory{impl: newDockerfileFactory(provider)}, nil +} + +// factoryImpl is the interface that the buildkit & dockerfile factories implement +type factoryImpl interface { + Build(ctx context.Context, settings BuildSettings) (*model.Model, BuildInfo, error) +} diff --git a/pkg/model/factory/ops/apt.go b/pkg/model/factory/ops/apt.go new file mode 100644 index 0000000000..996101ce07 --- /dev/null +++ b/pkg/model/factory/ops/apt.go @@ -0,0 +1,98 @@ +package ops + +import ( + "fmt" + "strings" + + "github.com/moby/buildkit/client/llb" + + "github.com/replicate/cog/pkg/model/factory/types" +) + +// AptInstall installs system packages using apt-get +type AptInstall struct { + packages []string +} + +// NewAptInstall creates a new apt install operation +func NewAptInstall(packages ...string) *AptInstall { + return &AptInstall{packages: packages} +} + +// // NewAptInstallFromConfig creates apt install operation from cog config +// func NewAptInstallFromConfig() *AptInstall { +// return &AptInstall{} +// } + +func (op *AptInstall) Name() string { + if len(op.packages) > 0 { + return fmt.Sprintf("apt-install %s", strings.Join(op.packages, " ")) + } + return "apt-install-from-config" +} + +func (op *AptInstall) ShouldRun(ctx types.Context, state types.State) bool { + if len(op.packages) > 0 { + return true + } + // Check config for packages + return len(ctx.Config.Build.SystemPackages) > 0 +} + +func (op *AptInstall) Apply(ctx types.Context, state llb.State) (llb.State, error) { + packages := op.packages + if len(packages) == 0 { + packages = ctx.Config.Build.SystemPackages + } + + if len(packages) == 0 { + return state, nil + } + + aptCache := llb.AsPersistentCacheDir("apt-cache", llb.CacheMountLocked) + pkgList := strings.Join(packages, " ") + + // 1. apt-get update + intermediate := state.Run( + llb.Shlex("apt-get update -qq"), + llb.AddMount("/var/cache/apt", llb.Scratch(), aptCache), + llb.WithCustomName("apt-update"), + ).Root() + + // 2. apt-get install + intermediate = intermediate.Run( + llb.Shlex(fmt.Sprintf("apt-get install -qqy --no-install-recommends %s", pkgList)), + llb.AddMount("/var/cache/apt", llb.Scratch(), aptCache), + llb.WithCustomNamef("apt-install %s", pkgList), + ).Root() + + // 3. cleanup + intermediate = intermediate.Run( + llb.Shlex("apt-get clean"), + llb.WithCustomName("apt-clean"), + ).Root() + + removeDirs := []string{ + "/var/lib/apt/lists/*", + // docker for mac appears to add /root/.cache/rosetta directory, kill it and the cache directory + "/root/.cache", + "/var/log/*", + "/var/cache/apt/*", + "/var/lib/apt/lists/*", + "/var/cache/debconf/*", + "/usr/share/doc-base/*", + "/usr/share/common-licenses", + } + + intermediate = intermediate.Run( + llb.Shlex(fmt.Sprintf("sh -c 'rm -rf %s'", strings.Join(removeDirs, " "))), + llb.WithCustomName(fmt.Sprintf("remove %s", strings.Join(removeDirs, " "))), + ).Root() + + flattened := state.File( + llb.Copy(llb.Diff(state, intermediate), "/", "/"), + llb.WithCustomName("install apt dependencies"), + ) + + return flattened, nil +} diff --git a/pkg/model/factory/ops/base.go b/pkg/model/factory/ops/base.go new file mode 100644 index 0000000000..5e9358e9e9 --- /dev/null +++ b/pkg/model/factory/ops/base.go @@ -0,0 +1,44 @@ +package ops + +import ( + "encoding/json" + "fmt" + + "github.com/distribution/reference" + "github.com/moby/buildkit/client/llb" + "github.com/moby/buildkit/client/llb/sourceresolver" + gatewayClient "github.com/moby/buildkit/frontend/gateway/client" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + + "github.com/replicate/cog/pkg/model/factory/state" + "github.com/replicate/cog/pkg/model/factory/types" +) + +func ResolveBaseImage(ctx types.Context, feClient gatewayClient.Client, platform ocispec.Platform, ref string) (llb.State, error) { + named, err := reference.ParseNormalizedNamed(ref) + if err != nil { + return llb.State{}, fmt.Errorf("failed to parse reference: %w", err) + } + // TODO[md]: is this necessary??? + named = reference.TagNameOnly(named) + + resolvedRef, _, blob, err := feClient.ResolveImageConfig(ctx, named.String(), sourceresolver.Opt{ + Platform: &platform, + ImageOpt: &sourceresolver.ResolveImageOpt{ + ResolveMode: llb.ResolveModePreferLocal.String(), + }, + }) + if err != nil { + return llb.State{}, fmt.Errorf("failed to resolve base image: %w", err) + } + + var img ocispec.Image + if err := json.Unmarshal(blob, &img); err != nil { + return llb.State{}, fmt.Errorf("failed to unmarshal image config: %w", err) + } + + meta := state.MetaFromImage(&img) + + baseState := llb.Image(resolvedRef, llb.Platform(platform)) + return state.WithMeta(baseState, meta), nil +} diff --git a/pkg/model/factory/ops/download.go b/pkg/model/factory/ops/download.go new file mode 100644 index 0000000000..acdb13b309 --- /dev/null +++ b/pkg/model/factory/ops/download.go @@ -0,0 +1,27 @@ +package ops + +import ( + "github.com/moby/buildkit/client/llb" + + "github.com/replicate/cog/pkg/model/factory/types" +) + +func Download(source string, containerPath string) *downloader { + return &downloader{ + source: source, + containerPath: containerPath, + } +} + +type downloader struct { + source string + containerPath string +} + +func (op *downloader) Apply(ctx types.Context, base llb.State) (llb.State, error) { + intermediate := base + + target := llb.HTTP(op.source, llb.Filename("download.bin"), llb.Chmod(0x755)) + + return intermediate.File(llb.Copy(target, "/download.bin", op.containerPath)), nil +} diff --git a/pkg/model/factory/ops/env.go b/pkg/model/factory/ops/env.go new file mode 100644 index 0000000000..5af11542c6 --- /dev/null +++ b/pkg/model/factory/ops/env.go @@ -0,0 +1,41 @@ +package ops + +import ( + "fmt" + "strings" + + "github.com/moby/buildkit/client/llb" + + "github.com/replicate/cog/pkg/model/factory/types" +) + +// setEnv sets environment variables +type setEnv struct { + vars map[string]string +} + +func SetEnv(vars map[string]string) *setEnv { + return &setEnv{vars: vars} +} + +func (op *setEnv) Name() string { + var pairs []string + for k, v := range op.vars { + pairs = append(pairs, fmt.Sprintf("%s=%s", k, v)) + } + return fmt.Sprintf("set-env %s", strings.Join(pairs, " ")) +} + +func (op *setEnv) ShouldRun(ctx types.Context, state types.State) bool { + return len(op.vars) > 0 +} + +func (op *setEnv) Apply(ctx types.Context, state llb.State) (llb.State, error) { + intermediate := state + + for key, value := range op.vars { + intermediate = intermediate.AddEnv(key, value) + } + + return intermediate, nil +} diff --git a/pkg/model/factory/ops/layer.go b/pkg/model/factory/ops/layer.go new file mode 100644 index 0000000000..4ecde05d00 --- /dev/null +++ b/pkg/model/factory/ops/layer.go @@ -0,0 +1,54 @@ +package ops + +import ( + "github.com/moby/buildkit/client/llb" + + "github.com/replicate/cog/pkg/model/factory/types" +) + +func Layer(name string, ops ...Operation) layerOp { + return layerOp{ + role: name, + ops: ops, + } +} + +type layerOp struct { + role string + ops []Operation +} + +func (op layerOp) Apply(ctx types.Context, base llb.State) (llb.State, error) { + // meta, err := state.GetMeta(ctx, base) + // if err != nil { + // return llb.State{}, err[] + // } + + var err error + intermediate := base + + for _, op := range op.ops { + intermediate, err = op.Apply(ctx, intermediate) + if err != nil { + return llb.State{}, err + } + } + + diff := llb.Diff(base, intermediate, llb.WithCustomNamef("layer.diff:%s", op.role)) + + // return llb.Merge([]llb.State{base, diff}, llb.WithCustomNamef("merge.layer: %s", op.role)), nil + // return diff, nil + final := base.File( + llb.Copy(diff, "/", "/"), + llb.WithCustomNamef("layer: %s", op.role), + ) + + // return diff, nil + // merged := llb.Merge([]llb.State{base, diff}) + + // meta.Layers = append(meta.Layers, state.LayerInfo{ + // Role: op.role, + // }) + + return final, nil +} diff --git a/pkg/model/factory/ops/op.go b/pkg/model/factory/ops/op.go new file mode 100644 index 0000000000..04fbda438b --- /dev/null +++ b/pkg/model/factory/ops/op.go @@ -0,0 +1,47 @@ +package ops + +import ( + "github.com/moby/buildkit/client/llb" + + "github.com/replicate/cog/pkg/model/factory/types" +) + +type Operation interface { + Apply(ctx types.Context, state llb.State) (llb.State, error) +} + +type funcOp struct { + fn func(ctx types.Context, state llb.State) (llb.State, error) +} + +func (op funcOp) Apply(ctx types.Context, state llb.State) (llb.State, error) { + return op.fn(ctx, state) +} + +func OpFunc(f func(ctx types.Context, state llb.State) (llb.State, error)) Operation { + return funcOp{ + fn: f, + } +} + +func Do(ops ...Operation) Operation { + return doit{ + ops: ops, + } +} + +type doit struct { + ops []Operation +} + +func (op doit) Apply(ctx types.Context, state llb.State) (llb.State, error) { + var err error + intermediate := state + for _, op := range op.ops { + intermediate, err = op.Apply(ctx, intermediate) + if err != nil { + return llb.State{}, err + } + } + return intermediate, nil +} diff --git a/pkg/model/factory/python.go b/pkg/model/factory/python.go new file mode 100644 index 0000000000..f7d5ea69c9 --- /dev/null +++ b/pkg/model/factory/python.go @@ -0,0 +1,343 @@ +package factory + +import ( + "bytes" + "fmt" + "io" + + "github.com/moby/buildkit/client/llb" + "github.com/moby/buildkit/frontend/gateway/client" + gatewayClient "github.com/moby/buildkit/frontend/gateway/client" + "github.com/moby/buildkit/solver/pb" + + "github.com/replicate/cog/pkg/dockerfile" + "github.com/replicate/cog/pkg/model/factory/ops" + "github.com/replicate/cog/pkg/model/factory/state" + "github.com/replicate/cog/pkg/model/factory/types" +) + +type PythonStack struct { + *types.BuildEnv +} + +func (stack *PythonStack) Solve(ctx types.Context, feClient gatewayClient.Client) (llb.State, error) { + baseImg := "debian:bookworm-slim" + + baseState, err := ops.ResolveBaseImage(ctx, feClient, ctx.Platform, baseImg) + if err != nil { + return llb.State{}, err + } + + intermediate, err := state.WithConfig(ctx, baseState, + state.WithExposedPort("5000/tcp"), + state.WithEntrypoint([]string{"/usr/bin/tini", "--"}), + state.WithCmd([]string{"python", "-m", "cog.server.http"}), + state.WithWorkingDir("/model-src"), + ) + if err != nil { + return llb.State{}, err + } + + return ops.Do( + ops.Layer("sysdeps", + ops.NewAptInstall("tini"), + // don't install pget for these tests since pget in PATH _forces_ a download from github releases! + // ops.Download("https://github.com/replicate/pget/releases/latest/download/pget_Linux_x86_64", "/usr/local/bin/pget"), + ), + ops.Layer("python", + stack.installPython("3.12"), + ), + ops.Layer("venv+model-deps", + stack.initVENV(), + stack.installModelDeps(), + ), + ops.Layer("model", + stack.installModel(), + stack.installSchema(), + ), + ops.Layer("hacks", + stack.installFakePip(), + ), + ).Apply(ctx, intermediate) +} + +func (stack *PythonStack) installPython(version string) ops.Operation { + return ops.OpFunc(func(ctx types.Context, base llb.State) (llb.State, error) { + + intermediate := base + uvCache := llb.AsPersistentCacheDir("uv-cache", llb.CacheMountLocked) + + intermediate = intermediate.AddEnv("UV_COMPILE_BYTECODE", "1") + intermediate = intermediate.AddEnv("UV_LINK_MODE", "copy") + intermediate = intermediate.AddEnv("UV_PYTHON_INSTALL_DIR", "/python") + intermediate = intermediate.AddEnv("UV_PYTHON_PREFERENCE", "only-managed") + + intermediate = intermediate.Run( + llb.Shlexf("/uv/uv python install %s", version), + llb.AddMount("/uv", llb.Image("ghcr.io/astral-sh/uv:latest", llb.Platform(ctx.Platform), llb.ResolveModePreferLocal)), + llb.AddMount("/root/.cache/uv", llb.Scratch(), uvCache), + ).Root() + + diff := llb.Diff(base, intermediate) + final := base.File( + llb.Copy(diff, "/python", "/python"), + llb.WithCustomNamef("wat install python %s", version), + ) + + return final, nil + }) +} + +func (stack *PythonStack) initVENV() ops.Operation { + return ops.OpFunc(func(ctx types.Context, base llb.State) (llb.State, error) { + + intermediate := base + uvCache := llb.AsPersistentCacheDir("uv-cache", llb.CacheMountLocked) + + intermediate = intermediate.AddEnv("UV_COMPILE_BYTECODE", "1") + intermediate = intermediate.AddEnv("UV_LINK_MODE", "copy") + intermediate = intermediate.AddEnv("UV_PYTHON_INSTALL_DIR", "/python") + intermediate = intermediate.AddEnv("UV_PYTHON_PREFERENCE", "only-managed") + + intermediate = intermediate.Run( + llb.Shlexf("/uv/uv venv /venv --python %s", "3.12"), + llb.WithCustomName("init venv"), + llb.AddMount("/uv", llb.Image("ghcr.io/astral-sh/uv:latest", llb.Platform(ctx.Platform))), + llb.AddMount("/root/.cache/uv", llb.Scratch(), uvCache), + ).Root() + + intermediate, err := state.PrependPath(ctx, intermediate, "/venv/bin") + if err != nil { + return llb.State{}, err + } + + return intermediate, nil + }) +} + +func (stack *PythonStack) installModelDeps() ops.Operation { + return ops.OpFunc(func(ctx types.Context, base llb.State) (llb.State, error) { + intermediate := base + + intermediate = intermediate.AddEnv("UV_COMPILE_BYTECODE", "1") + intermediate = intermediate.AddEnv("UV_LINK_MODE", "copy") + intermediate = intermediate.AddEnv("UV_PYTHON_INSTALL_DIR", "/python") + intermediate = intermediate.AddEnv("UV_PYTHON_PREFERENCE", "only-managed") + + // Create UV cache mount for faster builds + uvCache := llb.AsPersistentCacheDir("uv-cache", llb.CacheMountLocked) + uvImage := llb.Image("ghcr.io/astral-sh/uv:latest", llb.Platform(ctx.Platform), llb.ResolveModePreferLocal) + + // Get the embedded cog wheel file + wheelData, wheelFilename, err := dockerfile.ReadWheelFile() + if err != nil { + return base, fmt.Errorf("failed to read embedded cog wheel: %w", err) + } + + // Copy the wheel file to the container + wheelPath := "/tmp/" + wheelFilename + intermediate = intermediate.File( + llb.Mkfile(wheelPath, 0x644, wheelData), + llb.WithCustomName("copy-cog-wheel"), + ) + + // Install the cog wheel file and pydantic dependency + intermediate = intermediate.Run( + llb.Shlexf("/uv/uv pip install --python /venv/bin/python %s 'pydantic>=1.9,<3'", wheelPath), + llb.AddMount("/root/.cache/uv", llb.Scratch(), uvCache), + llb.AddMount("/uv", uvImage), + llb.WithCustomName("uv-install-cog-wheel"), + ).Root() + + intermediate = intermediate.File(llb.Rm(wheelPath)) + + // If Python requirements are specified, install them as well + if ctx.Config.Build.PythonRequirements != "" { + intermediate = intermediate.Run( + llb.Shlexf("/uv/uv pip install --python /venv/bin/python -r %s", ctx.Config.Build.PythonRequirements), + llb.AddMount("/root/.cache/uv", llb.Scratch(), uvCache), + llb.AddMount("/uv", uvImage), + llb.WithCustomNamef("uv-install-requirements %s", ctx.Config.Build.PythonRequirements), + ).Root() + } + + diff := llb.Diff(base, intermediate) + final := base.File( + llb.Copy(diff, "/", "/", llb.WithExcludePatterns([]string{"/root/.cache"})), + llb.WithCustomName("install model deps"), + ) + + return final, nil + }) +} + +func (stack *PythonStack) installModel() ops.Operation { + return ops.OpFunc(func(ctx types.Context, base llb.State) (llb.State, error) { + intermediate := base + + // why do we need to do this twice? + intermediate = intermediate.Dir("/model-src") + + // Copy the context files + intermediate = intermediate.File( + llb.Copy( + llb.Local("context"), + ".", + ".", + llb.WithExcludePatterns([]string{".cog", "__pycache__"}), + ), + // llb.IgnoreCache, + llb.WithCustomName("copy context"), + ) + + return intermediate, nil + }) +} + +func (stack *PythonStack) installSchema() ops.Operation { + return ops.OpFunc(func(ctx types.Context, base llb.State) (llb.State, error) { + intermediate := base + + schemaData, err := stack.generateSchemaInContainer(ctx, intermediate) + if err != nil { + return llb.State{}, fmt.Errorf("failed to generate schema: %w", err) + } + + intermediate, err = state.SetLabel(ctx, intermediate, types.LabelOpenAPISchema, string(schemaData)) + if err != nil { + return llb.State{}, fmt.Errorf("failed to set label: %w", err) + } + intermediate = intermediate.File(llb.Mkfile("schema.json", 0x644, schemaData)) + + return intermediate, nil + }) +} + +func (stack *PythonStack) generateSchemaInContainer(ctx types.Context, base llb.State) ([]byte, error) { + def, err := base.Marshal(ctx) + if err != nil { + return nil, err + } + + // fmt.Println("generate Schema solve") + res, err := ctx.Client.Solve(ctx, client.SolveRequest{ + Definition: def.ToPB(), + }) + if err != nil { + return nil, fmt.Errorf("failed to solve build: %w", err) + } + + container, err := ctx.Client.NewContainer(ctx, client.NewContainerRequest{ + Platform: &pb.Platform{ + OS: ctx.Platform.OS, + Architecture: ctx.Platform.Architecture, + Variant: ctx.Platform.Variant, + OSVersion: ctx.Platform.OSVersion, + OSFeatures: ctx.Platform.OSFeatures, + }, + Mounts: []client.Mount{ + { + Dest: "/", + Ref: res.Ref, + }, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to create container: %w", err) + } + defer func() { + if err := container.Release(ctx); err != nil { + fmt.Printf("failed to release container: %T\n", err) + fmt.Printf("failed to release container: %s\n", err) + } + }() + + stdoutR, stdoutW := io.Pipe() + stderrR, stderrW := io.Pipe() + + var stdout bytes.Buffer + go func() { + _, _ = io.Copy(&stdout, stdoutR) + stdoutR.Close() + }() + var stderr bytes.Buffer + go func() { + _, _ = io.Copy(&stderr, stderrR) + stderrR.Close() + }() + + env, err := state.GetEnv(ctx, base) + if err != nil { + return nil, fmt.Errorf("failed to get env: %w", err) + } + + process, err := container.Start(ctx, client.StartRequest{ + Cwd: "/model-src", + Args: []string{"python", "-m", "cog.command.openapi_schema"}, + Stdout: stdoutW, + Stderr: stderrW, + Env: env, + }) + if err != nil { + return nil, fmt.Errorf("failed to start container: %w", err) + } + + if err := process.Wait(); err != nil { + return nil, fmt.Errorf("failed to wait for process: (STDOUT: %s, STDERR: %s) %w", stdout.String(), stderr.String(), err) + } + + stdoutStr := stdout.String() + stderrStr := stderr.String() + + if stderrStr != "" { + return nil, fmt.Errorf("stderr: %s", stderrStr) + } + + // fmt.Println("stdout", stdoutStr) + + return []byte(stdoutStr), nil +} + +// r8 runtime overrides the entrypoint with a script that may run `pip`, which these self contained images don't have. This +// operation injects a noop `pip` module into the python venv so boot won't fail. +func (stack *PythonStack) installFakePip() ops.Operation { + return ops.OpFunc(func(ctx types.Context, base llb.State) (llb.State, error) { + intermediate := base + + // Create the dummy pip __main__.py content + pipMainContent := []byte(`#!/usr/bin/env python3 +import sys +import os +# Dummy pip that handles all commands silently +if os.environ.get('DEBUG_DUMMY_PIP'): + print(f'dummy-pip: ignoring command: {" ".join(sys.argv)}', file=sys.stderr) +sys.exit(0) +`) + + // Create the __init__.py content (empty) + initContent := []byte("") + + intermediate = intermediate.File( + llb.Mkdir("/venv/lib/python3.12/site-packages/pip", 0x755), + llb.WithCustomName("create-pip-directory"), + ) + + // Create the directory and files + intermediate = intermediate.Run( + llb.Shlex("mkdir -p /venv/lib/python3.12/site-packages/pip"), + llb.WithCustomName("create-pip-directory"), + ).Root() + + intermediate = intermediate.File( + llb.Mkfile("/venv/lib/python3.12/site-packages/pip/__main__.py", 0x644, pipMainContent), + llb.WithCustomName("create-pip-main"), + ) + + intermediate = intermediate.File( + llb.Mkfile("/venv/lib/python3.12/site-packages/pip/__init__.py", 0x644, initContent), + llb.WithCustomName("create-pip-init"), + ) + + return intermediate, nil + }) +} diff --git a/pkg/model/factory/state/env.go b/pkg/model/factory/state/env.go new file mode 100644 index 0000000000..75352ef451 --- /dev/null +++ b/pkg/model/factory/state/env.go @@ -0,0 +1,44 @@ +package state + +import ( + "github.com/moby/buildkit/client/llb" + + "github.com/replicate/cog/pkg/model/factory/types" +) + +func SetEnvs(ctx types.Context, state llb.State, env map[string]string) (llb.State, error) { + meta, err := GetMeta(ctx, state) + if err != nil { + return llb.State{}, err + } + for k, v := range env { + meta.Env[k] = v + } + return state, nil +} + +func SetEnv(ctx types.Context, state llb.State, k, v string) (llb.State, error) { + meta, err := GetMeta(ctx, state) + if err != nil { + return llb.State{}, err + } + meta.Env[k] = v + return state, nil +} + +func UnsetEnv(ctx types.Context, state llb.State, k string) (llb.State, error) { + meta, err := GetMeta(ctx, state) + if err != nil { + return llb.State{}, err + } + delete(meta.Env, k) + return state, nil +} + +func GetEnv(ctx types.Context, state llb.State) ([]string, error) { + meta, err := GetMeta(ctx, state) + if err != nil { + return nil, err + } + return meta.GetEnv(), nil +} diff --git a/pkg/model/factory/state/fork.go b/pkg/model/factory/state/fork.go new file mode 100644 index 0000000000..13fd830770 --- /dev/null +++ b/pkg/model/factory/state/fork.go @@ -0,0 +1,50 @@ +package state + +import ( + "errors" + + "github.com/moby/buildkit/client/llb" + + "github.com/replicate/cog/pkg/model/factory/types" +) + +func Fork(ctx types.Context, state llb.State) (llb.State, error) { + meta, err := GetMeta(ctx, state) + if err != nil { + return llb.State{}, err + } + + return state.WithValue(metaKey, meta.Clone()), nil +} + +func Merge(ctx types.Context, states ...llb.State) (llb.State, error) { + if len(states) == 0 { + return llb.State{}, errors.New("no states to merge") + } + + lower := states[0] + + if len(states) == 1 { + return lower, nil + } + + var metas []*Meta + var diffs []llb.State + for idx, s := range states { + meta, err := GetMeta(ctx, s) + if err != nil { + return llb.State{}, err + } + metas = append(metas, meta) + if idx == 0 { + diffs = append(diffs, s) + } else { + diffs = append(diffs, llb.Diff(states[0], s)) + } + } + + mergedMeta := metas[0].Clone() + mergedMeta.Merge(metas[1:]...) + + return llb.Merge(diffs).WithValue(metaKey, mergedMeta), nil +} diff --git a/pkg/model/factory/state/labels.go b/pkg/model/factory/state/labels.go new file mode 100644 index 0000000000..162eef3591 --- /dev/null +++ b/pkg/model/factory/state/labels.go @@ -0,0 +1,22 @@ +package state + +import ( + "fmt" + + "github.com/moby/buildkit/client/llb" + + "github.com/replicate/cog/pkg/model/factory/types" +) + +func SetLabel(ctx types.Context, state llb.State, k, v string) (llb.State, error) { + fmt.Println("SetLabel", k, v) + meta, err := GetMeta(ctx, state) + if err != nil { + return llb.State{}, err + } + if meta.Labels == nil { + meta.Labels = make(map[string]string) + } + meta.Labels[k] = v + return state, nil +} diff --git a/pkg/model/factory/state/meta.go b/pkg/model/factory/state/meta.go new file mode 100644 index 0000000000..f2b272d43f --- /dev/null +++ b/pkg/model/factory/state/meta.go @@ -0,0 +1,261 @@ +package state + +import ( + "fmt" + "maps" + "slices" + "strings" + + "github.com/moby/buildkit/client/llb" + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + + "github.com/replicate/cog/pkg/model/factory/types" +) + +func MetaFromImage(img *ocispec.Image) *Meta { + meta := &Meta{ + User: img.Config.User, + WorkingDir: img.Config.WorkingDir, + Entrypoint: slices.Clone(img.Config.Entrypoint), + Cmd: slices.Clone(img.Config.Cmd), + Labels: maps.Clone(img.Config.Labels), + ExposedPorts: maps.Clone(img.Config.ExposedPorts), + Env: map[string]string{}, + } + meta.SetEnviron(img.Config.Env) + + return meta +} + +type Meta struct { + // BaseImage *ocispec.Image + + User string + Cmd []string + Entrypoint []string + Env map[string]string + ExposedPorts map[string]struct{} + Labels map[string]string + Layers []LayerInfo + path []string + WorkingDir string +} + +func (m *Meta) SetEnviron(env []string) { + fmt.Println("set environ", env) + for _, envVar := range env { + parts := strings.SplitN(envVar, "=", 2) + if len(parts) != 2 { + continue + } + m.SetEnv(parts[0], parts[1]) + } +} + +func (m *Meta) SetEnv(k, v string) { + fmt.Println("set env", k, v) + if k == "PATH" { + m.SetPath(v) + return + } + if m.Env == nil { + m.Env = map[string]string{} + } + m.Env[k] = v +} + +func (m *Meta) UnsetEnv(k string) { + if k == "PATH" { + m.path = []string{} + return + } + delete(m.Env, k) +} + +func (m *Meta) GetEnv() []string { + env := []string{} + for k, v := range m.Env { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + env = append(env, fmt.Sprintf("PATH=%s", strings.Join(m.path, ":"))) + return env +} + +func (m *Meta) SetPath(path string) { + m.path = mergePaths(m.path, strings.Split(path, ":")) +} + +func (m *Meta) AppendPath(path string) { + for _, part := range strings.Split(path, ":") { + // ignore $PATH since we're appending + if part == "$PATH" { + continue + } else { + m.path = append(m.path, part) + } + } +} + +func (m *Meta) PrependPath(path string) { + m.path = mergePaths(m.path, strings.Split(path, ":")) +} + +func mergePaths(base []string, incoming []string) []string { + if !slices.Contains(incoming, "$PATH") { + return incoming + } + + var newPath []string + var baseApplied bool + for _, part := range incoming { + if part == "$PATH" && !baseApplied && base != nil { + newPath = append(newPath, base...) + baseApplied = true + continue + } + newPath = append(newPath, part) + } + + return newPath +} + +func (m *Meta) UnsetPath() { + m.path = []string{} +} + +func (m *Meta) ExposePort(port string) { + fmt.Println("set exposed port") + if m.ExposedPorts == nil { + m.ExposedPorts = map[string]struct{}{} + } + m.ExposedPorts[port] = struct{}{} +} + +func (m *Meta) UnexposePort(port string) { + fmt.Println("unset exposed port") + delete(m.ExposedPorts, port) +} + +func (m *Meta) Clone() *Meta { + return &Meta{ + User: m.User, + Cmd: slices.Clone(m.Cmd), + Entrypoint: slices.Clone(m.Entrypoint), + Env: maps.Clone(m.Env), + ExposedPorts: maps.Clone(m.ExposedPorts), + Labels: maps.Clone(m.Labels), + Layers: slices.Clone(m.Layers), + path: slices.Clone(m.path), + WorkingDir: m.WorkingDir, + } +} + +func (m *Meta) Merge(others ...*Meta) { + for _, other := range others { + m.User = other.User + m.Cmd = slices.Clone(other.Cmd) + m.Entrypoint = slices.Clone(other.Entrypoint) + maps.Copy(m.ExposedPorts, other.ExposedPorts) + for k, v := range other.Env { + m.SetEnv(k, v) + } + maps.Copy(m.Env, other.Env) + m.path = slices.Clone(other.path) + m.WorkingDir = other.WorkingDir + // this is almost certainly wrong + m.Layers = append(m.Layers, other.Layers...) + } +} + +func (m *Meta) ToImageConfig() ocispec.ImageConfig { + return ocispec.ImageConfig{ + User: m.User, + Cmd: m.Cmd, + Entrypoint: m.Entrypoint, + Env: m.GetEnv(), + ExposedPorts: m.ExposedPorts, + Labels: m.Labels, + WorkingDir: m.WorkingDir, + } +} + +type LayerInfo struct { + Digest digest.Digest + Role string +} + +var metaKey = struct{}{} + +// func initMeta(ctx context.Context, state llb.State) (llb.State, error) { +// val, err := state.Value(ctx, metaKey) +// if err != nil { +// return state, err +// } + +// meta := val.(*Meta) +// return state.WithValue(metaKey, meta), nil +// } + +func WithMeta(state llb.State, meta *Meta) llb.State { + return state.WithValue(metaKey, meta) +} + +func GetMeta(ctx types.Context, state llb.State) (*Meta, error) { + val, err := state.Value(ctx, metaKey) + if err != nil { + return nil, err + } + return val.(*Meta), nil +} + +type MetaOption func(state llb.State, meta *Meta) llb.State + +func WithWorkingDir(dir string) MetaOption { + return func(state llb.State, meta *Meta) llb.State { + meta.WorkingDir = dir + return state.Dir(dir) + } +} + +func WithExposedPort(port string) MetaOption { + return func(state llb.State, meta *Meta) llb.State { + meta.ExposePort(port) + return state + } +} + +func WithEntrypoint(entrypoint []string) MetaOption { + return func(state llb.State, meta *Meta) llb.State { + meta.Entrypoint = entrypoint + return state + } +} + +func WithCmd(cmd []string) MetaOption { + return func(state llb.State, meta *Meta) llb.State { + meta.Cmd = cmd + return state + } +} + +func WithConfig(ctx types.Context, state llb.State, opts ...MetaOption) (llb.State, error) { + meta, err := GetMeta(ctx, state) + if err != nil { + return state, err + } + + for _, opt := range opts { + state = opt(state, meta) + } + + return state, nil +} + +// func InitMeta(state llb.State) (llb.State, error ) { +// if state.V + +// meta := val.(*Meta) + +// return state.WithValue(metaKey, meta) +// } diff --git a/pkg/model/factory/state/path.go b/pkg/model/factory/state/path.go new file mode 100644 index 0000000000..510f243049 --- /dev/null +++ b/pkg/model/factory/state/path.go @@ -0,0 +1,25 @@ +package state + +import ( + "github.com/moby/buildkit/client/llb" + + "github.com/replicate/cog/pkg/model/factory/types" +) + +func PrependPath(ctx types.Context, state llb.State, val string) (llb.State, error) { + meta, err := GetMeta(ctx, state) + if err != nil { + return llb.State{}, err + } + meta.path = append([]string{val}, meta.path...) + return state, nil +} + +func AppendPath(ctx types.Context, state llb.State, val string) (llb.State, error) { + meta, err := GetMeta(ctx, state) + if err != nil { + return llb.State{}, err + } + meta.path = append(meta.path, val) + return state, nil +} diff --git a/pkg/model/factory/state/run.go b/pkg/model/factory/state/run.go new file mode 100644 index 0000000000..b6cbf09fab --- /dev/null +++ b/pkg/model/factory/state/run.go @@ -0,0 +1,25 @@ +package state + +import ( + "strings" + + "github.com/moby/buildkit/client/llb" + + "github.com/replicate/cog/pkg/model/factory/types" +) + +func Run(ctx types.Context, state llb.State, opts ...llb.RunOption) (llb.State, error) { + meta, err := GetMeta(ctx, state) + if err != nil { + return llb.State{}, err + } + + path := strings.Join(meta.path, ":") + opts = append(opts, llb.AddEnv("PATH", path)) + + for k, v := range meta.Env { + opts = append(opts, llb.AddEnv(k, v)) + } + + return state.Run(opts...).Root(), nil +} diff --git a/pkg/model/factory/types/ctx.go b/pkg/model/factory/types/ctx.go new file mode 100644 index 0000000000..7ec95693c8 --- /dev/null +++ b/pkg/model/factory/types/ctx.go @@ -0,0 +1,19 @@ +package types + +import ( + "context" + + gatewayClient "github.com/moby/buildkit/frontend/gateway/client" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + + "github.com/replicate/cog/pkg/config" +) + +type Context struct { + context.Context + + Config *config.Config + WorkingDir string + Platform ocispec.Platform + Client gatewayClient.Client +} diff --git a/pkg/model/factory/types/env.go b/pkg/model/factory/types/env.go new file mode 100644 index 0000000000..99303cd33d --- /dev/null +++ b/pkg/model/factory/types/env.go @@ -0,0 +1,15 @@ +package types + +import ( + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + + "github.com/replicate/cog/pkg/config" +) + +// BuildEnv holds the build context and state for operations +type BuildEnv struct { + // BuildKitClient client.Client + Config *config.Config + WorkingDir string + Platform ocispec.Platform +} diff --git a/pkg/model/factory/types/labels.go b/pkg/model/factory/types/labels.go new file mode 100644 index 0000000000..26a455b311 --- /dev/null +++ b/pkg/model/factory/types/labels.go @@ -0,0 +1,7 @@ +package types + +const ( + LabelVersion = "run.cog.version" + LabelConfig = "run.cog.config" + LabelOpenAPISchema = "run.cog.openapi_schema" +) diff --git a/pkg/model/factory/types/op.go b/pkg/model/factory/types/op.go new file mode 100644 index 0000000000..324b2db1d1 --- /dev/null +++ b/pkg/model/factory/types/op.go @@ -0,0 +1,15 @@ +package types + +import "context" + +// Operation defines the interface for build operations +type Operation interface { + // Apply executes the operation on the given state and returns the new state + Apply(ctx context.Context, buildEnv *BuildEnv, state State) (State, error) + + // Name returns a human-readable name for this operation + Name() string + + // ShouldRun determines if this operation should execute based on config/state + ShouldRun(ctx context.Context, buildEnv *BuildEnv, state State) bool +} diff --git a/pkg/model/factory/types/stack.go b/pkg/model/factory/types/stack.go new file mode 100644 index 0000000000..4a37e1ba9e --- /dev/null +++ b/pkg/model/factory/types/stack.go @@ -0,0 +1,4 @@ +package types + +type Stack struct { +} diff --git a/pkg/model/factory/types/state.go b/pkg/model/factory/types/state.go new file mode 100644 index 0000000000..10aea57b24 --- /dev/null +++ b/pkg/model/factory/types/state.go @@ -0,0 +1,149 @@ +package types + +import ( + "context" + "fmt" + "maps" + "slices" + "strings" + + "github.com/moby/buildkit/client/llb" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" +) + +// State represents the current build state including filesystem, environment, and metadata +type State struct { + LLB llb.State + // Env []string + // Labels map[string]string + Layers []LayerInfo + // Config ocispec.ImageConfig + + Cmd []string + Entrypoint []string + + Env map[string]string + Labels map[string]string +} + +func (s *State) Fork() State { + return State{ + LLB: s.LLB, + Layers: slices.Clone(s.Layers), + Env: maps.Clone(s.Env), + Labels: maps.Clone(s.Labels), + Cmd: slices.Clone(s.Cmd), + Entrypoint: slices.Clone(s.Entrypoint), + } +} + +func (s *State) MergeLLB(ctx context.Context, llbState llb.State) { + merged := llb.Merge([]llb.State{s.LLB, llbState}) + s.LLB = merged +} + +// SetEnv adds or updates an environment variable +func (s *State) SetEnv(key, value string) { + if s.Env == nil { + s.Env = make(map[string]string) + } + s.Env[key] = value + + // envVar := key + "=" + value + + // // Remove existing env var with same key + // for i, env := range s.env { + // if len(env) > len(key) && env[:len(key)+1] == key+"=" { + // s.env[i] = envVar + // return + // } + // } + + // // Add new env var + // s.env = append(s.env, envVar) +} + +func (s *State) UnsetEnv(key string) { + delete(s.Env, key) +} + +// SetLabel adds or updates a label +func (s *State) SetLabel(key, value string) { + if s.Labels == nil { + s.Labels = make(map[string]string) + } + s.Labels[key] = value +} + +func (s *State) UnsetLabel(key string) { + delete(s.Labels, key) +} + +func (s *State) ToPB(ctx context.Context) (*llb.Definition, error) { + return s.LLB.Marshal(ctx) +} + +func (s *State) ToImage() ocispec.ImageConfig { + cfg := ocispec.ImageConfig{ + Env: make([]string, 0, len(s.Env)), + Labels: make(map[string]string), + Cmd: slices.Clone(s.Cmd), + Entrypoint: slices.Clone(s.Entrypoint), + } + + for k, v := range s.Env { + cfg.Env = append(cfg.Env, fmt.Sprintf("%s=%s", k, v)) + } + + maps.Copy(cfg.Labels, s.Labels) + + return cfg +} + +// func StateFromBaseImage(img *docker.ResolvedImage) (State, error) { +// env := map[string]string{} + +// for _, envVar := range img.GetEnvironment() { +// k, v, err := parseEnv(envVar) +// if err != nil { +// return State{}, err +// } +// env[k] = v +// } + +// state := State{ +// LLB: llb.Image(img.Source, llb.Platform(img.Config.Platform)), +// Env: env, +// Labels: img.GetLabels(), +// Layers: []LayerInfo{ +// { +// Role: "base", +// Description: img.Source, +// }, +// }, +// } + +// if len(img.Config.Config.Cmd) > 0 { +// state.Cmd = slices.Clone(img.Config.Config.Cmd) +// } +// if len(img.Config.Config.Entrypoint) > 0 { +// state.Entrypoint = slices.Clone(img.Config.Config.Entrypoint) +// } + +// return state, nil +// } + +func parseEnv(kv string) (string, string, error) { + parts := strings.SplitN(kv, "=", 2) + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid env var: %s", kv) + } + return parts[0], parts[1], nil +} + +// LayerInfo tracks information about each layer added to the image +type LayerInfo struct { + Role string // e.g., "base", "sys-deps", "model", "weights" + Description string + Size int64 +} diff --git a/pkg/model/factory/types/transaction.go b/pkg/model/factory/types/transaction.go new file mode 100644 index 0000000000..c860ddbeef --- /dev/null +++ b/pkg/model/factory/types/transaction.go @@ -0,0 +1,76 @@ +package types + +import ( + "context" + + "github.com/moby/buildkit/client/llb" +) + +// Transaction represents a group of operations that can be squashed together +type Transaction struct { + operations []Operation + role string + name string +} + +// NewTransaction creates a new transaction with the given role +func NewTransaction(role, name string) *Transaction { + return &Transaction{ + role: role, + name: name, + } +} + +// Add adds an operation to the transaction +func (t *Transaction) Add(op Operation) *Transaction { + t.operations = append(t.operations, op) + return t +} + +// Apply executes all operations in the transaction and squashes the result +func (t *Transaction) Apply(ctx context.Context, buildEnv *BuildEnv, baseState State) (State, error) { + if len(t.operations) == 0 { + return baseState, nil + } + + // Start with base state + currentState := baseState + + // Apply all operations + for _, op := range t.operations { + if !op.ShouldRun(ctx, buildEnv, currentState) { + continue + } + + newState, err := op.Apply(ctx, buildEnv, currentState) + if err != nil { + return State{}, err + } + currentState = newState + } + + // Squash the result back to base + squashedLLB, err := squash(baseState.LLB, currentState.LLB) + if err != nil { + return State{}, err + } + + // Create final state with squashed filesystem but accumulated metadata + finalState := State{ + LLB: squashedLLB, + Env: currentState.Env, + Labels: currentState.Labels, + Layers: append(baseState.Layers, LayerInfo{ + Role: t.role, + Description: t.name, + }), + } + + return finalState, nil +} + +// squash combines two LLB states using diff/copy pattern +func squash(base, target llb.State) (llb.State, error) { + diff := llb.Diff(base, target) + return base.File(llb.Copy(diff, "/", "/")), nil +} diff --git a/pkg/model/model.go b/pkg/model/model.go new file mode 100644 index 0000000000..c13ee6f6bb --- /dev/null +++ b/pkg/model/model.go @@ -0,0 +1,49 @@ +package model + +import ( + "github.com/google/go-containerregistry/pkg/name" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + + "github.com/replicate/cog/pkg/config" +) + +// Model holds the resolved model metadata for a Cog model. +type Model struct { + // Ref is the fully qualified image reference, eg "r8.im/username/modelname" + Ref name.Reference + + Source ModelSource + + Config *config.Config + + Image ocispec.Image + Manifest *ocispec.Manifest +} + +// Name returns the name of the model, typically the repository name, eg "username/modelname" +func (m Model) Name() string { + return m.Ref.Context().RepositoryStr() +} + +// ImageRef returns the fully qualified image reference +func (m Model) ImageRef() string { + return m.Ref.Name() +} + +func (m Model) Reference() name.Reference { + return m.Ref +} + +func (m Model) Size() (n int64) { + for _, layer := range m.Manifest.Layers { + n += layer.Size + } + return +} + +type ModelSource string + +const ( + ModelSourceLocal ModelSource = "local" + ModelSourceRemote ModelSource = "remote" +) diff --git a/pkg/model/model_test.go b/pkg/model/model_test.go new file mode 100644 index 0000000000..8e3e409df6 --- /dev/null +++ b/pkg/model/model_test.go @@ -0,0 +1,35 @@ +package model + +import ( + "testing" + + "github.com/google/go-containerregistry/pkg/name" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReference(t *testing.T) { + testCases := []struct { + input string + expectedName string + expectedRef string + }{ + {input: "test/model:latest", expectedName: "test/model", expectedRef: "r8.im/test/model:latest"}, + {input: "nousername/model:latest", expectedName: "nousername/model", expectedRef: "r8.im/nousername/model:latest"}, + } + + for _, testCase := range testCases { + t.Run(testCase.input, func(t *testing.T) { + + parsedRef, err := name.ParseReference(testCase.input, name.WithDefaultRegistry("r8.im")) + require.NoError(t, err) + + model := Model{ + Ref: parsedRef, + } + + assert.Equal(t, testCase.expectedName, model.Name()) + assert.Equal(t, testCase.expectedRef, model.ImageRef()) + }) + } +} diff --git a/pkg/model/resolver.go b/pkg/model/resolver.go new file mode 100644 index 0000000000..090ba5161d --- /dev/null +++ b/pkg/model/resolver.go @@ -0,0 +1,143 @@ +package model + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/google/go-containerregistry/pkg/name" + containerregistryv1 "github.com/google/go-containerregistry/pkg/v1" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/docker/command" +) + +type ResolveOption func(o *resolver) + +func WithPlatform(platform ocispec.Platform) ResolveOption { + return func(o *resolver) { + o.platform = &platform + } +} + +func WithResolveMode(mode docker.ResolveMode) ResolveOption { + return func(o *resolver) { + o.mode = mode + } +} + +func WithProvider(provider command.Command) ResolveOption { + return func(o *resolver) { + o.provider = provider + } +} + +func WithConfig(config *config.Config) ResolveOption { + return func(o *resolver) { + o.config = config + } +} + +func WithDefaultRegistry(registry string) ResolveOption { + return func(o *resolver) { + o.defaultRegistry = registry + } +} + +func Resolve(ctx context.Context, imageRef string, opts ...ResolveOption) (*Model, error) { + resolver := &resolver{ + mode: docker.ResolveModeAuto, + } + + for _, opt := range opts { + opt(resolver) + } + + return resolver.resolve(ctx, imageRef) +} + +type resolver struct { + provider command.Command + platform *ocispec.Platform + mode docker.ResolveMode + + // override "docker.io" as the default registry since it implies specific behavior (eg "library" namespace) + defaultRegistry string + + // overrides that builders can pass in to work around missing metadata in dev-time models + config *config.Config +} + +func (r *resolver) resolve(ctx context.Context, imageRef string) (*Model, error) { + parseOpts := []name.Option{} + if r.defaultRegistry != "" { + parseOpts = append(parseOpts, name.WithDefaultRegistry(r.defaultRegistry)) + } + + ref, err := name.ParseReference(imageRef, parseOpts...) + if err != nil { + return nil, err + } + + img, source, err := docker.ResolveImage(ctx, imageRef, r.provider, r.platform, r.mode) + if err != nil { + return nil, err + } + + model, err := r.modelFromImage(ref, source, img) + if err != nil { + return nil, err + } + + return model, nil +} + +func (r *resolver) modelFromImage(ref name.Reference, source docker.ImageSource, img containerregistryv1.Image) (*Model, error) { + model := &Model{ + Ref: ref, + Source: ModelSource(source), + } + + rawConfig, err := img.RawConfigFile() + if err != nil { + return nil, fmt.Errorf("failed to get image config: %w", err) + } + if err := json.Unmarshal(rawConfig, &model.Image); err != nil { + return nil, fmt.Errorf("failed to unmarshal image: %w", err) + } + + rawManifest, err := img.RawManifest() + if err != nil { + return nil, fmt.Errorf("failed to get image manifest: %w", err) + } + if err := json.Unmarshal(rawManifest, &model.Manifest); err != nil { + return nil, fmt.Errorf("failed to unmarshal manifest: %w", err) + } + + // this is a hack to allow base images built with the dockerfile factory to work without the run.cog.config label + if r.config != nil { + model.Config = r.config + } else { + model.Config, err = cogConfigFromImage(model.Image.Config) + if err != nil { + return nil, fmt.Errorf("failed to get cog config from image: %w", err) + } + } + + return model, nil +} + +func cogConfigFromImage(imageConfig ocispec.ImageConfig) (*config.Config, error) { + encodedConfig, ok := imageConfig.Labels[LabelCogConfig] + if !ok { + return nil, fmt.Errorf("cog config not found in image labels") + } + + cfg := &config.Config{} + if err := json.Unmarshal([]byte(encodedConfig), cfg); err != nil { + return nil, fmt.Errorf("failed to unmarshal cog config: %w", err) + } + return cfg, nil +} diff --git a/pkg/model/spec.go b/pkg/model/spec.go new file mode 100644 index 0000000000..bf558b71a1 --- /dev/null +++ b/pkg/model/spec.go @@ -0,0 +1,7 @@ +package model + +const ( + LabelVersion = "run.cog.version" + LabelCogConfig = "run.cog.config" + LabelOpenAPISchema = "run.cog.openapi_schema" +) diff --git a/pkg/predict/predictor.go b/pkg/predict/predictor.go index fbc52e227c..85c15e04d3 100644 --- a/pkg/predict/predictor.go +++ b/pkg/predict/predictor.go @@ -58,11 +58,7 @@ type Predictor struct { port int } -func NewPredictor(ctx context.Context, runOptions command.RunOptions, isTrain bool, fastFlag bool, dockerCommand command.Command) (*Predictor, error) { - if fastFlag { - console.Info("Fast predictor enabled.") - } - +func NewPredictor(ctx context.Context, runOptions command.RunOptions, isTrain bool, dockerCommand command.Command) (*Predictor, error) { if global.Debug { runOptions.Env = append(runOptions.Env, "COG_LOG_LEVEL=debug") } else { diff --git a/pkg/util/env.go b/pkg/util/env.go index 60dc4b7b6e..6a843b47e4 100644 --- a/pkg/util/env.go +++ b/pkg/util/env.go @@ -2,6 +2,7 @@ package util import ( "os" + "strconv" "github.com/replicate/cog/pkg/util/console" ) @@ -20,3 +21,9 @@ func GetEnvOrDefault[T any](key string, defaultVal T, conversionFunc func(string } return defaultVal } + +// EnvIsTruthy returns true if the environment variable is set to 1, t, T, TRUE, true, True. Any other value returns false. +func EnvIsTruthy(key string) bool { + ok, _ := strconv.ParseBool(os.Getenv(key)) + return ok +} diff --git a/pkg/util/print.go b/pkg/util/print.go new file mode 100644 index 0000000000..e00334c603 --- /dev/null +++ b/pkg/util/print.go @@ -0,0 +1,11 @@ +package util + +import ( + "encoding/json" + "fmt" +) + +func PrettyPrintJSON(thing any) { + json, _ := json.MarshalIndent(thing, "", " ") + fmt.Println(string(json)) +}