Skip to content

Commit 1071b75

Browse files
committed
Allow to install ollama models from CLI
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 25d3fb3 commit 1071b75

File tree

4 files changed

+51
-6
lines changed

4 files changed

+51
-6
lines changed

core/cli/models.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
cliContext "github.com/go-skynet/LocalAI/core/cli/context"
88

9+
"github.com/go-skynet/LocalAI/pkg/downloader"
910
"github.com/go-skynet/LocalAI/pkg/gallery"
1011
"github.com/go-skynet/LocalAI/pkg/startup"
1112
"github.com/rs/zerolog/log"
@@ -79,13 +80,15 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
7980
return err
8081
}
8182

82-
model := gallery.FindModel(models, modelName, mi.ModelsPath)
83-
if model == nil {
84-
log.Error().Str("model", modelName).Msg("model not found")
85-
return err
86-
}
83+
if !downloader.LooksLikeOCI(modelName) {
84+
model := gallery.FindModel(models, modelName, mi.ModelsPath)
85+
if model == nil {
86+
log.Error().Str("model", modelName).Msg("model not found")
87+
return err
88+
}
8789

88-
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
90+
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
91+
}
8992
err = startup.InstallModels(galleries, "", mi.ModelsPath, progressCallback, modelName)
9093
if err != nil {
9194
return err

core/cli/util.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@ package cli
22

33
import (
44
"fmt"
5+
"path/filepath"
56

67
"github.com/rs/zerolog/log"
78

89
cliContext "github.com/go-skynet/LocalAI/core/cli/context"
10+
"github.com/go-skynet/LocalAI/pkg/downloader"
11+
"github.com/go-skynet/LocalAI/pkg/utils"
912
gguf "github.com/thxcode/gguf-parser-go"
1013
)
1114

1215
type UtilCMD struct {
1316
GGUFInfo GGUFInfoCMD `cmd:"" name:"gguf-info" help:"Get information about a GGUF file"`
17+
Download DownloadCMD `cmd:"" name:"download" help:"Download a file or a model from an OCI registry"`
1418
}
1519

1620
type GGUFInfoCMD struct {
@@ -53,3 +57,16 @@ func (u *GGUFInfoCMD) Run(ctx *cliContext.Context) error {
5357

5458
return nil
5559
}
60+
61+
type DownloadCMD struct {
62+
Args []string `arg:"" optional:"" name:"args" help:"File URL and name to download"`
63+
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
64+
}
65+
66+
func (u *DownloadCMD) Run(ctx *cliContext.Context) error {
67+
if len(u.Args) < 2 {
68+
return fmt.Errorf("no URL or model name provided")
69+
}
70+
71+
return downloader.DownloadFile(u.Args[0], filepath.Join(u.ModelsPath, u.Args[1]), "", 1, 1, utils.DisplayDownloadFunction)
72+
}

pkg/downloader/uri.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,11 @@ func DownloadFile(url string, filePath, sha string, fileN, total int, downloadSt
173173
}
174174

175175
if strings.HasPrefix(url, OllamaPrefix) {
176+
url = strings.TrimPrefix(url, OllamaPrefix)
176177
return oci.OllamaFetchModel(url, filePath, progressStatus)
177178
}
178179

180+
url = strings.TrimPrefix(url, OCIPrefix)
179181
img, err := oci.GetImage(url, "", nil, nil)
180182
if err != nil {
181183
return fmt.Errorf("failed to get image %q: %v", url, err)

pkg/startup/model_preload.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"os"
77
"path/filepath"
8+
"strings"
89

910
"github.com/go-skynet/LocalAI/embedded"
1011
"github.com/go-skynet/LocalAI/pkg/downloader"
@@ -52,6 +53,28 @@ func InstallModels(galleries []gallery.Gallery, modelLibraryURL string, modelPat
5253
log.Error().Err(e).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition")
5354
err = errors.Join(err, e)
5455
}
56+
case downloader.LooksLikeOCI(url):
57+
log.Debug().Msgf("[startup] resolved OCI model to download: %s", url)
58+
59+
// convert OCI image name to a file name.
60+
ociName := strings.TrimPrefix(url, downloader.OCIPrefix)
61+
ociName = strings.TrimPrefix(ociName, downloader.OllamaPrefix)
62+
ociName = strings.ReplaceAll(ociName, "/", "__")
63+
ociName = strings.ReplaceAll(ociName, ":", "__")
64+
65+
// check if file exists
66+
if _, e := os.Stat(filepath.Join(modelPath, ociName)); errors.Is(e, os.ErrNotExist) {
67+
modelDefinitionFilePath := filepath.Join(modelPath, ociName)
68+
e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) {
69+
utils.DisplayDownloadFunction(fileName, current, total, percent)
70+
})
71+
if e != nil {
72+
log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model")
73+
err = errors.Join(err, e)
74+
}
75+
}
76+
77+
log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName)
5578
case downloader.LooksLikeURL(url):
5679
log.Debug().Msgf("[startup] resolved model to download: %s", url)
5780

0 commit comments

Comments
 (0)