diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index a9f9a50a..92d71f8b 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -24,7 +24,7 @@ jobs: runs-on: ubuntu-latest services: postgres: - image: postgres:14.2 + image: ankane/pgvector:v0.4.4 env: POSTGRES_PASSWORD: postgres POSTGRES_USER: postgres diff --git a/generator/client/golang/templates/queryx/vector.go b/generator/client/golang/templates/queryx/vector.go new file mode 100644 index 00000000..e3596f2a --- /dev/null +++ b/generator/client/golang/templates/queryx/vector.go @@ -0,0 +1,92 @@ +package queryx + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "strconv" + "strings" +) + +type Vector struct { + Val []float32 + Null bool + Set bool +} + +func NewVector(v []float32) Vector { + return Vector{Val: v, Set: true} +} + +func NewNullableVector(v *[]float32) Vector { + if v != nil { + return Vector{Val: *v, Set: true} + } + return Vector{Null: true, Set: true} +} + +func (v Vector) String() string { + var buf strings.Builder + buf.WriteString("[") + + for i := 0; i < len(v.Val); i++ { + if i > 0 { + buf.WriteString(",") + } + buf.WriteString(strconv.FormatFloat(float64(v.Val[i]), 'f', -1, 32)) + } + + buf.WriteString("]") + return buf.String() +} + +func (v *Vector) Parse(s string) error { + v.Val = make([]float32, 0) + sp := strings.Split(s[1:len(s)-1], ",") + for i := 0; i < len(sp); i++ { + n, err := strconv.ParseFloat(sp[i], 32) + if err != nil { + return err + } + v.Val = append(v.Val, float32(n)) + } + return nil +} + +// Scan implements the Scanner interface. +func (v *Vector) Scan(src interface{}) (err error) { + switch src := src.(type) { + case []byte: + return v.Parse(string(src)) + case string: + return v.Parse(src) + default: + return fmt.Errorf("unsupported data type: %T", src) + } +} + +// Value implements the driver Valuer interface. +func (v Vector) Value() (driver.Value, error) { + return v.String(), nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (v Vector) MarshalJSON() ([]byte, error) { + if v.Null { + return json.Marshal(nil) + } + return json.Marshal(v.Val) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (v *Vector) UnmarshalJSON(data []byte) error { + v.Set = true + if string(data) == "null" { + v.Null = true + return nil + } + if err := json.Unmarshal(data, &v.Val); err != nil { + return err + } + return nil +} diff --git a/generator/client/golang/templates/queryx/vector_column.postgresql.go b/generator/client/golang/templates/queryx/vector_column.postgresql.go new file mode 100644 index 00000000..a3e10435 --- /dev/null +++ b/generator/client/golang/templates/queryx/vector_column.postgresql.go @@ -0,0 +1,13 @@ +package queryx + +type VectorColumn struct { + Name string + Table *Table +} + +func (t *Table) NewVectorColumn(name string) *VectorColumn { + return &VectorColumn{ + Table: t, + Name: name, + } +} diff --git a/generator/client/golang/templates/queryx/vector_test.go b/generator/client/golang/templates/queryx/vector_test.go new file mode 100644 index 00000000..415207c3 --- /dev/null +++ b/generator/client/golang/templates/queryx/vector_test.go @@ -0,0 +1,39 @@ +package queryx + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewVector(t *testing.T) { + v1 := NewVector([]float32{1, 2, 3}) + require.Equal(t, []float32{1, 2, 3}, v1.Val) + require.Equal(t, false, v1.Null) + + v2 := NewNullableVector(nil) + require.Equal(t, true, v2.Null) +} + +func TestVectorJSON(t *testing.T) { + type Foo struct { + X Vector `json:"x"` + Y Vector `json:"y"` + } + x := NewVector([]float32{1, 2, 3}) + y := NewNullableVector(nil) + s := `{"x":[1,2,3],"y":null}` + + f1 := Foo{X: x, Y: y} + b, err := json.Marshal(f1) + require.NoError(t, err) + require.Equal(t, s, string(b)) + + var f2 Foo + err = json.Unmarshal([]byte(s), &f2) + require.NoError(t, err) + require.Equal(t, x, f2.X) + require.Equal(t, y, f2.Y) + +} diff --git a/inflect/golang.go b/inflect/golang.go index aa4c1087..326bf4a2 100644 --- a/inflect/golang.go +++ b/inflect/golang.go @@ -1,8 +1,6 @@ package inflect import ( - "fmt" - "log" "strings" ) @@ -47,9 +45,10 @@ func goModelType(t string, null bool) string { return "queryx.Float" case "json", "jsonb": return "queryx.JSON" + case "vector": + return "queryx.Vector" default: - log.Fatal(fmt.Errorf("unhandled data type %s in goModelType", t)) - return "" + return t } } else { switch t { @@ -73,9 +72,10 @@ func goModelType(t string, null bool) string { return "float" case "json", "jsonb": return "queryx.JSON" + case "vector": + return "queryx.Vector" default: - log.Fatal(fmt.Errorf("unhandled data type %s in goModelType", t)) - return "" + return t } } } @@ -103,9 +103,10 @@ func goType(t string) string { return "Float" case "json", "jsonb": return "JSON" + case "vector": + return "Vector" default: - log.Fatal(fmt.Errorf("unhandled data type %s in goType", t)) - return "" + return t } } @@ -124,8 +125,9 @@ func goChangeSetType(t string) string { return "float64" case "json", "jsonb": return "map[string]interface{}" + case "vector": + return "[]float32" default: - log.Fatal(fmt.Errorf("unhandled data type %s in goChangeSetType", t)) - return "" + return t } } diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 2e3474e6..35b95f43 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -1,4 +1,4 @@ -package main +package integration import ( "database/sql" diff --git a/internal/integration/postgresql.hcl b/internal/integration/postgresql.hcl index 80bda1c8..c6dcb173 100644 --- a/internal/integration/postgresql.hcl +++ b/internal/integration/postgresql.hcl @@ -153,4 +153,11 @@ database "db" { columns = ["id"] } } + + model "Item" { + column "embedding" { + type = vector + dimension = 1536 + } + } } diff --git a/internal/integration/postgresql_client_test.go b/internal/integration/postgresql_client_test.go new file mode 100644 index 00000000..9d8ebe0e --- /dev/null +++ b/internal/integration/postgresql_client_test.go @@ -0,0 +1,32 @@ +package integration + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/swiftcarrot/queryx/internal/integration/db/queryx" +) + +func TestVector(t *testing.T) { + _, err := c.QueryItem().DeleteAll() + require.NoError(t, err) + + item1, err := c.QueryItem().Create(c.ChangeItem().SetEmbedding([]float32{1, 2, 3})) + require.NoError(t, err) + require.Equal(t, item1.Embedding.Val, []float32{1, 2, 3}) + + item2, err := c.QueryItem().Create(c.ChangeItem().SetEmbedding([]float32{4, 5, 6})) + require.NoError(t, err) + require.Equal(t, item2.Embedding.Val, []float32{4, 5, 6}) + + type Foo struct { + embedding queryx.Vector `db:"embedding"` + } + var rows []Foo + err = c.Query("SELECT embedding FROM items ORDER BY embedding <-> '[3,1,2]'").Scan(&rows) + require.NoError(t, err) + require.Equal(t, []Foo{ + {queryx.NewVector([]float32{1, 2, 3})}, + {queryx.NewVector([]float32{4, 5, 6})}, + }, rows) +} diff --git a/schema/hcl.go b/schema/hcl.go index f86642cf..7dd8a198 100644 --- a/schema/hcl.go +++ b/schema/hcl.go @@ -73,6 +73,7 @@ var hclColumn = &hcl.BodySchema{ {Name: "null"}, {Name: "default"}, {Name: "unique"}, + {Name: "dimension"}, }, Blocks: []hcl.BlockHeaderSchema{}, } @@ -496,6 +497,8 @@ func columnFromBlock(block *hcl.Block, ctx *hcl.EvalContext) (*Column, error) { column.Array = valueAsBool(value) case "null": column.Null = valueAsBool(value) + case "dimension": + column.Dimension = valueAsInt(value) } } for name, attr := range content.Attributes { @@ -565,18 +568,20 @@ var env = function.New(&function.Spec{ func Parse(body hcl.Body) (*Schema, error) { ctx := &hcl.EvalContext{ Variables: map[string]cty.Value{ - "string": cty.StringVal("string"), - "text": cty.StringVal("text"), - "boolean": cty.StringVal("boolean"), - "date": cty.StringVal("date"), - "time": cty.StringVal("time"), - "datetime": cty.StringVal("datetime"), - "float": cty.StringVal("float"), - "integer": cty.StringVal("integer"), - "bigint": cty.StringVal("bigint"), - "json": cty.StringVal("json"), - "jsonb": cty.StringVal("jsonb"), - "uuid": cty.StringVal("uuid"), + "string": cty.StringVal("string"), + "text": cty.StringVal("text"), + "boolean": cty.StringVal("boolean"), + "date": cty.StringVal("date"), + "datetime": cty.StringVal("datetime"), + "float": cty.StringVal("float"), + "integer": cty.StringVal("integer"), + "bigint": cty.StringVal("bigint"), + "time": cty.StringVal("time"), + "timestamp": cty.StringVal("timestamp"), + "json": cty.StringVal("json"), + "jsonb": cty.StringVal("jsonb"), + "uuid": cty.StringVal("uuid"), + "vector": cty.StringVal("vector"), }, Functions: map[string]function.Function{ "env": env, diff --git a/schema/model.go b/schema/model.go index dac50a08..e7cf3196 100644 --- a/schema/model.go +++ b/schema/model.go @@ -51,6 +51,8 @@ type Column struct { // sql auto_increment AutoIncrement bool Default interface{} // TODO: support default + // vector dimension + Dimension int } type Type struct { diff --git a/schema/postgresql.go b/schema/postgresql.go index 64e735ab..9af0d080 100644 --- a/schema/postgresql.go +++ b/schema/postgresql.go @@ -2,6 +2,7 @@ package schema import ( "fmt" + "log" "strconv" "strings" @@ -20,6 +21,8 @@ func (d *Database) CreatePostgreSQLSchema(dbName string) *schema.Schema { for _, c := range model.Columns { col := schema.NewColumn(c.Name) + log.Println(c.Type) + switch c.Type { case "bigint": if c.AutoIncrement { @@ -101,6 +104,10 @@ func (d *Database) CreatePostgreSQLSchema(dbName string) *schema.Schema { col.SetType(&postgres.UUIDType{T: postgres.TypeUUID}).SetDefault(&schema.RawExpr{X: d}) } } + case "vector": + col.SetType(&postgres.UserDefinedType{T: fmt.Sprintf("vector(%d)", c.Dimension)}) + default: + col.SetType(&postgres.UserDefinedType{T: c.Type}) } col.SetNull(c.Null)