Skip to content

Commit 78ef045

Browse files
committed
wip
1 parent f0e265a commit 78ef045

File tree

21 files changed

+485
-336
lines changed

21 files changed

+485
-336
lines changed

api/backend/llm.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ type TokenUsage struct {
2626
Completion int
2727
}
2828

29-
func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
29+
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
3030
modelFile := c.Model
3131

3232
grpcOpts := gRPCModelOpts(c)
@@ -72,6 +72,7 @@ func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c
7272
fn := func() (LLMResponse, error) {
7373
opts := gRPCPredictOpts(c, loader.ModelPath)
7474
opts.Prompt = s
75+
opts.Images = images
7576

7677
tokenUsage := TokenUsage{}
7778

api/backend/options.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
4545
DraftModel: c.DraftModel,
4646
AudioPath: c.VallE.AudioPath,
4747
Quantization: c.Quantization,
48+
MMProj: c.MMProj,
4849
LoraAdapter: c.LoraAdapter,
4950
LoraBase: c.LoraBase,
5051
NGQA: c.NGQA,

api/config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ type LLMConfig struct {
104104
DraftModel string `yaml:"draft_model"`
105105
NDraft int32 `yaml:"n_draft"`
106106
Quantization string `yaml:"quantization"`
107+
MMProj string `yaml:"mmproj"`
107108
}
108109

109110
type AutoGPTQ struct {

api/openai/chat.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
8181
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
8282
}
8383

84+
if input.ResponseFormat == "json_object" {
85+
input.Grammar = grammar.JSONBNF
86+
}
87+
8488
// process functions if we have any defined or if we have a function call string
8589
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
8690
log.Debug().Msgf("Response needs to process functions")
@@ -140,14 +144,14 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
140144
}
141145
}
142146
r := config.Roles[role]
143-
contentExists := i.Content != nil && *i.Content != ""
147+
contentExists := i.Content != nil && i.StringContent != ""
144148
// First attempt to populate content via a chat message specific template
145149
if config.TemplateConfig.ChatMessage != "" {
146150
chatMessageData := model.ChatMessageTemplateData{
147151
SystemPrompt: config.SystemPrompt,
148152
Role: r,
149153
RoleName: role,
150-
Content: *i.Content,
154+
Content: i.StringContent,
151155
MessageIndex: messageIndex,
152156
}
153157
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
@@ -166,7 +170,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
166170
if content == "" {
167171
if r != "" {
168172
if contentExists {
169-
content = fmt.Sprint(r, " ", *i.Content)
173+
content = fmt.Sprint(r, " ", i.StringContent)
170174
}
171175
if i.FunctionCall != nil {
172176
j, err := json.Marshal(i.FunctionCall)
@@ -180,7 +184,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
180184
}
181185
} else {
182186
if contentExists {
183-
content = fmt.Sprint(*i.Content)
187+
content = fmt.Sprint(i.StringContent)
184188
}
185189
if i.FunctionCall != nil {
186190
j, err := json.Marshal(i.FunctionCall)
@@ -334,7 +338,11 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
334338
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
335339
// Note: This costs (in term of CPU) another computation
336340
config.Grammar = ""
337-
predFunc, err := backend.ModelInference(input.Context, predInput, o.Loader, *config, o, nil)
341+
images := []string{}
342+
for _, m := range input.Messages {
343+
images = append(images, m.StringImages...)
344+
}
345+
predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil)
338346
if err != nil {
339347
log.Error().Msgf("inference error: %s", err.Error())
340348
return

api/openai/completion.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
config "github.com/go-skynet/LocalAI/api/config"
1313
"github.com/go-skynet/LocalAI/api/options"
1414
"github.com/go-skynet/LocalAI/api/schema"
15+
"github.com/go-skynet/LocalAI/pkg/grammar"
1516
model "github.com/go-skynet/LocalAI/pkg/model"
1617
"github.com/gofiber/fiber/v2"
1718
"github.com/google/uuid"
@@ -64,6 +65,10 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
6465
return fmt.Errorf("failed reading parameters from request:%w", err)
6566
}
6667

68+
if input.ResponseFormat == "json_object" {
69+
input.Grammar = grammar.JSONBNF
70+
}
71+
6772
log.Debug().Msgf("Parameter Config: %+v", config)
6873

6974
if input.Stream {

api/openai/inference.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,13 @@ func ComputeChoices(
2323
n = 1
2424
}
2525

26+
images := []string{}
27+
for _, m := range req.Messages {
28+
images = append(images, m.StringImages...)
29+
}
30+
2631
// get the model function to call for the result
27-
predFunc, err := backend.ModelInference(req.Context, predInput, loader, *config, o, tokenCallback)
32+
predFunc, err := backend.ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback)
2833
if err != nil {
2934
return result, backend.TokenUsage{}, err
3035
}

api/openai/request.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ package openai
22

33
import (
44
"context"
5+
"encoding/base64"
56
"encoding/json"
67
"fmt"
8+
"io/ioutil"
9+
"net/http"
710
"os"
811
"path/filepath"
912
"strings"
@@ -61,6 +64,37 @@ func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *sche
6164
return modelFile, input, nil
6265
}
6366

67+
// getBase64Image returns the base64-encoded payload of the image referenced by s.
//
// Two input forms are accepted:
//   - an "http://" or "https://" URL: the image is downloaded into memory and
//     its bytes are encoded with standard base64;
//   - a base64 image data URI ("data:image/<type>;base64,<payload>"): the
//     header is dropped and the payload is returned as-is, without validating
//     that it is well-formed base64.
//
// Any other input, a failed download, or a non-2xx HTTP response yields an
// error and an empty string.
func getBase64Image(s string) (string, error) {
	if strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://") {
		// NOTE(review): http.Get goes through http.DefaultClient, which sets no
		// timeout; a stalled remote host will block this call.
		resp, err := http.Get(s)
		if err != nil {
			return "", err
		}
		defer resp.Body.Close()

		// Reject error responses so that an HTML error page is never encoded
		// and forwarded to the backend as if it were image data.
		if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
			return "", fmt.Errorf("downloading image: unexpected status %s", resp.Status)
		}

		// read the image data into memory
		data, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return "", err
		}

		return base64.StdEncoding.EncodeToString(data), nil
	}

	// Strip the data URI header for any image media type, not only JPEG.
	const marker = ";base64,"
	if strings.HasPrefix(s, "data:image/") {
		// The header must not contain a comma: a comma before the marker means
		// the payload already started and the URI is not base64-encoded.
		if i := strings.Index(s, marker); i >= 0 && !strings.Contains(s[:i], ",") {
			return s[i+len(marker):], nil
		}
	}

	return "", fmt.Errorf("not valid string")
}
97+
6498
func updateConfig(config *config.Config, input *schema.OpenAIRequest) {
6599
if input.Echo {
66100
config.Echo = input.Echo
@@ -129,6 +163,35 @@ func updateConfig(config *config.Config, input *schema.OpenAIRequest) {
129163
}
130164
}
131165

166+
// Decode each request's message content
167+
index := 0
168+
for _, m := range input.Messages {
169+
switch content := m.Content.(type) {
170+
case string:
171+
m.StringContent = content
172+
case []interface{}:
173+
dat, _ := json.Marshal(content)
174+
c := []schema.Content{}
175+
json.Unmarshal(dat, &c)
176+
for _, pp := range c {
177+
if pp.Type == "text" {
178+
m.StringContent = pp.Text
179+
} else if pp.Type == "image_url" {
180+
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
181+
base64, err := getBase64Image(pp.ImageURL)
182+
if err == nil {
183+
m.StringImages = append(m.StringImages, base64) // TODO: make sure that we only return base64 stuff
184+
// set a placeholder for each image
185+
m.StringContent = m.StringContent + fmt.Sprintf("[img-%d]", index)
186+
index++
187+
} else {
188+
fmt.Print("Failed encoding image", err)
189+
}
190+
}
191+
}
192+
}
193+
}
194+
132195
if input.RepeatPenalty != 0 {
133196
config.RepeatPenalty = input.RepeatPenalty
134197
}

api/schema/openai.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,21 @@ type Choice struct {
5555
Text string `json:"text,omitempty"`
5656
}
5757

58+
// Content is one typed part of a multi-part message content list.
// updateConfig decodes list-valued message content into a slice of Content
// and handles the "text" and "image_url" part types.
type Content struct {
59+
// Type selects which of the other fields is meaningful ("text" or "image_url").
Type string `json:"type" yaml:"type"`
60+
// Text is the text of a "text" part.
Text string `json:"text" yaml:"text"`
61+
// ImageURL is the image reference of an "image_url" part; it is passed to
// getBase64Image, which accepts an http(s) URL or a base64 data URI.
// NOTE(review): the OpenAI API sends image_url as an object with a "url"
// field rather than a bare string — confirm that decoding into a string is intended.
ImageURL string `json:"image_url" yaml:"image_url"`
62+
}
63+
5864
// Message is a single chat message of a request or response.
type Message struct {
5965
// The message role
6066
Role string `json:"role,omitempty" yaml:"role"`
6167
// The message content
62-
Content *string `json:"content" yaml:"content"`
68+
// Content is untyped because it may arrive either as a plain string or as a
// list of typed parts; updateConfig type-switches on it to fill the fields below.
Content interface{} `json:"content" yaml:"content"`
69+
70+
// StringContent is the text derived from Content: the string itself, or the
// text part plus one "[img-N]" placeholder per successfully loaded image.
StringContent string `json:"string_content,omitempty" yaml:"string_content,omitempty"`
71+
// StringImages holds the base64 payload of each image part of Content
// that getBase64Image could load.
StringImages []string `json:"string_images,omitempty" yaml:"string_images,omitempty"`
72+
6373
// A result of a function call
6474
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
6575
}

backend/cpp/llama/grpc-server.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
// llama.cpp gRPC C++ backend server
22
//
3-
// Ettore Di Giacinto <[email protected]>
3+
// Ettore Di Giacinto <[email protected]> and llama.cpp authors
44
//
55
// This is a gRPC server for llama.cpp compatible with the LocalAI proto
6-
// Note: this is a re-adaptation of the original llama.cpp example/server.cpp for HTTP,
6+
// Note: this is a re-adaptation of the original llama.cpp example/server.cpp for HTTP (https://github.com/ggerganov/llama.cpp/tree/master/examples/server),
77
// but modified to work with gRPC
88
//
99

@@ -39,7 +39,7 @@ using grpc::Status;
3939
using backend::HealthMessage;
4040

4141

42-
///// LLAMA.CPP server
42+
///// LLAMA.CPP server code below
4343

4444
using json = nlohmann::json;
4545

@@ -1809,7 +1809,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
18091809

18101810
/////////////////////////////////
18111811
////////////////////////////////
1812-
//////// LOCALAI
1812+
//////// LOCALAI code starts below here
1813+
/////////////////////////////////
1814+
////////////////////////////////
18131815

18141816
bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model
18151817

@@ -1880,6 +1882,16 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
18801882
data["prompt"] = predict->prompt();
18811883
data["ignore_eos"] = predict->ignoreeos();
18821884

1885+
// for each image in the request, add the image data
1886+
//
1887+
for (int i = 0; i < predict->images_size(); i++) {
1888+
data["image_data"].push_back(json
1889+
{
1890+
{"id", i},
1891+
{"data", predict->images(i)},
1892+
});
1893+
}
1894+
18831895
data["stop"] = predict->stopprompts();
18841896
// data["n_probs"] = predict->nprobs();
18851897
//TODO: images,
@@ -1953,14 +1965,17 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
19531965
// }
19541966
// }
19551967

1956-
1957-
19581968
static void params_parse(const backend::ModelOptions* request,
19591969
gpt_params & params) {
19601970

19611971
// this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809
19621972

19631973
params.model = request->modelfile();
1974+
if (!request->mmproj().empty()) {
1975+
// get the directory of modelfile
1976+
std::string model_dir = params.model.substr(0, params.model.find_last_of("/\\"));
1977+
params.mmproj = model_dir + request->mmproj();
1978+
}
19641979
// params.model_alias ??
19651980
params.model_alias = request->modelfile();
19661981
params.n_ctx = request->contextsize();
@@ -2071,16 +2086,6 @@ class BackendServiceImpl final : public backend::Backend::Service {
20712086
break;
20722087
}
20732088
}
2074-
return grpc::Status::OK;
2075-
2076-
2077-
// auto on_complete = [task_id, &llama] (bool)
2078-
// {
2079-
// // cancel
2080-
// llama.request_cancel(task_id);
2081-
// };
2082-
2083-
20842089

20852090
return grpc::Status::OK;
20862091
}

0 commit comments

Comments
 (0)