Skip to content

Commit 589226e

Browse files
committed
add bpe interface
1 parent 5416484 commit 589226e

File tree

5 files changed

+27
-7
lines changed

5 files changed

+27
-7
lines changed

encoding.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ func initEncoding(encodingName string) (*Encoding, error) {
116116
}
117117

118118
func cl100k_base() (*Encoding, error) {
119-
ranks, err := loadTiktokenBpe("https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken")
119+
ranks, err := bpeLoader.LoadTiktokenBpe("https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken")
120120
if err != nil {
121121
return nil, err
122122
}
@@ -136,7 +136,7 @@ func cl100k_base() (*Encoding, error) {
136136
}
137137

138138
func p50k_edit() (*Encoding, error) {
139-
ranks, err := loadTiktokenBpe("https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken")
139+
ranks, err := bpeLoader.LoadTiktokenBpe("https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken")
140140
if err != nil {
141141
return nil, err
142142
}
@@ -150,7 +150,7 @@ func p50k_edit() (*Encoding, error) {
150150
}
151151

152152
func p50k_base() (*Encoding, error) {
153-
ranks, err := loadTiktokenBpe("https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken")
153+
ranks, err := bpeLoader.LoadTiktokenBpe("https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken")
154154
if err != nil {
155155
return nil, err
156156
}
@@ -173,7 +173,7 @@ func p50k_base() (*Encoding, error) {
173173
}
174174

175175
func r50k_base() (*Encoding, error) {
176-
ranks, err := loadTiktokenBpe("https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken")
176+
ranks, err := bpeLoader.LoadTiktokenBpe("https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken")
177177
if err != nil {
178178
return nil, err
179179
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module github.com/pkoukk/tiktoken-go
33
go 1.19
44

55
require (
6-
github.com/dlclark/regexp2 v1.8.1
6+
github.com/dlclark/regexp2 v1.10.0
77
github.com/google/uuid v1.3.0
88
github.com/stretchr/testify v1.8.2
99
)

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
22
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
33
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4-
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
5-
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
4+
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
5+
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
66
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
77
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
88
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

load.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ import (
1414
"github.com/google/uuid"
1515
)
1616

17+
type BpeLoader interface {
18+
LoadTiktokenBpe(tiktokenBpeFile string) (map[string]int, error)
19+
}
20+
1721
func readFile(blobpath string) ([]byte, error) {
1822
if !strings.HasPrefix(blobpath, "http://") && !strings.HasPrefix(blobpath, "https://") {
1923
file, err := os.Open(blobpath)
@@ -91,3 +95,13 @@ func loadTiktokenBpe(tiktokenBpeFile string) (map[string]int, error) {
9195
}
9296
return bpeRanks, nil
9397
}
98+
99+
type defaultBpeLoader struct{}
100+
101+
func (l *defaultBpeLoader) LoadTiktokenBpe(tiktokenBpeFile string) (map[string]int, error) {
102+
return loadTiktokenBpe(tiktokenBpeFile)
103+
}
104+
105+
func NewDefaultBpeLoader() BpeLoader {
106+
return &defaultBpeLoader{}
107+
}

tiktoken.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ import (
88
"github.com/dlclark/regexp2"
99
)
1010

11+
var bpeLoader BpeLoader = NewDefaultBpeLoader()
12+
13+
func SetBpeLoader(loader BpeLoader) {
14+
bpeLoader = loader
15+
}
16+
1117
func GetEncoding(encodingName string) (*Tiktoken, error) {
1218
enc, err := getEncoding(encodingName)
1319
if err != nil {

0 commit comments

Comments
 (0)