Skip to content

Commit 2488c44

Browse files
authored
feat: bert.cpp token embeddings (#241)
1 parent b4241d0 commit 2488c44

File tree

4 files changed

+38
-6
lines changed

4 files changed

+38
-6
lines changed

Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ GOGPT2_VERSION?=92421a8cf61ed6e03babd9067af292b094cb1307
1010
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
1111
RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47
1212
WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993
13-
BERT_VERSION?=ec771ec715576ac050263bb7bb74bfd616a5ba13
13+
BERT_VERSION?=ac22f8f74aec5e31bc46242c17e7d511f127856b
1414
BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1
1515

1616

@@ -182,6 +182,7 @@ test-models/testmodel:
182182
mkdir test-dir
183183
wget https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerberas-111m-q4_0.bin -O test-models/testmodel
184184
wget https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
185+
wget https://huggingface.co/skeskinen/ggml/resolve/main/all-MiniLM-L6-v2/ggml-model-q4_0.bin -O test-models/bert
185186
wget https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
186187
cp tests/fixtures/* test-models
187188

api/api_test.go

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ var _ = Describe("API test", func() {
4747
It("returns the models list", func() {
4848
models, err := client.ListModels(context.TODO())
4949
Expect(err).ToNot(HaveOccurred())
50-
Expect(len(models.Models)).To(Equal(5))
51-
Expect(models.Models[0].ID).To(Equal("testmodel"))
50+
Expect(len(models.Models)).To(Equal(7))
5251
})
5352
It("can generate completions", func() {
5453
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"})
@@ -97,6 +96,33 @@ var _ = Describe("API test", func() {
9796
Expect(err).ToNot(HaveOccurred())
9897
Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting"))
9998
})
99+
100+
It("calculate embeddings", func() {
101+
if runtime.GOOS != "linux" {
102+
Skip("test supported only on linux")
103+
}
104+
resp, err := client.CreateEmbeddings(
105+
context.Background(),
106+
openai.EmbeddingRequest{
107+
Model: openai.AdaEmbeddingV2,
108+
Input: []string{"sun", "cat"},
109+
},
110+
)
111+
Expect(err).ToNot(HaveOccurred())
112+
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
113+
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
114+
115+
sunEmbedding := resp.Data[0].Embedding
116+
resp2, err := client.CreateEmbeddings(
117+
context.Background(),
118+
openai.EmbeddingRequest{
119+
Model: openai.AdaEmbeddingV2,
120+
Input: []string{"sun"},
121+
},
122+
)
123+
Expect(err).ToNot(HaveOccurred())
124+
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
125+
})
100126
})
101127

102128
Context("Config file", func() {
@@ -123,8 +149,7 @@ var _ = Describe("API test", func() {
123149

124150
models, err := client.ListModels(context.TODO())
125151
Expect(err).ToNot(HaveOccurred())
126-
Expect(len(models.Models)).To(Equal(7))
127-
Expect(models.Models[0].ID).To(Equal("testmodel"))
152+
Expect(len(models.Models)).To(Equal(9))
128153
})
129154
It("can generate chat completions from config file", func() {
130155
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})

api/prediction.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config)
6868
case *bert.Bert:
6969
fn = func() ([]float32, error) {
7070
if len(tokens) > 0 {
71-
return nil, fmt.Errorf("embeddings endpoint for this model supports only string")
71+
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads))
7272
}
7373
return model.Embeddings(s, bert.SetThreads(c.Threads))
7474
}

tests/fixtures/embeddings.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
name: text-embedding-ada-002
2+
parameters:
3+
model: bert
4+
threads: 14
5+
backend: bert-embeddings
6+
embeddings: true

0 commit comments

Comments
 (0)