From 2ddcca1172a623422fc682d1b4e89d8fc005e79e Mon Sep 17 00:00:00 2001 From: dysrama Date: Wed, 30 Apr 2025 16:06:12 +0200 Subject: [PATCH 1/2] Support for streaming formatters --- go/ai/format_array.go | 12 +++ go/ai/format_enum.go | 12 +++ go/ai/format_json.go | 55 +++++++++- go/ai/format_jsonl.go | 12 +++ go/ai/format_text.go | 15 +++ go/ai/formatter.go | 4 + go/ai/formatter_test.go | 216 ++++++++++++++++++++++++++++++++++++++- go/ai/generate.go | 5 + go/ai/generator_test.go | 4 +- go/go.mod | 1 + go/go.sum | 2 + go/internal/base/json.go | 21 +++- 12 files changed, 344 insertions(+), 15 deletions(-) diff --git a/go/ai/format_array.go b/go/ai/format_array.go index 2dd6d46781..977684af08 100644 --- a/go/ai/format_array.go +++ b/go/ai/format_array.go @@ -15,6 +15,7 @@ package ai import ( + "context" "encoding/json" "errors" "fmt" @@ -68,6 +69,12 @@ func (a arrayHandler) Config() ModelOutputConfig { return a.config } +func (a arrayHandler) StreamCallback(cb ModelStreamCallback) ModelStreamCallback { + return func(ctx context.Context, mrc *ModelResponseChunk) error { + return cb(ctx, mrc) + } +} + // ParseMessage parses the message and returns the formatted message. func (a arrayHandler) ParseMessage(m *Message) (*Message, error) { if a.config.Format == OutputFormatArray { @@ -103,3 +110,8 @@ func (a arrayHandler) ParseMessage(m *Message) (*Message, error) { return m, nil } + +// ParseChunk parse the chunk and returns a new formatted chunk. +func (a arrayHandler) ParseChunk(c *ModelResponseChunk) (*ModelResponseChunk, error) { + return c, nil +} diff --git a/go/ai/format_enum.go b/go/ai/format_enum.go index f984d62193..ee1c3d0c58 100644 --- a/go/ai/format_enum.go +++ b/go/ai/format_enum.go @@ -15,6 +15,7 @@ package ai import ( + "context" "errors" "fmt" "regexp" @@ -67,6 +68,12 @@ func (e enumHandler) Config() ModelOutputConfig { return e.config } +func (e enumHandler) StreamCallback(cb ModelStreamCallback) ModelStreamCallback { + return func(ctx context.Context, mrc *ModelResponseChunk) error { + return cb(ctx, mrc) + } +} + // ParseMessage parses the message and returns the formatted message. func (e enumHandler) ParseMessage(m *Message) (*Message, error) { if e.config.Format == OutputFormatEnum { @@ -100,6 +107,11 @@ func (e enumHandler) ParseMessage(m *Message) (*Message, error) { return m, nil } +// ParseChunk parse the chunk and returns a new formatted chunk. +func (e enumHandler) ParseChunk(c *ModelResponseChunk) (*ModelResponseChunk, error) { + return c, nil +} + // Get enum strings from json schema func objectEnums(schema map[string]any) []string { var enums []string diff --git a/go/ai/format_json.go b/go/ai/format_json.go index 4bef3ed0ea..3bbefc9f85 100644 --- a/go/ai/format_json.go +++ b/go/ai/format_json.go @@ -15,10 +15,13 @@ package ai import ( + "context" "encoding/json" "errors" "fmt" + partialparser "github.com/blaze2305/partial-json-parser" + "github.com/blaze2305/partial-json-parser/options" "github.com/firebase/genkit/go/internal/base" ) @@ -55,8 +58,9 @@ func (j jsonFormatter) Handler(schema map[string]any) (FormatHandler, error) { // jsonHandler is a handler for the JSON formatter. type jsonHandler struct { - instructions string - config ModelOutputConfig + instructions string + config ModelOutputConfig + previousParts []*Part } // Instructions returns the instructions for the formatter. @@ -69,6 +73,21 @@ func (j jsonHandler) Config() ModelOutputConfig { return j.config } +// StreamCallback handler for streaming formatted responses +func (j jsonHandler) StreamCallback(cb ModelStreamCallback) ModelStreamCallback { + return func(ctx context.Context, mrc *ModelResponseChunk) error { + j.previousParts = append(j.previousParts, mrc.Content...) + mrc.Content = j.previousParts + + parsed, err := j.ParseChunk(mrc) + if err != nil { + return err + } + + return cb(ctx, parsed) + } +} + // ParseMessage parses the message and returns the formatted message. func (j jsonHandler) ParseMessage(m *Message) (*Message, error) { if j.config.Format == OutputFormatJSON { @@ -85,7 +104,6 @@ func (j jsonHandler) ParseMessage(m *Message) (*Message, error) { } text := base.ExtractJSONFromMarkdown(part.Text) - if j.config.Schema != nil { var schemaBytes []byte schemaBytes, err := json.Marshal(j.config.Schema) @@ -107,3 +125,34 @@ func (j jsonHandler) ParseMessage(m *Message) (*Message, error) { return m, nil } + +// ParseChunk parse the chunk and returns a new formatted chunk. +func (j jsonHandler) ParseChunk(c *ModelResponseChunk) (*ModelResponseChunk, error) { + if j.config.Format == OutputFormatJSON { + if c == nil { + return nil, errors.New("chunk is empty") + } + + if len(c.Content) == 0 { + return nil, errors.New("message has no content") + } + + // Get all chunks streamed so far + text := c.Text() + text = base.ExtractJSONFromMarkdown(text) + // Try and extract a json object + text = base.GetJsonObject(text) + if text != "" { + var err error + text, err = partialparser.ParseMalformedString(text, options.ALL, false) + if err != nil { + return nil, errors.New("message is not a valid JSON") + } + } else { + return nil, nil + } + + c.Content = []*Part{NewJSONPart(text)} + } + return c, nil +} diff --git a/go/ai/format_jsonl.go b/go/ai/format_jsonl.go index 576ec6ec4c..91a37a0ab6 100644 --- a/go/ai/format_jsonl.go +++ b/go/ai/format_jsonl.go @@ -15,6 +15,7 @@ package ai import ( + "context" "encoding/json" "errors" "fmt" @@ -69,6 +70,12 @@ func (j jsonlHandler) Config() ModelOutputConfig { return j.config } +func (j jsonlHandler) StreamCallback(cb ModelStreamCallback) ModelStreamCallback { + return func(ctx context.Context, mrc *ModelResponseChunk) error { + return cb(ctx, mrc) + } +} + // ParseMessage parses the message and returns the formatted message. func (j jsonlHandler) ParseMessage(m *Message) (*Message, error) { if j.config.Format == OutputFormatJSONL { @@ -106,3 +113,8 @@ func (j jsonlHandler) ParseMessage(m *Message) (*Message, error) { return m, nil } + +// ParseChunk parse the chunk and returns a new formatted chunk. +func (j jsonlHandler) ParseChunk(c *ModelResponseChunk) (*ModelResponseChunk, error) { + return c, nil +} diff --git a/go/ai/format_text.go b/go/ai/format_text.go index b9b85d9b8a..79ad1cbfac 100644 --- a/go/ai/format_text.go +++ b/go/ai/format_text.go @@ -14,6 +14,10 @@ package ai +import ( + "context" +) + type textFormatter struct{} // Name returns the name of the formatter. @@ -47,7 +51,18 @@ func (t textHandler) Instructions() string { return t.instructions } +func (t textHandler) StreamCallback(cb ModelStreamCallback) ModelStreamCallback { + return func(ctx context.Context, mrc *ModelResponseChunk) error { + return cb(ctx, mrc) + } +} + // ParseMessage parses the message and returns the formatted message. func (t textHandler) ParseMessage(m *Message) (*Message, error) { return m, nil } + +// ParseChunk parse the chunk and returns a new formatted chunk. +func (t textHandler) ParseChunk(c *ModelResponseChunk) (*ModelResponseChunk, error) { + return c, nil +} diff --git a/go/ai/formatter.go b/go/ai/formatter.go index b6ec7c81f2..7546f41321 100644 --- a/go/ai/formatter.go +++ b/go/ai/formatter.go @@ -50,10 +50,14 @@ type Formatter interface { type FormatHandler interface { // ParseMessage parses the message and returns a new formatted message. ParseMessage(message *Message) (*Message, error) + // ParseChunk parse the chunk and returns a new formatted chunk. + ParseChunk(chunk *ModelResponseChunk) (*ModelResponseChunk, error) // Instructions returns the formatter instructions to embed in the prompt. Instructions() string // Config returns the output config for the model request. Config() ModelOutputConfig + // Stream callback returns a ModelStreamCallback + StreamCallback(cb ModelStreamCallback) ModelStreamCallback } // ConfigureFormats registers default formats in the registry diff --git a/go/ai/formatter_test.go b/go/ai/formatter_test.go index 27ecb31d37..8839b817eb 100644 --- a/go/ai/formatter_test.go +++ b/go/ai/formatter_test.go @@ -41,7 +41,7 @@ func TestConstrainedGenerate(t *testing.T) { formatModel := DefineModel(r, "test", "format", &modelInfo, func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) { if msc != nil { msc(ctx, &ModelResponseChunk{ - Content: []*Part{NewTextPart("stream!")}, + Content: []*Part{NewJSONPart("{\"Foo\": \"bar\"}")}, }) } @@ -53,7 +53,7 @@ func TestConstrainedGenerate(t *testing.T) { t.Run("doesn't inject instructions when model supports native contrained generation", func(t *testing.T) { wantText := JSON - wantStreamText := "stream!" + wantStreamText := "{\"Foo\": \"bar\"}" wantRequest := &ModelRequest{ Messages: []*Message{ { @@ -113,7 +113,7 @@ func TestConstrainedGenerate(t *testing.T) { t.Run("doesn't use format instructions when explicitly instructed not to", func(t *testing.T) { wantText := JSON - wantStreamText := "stream!" + wantStreamText := "{\"Foo\": \"bar\"}" wantRequest := &ModelRequest{ Messages: []*Message{ { @@ -166,7 +166,7 @@ func TestConstrainedGenerate(t *testing.T) { t.Run("uses format instructions given by user", func(t *testing.T) { customInstructions := "The generated output should be in JSON format and conform to the following schema:\n\n```{\"additionalProperties\":false,\"properties\":{\"foo\":{\"type\":\"string\"}},\"required\":[\"foo\"],\"type\":\"object\"}```" wantText := JSON - wantStreamText := "stream!" + wantStreamText := "{\"Foo\": \"bar\"}" wantRequest := &ModelRequest{ Messages: []*Message{ { @@ -224,7 +224,7 @@ func TestConstrainedGenerate(t *testing.T) { t.Run("uses simulated constrained generation when explicitly told to do so", func(t *testing.T) { wantText := JSON - wantStreamText := "stream!" + wantStreamText := "{\"Foo\": \"bar\"}" wantRequest := &ModelRequest{ Messages: []*Message{ { @@ -608,6 +608,212 @@ func TestJsonParser(t *testing.T) { } } +func TestJsonStreamingParser(t *testing.T) { + tests := []struct { + name string + schema map[string]any + response []*ModelResponseChunk + want []*ModelResponseChunk + wantErr bool + }{ + { + name: "parses complete JSON object", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "integer"}, + }, + "additionalProperties": false, + }, + response: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewTextPart(JSONMarkdown(`{"name": "John", "age": 19}`)), + }, + }, + }, + want: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart("{\"name\": \"John\", \"age\": 19}"), + }, + }, + }, + wantErr: false, + }, + { + name: "handles partial JSON", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "integer"}, + "name": map[string]any{"type": "string"}, + }, + "additionalProperties": false, + }, + response: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewTextPart("{\"id\": 1"), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewTextPart(", \"name\": \"test\"}"), + }, + }, + }, + want: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart("{\"id\": 1}"), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart("{\"id\": 1, \"name\": \"test\"}"), + }, + }, + }, + wantErr: false, + }, + { + name: "handles object with array JSON", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "Countries": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "string", + }, + }, + }, + "additionalProperties": false, + }, + response: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewTextPart("{\""), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewTextPart("Countries\": [\"China\", \"India\", \"United States\", \"Indonesia\", \""), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewTextPart("Pakistan\", \"Brazil\", \"Nigeria\", \"Bangladesh\", \"Russia\", \"Mexico"), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewTextPart("\"]}"), + }, + }, + }, + want: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart("{}"), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart("{\"Countries\": [\"China\", \"India\", \"United States\", \"Indonesia\", \"\"]}"), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart("{\"Countries\": [\"China\", \"India\", \"United States\", \"Indonesia\", \"Pakistan\", \"Brazil\", \"Nigeria\", \"Bangladesh\", \"Russia\", \"Mexico\"]}"), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart("{\"Countries\": [\"China\", \"India\", \"United States\", \"Indonesia\", \"Pakistan\", \"Brazil\", \"Nigeria\", \"Bangladesh\", \"Russia\", \"Mexico\"]}"), + }, + }, + }, + wantErr: false, + }, + { + name: "handles preamble with code fence", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "integer"}, + }, + "additionalProperties": false, + }, + response: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewTextPart("Here is the JSON:\n\n```json\n"), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewTextPart("{\"id\": 1}\n```"), + }, + }, + }, + want: []*ModelResponseChunk{ + nil, + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart("{\"id\": 1}"), + }, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + formatter := jsonFormatter{} + handler, err := formatter.Handler(tt.schema) + if err != nil { + t.Fatalf("Handler() error = %v", err) + } + + var previousParts []*Part + for i, chunk := range tt.response { + previousParts = append(previousParts, chunk.Content...) + chunk.Content = previousParts + + messageChunk, err := handler.ParseChunk(chunk) + if err != nil { + t.Errorf("ParseChunk() error = %v, wantErr %v", err, tt.wantErr) + } + + if diff := cmp.Diff(tt.want[i], messageChunk); diff != "" { + t.Errorf("Request msgs diff (+got -want):\n%s", diff) + } + } + }) + } +} + func TestTextParser(t *testing.T) { tests := []struct { name string diff --git a/go/ai/generate.go b/go/ai/generate.go index d6b7573e2a..ae3ca77dec 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -252,6 +252,11 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera // This is optional to make the output config internally consistent. outputCfg.Schema = nil } + + // Handle output format parsing for chunks + if cb != nil { + cb = formatHandler.StreamCallback(cb) + } } req := &ModelRequest{ diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go index c9e3d1aec9..15d7227380 100644 --- a/go/ai/generator_test.go +++ b/go/ai/generator_test.go @@ -246,7 +246,7 @@ func TestGenerate(t *testing.T) { bananaModel := DefineModel(r, "test", "banana", &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) { if msc != nil { msc(ctx, &ModelResponseChunk{ - Content: []*Part{NewTextPart("stream!")}, + Content: []*Part{NewJSONPart(JSON)}, }) } @@ -258,7 +258,7 @@ func TestGenerate(t *testing.T) { t.Run("constructs request", func(t *testing.T) { wantText := JSON - wantStreamText := "stream!" + wantStreamText := "{\"subject\": \"bananas\", \"location\": \"tropics\"}" wantRequest := &ModelRequest{ Messages: []*Message{ { diff --git a/go/go.mod b/go/go.mod index 98dd9ba812..c9663a972f 100644 --- a/go/go.mod +++ b/go/go.mod @@ -54,6 +54,7 @@ require ( github.com/MicahParks/keyfunc v1.9.0 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/blaze2305/partial-json-parser v0.1.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f // indirect diff --git a/go/go.sum b/go/go.sum index 046d519209..9b54d74e65 100644 --- a/go/go.sum +++ b/go/go.sum @@ -49,6 +49,8 @@ github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3d github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/blaze2305/partial-json-parser v0.1.1 h1:CAakvirNBTZVAW0iLEInUGaKAiRDB3ECYNCllLaUVPw= +github.com/blaze2305/partial-json-parser v0.1.1/go.mod h1:bpALLUZPippHpE3gJJguWnpIXDsoqBY1KJN2hKJjBrw= github.com/blues/jsonata-go v1.5.4 h1:XCsXaVVMrt4lcpKeJw6mNJHqQpWU751cnHdCFUq3xd8= github.com/blues/jsonata-go v1.5.4/go.mod h1:uns2jymDrnI7y+UFYCqsRTEiAH22GyHnNXrkupAVFWI= github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= diff --git a/go/internal/base/json.go b/go/internal/base/json.go index 81e0b12214..d2f9954eac 100644 --- a/go/internal/base/json.go +++ b/go/internal/base/json.go @@ -148,11 +148,10 @@ func GetJsonObjectLines(text string) []string { continue } - // Trim leading and trailing whitespace from the current line. - trimmedLine := strings.TrimSpace(line) - // Check if the trimmed line starts with the character '{'. - if strings.HasPrefix(trimmedLine, "{") { - // If it does, append the trimmed line to our result slice. + // Check the given string for a json object + trimmedLine := GetJsonObject(line) + if trimmedLine != "" { + // If it exists, append the trimmed line to our result slice. result = append(result, trimmedLine) } } @@ -160,3 +159,15 @@ func GetJsonObjectLines(text string) []string { // Return the slice containing the filtered and trimmed lines. return result } + +// Looks for a JsonObject in the given string, returns empty string if none found +func GetJsonObject(text string) (result string) { + // Trim leading and trailing whitespace from the current line. + trimmedLine := strings.TrimSpace(text) + // Check if the trimmed line starts with the character '{'. + if strings.HasPrefix(trimmedLine, "{") { + result = trimmedLine + } + + return result +} From 0e9663f7c902a73267061e3f25125624d28c68ae Mon Sep 17 00:00:00 2001 From: dysrama Date: Fri, 9 May 2025 14:42:51 +0200 Subject: [PATCH 2/2] Support for streaming text and jsonl --- go/ai/format_jsonl.go | 68 +++++++++- go/ai/formatter_test.go | 283 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 348 insertions(+), 3 deletions(-) diff --git a/go/ai/format_jsonl.go b/go/ai/format_jsonl.go index 91a37a0ab6..fbdd29ea3a 100644 --- a/go/ai/format_jsonl.go +++ b/go/ai/format_jsonl.go @@ -19,7 +19,10 @@ import ( "encoding/json" "errors" "fmt" + "strings" + partialparser "github.com/blaze2305/partial-json-parser" + "github.com/blaze2305/partial-json-parser/options" "github.com/firebase/genkit/go/internal/base" ) @@ -56,8 +59,9 @@ func (j jsonlFormatter) Handler(schema map[string]any) (FormatHandler, error) { } type jsonlHandler struct { - instructions string - config ModelOutputConfig + instructions string + config ModelOutputConfig + previousParts []*Part } // Instructions returns the instructions for the formatter. @@ -72,7 +76,15 @@ func (j jsonlHandler) Config() ModelOutputConfig { func (j jsonlHandler) StreamCallback(cb ModelStreamCallback) ModelStreamCallback { return func(ctx context.Context, mrc *ModelResponseChunk) error { - return cb(ctx, mrc) + j.previousParts = append(j.previousParts, mrc.Content...) + mrc.Content = j.previousParts + + parsed, err := j.ParseChunk(mrc) + if err != nil { + return err + } + + return cb(ctx, parsed) } } @@ -116,5 +128,55 @@ func (j jsonlHandler) ParseMessage(m *Message) (*Message, error) { // ParseChunk parse the chunk and returns a new formatted chunk. func (j jsonlHandler) ParseChunk(c *ModelResponseChunk) (*ModelResponseChunk, error) { + if j.config.Format == OutputFormatJSONL { + if c == nil { + return nil, errors.New("message is empty") + } + if len(c.Content) == 0 { + return nil, errors.New("message has no content") + } + + // Get all chunks streamed so far + text := c.Text() + + startIndex := 0 + // If there are previous chunks, adjust startIndex based on the last newline + // in the previous text to ensure complete lines are processed from the accumulatedText. + noParts := len(c.Content) + if c.Content != nil && noParts > 1 { + var sb strings.Builder + i := 0 + for i < noParts-1 { + sb.WriteString(c.Content[i].Text) + i++ + } + + previousText := sb.String() + lastNewline := strings.LastIndex(previousText, `\n`) + + if lastNewline != -1 { + // Exclude the newline + startIndex = lastNewline + 2 + } + } + + text = text[startIndex:] + + var newParts []*Part + lines := base.GetJsonObjectLines(text) + for _, line := range lines { + if line != "" { + var err error + line, err = partialparser.ParseMalformedString(line, options.ALL, false) + if err != nil { + return nil, errors.New("message is not a valid JSON") + } + + newParts = append(newParts, NewJSONPart(line)) + } + } + + c.Content = newParts + } return c, nil } diff --git a/go/ai/formatter_test.go b/go/ai/formatter_test.go index 8839b817eb..a2c3d3230d 100644 --- a/go/ai/formatter_test.go +++ b/go/ai/formatter_test.go @@ -880,6 +880,89 @@ func TestTextParser(t *testing.T) { } } +func TestTextStreamingParser(t *testing.T) { + tests := []struct { + name string + response []*ModelResponseChunk + want []*ModelResponseChunk + wantErr bool + }{ + { + name: "emits text chunks as they arrive", + response: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewTextPart("Hello"), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewTextPart(" world"), + }, + }, + }, + want: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewTextPart("Hello"), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewTextPart(" world"), + }, + }, + }, + wantErr: false, + }, + { + name: "handles empty chunks", + response: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewTextPart(""), + }, + }, + }, + want: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewTextPart(""), + }, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + formatter := textFormatter{} + handler, err := formatter.Handler(nil) + if err != nil { + t.Fatalf("Handler() error = %v", err) + } + + for i, chunk := range tt.response { + messageChunk, err := handler.ParseChunk(chunk) + if err != nil { + t.Errorf("ParseChunk() error = %v, wantErr %v", err, tt.wantErr) + } + + if diff := cmp.Diff(tt.want[i], messageChunk); diff != "" { + t.Errorf("Request msgs diff (+got -want):\n%s", diff) + } + } + }) + } +} + func TestJsonlParser(t *testing.T) { tests := []struct { name string @@ -993,6 +1076,206 @@ func TestJsonlParser(t *testing.T) { } } +func TestJsonlStreamingParser(t *testing.T) { + tests := []struct { + name string + schema map[string]any + response []*ModelResponseChunk + want []*ModelResponseChunk + wantErr bool + }{ + { + name: "emits complete JSON objects as they arrive", + schema: map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "integer"}, + "name": map[string]any{"type": "string"}, + }, + }, + "additionalProperties": false, + }, + response: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewTextPart(`{"id": 1, "name": "first"}\n`), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewTextPart(`{"id": 2, "name": "second"}\n{"id": 3`), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewTextPart(`, "name": "third"}\n`), + }, + }, + }, + want: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart(`{"id": 1, "name": "first"}`), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart(`{"id": 2, "name": "second"}`), + NewJSONPart(`{"id": 3}`), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart(`{"id": 3, "name": "third"}`), + }, + }, + }, + wantErr: false, + }, + { + name: "handles single object", + schema: map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "integer"}, + "name": map[string]any{"type": "string"}, + }, + }, + "additionalProperties": false, + }, + response: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewTextPart(JSONMarkdown(`{"id": 1, "name": "single"}\n`)), + }, + }, + }, + want: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart(`{"id": 1, "name": "single"}`), + }, + }, + }, + wantErr: false, + }, + { + name: "handles preamble with code fence", + schema: map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "integer"}, + "name": map[string]any{"type": "string"}, + }, + }, + "additionalProperties": false, + }, + response: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewTextPart("Here are the objects:\n\n```\n"), + }, + }, + { + Role: RoleModel, + Content: []*Part{ + NewTextPart("{\"id\": 1, \"name\": \"item\"}\n```"), + }, + }, + }, + want: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: nil, + }, + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart(`{"id": 1, "name": "item"}`), + }, + }, + }, + wantErr: false, + }, + { + name: "ignores non-object lines", + schema: map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "integer"}, + }, + }, + "additionalProperties": false, + }, + response: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewTextPart(JSONMarkdown(`First object:\n{"id": 1}\nSecond object:\n{"id": 2}\n`)), + }, + }, + }, + want: []*ModelResponseChunk{ + { + Role: RoleModel, + Content: []*Part{ + NewJSONPart(`{"id": 1}`), + NewJSONPart(`{"id": 2}`), + }, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + formatter := jsonlFormatter{} + handler, err := formatter.Handler(tt.schema) + if err != nil { + t.Fatalf("Handler() error = %v", err) + } + + var previousParts []*Part + var receivedChunks []*ModelResponseChunk + for _, chunk := range tt.response { + previousParts = append(previousParts, chunk.Content...) + chunk.Content = previousParts + + messageChunk, err := handler.ParseChunk(chunk) + if err != nil { + t.Errorf("ParseChunk() error = %v, wantErr %v", err, tt.wantErr) + } + + if messageChunk != nil { + receivedChunks = append(receivedChunks, messageChunk) + } + } + + if diff := cmp.Diff(tt.want, receivedChunks); diff != "" { + t.Errorf("Request msgs diff (+got -want):\n%s", diff) + } + }) + } +} + func TestArrayParser(t *testing.T) { tests := []struct { name string