diff --git a/documentation/src/main/asciidoc/userguide/chapters/query/extensions/Vector.adoc b/documentation/src/main/asciidoc/userguide/chapters/query/extensions/Vector.adoc index 7c4c93b18f8a..cc317634217a 100644 --- a/documentation/src/main/asciidoc/userguide/chapters/query/extensions/Vector.adoc +++ b/documentation/src/main/asciidoc/userguide/chapters/query/extensions/Vector.adoc @@ -12,10 +12,21 @@ The Hibernate ORM Vector module contains support for mathematical vector types a This is useful for AI/ML topics like vector similarity search and Retrieval-Augmented Generation (RAG). The module comes with support for a special `vector` data type that essentially represents an array of bytes, floats, or doubles. -So far, both the PostgreSQL extension `pgvector` and the Oracle database 23ai+ `AI Vector Search` feature are supported, but in theory, -the vector specific functions could be implemented to work with every database that supports arrays. +Currently, the following databases are supported: -For further details, refer to the https://github.com/pgvector/pgvector#querying[pgvector documentation] or the https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[AI Vector Search documentation]. +* PostgreSQL 13+ through the https://github.com/pgvector/pgvector#querying[`pgvector` extension] +* https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[Oracle database 23ai+] +* https://mariadb.com/docs/server/reference/sql-structure/vectors/vector-overview[MariaDB 11.7+] +* https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html[MySQL 9.0+] + +In theory, the vector-specific functions could be implemented to work with every database that supports arrays. + +[WARNING] +==== +Per the https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html#function_distance[MySQL documentation], +the various vector distance functions for MySQL only work on MySQL cloud offerings like +https://dev.mysql.com/doc/heatwave/en/mys-hw-about-heatwave.html[HeatWave MySQL on OCI]. +==== [[vector-module-setup]] === Setup @@ -57,7 +68,7 @@ As Oracle AI Vector Search supports different types of elements (to ensure bette ==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=usage-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=usage-example] ---- ==== @@ -113,7 +124,7 @@ which is `1 - inner_product( v1, v2 ) / ( vector_norm( v1 ) * vector_norm( v2 ) ==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=cosine-distance-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=cosine-distance-example] ---- ==== @@ -128,7 +139,7 @@ The `l2_distance()` function is an alias. ==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=euclidean-distance-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=euclidean-distance-example] ---- ==== @@ -143,7 +154,7 @@ The `l1_distance()` function is an alias. ==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=taxicab-distance-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=taxicab-distance-example] ---- ==== @@ -158,7 +169,7 @@ and the `inner_product()` function as well, but multiplies the result time `-1`. 
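As a rough illustration of how the distance functions described above can be combined with the usual query APIs, a similarity search might look like the following sketch; the `Book` entity and its `embedding` attribute are made-up names and not part of the module.

[source, java, indent=0]
----
// Order books by similarity of their embedding to a query vector (illustrative sketch)
List<Book> mostSimilar = session.createSelectionQuery(
				"select b from Book b order by cosine_distance(b.embedding, :query)", Book.class )
		.setParameter( "query", new float[] { 0.1f, 0.2f, 0.3f } )
		.setMaxResults( 5 )
		.getResultList();
----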
==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=inner-product-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=inner-product-example] ---- ==== @@ -171,7 +182,7 @@ Determines the dimensions of a vector. ==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=vector-dims-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=vector-dims-example] ---- ==== @@ -185,7 +196,7 @@ which is `sqrt( sum( v_i^2 ) )`. ==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=vector-norm-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=vector-norm-example] ---- ==== diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/OracleTypes.java b/hibernate-core/src/main/java/org/hibernate/dialect/OracleTypes.java index 89f8c76690fc..722122dfd546 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/OracleTypes.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/OracleTypes.java @@ -15,4 +15,5 @@ public class OracleTypes { public static final int VECTOR_INT8 = -106; public static final int VECTOR_FLOAT32 = -107; public static final int VECTOR_FLOAT64 = -108; + public static final int VECTOR_BINARY = -109; } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/CastFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/CastFunction.java index cb0dde411195..d99c750b3491 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/CastFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/CastFunction.java @@ -77,8 +77,14 @@ public void render( renderCastArrayToString( sqlAppender, arguments.get( 0 ), dialect, walker ); } else { - new PatternRenderer( dialect.castPattern( sourceType, targetType ) ) - .render( sqlAppender, arguments, walker ); + String castPattern = targetJdbcMapping.getJdbcType().castFromPattern( sourceMapping ); + if ( castPattern == null ) { + castPattern = sourceMapping.getJdbcType().castToPattern( targetJdbcMapping ); + if ( castPattern == null ) { + castPattern = dialect.castPattern( sourceType, targetType ); + } + } + new PatternRenderer( castPattern ).render( sqlAppender, arguments, walker ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/SumReturnTypeResolver.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/SumReturnTypeResolver.java index 2cee65aeed1c..99354744dd9d 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/SumReturnTypeResolver.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/SumReturnTypeResolver.java @@ -90,6 +90,14 @@ public ReturnableType resolveFunctionReturnType( case NUMERIC: return BigInteger.class.isAssignableFrom( basicType.getJavaType() ) ? bigIntegerType : bigDecimalType; case VECTOR: + case VECTOR_BINARY: + case VECTOR_INT8: + case VECTOR_FLOAT16: + case VECTOR_FLOAT32: + case VECTOR_FLOAT64: + case SPARSE_VECTOR_INT8: + case SPARSE_VECTOR_FLOAT32: + case SPARSE_VECTOR_FLOAT64: return basicType; } return bigDecimalType; @@ -123,6 +131,14 @@ public BasicValuedMapping resolveFunctionReturnType( final Class argTypeClass = jdbcMapping.getJavaTypeDescriptor().getJavaTypeClass(); return BigInteger.class.isAssignableFrom( argTypeClass ) ? 
bigIntegerType : bigDecimalType; case VECTOR: + case VECTOR_BINARY: + case VECTOR_INT8: + case VECTOR_FLOAT16: + case VECTOR_FLOAT32: + case VECTOR_FLOAT64: + case SPARSE_VECTOR_INT8: + case SPARSE_VECTOR_FLOAT32: + case SPARSE_VECTOR_FLOAT64: return (BasicValuedMapping) jdbcMapping; } return bigDecimalType; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/type/OracleArrayJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/type/OracleArrayJdbcType.java index acbe11e3b481..7712cd7e08d2 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/OracleArrayJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/OracleArrayJdbcType.java @@ -10,6 +10,7 @@ import java.sql.SQLException; import java.sql.Types; import java.util.Locale; +import java.util.Objects; import org.hibernate.HibernateException; import org.hibernate.boot.model.relational.Database; @@ -288,4 +289,16 @@ public String getFriendlyName() { public String toString() { return "OracleArrayTypeDescriptor(" + typeName + ")"; } + + @Override + public boolean equals(Object that) { + return super.equals( that ) + && that instanceof OracleArrayJdbcType jdbcType + && Objects.equals( typeName, jdbcType.typeName ); + } + + @Override + public int hashCode() { + return Objects.hashCode( typeName ) + 31 * super.hashCode(); + } } diff --git a/hibernate-core/src/main/java/org/hibernate/type/BasicCollectionType.java b/hibernate-core/src/main/java/org/hibernate/type/BasicCollectionType.java index 70160c445fc5..ad34e08ec35c 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/BasicCollectionType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/BasicCollectionType.java @@ -35,6 +35,16 @@ public BasicCollectionType( this.name = determineName( collectionTypeDescriptor, baseDescriptor ); } + public BasicCollectionType( + BasicType baseDescriptor, + JdbcType arrayJdbcType, + JavaType collectionTypeDescriptor, + String typeName) { + super( arrayJdbcType, collectionTypeDescriptor ); + this.baseDescriptor = baseDescriptor; + this.name = typeName; + } + private static String determineName(BasicCollectionJavaType collectionTypeDescriptor, BasicType baseDescriptor) { final String elementTypeName = determineElementTypeName( baseDescriptor ); switch ( collectionTypeDescriptor.getSemantics().getCollectionClassification() ) { diff --git a/hibernate-core/src/main/java/org/hibernate/type/BasicTypeRegistry.java b/hibernate-core/src/main/java/org/hibernate/type/BasicTypeRegistry.java index 3ce06f230a0f..e6fcbc3ec032 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/BasicTypeRegistry.java +++ b/hibernate-core/src/main/java/org/hibernate/type/BasicTypeRegistry.java @@ -16,11 +16,13 @@ import org.hibernate.internal.CoreMessageLogger; import org.hibernate.internal.util.StringHelper; import org.hibernate.internal.util.collections.CollectionHelper; +import org.hibernate.tool.schema.extract.spi.ColumnTypeInformation; import org.hibernate.type.descriptor.converter.spi.BasicValueConverter; import org.hibernate.type.descriptor.java.BasicPluralJavaType; import org.hibernate.type.descriptor.java.ImmutableMutabilityPlan; import org.hibernate.type.descriptor.java.JavaType; import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.DelegatingJdbcTypeIndicators; import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.internal.BasicTypeImpl; import org.hibernate.type.internal.ConvertedBasicTypeImpl; @@ -166,8 +168,48 @@ private 
BasicType resolvedType(ArrayJdbcType arrayType, BasicPluralJavaTy typeConfiguration, typeConfiguration.getCurrentBaseSqlTypeIndicators().getDialect(), elementType, - null, - typeConfiguration.getCurrentBaseSqlTypeIndicators() + new ColumnTypeInformation() { + @Override + public Boolean getNullable() { + return null; + } + + @Override + public int getTypeCode() { + return arrayType.getDefaultSqlTypeCode(); + } + + @Override + public String getTypeName() { + return null; + } + + @Override + public int getColumnSize() { + return 0; + } + + @Override + public int getDecimalDigits() { + return 0; + } + }, + new DelegatingJdbcTypeIndicators( typeConfiguration.getCurrentBaseSqlTypeIndicators() ) { + @Override + public Integer getExplicitJdbcTypeCode() { + return arrayType.getDefaultSqlTypeCode(); + } + + @Override + public int getPreferredSqlTypeCodeForArray() { + return arrayType.getDefaultSqlTypeCode(); + } + + @Override + public int getPreferredSqlTypeCodeForArray(int elementSqlTypeCode) { + return arrayType.getDefaultSqlTypeCode(); + } + } ); if ( resolvedType instanceof BasicPluralType ) { register( resolvedType ); diff --git a/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java b/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java index f2354b799d16..a465f4bee052 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java +++ b/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java @@ -681,10 +681,10 @@ public class SqlTypes { /** - * A type code representing an {@code embedding vector} type for databases + * A type code representing a {@code vector} type for databases * like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL}, * {@link org.hibernate.dialect.OracleDialect Oracle 23ai} and {@link org.hibernate.dialect.MariaDBDialect MariaDB}. - * An embedding vector essentially is a {@code float[]} with a fixed size. + * A vector essentially is a {@code float[]} with a fixed length. * * @since 6.4 */ @@ -701,10 +701,39 @@ public class SqlTypes { public static final int VECTOR_FLOAT32 = 10_002; /** - * A type code representing a double-precision floating-point type for Oracle 23ai database. + * A type code representing a double-precision floating-point vector type for Oracle 23ai database. */ public static final int VECTOR_FLOAT64 = 10_003; + /** + * A type code representing a bit precision vector type for databases + * like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} and + * {@link org.hibernate.dialect.OracleDialect Oracle 23ai}. + */ + public static final int VECTOR_BINARY = 10_004; + + /** + * A type code representing a half-precision floating-point vector type for databases + * like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL}. + */ + public static final int VECTOR_FLOAT16 = 10_005; + + /** + * A type code representing a sparse single-byte integer vector type for Oracle 23ai database. + */ + public static final int SPARSE_VECTOR_INT8 = 10_006; + + /** + * A type code representing a sparse single-precision floating-point vector type for Oracle 23ai database. + */ + public static final int SPARSE_VECTOR_FLOAT32 = 10_007; + + /** + * A type code representing a sparse double-precision floating-point vector type for Oracle 23ai database. 
+ */ + public static final int SPARSE_VECTOR_FLOAT64 = 10_008; + + private SqlTypes() { } diff --git a/hibernate-core/src/main/java/org/hibernate/type/StandardBasicTypes.java b/hibernate-core/src/main/java/org/hibernate/type/StandardBasicTypes.java index fc521d5146f6..8c16770b3caf 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/StandardBasicTypes.java +++ b/hibernate-core/src/main/java/org/hibernate/type/StandardBasicTypes.java @@ -749,6 +749,14 @@ private StandardBasicTypes() { "byte_vector", byte[].class, SqlTypes.VECTOR_INT8 ); + /** + * The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_FLOAT16 VECTOR_FLOAT16}, + * specifically for embedding half-precision floating-point (16-bit) vectors such as those provided by the PostgreSQL extension pgvector. + */ + public static final BasicTypeReference VECTOR_FLOAT16 = new BasicTypeReference<>( + "float16_vector", float[].class, SqlTypes.VECTOR_FLOAT16 + ); + /** * The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR}, * specifically for embedding single-precision floating-point (32-bits) vectors like provided by Oracle 23ai. @@ -765,6 +773,38 @@ private StandardBasicTypes() { "double_vector", double[].class, SqlTypes.VECTOR_FLOAT64 ); + /** + * The standard Hibernate type for mapping {@code byte[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_BINARY VECTOR_BINARY}, + * specifically for embedding bit vectors such as those provided by Oracle 23ai. + */ + public static final BasicTypeReference VECTOR_BINARY = new BasicTypeReference<>( + "binary_vector", byte[].class, SqlTypes.VECTOR_BINARY + ); + +// /** +// * The standard Hibernate type for mapping {@code byte[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_INT8 VECTOR_INT8}, +// * specifically for embedding integer vectors (8-bits) like provided by Oracle 23ai. +// */ +// public static final BasicTypeReference SPARSE_VECTOR_INT8 = new BasicTypeReference<>( +// "sparse_byte_vector", byte[].class, SqlTypes.SPARSE_VECTOR_INT8 +// ); +// +// /** +// * The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR}, +// * specifically for embedding single-precision floating-point (32-bits) vectors like provided by Oracle 23ai. +// */ +// public static final BasicTypeReference SPARSE_VECTOR_FLOAT32 = new BasicTypeReference<>( +// "sparse_float_vector", float[].class, SqlTypes.SPARSE_VECTOR_FLOAT32 +// ); +// +// /** +// * The standard Hibernate type for mapping {@code double[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR}, +// * specifically for embedding double-precision floating-point (64-bits) vectors like provided by Oracle 23ai.
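For orientation, a mapping that uses these type codes might look roughly like the sketch below; the entity and attribute names are invented, and the annotations shown (`@JdbcTypeCode`, `@Array`) are simply the existing Hibernate annotations applied to the new codes, so treat the details (in particular the length semantics for binary vectors) as assumptions.

[source, java, indent=0]
----
import org.hibernate.annotations.Array;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;

import jakarta.persistence.Entity;
import jakarta.persistence.Id;

@Entity
class DocumentEmbedding {
	@Id
	Long id;

	// Dense single-precision vector with 3 dimensions
	@JdbcTypeCode(SqlTypes.VECTOR)
	@Array(length = 3)
	float[] embedding;

	// Binary (bit) vector using the VECTOR_BINARY type code introduced in this patch (length semantics assumed)
	@JdbcTypeCode(SqlTypes.VECTOR_BINARY)
	@Array(length = 64)
	byte[] signature;
}
----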
+// */ +// public static final BasicTypeReference SPARSE_VECTOR_FLOAT64 = new BasicTypeReference<>( +// "sparse_double_vector", double[].class, SqlTypes.SPARSE_VECTOR_FLOAT64 +// ); + public static void prime(TypeConfiguration typeConfiguration) { BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); @@ -1286,6 +1326,34 @@ public static void prime(TypeConfiguration typeConfiguration) { "byte_vector" ); + handle( + VECTOR_BINARY, + null, + basicTypeRegistry, + "bit_vector" + ); + +// handle( +// SPARSE_VECTOR_FLOAT32, +// null, +// basicTypeRegistry, +// "sparse_float_vector" +// ); +// +// handle( +// SPARSE_VECTOR_FLOAT64, +// null, +// basicTypeRegistry, +// "sparse_double_vector" +// ); +// +// handle( +// SPARSE_VECTOR_INT8, +// null, +// basicTypeRegistry, +// "sparse_byte_vector" +// ); + // Specialized version handlers diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java index c3bb08aca6cc..3ec393193838 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java @@ -9,10 +9,12 @@ import java.sql.SQLException; import java.sql.Types; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.Incubating; import org.hibernate.boot.model.relational.Database; import org.hibernate.dialect.Dialect; import org.hibernate.engine.jdbc.Size; +import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.query.sqm.CastType; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.spi.StringBuilderSqlAppender; @@ -367,6 +369,30 @@ default String getExtraCreateTableInfo(JavaType javaType, String columnName, return ""; } + /** + * Returns the cast pattern from the given source type to this type, or {@code null} if not possible. + * + * @param sourceMapping The source type + * @return The cast pattern or null + * @since 7.1 + */ + @Incubating + default @Nullable String castFromPattern(JdbcMapping sourceMapping) { + return null; + } + + /** + * Returns the cast pattern from this type to the given target type, or {@code null} if not possible. 
+ * + * @param targetJdbcMapping The target type + * @return The cast pattern or null + * @since 7.1 + */ + @Incubating + default @Nullable String castToPattern(JdbcMapping targetJdbcMapping) { + return null; + } + @Incubating default boolean isComparable() { final int code = getDefaultSqlTypeCode(); diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/XmlAsStringArrayJdbcType.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/XmlAsStringArrayJdbcType.java index 4cdc0ad40d1b..337d38af5135 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/XmlAsStringArrayJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/XmlAsStringArrayJdbcType.java @@ -174,4 +174,20 @@ protected X doExtract(CallableStatement statement, String name, WrapperOptions o }; } + + @Override + public boolean equals(Object that) { + return super.equals( that ) + && that instanceof XmlAsStringArrayJdbcType jdbcType + && ddlTypeCode == jdbcType.ddlTypeCode + && nationalized == jdbcType.nationalized; + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + Boolean.hashCode( nationalized ); + result = 31 * result + ddlTypeCode; + return result; + } } diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/orm/junit/DialectFeatureChecks.java b/hibernate-testing/src/main/java/org/hibernate/testing/orm/junit/DialectFeatureChecks.java index 7bb583f3801f..2fdd8bcc2090 100644 --- a/hibernate-testing/src/main/java/org/hibernate/testing/orm/junit/DialectFeatureChecks.java +++ b/hibernate-testing/src/main/java/org/hibernate/testing/orm/junit/DialectFeatureChecks.java @@ -13,9 +13,11 @@ import org.hibernate.boot.internal.MetadataBuilderImpl; import org.hibernate.boot.internal.NamedProcedureCallDefinitionImpl; import org.hibernate.boot.model.FunctionContributions; +import org.hibernate.boot.model.FunctionContributor; import org.hibernate.boot.model.IdentifierGeneratorDefinition; import org.hibernate.boot.model.NamedEntityGraphDefinition; import org.hibernate.boot.model.TypeContributions; +import org.hibernate.boot.model.TypeContributor; import org.hibernate.boot.model.TypeDefinition; import org.hibernate.boot.model.TypeDefinitionRegistry; import org.hibernate.boot.model.convert.spi.ConverterAutoApplyHandler; @@ -97,6 +99,7 @@ import org.hibernate.type.descriptor.java.StringJavaType; import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.descriptor.jdbc.VarcharJdbcType; +import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry; import org.hibernate.type.internal.BasicTypeImpl; import org.hibernate.type.spi.TypeConfiguration; import org.hibernate.usertype.CompositeUserType; @@ -105,6 +108,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.ServiceLoader; import java.util.Set; import java.util.UUID; import java.util.function.Consumer; @@ -1076,6 +1080,132 @@ public boolean apply(Dialect dialect) { } } + public static class SupportsVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.VECTOR ); + } + } + + public static class SupportsFloat16VectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.VECTOR_FLOAT16 ); + } + } + + public static class SupportsFloatVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, 
SqlTypes.VECTOR_FLOAT32 ); + } + } + + public static class SupportsDoubleVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.VECTOR_FLOAT64 ); + } + } + + public static class SupportsByteVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.VECTOR_INT8 ); + } + } + + public static class SupportsBinaryVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.VECTOR_BINARY ); + } + } + + public static class SupportsSparseFloatVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.SPARSE_VECTOR_FLOAT32 ); + } + } + + public static class SupportsSparseDoubleVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.SPARSE_VECTOR_FLOAT64 ); + } + } + + public static class SupportsSparseByteVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.SPARSE_VECTOR_INT8 ); + } + } + + public static class SupportsCosineDistance implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "cosine_distance" ); + } + } + + public static class SupportsEuclideanDistance implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "euclidean_distance" ); + } + } + + public static class SupportsTaxicabDistance implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "taxicab_distance" ); + } + } + + public static class SupportsHammingDistance implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "hamming_distance" ); + } + } + + public static class SupportsJaccardDistance implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "jaccard_distance" ); + } + } + + public static class SupportsInnerProduct implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "inner_product" ); + } + } + + public static class SupportsVectorDims implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "vector_dims" ); + } + } + + public static class SupportsVectorNorm implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "vector_norm" ); + } + } + + public static class SupportsL2Norm implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "l2_norm" ); + } + } + + public static class SupportsL2Normalize implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "l2_normalize" ); + } + } + + public static class SupportsSubvector implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "subvector" ); + } + } + + public static class SupportsBinaryQuantize implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "binary_quantize" ); + } + } + public static class IsJtds implements DialectFeatureCheck { public boolean apply(Dialect dialect) { return dialect instanceof SybaseDialect 
&& ( (SybaseDialect) dialect ).getDriverKind() == SybaseDriverKind.JTDS; @@ -1141,7 +1271,7 @@ public boolean apply(Dialect dialect) { } } - private static final HashMap FUNCTION_REGISTRIES = new HashMap<>(); + private static final HashMap FUNCTION_CONTRIBUTIONS = new HashMap<>(); public static boolean definesFunction(Dialect dialect, String functionName) { return getSqmFunctionRegistry( dialect ).findFunctionDescriptor( functionName ) != null; @@ -1151,6 +1281,11 @@ public static boolean definesSetReturningFunction(Dialect dialect, String functi return getSqmFunctionRegistry( dialect ).findSetReturningFunctionDescriptor( functionName ) != null; } + public static boolean definesDdlType(Dialect dialect, int typeCode) { + final DdlTypeRegistry ddlTypeRegistry = getFunctionContributions( dialect ).typeConfiguration.getDdlTypeRegistry(); + return ddlTypeRegistry.getDescriptor( typeCode ) != null; + } + public static class SupportsSubqueryInSelect implements DialectFeatureCheck { @Override public boolean apply(Dialect dialect) { @@ -1172,24 +1307,33 @@ public boolean apply(Dialect dialect) { } } - private static SqmFunctionRegistry getSqmFunctionRegistry(Dialect dialect) { - SqmFunctionRegistry sqmFunctionRegistry = FUNCTION_REGISTRIES.get( dialect ); - if ( sqmFunctionRegistry == null ) { + return getFunctionContributions( dialect ).functionRegistry; + } + + private static FakeFunctionContributions getFunctionContributions(Dialect dialect) { + FakeFunctionContributions functionContributions = FUNCTION_CONTRIBUTIONS.get( dialect ); + if ( functionContributions == null ) { final TypeConfiguration typeConfiguration = new TypeConfiguration(); final SqmFunctionRegistry functionRegistry = new SqmFunctionRegistry(); typeConfiguration.scope( new FakeMetadataBuildingContext( typeConfiguration, functionRegistry ) ); final FakeTypeContributions typeContributions = new FakeTypeContributions( typeConfiguration ); - final FakeFunctionContributions functionContributions = new FakeFunctionContributions( + functionContributions = new FakeFunctionContributions( dialect, typeConfiguration, functionRegistry ); dialect.contribute( typeContributions, typeConfiguration.getServiceRegistry() ); dialect.initializeFunctionRegistry( functionContributions ); - FUNCTION_REGISTRIES.put( dialect, sqmFunctionRegistry = functionContributions.functionRegistry ); + for ( TypeContributor typeContributor : ServiceLoader.load( TypeContributor.class ) ) { + typeContributor.contribute( typeContributions, typeConfiguration.getServiceRegistry() ); + } + for ( FunctionContributor functionContributor : ServiceLoader.load( FunctionContributor.class ) ) { + functionContributor.contributeFunctions( functionContributions ); + } + FUNCTION_CONTRIBUTIONS.put( dialect, functionContributions ); } - return sqmFunctionRegistry; + return functionContributions; } public static class FakeTypeContributions implements TypeContributions { diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/AbstractSparseVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/AbstractSparseVector.java new file mode 100644 index 000000000000..abfee8832ed9 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/AbstractSparseVector.java @@ -0,0 +1,138 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector; + +import org.hibernate.internal.util.collections.ArrayHelper; + +import java.util.AbstractList; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * Base class for sparse vectors. + * + * @since 7.1 + */ +public abstract class AbstractSparseVector extends AbstractList { + + protected static final int[] EMPTY_INT_ARRAY = new int[0]; + + protected interface ElementParser { + V parse(String string, int start, int end); + } + protected record ParsedVector(int size, int[] indices, List elements) { + } + + protected static ParsedVector parseSparseVector(String string, ElementParser parser) { + if ( string == null || !string.startsWith( "[" ) || !string.endsWith( "]" ) ) { + throw invalidVector( string ); + } + final int lengthEndIndex = string.indexOf( ',', 2 ); + if ( lengthEndIndex == -1 ) { + throw invalidVector( string ); + } + final int indicesStartIndex = lengthEndIndex + 1; + if ( string.charAt( indicesStartIndex ) != '[' ) { + throw invalidVector( string ); + } + final int indicesEndIndex = string.indexOf( ']', indicesStartIndex + 1 ); + if ( indicesEndIndex == -1 ) { + throw invalidVector( string ); + } + final int commaIndex = indicesEndIndex + 1; + if ( string.charAt( commaIndex ) != ',' ) { + throw invalidVector( string ); + } + final int elementsStartIndex = commaIndex + 1; + if ( string.charAt( elementsStartIndex ) != '[' ) { + throw invalidVector( string ); + } + final int elementsEndIndex = string.indexOf( ']', elementsStartIndex + 1 ); + if ( elementsEndIndex == -1 ) { + throw invalidVector( string ); + } + if ( elementsEndIndex != string.length() - 2 ) { + throw invalidVector( string ); + } + final int size = Integer.parseInt( string, 1, lengthEndIndex, 10 ); + int start = indicesStartIndex + 1; + final List indicesList = new ArrayList<>(); + if ( start < indicesEndIndex ) { + for ( int i = start; i < indicesEndIndex; i++ ) { + if ( string.charAt( i ) == ',' ) { + indicesList.add( Integer.parseInt( string, start, i, 10 ) ); + start = i + 1; + } + } + indicesList.add( Integer.parseInt( string, start, indicesEndIndex, 10 ) ); + } + final int[] indices = ArrayHelper.toIntArray( indicesList ); + final List elements = new ArrayList<>( indices.length ); + start = elementsStartIndex + 1; + if ( start < elementsEndIndex ) { + for ( int i = start; i < elementsEndIndex; i++ ) { + if ( string.charAt( i ) == ',' ) { + elements.add( parser.parse( string, start, i ) ); + start = i + 1; + } + } + elements.add( parser.parse( string, start, elementsEndIndex ) ); + } + return new ParsedVector<>( size, indices, elements ); + } + + private static IllegalArgumentException invalidVector(String string) { + return new IllegalArgumentException( "Invalid sparse vector string: " + string ); + } + + protected static int[] validateIndices(int[] indices, int dataLength, int size) { + if ( indices == null ) { + throw new IllegalArgumentException( "indices cannot be null" ); + } + if ( indices.length != dataLength ) { + throw new IllegalArgumentException( "indices length does not match data length" ); + } + int previousIndex = -1; + for ( int i = 0; i < indices.length; i++ ) { + if ( indices[i] < 0 ) { + throw new IllegalArgumentException( "indices[" + i + "] < 0" ); + } + else if ( indices[i] < previousIndex ) { + throw new IllegalArgumentException( "Indices array is not sorted ascendingly." 
); + } + previousIndex = indices[i]; + } + if ( previousIndex >= size ) { + throw new IllegalArgumentException( "Indices array contains index " + previousIndex + " that is greater than or equal to size: " + size ); + } + return indices; + } + + @Override + public void clear() { + throw new UnsupportedOperationException( "Cannot remove from sparse vector" ); + } + + @Override + public E remove(int index) { + throw new UnsupportedOperationException( "Cannot remove from sparse vector" ); + } + + @Override + public boolean add(E aByte) { + throw new UnsupportedOperationException( "Cannot add to sparse vector" ); + } + + @Override + public void add(int index, E element) { + throw new UnsupportedOperationException( "Cannot add to sparse vector" ); + } + + @Override + public boolean addAll(int index, Collection c) { + throw new UnsupportedOperationException( "Cannot add to sparse vector" ); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/PGVectorFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/PGVectorFunctionContributor.java deleted file mode 100644 index 7b52fcafa592..000000000000 --- a/hibernate-vector/src/main/java/org/hibernate/vector/PGVectorFunctionContributor.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright Red Hat Inc. and Hibernate Authors - */ -package org.hibernate.vector; - -import org.hibernate.boot.model.FunctionContributions; -import org.hibernate.boot.model.FunctionContributor; -import org.hibernate.dialect.CockroachDialect; -import org.hibernate.dialect.Dialect; -import org.hibernate.dialect.PostgreSQLDialect; - -public class PGVectorFunctionContributor implements FunctionContributor { - - @Override - public void contributeFunctions(FunctionContributions functionContributions) { - final Dialect dialect = functionContributions.getDialect(); - if (dialect instanceof PostgreSQLDialect || dialect instanceof CockroachDialect) { - final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); - - vectorFunctionFactory.cosineDistance( "?1<=>?2" ); - vectorFunctionFactory.euclideanDistance( "?1<->?2" ); - vectorFunctionFactory.l1Distance( "l1_distance(?1,?2)" ); - - vectorFunctionFactory.innerProduct( "(?1<#>?2)*-1" ); - vectorFunctionFactory.negativeInnerProduct( "?1<#>?2" ); - - vectorFunctionFactory.vectorDimensions(); - vectorFunctionFactory.vectorNorm(); - } - } - - @Override - public int ordinal() { - return 200; - } -} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/SparseByteVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/SparseByteVector.java new file mode 100644 index 000000000000..b47a2c0fb372 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/SparseByteVector.java @@ -0,0 +1,204 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import java.util.Arrays; +import java.util.List; + +/** + * {@link java.util.List} implementation for a sparse byte vector. 
+ * + * @since 7.1 + */ +public class SparseByteVector extends AbstractSparseVector { + + private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + + private final int size; + private int[] indices = EMPTY_INT_ARRAY; + private byte[] data = EMPTY_BYTE_ARRAY; + + public SparseByteVector(int size) { + if ( size <= 0 ) { + throw new IllegalArgumentException( "size must be greater than zero" ); + } + this.size = size; + } + + public SparseByteVector(List list) { + if ( list instanceof SparseByteVector sparseVector ) { + size = sparseVector.size; + indices = sparseVector.indices.clone(); + data = sparseVector.data.clone(); + } + else { + if ( list == null ) { + throw new IllegalArgumentException( "list cannot be null" ); + } + if ( list.isEmpty() ) { + throw new IllegalArgumentException( "list cannot be empty" ); + } + int size = 0; + int[] indices = new int[list.size()]; + byte[] data = new byte[list.size()]; + for ( int i = 0; i < list.size(); i++ ) { + final Byte b = list.get( i ); + if ( b != null && b != 0 ) { + indices[size] = i; + data[size] = b; + size++; + } + } + this.size = list.size(); + this.indices = Arrays.copyOf( indices, size ); + this.data = Arrays.copyOf( data, size ); + } + } + + public SparseByteVector(byte[] denseVector) { + if ( denseVector == null ) { + throw new IllegalArgumentException( "denseVector cannot be null" ); + } + if ( denseVector.length == 0 ) { + throw new IllegalArgumentException( "denseVector cannot be empty" ); + } + int size = 0; + int[] indices = new int[denseVector.length]; + byte[] data = new byte[denseVector.length]; + for ( int i = 0; i < denseVector.length; i++ ) { + final byte b = denseVector[i]; + if ( b != 0 ) { + indices[size] = i; + data[size] = b; + size++; + } + } + this.size = denseVector.length; + this.indices = Arrays.copyOf( indices, size ); + this.data = Arrays.copyOf( data, size ); + } + + public SparseByteVector(int size, int[] indices, byte[] data) { + this( validateData( data, size ), validateIndices( indices, data.length, size ), size ); + } + + private SparseByteVector(byte[] data, int[] indices, int size) { + this.size = size; + this.indices = indices; + this.data = data; + } + + public SparseByteVector(String string) { + final ParsedVector parsedVector = + parseSparseVector( string, (s, start, end) -> Byte.parseByte( s.substring( start, end ) ) ); + this.size = parsedVector.size(); + this.indices = parsedVector.indices(); + this.data = toByteArray( parsedVector.elements() ); + } + + private static byte[] toByteArray(List elements) { + final byte[] result = new byte[elements.size()]; + for ( int i = 0; i < elements.size(); i++ ) { + result[i] = elements.get(i); + } + return result; + } + + private static byte[] validateData(byte[] data, int size) { + if ( size == 0 ) { + throw new IllegalArgumentException( "size cannot be 0" ); + } + if ( data == null ) { + throw new IllegalArgumentException( "data cannot be null" ); + } + if ( size < data.length ) { + throw new IllegalArgumentException( "size cannot be smaller than data size" ); + } + for ( int i = 0; i < data.length; i++ ) { + if ( data[i] == 0 ) { + throw new IllegalArgumentException( "data[" + i + "] == 0" ); + } + } + return data; + } + + @Override + public SparseByteVector clone() { + return new SparseByteVector( data.clone(), indices.clone(), size ); + } + + @Override + public Byte get(int index) { + final int foundIndex = Arrays.binarySearch( indices, index ); + return foundIndex < 0 ? 
0 : data[foundIndex]; + } + + @Override + public Byte set(int index, Byte element) { + final int foundIndex = Arrays.binarySearch( indices, index ); + if ( foundIndex < 0 ) { + if ( element != null && element != 0 ) { + final int[] newIndices = new int[indices.length + 1]; + final byte[] newData = new byte[data.length + 1]; + final int insertionPoint = -foundIndex - 1; + System.arraycopy( indices, 0, newIndices, 0, insertionPoint ); + System.arraycopy( data, 0, newData, 0, insertionPoint ); + newIndices[insertionPoint] = index; + newData[insertionPoint] = element; + System.arraycopy( indices, insertionPoint, newIndices, insertionPoint + 1, indices.length - insertionPoint ); + System.arraycopy( data, insertionPoint, newData, insertionPoint + 1, data.length - insertionPoint ); + this.indices = newIndices; + this.data = newData; + } + return null; + } + else { + final byte oldValue = data[foundIndex]; + if ( element != null && element != 0 ) { + data[foundIndex] = element; + } + else { + final int[] newIndices = new int[indices.length - 1]; + final byte[] newData = new byte[data.length - 1]; + System.arraycopy( indices, 0, newIndices, 0, foundIndex ); + System.arraycopy( data, 0, newData, 0, foundIndex ); + System.arraycopy( indices, foundIndex + 1, newIndices, foundIndex, indices.length - foundIndex - 1 ); + System.arraycopy( data, foundIndex + 1, newData, foundIndex, data.length - foundIndex - 1 ); + this.indices = newIndices; + this.data = newData; + } + return oldValue; + } + } + + public byte[] toDenseVector() { + final byte[] result = new byte[this.size]; + for ( int i = 0; i < indices.length; i++ ) { + result[indices[i]] = data[i]; + } + return result; + } + + public int[] indices() { + return indices; + } + + public byte[] bytes() { + return data; + } + + @Override + public int size() { + return size; + } + + @Override + public String toString() { + return '[' + size + + ',' + Arrays.toString( indices ) + + ',' + Arrays.toString( data ) + + ']'; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/SparseDoubleVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/SparseDoubleVector.java new file mode 100644 index 000000000000..d5991dd8ed29 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/SparseDoubleVector.java @@ -0,0 +1,184 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import java.util.Arrays; +import java.util.List; + +/** + * {@link List} implementation for a sparse byte vector. 
+ * + * @since 7.1 + */ +public class SparseDoubleVector extends AbstractSparseVector { + + private static final double[] EMPTY_FLOAT_ARRAY = new double[0]; + + private final int size; + private int[] indices = EMPTY_INT_ARRAY; + private double[] data = EMPTY_FLOAT_ARRAY; + + public SparseDoubleVector(int size) { + this.size = size; + } + + public SparseDoubleVector(List list) { + if ( list instanceof SparseDoubleVector sparseVector ) { + size = sparseVector.size; + indices = sparseVector.indices.clone(); + data = sparseVector.data.clone(); + } + else { + int size = 0; + int[] indices = new int[list.size()]; + double[] data = new double[list.size()]; + for ( int i = 0; i < list.size(); i++ ) { + final Double b = list.get( i ); + if ( b != null && b != 0 ) { + indices[size] = i; + data[size] = b; + size++; + } + } + this.size = list.size(); + this.indices = Arrays.copyOf( indices, size ); + this.data = Arrays.copyOf( data, size ); + } + } + + public SparseDoubleVector(double[] denseVector) { + int size = 0; + int[] indices = new int[denseVector.length]; + double[] data = new double[denseVector.length]; + for ( int i = 0; i < denseVector.length; i++ ) { + final double b = denseVector[i]; + if ( b != 0 ) { + indices[size] = i; + data[size] = b; + size++; + } + } + this.size = denseVector.length; + this.indices = Arrays.copyOf( indices, size ); + this.data = Arrays.copyOf( data, size ); + } + + public SparseDoubleVector(int size, int[] indices, double[] data) { + this( validateData( data ), validateIndices( indices, data.length, size ), size ); + } + + private SparseDoubleVector(double[] data, int[] indices, int size) { + this.size = size; + this.indices = indices; + this.data = data; + } + + public SparseDoubleVector(String string) { + final ParsedVector parsedVector = + parseSparseVector( string, (s, start, end) -> Double.parseDouble( s.substring( start, end ) ) ); + this.size = parsedVector.size(); + this.indices = parsedVector.indices(); + this.data = toDoubleArray( parsedVector.elements() ); + } + + private static double[] toDoubleArray(List elements) { + final double[] result = new double[elements.size()]; + for ( int i = 0; i < elements.size(); i++ ) { + result[i] = elements.get(i); + } + return result; + } + + private static double[] validateData(double[] data) { + if ( data == null ) { + throw new IllegalArgumentException( "data cannot be null" ); + } + for ( int i = 0; i < data.length; i++ ) { + if ( data[i] == 0 ) { + throw new IllegalArgumentException( "data[" + i + "] == 0" ); + } + } + return data; + } + + @Override + public SparseDoubleVector clone() { + return new SparseDoubleVector( data.clone(), indices.clone(), size ); + } + + @Override + public Double get(int index) { + final int foundIndex = Arrays.binarySearch( indices, index ); + return foundIndex < 0 ? 
0 : data[foundIndex]; + } + + @Override + public Double set(int index, Double element) { + final int foundIndex = Arrays.binarySearch( indices, index ); + if ( foundIndex < 0 ) { + if ( element != null && element != 0 ) { + final int[] newIndices = new int[indices.length + 1]; + final double[] newData = new double[data.length + 1]; + final int insertionPoint = -foundIndex - 1; + System.arraycopy( indices, 0, newIndices, 0, insertionPoint ); + System.arraycopy( data, 0, newData, 0, insertionPoint ); + newIndices[insertionPoint] = index; + newData[insertionPoint] = element; + System.arraycopy( indices, insertionPoint, newIndices, insertionPoint + 1, indices.length - insertionPoint ); + System.arraycopy( data, insertionPoint, newData, insertionPoint + 1, data.length - insertionPoint ); + this.indices = newIndices; + this.data = newData; + } + return null; + } + else { + final double oldValue = data[foundIndex]; + if ( element != null && element != 0 ) { + data[foundIndex] = element; + } + else { + final int[] newIndices = new int[indices.length - 1]; + final double[] newData = new double[data.length - 1]; + System.arraycopy( indices, 0, newIndices, 0, foundIndex ); + System.arraycopy( data, 0, newData, 0, foundIndex ); + System.arraycopy( indices, foundIndex + 1, newIndices, foundIndex, indices.length - foundIndex - 1 ); + System.arraycopy( data, foundIndex + 1, newData, foundIndex, data.length - foundIndex - 1 ); + this.indices = newIndices; + this.data = newData; + } + return oldValue; + } + } + + public double[] toDenseVector() { + final double[] result = new double[this.size]; + for ( int i = 0; i < indices.length; i++ ) { + result[indices[i]] = data[i]; + } + return result; + } + + public int[] indices() { + return indices; + } + + public double[] doubles() { + return data; + } + + @Override + public int size() { + return size; + } + + @Override + public String toString() { + return '[' + size + + ',' + Arrays.toString( indices ) + + ',' + Arrays.toString( data ) + + ']'; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/SparseFloatVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/SparseFloatVector.java new file mode 100644 index 000000000000..f683aec38332 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/SparseFloatVector.java @@ -0,0 +1,184 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import java.util.Arrays; +import java.util.List; + +/** + * {@link List} implementation for a sparse byte vector. 
+ * + * @since 7.1 + */ +public class SparseFloatVector extends AbstractSparseVector { + + private static final float[] EMPTY_FLOAT_ARRAY = new float[0]; + + private final int size; + private int[] indices = EMPTY_INT_ARRAY; + private float[] data = EMPTY_FLOAT_ARRAY; + + public SparseFloatVector(int size) { + this.size = size; + } + + public SparseFloatVector(List list) { + if ( list instanceof SparseFloatVector sparseVector ) { + size = sparseVector.size; + indices = sparseVector.indices.clone(); + data = sparseVector.data.clone(); + } + else { + int size = 0; + int[] indices = new int[list.size()]; + float[] data = new float[list.size()]; + for ( int i = 0; i < list.size(); i++ ) { + final Float b = list.get( i ); + if ( b != null && b != 0 ) { + indices[size] = i; + data[size] = b; + size++; + } + } + this.size = list.size(); + this.indices = Arrays.copyOf( indices, size ); + this.data = Arrays.copyOf( data, size ); + } + } + + public SparseFloatVector(float[] denseVector) { + int size = 0; + int[] indices = new int[denseVector.length]; + float[] data = new float[denseVector.length]; + for ( int i = 0; i < denseVector.length; i++ ) { + final float b = denseVector[i]; + if ( b != 0 ) { + indices[size] = i; + data[size] = b; + size++; + } + } + this.size = denseVector.length; + this.indices = Arrays.copyOf( indices, size ); + this.data = Arrays.copyOf( data, size ); + } + + public SparseFloatVector(int size, int[] indices, float[] data) { + this( validateData( data ), validateIndices( indices, data.length, size ), size ); + } + + private SparseFloatVector(float[] data, int[] indices, int size) { + this.size = size; + this.indices = indices; + this.data = data; + } + + public SparseFloatVector(String string) { + final ParsedVector parsedVector = + parseSparseVector( string, (s, start, end) -> Float.parseFloat( s.substring( start, end ) ) ); + this.size = parsedVector.size(); + this.indices = parsedVector.indices(); + this.data = toFloatArray( parsedVector.elements() ); + } + + private static float[] toFloatArray(List elements) { + final float[] result = new float[elements.size()]; + for ( int i = 0; i < elements.size(); i++ ) { + result[i] = elements.get(i); + } + return result; + } + + private static float[] validateData(float[] data) { + if ( data == null ) { + throw new IllegalArgumentException( "data cannot be null" ); + } + for ( int i = 0; i < data.length; i++ ) { + if ( data[i] == 0 ) { + throw new IllegalArgumentException( "data[" + i + "] == 0" ); + } + } + return data; + } + + @Override + public SparseFloatVector clone() { + return new SparseFloatVector( data.clone(), indices.clone(), size ); + } + + @Override + public Float get(int index) { + final int foundIndex = Arrays.binarySearch( indices, index ); + return foundIndex < 0 ? 
0 : data[foundIndex]; + } + + @Override + public Float set(int index, Float element) { + final int foundIndex = Arrays.binarySearch( indices, index ); + if ( foundIndex < 0 ) { + if ( element != null && element != 0 ) { + final int[] newIndices = new int[indices.length + 1]; + final float[] newData = new float[data.length + 1]; + final int insertionPoint = -foundIndex - 1; + System.arraycopy( indices, 0, newIndices, 0, insertionPoint ); + System.arraycopy( data, 0, newData, 0, insertionPoint ); + newIndices[insertionPoint] = index; + newData[insertionPoint] = element; + System.arraycopy( indices, insertionPoint, newIndices, insertionPoint + 1, indices.length - insertionPoint ); + System.arraycopy( data, insertionPoint, newData, insertionPoint + 1, data.length - insertionPoint ); + this.indices = newIndices; + this.data = newData; + } + return null; + } + else { + final float oldValue = data[foundIndex]; + if ( element != null && element != 0 ) { + data[foundIndex] = element; + } + else { + final int[] newIndices = new int[indices.length - 1]; + final float[] newData = new float[data.length - 1]; + System.arraycopy( indices, 0, newIndices, 0, foundIndex ); + System.arraycopy( data, 0, newData, 0, foundIndex ); + System.arraycopy( indices, foundIndex + 1, newIndices, foundIndex, indices.length - foundIndex - 1 ); + System.arraycopy( data, foundIndex + 1, newData, foundIndex, data.length - foundIndex - 1 ); + this.indices = newIndices; + this.data = newData; + } + return oldValue; + } + } + + public float[] toDenseVector() { + final float[] result = new float[this.size]; + for ( int i = 0; i < indices.length; i++ ) { + result[indices[i]] = data[i]; + } + return result; + } + + public int[] indices() { + return indices; + } + + public float[] floats() { + return data; + } + + @Override + public int size() { + return size; + } + + @Override + public String toString() { + return '[' + size + + ',' + Arrays.toString( indices ) + + ',' + Arrays.toString( data ) + + ']'; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleSparseVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleSparseVectorJdbcType.java new file mode 100644 index 000000000000..cab0cf11516d --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleSparseVectorJdbcType.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.vector.AbstractSparseVector; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.SQLException; + +public abstract class AbstractOracleSparseVectorJdbcType extends AbstractOracleVectorJdbcType { + + public AbstractOracleSparseVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + return new OracleJdbcLiteralFormatterSparseVector<>( javaTypeDescriptor, getVectorParameters() ); + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + st.setObject( index, getBindValue( value, options ) ); + } + else { + st.setString( index, stringVector( value, options ) ); + } + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + st.setObject( name, getBindValue( value, options ) ); + } + else { + st.setString( name, stringVector( value, options ) ); + } + } + + private String stringVector(X value, WrapperOptions options) { + return ((AbstractOracleSparseVectorJdbcType) getJdbcType()).getStringVector( value, getJavaType(), options ); + } + + @Override + public Object getBindValue(X value, WrapperOptions options) { + return ((AbstractOracleSparseVectorJdbcType) getJdbcType()).getBindValue( getJavaType(), value, options ); + } + }; + } + + protected abstract Object getBindValue(JavaType javaType, X value, WrapperOptions options); + + @Override + protected String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options) { + return javaTypeDescriptor.unwrap( vector, AbstractSparseVector.class, options ).toString(); + } + + @Override + protected Class getNativeJavaType() { + return Object.class; + } + + @Override + protected int getNativeTypeCode() { + return SqlTypes.OTHER; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleVectorJdbcType.java similarity index 82% rename from hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleVectorJdbcType.java index 48407feab48f..23d839f428ff 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleVectorJdbcType.java @@ -2,14 +2,16 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. 
and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import java.sql.CallableStatement; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.ValueBinder; @@ -24,7 +26,6 @@ import org.hibernate.type.descriptor.jdbc.BasicExtractor; import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; import org.hibernate.type.descriptor.jdbc.JdbcType; -import org.hibernate.type.descriptor.jdbc.internal.JdbcLiteralFormatterArray; /** * Specialized type mapping for generic vector {@link SqlTypes#VECTOR} SQL data type for Oracle. @@ -43,13 +44,32 @@ public AbstractOracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSu this.isVectorSupported = isVectorSupported; } - public abstract void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect); + @Override + public @Nullable String castToPattern(JdbcMapping targetJdbcMapping) { + return targetJdbcMapping.getJdbcType().isStringLike() ? "from_vector(?1 returning ?2)" : null; + } + + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping) { + return sourceMapping.getJdbcType().isStringLike() ? "to_vector(?1," + getVectorParameters() + ")" : null; + } @Override - public int getDefaultSqlTypeCode() { - return SqlTypes.VECTOR; + public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { + if ( isVectorSupported) { + appender.append( writeExpression ); + } + else { + appender.append( "to_vector(" ); + appender.append( writeExpression ); + appender.append( ',' ); + appender.append( getVectorParameters() ); + appender.append( ')' ); + } } + public abstract String getVectorParameters(); + @Override public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { final JavaType elementJavaType; @@ -65,9 +85,10 @@ else if ( javaTypeDescriptor instanceof BasicPluralJavaType ) { else { throw new IllegalArgumentException( "not a BasicPluralJavaType" ); } - return new JdbcLiteralFormatterArray<>( + return new OracleJdbcLiteralFormatterVector<>( javaTypeDescriptor, - getElementJdbcType().getJdbcLiteralFormatter( elementJavaType ) + getElementJdbcType().getJdbcLiteralFormatter( elementJavaType ), + getVectorParameters().replace( ",sparse", "" ) ); } @@ -76,7 +97,6 @@ public String toString() { return "OracleVectorTypeDescriptor"; } - @Override public ValueBinder getBinder(final JavaType javaTypeDescriptor) { return new BasicBinder<>( javaTypeDescriptor, this ) { @@ -142,7 +162,7 @@ protected X doExtract(CallableStatement statement, String name, WrapperOptions o }; } - protected abstract T getVectorArray(String string); + protected abstract Object getVectorArray(String string); protected abstract String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options); diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBFunctionContributor.java similarity index 86% rename from hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBFunctionContributor.java index 
ac14aa3d48cd..f707dbb69d1e 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBFunctionContributor.java @@ -2,7 +2,7 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.FunctionContributor; @@ -13,7 +13,7 @@ public class MariaDBFunctionContributor implements FunctionContributor { @Override public void contributeFunctions(FunctionContributions functionContributions) { final Dialect dialect = functionContributions.getDialect(); - if ( dialect instanceof MariaDBDialect ) { + if ( dialect instanceof MariaDBDialect && dialect.getVersion().isSameOrAfter( 11, 7 ) ) { final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); vectorFunctionFactory.cosineDistance( "vec_distance_cosine(?1,?2)" ); diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBTypeContributor.java similarity index 63% rename from hibernate-vector/src/main/java/org/hibernate/vector/MariaDBTypeContributor.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBTypeContributor.java index 78a3540db69d..bddcdfb415b1 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBTypeContributor.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBTypeContributor.java @@ -2,13 +2,12 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import org.hibernate.boot.model.TypeContributions; import org.hibernate.boot.model.TypeContributor; import org.hibernate.dialect.Dialect; import org.hibernate.dialect.MariaDBDialect; -import org.hibernate.engine.jdbc.Size; import org.hibernate.engine.jdbc.spi.JdbcServices; import org.hibernate.service.ServiceRegistry; import org.hibernate.type.BasicArrayType; @@ -19,7 +18,6 @@ import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; -import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; import org.hibernate.type.spi.TypeConfiguration; import java.lang.reflect.Type; @@ -34,35 +32,45 @@ public class MariaDBTypeContributor implements TypeContributor { @Override public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect(); - if ( dialect instanceof MariaDBDialect ) { + if ( dialect instanceof MariaDBDialect && dialect.getVersion().isSameOrAfter( 11, 7 ) ) { final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration(); final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry(); final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry(); final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); final BasicType floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ); - final ArrayJdbcType vectorJdbcType = new BinaryVectorJdbcType( jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ) ); - jdbcTypeRegistry.addDescriptor( 
SqlTypes.VECTOR, vectorJdbcType ); + final ArrayJdbcType genericVectorJdbcType = new MariaDBVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType ); + final ArrayJdbcType floatVectorJdbcType = new MariaDBVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR_FLOAT32 + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType ); for ( Type vectorJavaType : VECTOR_JAVA_TYPES ) { basicTypeRegistry.register( new BasicArrayType<>( floatBasicType, - vectorJdbcType, + genericVectorJdbcType, javaTypeRegistry.getDescriptor( vectorJavaType ) ), StandardBasicTypes.VECTOR.getName() ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + floatVectorJdbcType, + javaTypeRegistry.getDescriptor( vectorJavaType ) + ), + StandardBasicTypes.VECTOR_FLOAT32.getName() + ); } typeConfiguration.getDdlTypeRegistry().addDescriptor( - new DdlTypeImpl( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) { - @Override - public String getTypeName(Size size) { - return getTypeName( - size.getArrayLength() == null ? null : size.getArrayLength().longValue(), - null, - null - ); - } - } + new VectorDdlType( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "vector($l)", "vector", dialect ) ); } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/BinaryVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBVectorJdbcType.java similarity index 73% rename from hibernate-vector/src/main/java/org/hibernate/vector/BinaryVectorJdbcType.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBVectorJdbcType.java index 2f25d70edbd8..52b6db8fdc6c 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/BinaryVectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBVectorJdbcType.java @@ -2,11 +2,12 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.sql.ast.spi.SqlAppender; -import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.ValueBinder; import org.hibernate.type.descriptor.ValueExtractor; import org.hibernate.type.descriptor.WrapperOptions; @@ -22,15 +23,18 @@ import java.sql.ResultSet; import java.sql.SQLException; -public class BinaryVectorJdbcType extends ArrayJdbcType { +public class MariaDBVectorJdbcType extends ArrayJdbcType { - public BinaryVectorJdbcType(JdbcType elementJdbcType) { + private final int sqlType; + + public MariaDBVectorJdbcType(JdbcType elementJdbcType, int sqlType) { super( elementJdbcType ); + this.sqlType = sqlType; } @Override public int getDefaultSqlTypeCode() { - return SqlTypes.VECTOR; + return sqlType; } @Override @@ -46,6 +50,16 @@ public void appendWriteExpression(String writeExpression, SqlAppender appender, appender.append( writeExpression ); } + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping) { + return sourceMapping.getJdbcType().isStringLike() ? 
"vec_fromtext(?1)" : null; + } + + @Override + public @Nullable String castToPattern(JdbcMapping targetJdbcMapping) { + return targetJdbcMapping.getJdbcType().isStringLike() ? "vec_totext(?1)" : null; + } + @Override public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { return new BasicExtractor<>( javaTypeDescriptor, this ) { @@ -88,4 +102,16 @@ public Object getBindValue(X value, WrapperOptions options) { } }; } + + @Override + public boolean equals(Object that) { + return super.equals( that ) + && that instanceof MariaDBVectorJdbcType vectorJdbcType + && sqlType == vectorJdbcType.sqlType; + } + + @Override + public int hashCode() { + return sqlType + 31 * super.hashCode(); + } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLFunctionContributor.java new file mode 100644 index 000000000000..0d93fb55128f --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLFunctionContributor.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.boot.model.FunctionContributions; +import org.hibernate.boot.model.FunctionContributor; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.MySQLDialect; + +public class MySQLFunctionContributor implements FunctionContributor { + @Override + public void contributeFunctions(FunctionContributions functionContributions) { + final Dialect dialect = functionContributions.getDialect(); + if ( dialect instanceof MySQLDialect mySQLDialect && mySQLDialect.getMySQLVersion().isSameOrAfter( 9, 0 ) ) { + final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); + + vectorFunctionFactory.cosineDistance( "distance(?1,?2,'cosine')" ); + vectorFunctionFactory.euclideanDistance( "distance(?1,?2,'euclidean')" ); + vectorFunctionFactory.innerProduct( "distance(?1,?2,'dot')*-1" ); + vectorFunctionFactory.negativeInnerProduct( "distance(?1,?2,'dot')" ); + + vectorFunctionFactory.registerNamedVectorFunction( + "vector_dim", + functionContributions.getTypeConfiguration().getBasicTypeForJavaType( Integer.class ), + 1 + ); + functionContributions.getFunctionRegistry().registerAlternateKey( "vector_dims", "vector_dim" ); + } + } + + @Override + public int ordinal() { + return 200; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/PGVectorTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLTypeContributor.java similarity index 59% rename from hibernate-vector/src/main/java/org/hibernate/vector/PGVectorTypeContributor.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLTypeContributor.java index 1cfcf5257487..da0ff31a52d6 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/PGVectorTypeContributor.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLTypeContributor.java @@ -2,16 +2,12 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. 
and Hibernate Authors */ -package org.hibernate.vector; - -import java.lang.reflect.Type; +package org.hibernate.vector.internal; import org.hibernate.boot.model.TypeContributions; import org.hibernate.boot.model.TypeContributor; -import org.hibernate.dialect.CockroachDialect; import org.hibernate.dialect.Dialect; -import org.hibernate.dialect.PostgreSQLDialect; -import org.hibernate.engine.jdbc.Size; +import org.hibernate.dialect.MySQLDialect; import org.hibernate.engine.jdbc.spi.JdbcServices; import org.hibernate.service.ServiceRegistry; import org.hibernate.type.BasicArrayType; @@ -22,10 +18,11 @@ import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; -import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; import org.hibernate.type.spi.TypeConfiguration; -public class PGVectorTypeContributor implements TypeContributor { +import java.lang.reflect.Type; + +public class MySQLTypeContributor implements TypeContributor { private static final Type[] VECTOR_JAVA_TYPES = { Float[].class, @@ -35,36 +32,45 @@ public class PGVectorTypeContributor implements TypeContributor { @Override public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect(); - if ( dialect instanceof PostgreSQLDialect || - dialect instanceof CockroachDialect ) { + if ( dialect instanceof MySQLDialect mySQLDialect && mySQLDialect.getMySQLVersion().isSameOrAfter( 9, 0 ) ) { final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration(); final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry(); final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry(); final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); final BasicType floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ); - final ArrayJdbcType vectorJdbcType = new VectorJdbcType( jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ) ); - jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, vectorJdbcType ); + final ArrayJdbcType genericVectorJdbcType = new MySQLVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType ); + final ArrayJdbcType floatVectorJdbcType = new MySQLVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR_FLOAT32 + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType ); for ( Type vectorJavaType : VECTOR_JAVA_TYPES ) { basicTypeRegistry.register( new BasicArrayType<>( floatBasicType, - vectorJdbcType, + genericVectorJdbcType, javaTypeRegistry.getDescriptor( vectorJavaType ) ), StandardBasicTypes.VECTOR.getName() ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + floatVectorJdbcType, + javaTypeRegistry.getDescriptor( vectorJavaType ) + ), + StandardBasicTypes.VECTOR_FLOAT32.getName() + ); } typeConfiguration.getDdlTypeRegistry().addDescriptor( - new DdlTypeImpl( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) { - @Override - public String getTypeName(Size size) { - return getTypeName( - size.getArrayLength() == null ? 
null : size.getArrayLength().longValue(), - null, - null - ); - } - } + new VectorDdlType( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "vector($l)", "vector", dialect ) ); } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLVectorJdbcType.java new file mode 100644 index 000000000000..121244bb6903 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLVectorJdbcType.java @@ -0,0 +1,122 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.dialect.Dialect; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.spi.TypeConfiguration; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; + +import static org.hibernate.vector.internal.VectorHelper.parseFloatVector; + +public class MySQLVectorJdbcType extends ArrayJdbcType { + + private final int sqlType; + + public MySQLVectorJdbcType(JdbcType elementJdbcType, int sqlType) { + super( elementJdbcType ); + this.sqlType = sqlType; + } + + @Override + public int getDefaultSqlTypeCode() { + return sqlType; + } + + @Override + public JavaType getJdbcRecommendedJavaTypeMapping( + Integer precision, + Integer scale, + TypeConfiguration typeConfiguration) { + return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class ); + } + + @Override + public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { + appender.append( "string_to_vector(" ); + appender.append( writeExpression ); + appender.append( ')' ); + } + + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping) { + return sourceMapping.getJdbcType().isStringLike() ? "string_to_vector(?1)" : null; + } + + @Override + public @Nullable String castToPattern(JdbcMapping targetJdbcMapping) { + return targetJdbcMapping.getJdbcType().isStringLike() ? 
"vector_to_string(?1)" : null; + } + + @Override + public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseFloatVector( rs.getBytes( paramIndex ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseFloatVector( statement.getBytes( index ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseFloatVector( statement.getBytes( name ) ), options ); + } + + }; + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException { + st.setString( index, getBindValue( value, options ) ); + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + st.setString( name, getBindValue( value, options ) ); + } + + @Override + public String getBindValue(X value, WrapperOptions options) { + return Arrays.toString( getJavaType().unwrap( value, float[].class, options ) ); + } + }; + } + + @Override + public boolean equals(Object that) { + return super.equals( that ) + && that instanceof MySQLVectorJdbcType vectorJdbcType + && sqlType == vectorJdbcType.sqlType; + } + + @Override + public int hashCode() { + return sqlType + 31 * super.hashCode(); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleBinaryVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleBinaryVectorJdbcType.java new file mode 100644 index 000000000000..65067976cb50 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleBinaryVectorJdbcType.java @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.OracleTypes; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcType; + +import java.util.Arrays; + +/** + * Specialized type mapping for binary vector {@link SqlTypes#VECTOR_BINARY} SQL data type for Oracle. 
+ */ +public class OracleBinaryVectorJdbcType extends AbstractOracleVectorJdbcType { + + public OracleBinaryVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public String getVectorParameters() { + return "*,binary"; + } + + @Override + public String getFriendlyName() { + return "VECTOR_BINARY"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR_BINARY; + } + + @Override + protected byte[] getVectorArray(String string) { + return VectorHelper.parseByteVector( string ); + } + + @Override + protected String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options) { + return Arrays.toString( javaTypeDescriptor.unwrap( vector, byte[].class, options ) ); + } + + protected Class getNativeJavaType(){ + return byte[].class; + } + + protected int getNativeTypeCode(){ + return OracleTypes.VECTOR_BINARY; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleByteVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleByteVectorJdbcType.java similarity index 53% rename from hibernate-vector/src/main/java/org/hibernate/vector/OracleByteVectorJdbcType.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleByteVectorJdbcType.java index 76379fed45bc..6521cae13346 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/OracleByteVectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleByteVectorJdbcType.java @@ -2,14 +2,11 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import java.util.Arrays; -import java.util.BitSet; -import org.hibernate.dialect.Dialect; import org.hibernate.dialect.OracleTypes; -import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.WrapperOptions; import org.hibernate.type.descriptor.java.JavaType; @@ -22,18 +19,13 @@ */ public class OracleByteVectorJdbcType extends AbstractOracleVectorJdbcType { - - private static final byte[] EMPTY = new byte[0]; - public OracleByteVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { super( elementJdbcType, isVectorSupported ); } @Override - public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { - appender.append( "to_vector(" ); - appender.append( writeExpression ); - appender.append( ", *, INT8)" ); + public String getVectorParameters() { + return "*,int8"; } @Override @@ -48,31 +40,7 @@ public int getDefaultSqlTypeCode() { @Override protected byte[] getVectorArray(String string) { - if ( string == null ) { - return null; - } - if ( string.length() == 2 ) { - return EMPTY; - } - final BitSet commaPositions = new BitSet(); - int size = 1; - for ( int i = 1; i < string.length(); i++ ) { - final char c = string.charAt( i ); - if ( c == ',' ) { - commaPositions.set( i ); - size++; - } - } - final byte[] result = new byte[size]; - int doubleStartIndex = 1; - int commaIndex; - int index = 0; - while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { - result[index++] = Byte.parseByte( string.substring( doubleStartIndex, commaIndex ) ); - doubleStartIndex = commaIndex + 1; - } - result[index] = Byte.parseByte( string.substring( doubleStartIndex, string.length() - 1 ) ); - return result; + return VectorHelper.parseByteVector( 
string ); } @Override @@ -82,10 +50,10 @@ protected String getStringVector(T vector, JavaType javaTypeDescriptor, W protected Class getNativeJavaType(){ return byte[].class; - }; + } protected int getNativeTypeCode(){ return OracleTypes.VECTOR_INT8; - }; + } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleDoubleVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleDoubleVectorJdbcType.java similarity index 53% rename from hibernate-vector/src/main/java/org/hibernate/vector/OracleDoubleVectorJdbcType.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleDoubleVectorJdbcType.java index 9a2c07318ffb..c32ffaf4da58 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/OracleDoubleVectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleDoubleVectorJdbcType.java @@ -2,14 +2,11 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import java.util.Arrays; -import java.util.BitSet; -import org.hibernate.dialect.Dialect; import org.hibernate.dialect.OracleTypes; -import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.WrapperOptions; import org.hibernate.type.descriptor.java.JavaType; @@ -22,18 +19,13 @@ */ public class OracleDoubleVectorJdbcType extends AbstractOracleVectorJdbcType { - private static final double[] EMPTY = new double[0]; - public OracleDoubleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { super( elementJdbcType, isVectorSupported ); } - @Override - public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { - appender.append( "to_vector(" ); - appender.append( writeExpression ); - appender.append( ", *, FLOAT64)" ); + public String getVectorParameters() { + return "*,float64"; } @Override @@ -48,31 +40,7 @@ public int getDefaultSqlTypeCode() { @Override protected double[] getVectorArray(String string) { - if ( string == null ) { - return null; - } - if ( string.length() == 2 ) { - return EMPTY; - } - final BitSet commaPositions = new BitSet(); - int size = 1; - for ( int i = 1; i < string.length(); i++ ) { - final char c = string.charAt( i ); - if ( c == ',' ) { - commaPositions.set( i ); - size++; - } - } - final double[] result = new double[size]; - int doubleStartIndex = 1; - int commaIndex; - int index = 0; - while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { - result[index++] = Double.parseDouble( string.substring( doubleStartIndex, commaIndex ) ); - doubleStartIndex = commaIndex + 1; - } - result[index] = Double.parseDouble( string.substring( doubleStartIndex, string.length() - 1 ) ); - return result; + return VectorHelper.parseDoubleVector( string ); } @Override diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleFloatVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleFloatVectorJdbcType.java similarity index 53% rename from hibernate-vector/src/main/java/org/hibernate/vector/OracleFloatVectorJdbcType.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleFloatVectorJdbcType.java index acb06905c4b9..5fea12a422d1 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/OracleFloatVectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleFloatVectorJdbcType.java 
@@ -2,14 +2,11 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import java.util.Arrays; -import java.util.BitSet; -import org.hibernate.dialect.Dialect; import org.hibernate.dialect.OracleTypes; -import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.WrapperOptions; import org.hibernate.type.descriptor.java.JavaType; @@ -23,18 +20,13 @@ public class OracleFloatVectorJdbcType extends AbstractOracleVectorJdbcType { - - private static final float[] EMPTY = new float[0]; - public OracleFloatVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { super( elementJdbcType, isVectorSupported ); } @Override - public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { - appender.append( "to_vector(" ); - appender.append( writeExpression ); - appender.append( ", *, FLOAT32)" ); + public String getVectorParameters() { + return "*,float32"; } @Override @@ -49,31 +41,7 @@ public int getDefaultSqlTypeCode() { @Override protected float[] getVectorArray(String string) { - if ( string == null ) { - return null; - } - if ( string.length() == 2 ) { - return EMPTY; - } - final BitSet commaPositions = new BitSet(); - int size = 1; - for ( int i = 1; i < string.length(); i++ ) { - final char c = string.charAt( i ); - if ( c == ',' ) { - commaPositions.set( i ); - size++; - } - } - final float[] result = new float[size]; - int doubleStartIndex = 1; - int commaIndex; - int index = 0; - while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { - result[index++] = Float.parseFloat( string.substring( doubleStartIndex, commaIndex ) ); - doubleStartIndex = commaIndex + 1; - } - result[index] = Float.parseFloat( string.substring( doubleStartIndex, string.length() - 1 ) ); - return result; + return VectorHelper.parseFloatVector( string ); } @Override diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleJdbcLiteralFormatterSparseVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleJdbcLiteralFormatterSparseVector.java new file mode 100644 index 000000000000..7fafe6e48997 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleJdbcLiteralFormatterSparseVector.java @@ -0,0 +1,32 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter; +import org.hibernate.vector.AbstractSparseVector; + +public class OracleJdbcLiteralFormatterSparseVector extends BasicJdbcLiteralFormatter { + + private final String vectorParameters; + + public OracleJdbcLiteralFormatterSparseVector(JavaType javaType, String vectorParameters) { + super( javaType ); + this.vectorParameters = vectorParameters; + } + + @Override + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { + appender.append( "to_vector(" ); + appender.append( getJavaType().unwrap( value, AbstractSparseVector.class, wrapperOptions ).toString() ); + appender.append( "," ); + appender.append( vectorParameters ); + appender.append( ')' ); + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleJdbcLiteralFormatterVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleJdbcLiteralFormatterVector.java new file mode 100644 index 000000000000..e463735df099 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleJdbcLiteralFormatterVector.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter; + +public class OracleJdbcLiteralFormatterVector extends BasicJdbcLiteralFormatter { + + private final JdbcLiteralFormatter elementFormatter; + private final String vectorParameters; + + public OracleJdbcLiteralFormatterVector(JavaType javaType, JdbcLiteralFormatter elementFormatter, String vectorParameters) { + super( javaType ); + //noinspection unchecked + this.elementFormatter = (JdbcLiteralFormatter) elementFormatter; + this.vectorParameters = vectorParameters; + } + + @Override + public void appendJdbcLiteral(SqlAppender appender, Object value, Dialect dialect, WrapperOptions wrapperOptions) { + final Object[] objects = unwrapArray( value, wrapperOptions ); + appender.append( "to_vector('" ); + char separator = '['; + for ( Object o : objects ) { + appender.append( separator ); + elementFormatter.appendJdbcLiteral( appender, o, dialect, wrapperOptions ); + separator = ','; + } + appender.append( "]'," ); + appender.append( vectorParameters ); + appender.append( ')' ); + } + + private Object[] unwrapArray(Object value, WrapperOptions wrapperOptions) { + return unwrap( value, Object[].class, wrapperOptions ); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseByteVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseByteVectorJdbcType.java new file mode 100644 index 000000000000..9c1f77c19f30 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseByteVectorJdbcType.java @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import oracle.sql.VECTOR; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.vector.SparseByteVector; + +import java.sql.CallableStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +/** + * Specialized type mapping for sparse single-byte integer vector {@link SqlTypes#SPARSE_VECTOR_INT8} SQL data type for Oracle. + */ +public class OracleSparseByteVectorJdbcType extends AbstractOracleSparseVectorJdbcType { + + public OracleSparseByteVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public String getVectorParameters() { + return "*,int8,sparse"; + } + + @Override + public String getFriendlyName() { + return "SPARSE_VECTOR_INT8"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.SPARSE_VECTOR_INT8; + } + + @Override + protected Object getBindValue(JavaType javaType, X value, WrapperOptions options) { + if ( isVectorSupported ) { + final SparseByteVector sparseVector = javaType.unwrap( value, SparseByteVector.class, options ); + return VECTOR.SparseByteArray.of( sparseVector.size(), sparseVector.indices(), sparseVector.bytes() ); + } + else { + return getStringVector( value, javaType, options ); + } + } + + @Override + public ValueExtractor getExtractor(final JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( rs.getObject( paramIndex, VECTOR.SparseByteArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( rs.getString( paramIndex ) ), options ); + } + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( statement.getObject( index, VECTOR.SparseByteArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( statement.getString( index ) ), options ); + } + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( statement.getObject( name, VECTOR.SparseByteArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( statement.getString( name ) ), options ); + } + } + + private Object wrapNativeValue(VECTOR.SparseByteArray nativeValue) { + return nativeValue == null + ? 
null + : new SparseByteVector( nativeValue.length(), nativeValue.indices(), nativeValue.values() ); + } + + private Object wrapStringValue(String value) { + return ((AbstractOracleVectorJdbcType) getJdbcType() ).getVectorArray( value ); + } + + }; + } + + @Override + protected SparseByteVector getVectorArray(String string) { + if ( string == null ) { + return null; + } + return new SparseByteVector( string ); + } + + @Override + protected Class getNativeJavaType() { + return VECTOR.SparseByteArray.class; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseDoubleVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseDoubleVectorJdbcType.java new file mode 100644 index 000000000000..74f859e91cef --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseDoubleVectorJdbcType.java @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import oracle.sql.VECTOR; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.vector.SparseDoubleVector; + +import java.sql.CallableStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +/** + * Specialized type mapping for sparse double-precision floating-point vector {@link SqlTypes#SPARSE_VECTOR_FLOAT64} SQL data type for Oracle. + */ +public class OracleSparseDoubleVectorJdbcType extends AbstractOracleSparseVectorJdbcType { + + public OracleSparseDoubleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public String getVectorParameters() { + return "*,float64,sparse"; + } + + @Override + public String getFriendlyName() { + return "SPARSE_VECTOR_FLOAT64"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.SPARSE_VECTOR_FLOAT64; + } + + @Override + protected Object getBindValue(JavaType javaType, X value, WrapperOptions options) { + if ( isVectorSupported ) { + final SparseDoubleVector sparseVector = javaType.unwrap( value, SparseDoubleVector.class, options ); + return VECTOR.SparseDoubleArray.of( sparseVector.size(), sparseVector.indices(), sparseVector.doubles() ); + } + else { + return getStringVector( value, javaType, options ); + } + } + + @Override + public ValueExtractor getExtractor(final JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( rs.getObject( paramIndex, VECTOR.SparseDoubleArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( rs.getString( paramIndex ) ), options ); + } + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( statement.getObject( index, VECTOR.SparseDoubleArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( statement.getString( index ) ), options ); + } + } + + @Override +
protected X doExtract(CallableStatement statement, String name, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( statement.getObject( name, VECTOR.SparseDoubleArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( statement.getString( name ) ), options ); + } + } + + private Object wrapNativeValue(VECTOR.SparseDoubleArray nativeValue) { + return nativeValue == null + ? null + : new SparseDoubleVector( nativeValue.length(), nativeValue.indices(), nativeValue.values() ); + } + + private Object wrapStringValue(String value) { + return ((AbstractOracleVectorJdbcType) getJdbcType() ).getVectorArray( value ); + } + + }; + } + + @Override + protected SparseDoubleVector getVectorArray(String string) { + if ( string == null ) { + return null; + } + return new SparseDoubleVector( string ); + } + + @Override + protected Class getNativeJavaType() { + return VECTOR.SparseDoubleArray.class; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseFloatVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseFloatVectorJdbcType.java new file mode 100644 index 000000000000..2425c15f9bf7 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseFloatVectorJdbcType.java @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import oracle.sql.VECTOR; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.vector.SparseFloatVector; + +import java.sql.CallableStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +/** + * Specialized type mapping for sparse single-precision floating-point vector {@link SqlTypes#SPARSE_VECTOR_FLOAT32} SQL data type for Oracle. 
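+ * Only non-zero dimensions are stored, as index/value pairs together with the total number of dimensions.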
+ */ +public class OracleSparseFloatVectorJdbcType extends AbstractOracleSparseVectorJdbcType { + + public OracleSparseFloatVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public String getVectorParameters() { + return "*,float32,sparse"; + } + + @Override + public String getFriendlyName() { + return "SPARSE_VECTOR_FLOAT32"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.SPARSE_VECTOR_FLOAT32; + } + + @Override + protected Object getBindValue(JavaType javaType, X value, WrapperOptions options) { + if ( isVectorSupported ) { + final SparseFloatVector sparseVector = javaType.unwrap( value, SparseFloatVector.class, options ); + return VECTOR.SparseFloatArray.of( sparseVector.size(), sparseVector.indices(), sparseVector.floats() ); + } + else { + return getStringVector( value, javaType, options ); + } + } + + @Override + public ValueExtractor getExtractor(final JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( rs.getObject( paramIndex, VECTOR.SparseFloatArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( rs.getString( paramIndex ) ), options ); + } + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( statement.getObject( index, VECTOR.SparseFloatArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( statement.getString( index ) ), options ); + } + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( statement.getObject( name, VECTOR.SparseFloatArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( statement.getString( name ) ), options ); + } + } + + private Object wrapNativeValue(VECTOR.SparseFloatArray nativeValue) { + return nativeValue == null + ? 
null + : new SparseFloatVector( nativeValue.length(), nativeValue.indices(), nativeValue.values() ); + } + + private Object wrapStringValue(String value) { + return ((AbstractOracleVectorJdbcType) getJdbcType() ).getVectorArray( value ); + } + + }; + } + + @Override + protected SparseFloatVector getVectorArray(String string) { + if ( string == null ) { + return null; + } + return new SparseFloatVector( string ); + } + + @Override + protected Class getNativeJavaType() { + return VECTOR.SparseFloatArray.class; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorFunctionContributor.java similarity index 80% rename from hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorFunctionContributor.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorFunctionContributor.java index 69572ac79c7d..52b55feefbc4 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorFunctionContributor.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorFunctionContributor.java @@ -2,7 +2,7 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.FunctionContributor; @@ -14,19 +14,21 @@ public class OracleVectorFunctionContributor implements FunctionContributor { @Override public void contributeFunctions(FunctionContributions functionContributions) { final Dialect dialect = functionContributions.getDialect(); - if ( dialect instanceof OracleDialect ) { + if ( dialect instanceof OracleDialect && dialect.getVersion().isSameOrAfter( 23, 4 ) ) { final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); vectorFunctionFactory.cosineDistance( "vector_distance(?1,?2,COSINE)" ); vectorFunctionFactory.euclideanDistance( "vector_distance(?1,?2,EUCLIDEAN)" ); vectorFunctionFactory.l1Distance( "vector_distance(?1,?2,MANHATTAN)" ); vectorFunctionFactory.hammingDistance( "vector_distance(?1,?2,HAMMING)" ); + vectorFunctionFactory.jaccardDistance( "vector_distance(?1,?2,JACCARD)" ); vectorFunctionFactory.innerProduct( "vector_distance(?1,?2,DOT)*-1" ); vectorFunctionFactory.negativeInnerProduct( "vector_distance(?1,?2,DOT)" ); vectorFunctionFactory.vectorDimensions(); vectorFunctionFactory.vectorNorm(); + functionContributions.getFunctionRegistry().registerAlternateKey( "l2_norm", "vector_norm" ); } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorJdbcType.java similarity index 75% rename from hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorJdbcType.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorJdbcType.java index cdfa8dd219f9..f4b9e74cceac 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorJdbcType.java @@ -2,11 +2,9 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. 
and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; -import org.hibernate.dialect.Dialect; import org.hibernate.dialect.OracleTypes; -import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.jdbc.JdbcType; @@ -20,7 +18,6 @@ */ public class OracleVectorJdbcType extends OracleFloatVectorJdbcType { - public OracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { super( elementJdbcType, isVectorSupported ); } @@ -31,10 +28,8 @@ public String getFriendlyName() { } @Override - public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { - appender.append( "to_vector(" ); - appender.append( writeExpression ); - appender.append( ", *, *)" ); + public String getVectorParameters() { + return "*,*"; } @Override diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorTypeContributor.java similarity index 57% rename from hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorTypeContributor.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorTypeContributor.java index 480bb58ef372..0dbf14b296b9 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorTypeContributor.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorTypeContributor.java @@ -2,24 +2,23 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import org.hibernate.boot.model.TypeContributions; import org.hibernate.boot.model.TypeContributor; import org.hibernate.dialect.Dialect; import org.hibernate.dialect.OracleDialect; -import org.hibernate.engine.jdbc.Size; import org.hibernate.engine.jdbc.spi.JdbcServices; import org.hibernate.internal.util.StringHelper; import org.hibernate.service.ServiceRegistry; import org.hibernate.type.BasicArrayType; +import org.hibernate.type.BasicCollectionType; import org.hibernate.type.BasicTypeRegistry; import org.hibernate.type.SqlTypes; import org.hibernate.type.StandardBasicTypes; import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; -import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; import org.hibernate.type.spi.TypeConfiguration; public class OracleVectorTypeContributor implements TypeContributor { @@ -57,7 +56,30 @@ public void contribute(TypeContributions typeContributions, ServiceRegistry serv isVectorSupported ); jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_INT8, byteVectorJdbcType ); + final JdbcType bitVectorJdbcType = new OracleBinaryVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.TINYINT ), + isVectorSupported + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_BINARY, bitVectorJdbcType ); + final JdbcType sparseByteVectorJdbcType = new OracleSparseByteVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.TINYINT ), + isVectorSupported + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.SPARSE_VECTOR_INT8, sparseByteVectorJdbcType ); + final JdbcType sparseFloatVectorJdbcType = new OracleSparseFloatVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + isVectorSupported + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.SPARSE_VECTOR_FLOAT32, 
sparseFloatVectorJdbcType ); + final JdbcType sparseDoubleVectorJdbcType = new OracleSparseDoubleVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.DOUBLE ), + isVectorSupported + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.SPARSE_VECTOR_FLOAT64, sparseDoubleVectorJdbcType ); + javaTypeRegistry.addDescriptor( SparseByteVectorJavaType.INSTANCE ); + javaTypeRegistry.addDescriptor( SparseFloatVectorJavaType.INSTANCE ); + javaTypeRegistry.addDescriptor( SparseDoubleVectorJavaType.INSTANCE ); // Resolving basic types after jdbc types are registered. basicTypeRegistry.register( @@ -92,50 +114,62 @@ public void contribute(TypeContributions typeContributions, ServiceRegistry serv ), StandardBasicTypes.VECTOR_INT8.getName() ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.BYTE ), + bitVectorJdbcType, + javaTypeRegistry.getDescriptor( byte[].class ) + ), + StandardBasicTypes.VECTOR_BINARY.getName() + ); + basicTypeRegistry.register( + new BasicCollectionType<>( + basicTypeRegistry.resolve( StandardBasicTypes.BYTE ), + sparseByteVectorJdbcType, + SparseByteVectorJavaType.INSTANCE, + "sparse_byte_vector" + ) + ); + basicTypeRegistry.register( + new BasicCollectionType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + sparseFloatVectorJdbcType, + SparseFloatVectorJavaType.INSTANCE, + "sparse_float_vector" + ) + ); + basicTypeRegistry.register( + new BasicCollectionType<>( + basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE ), + sparseDoubleVectorJdbcType, + SparseDoubleVectorJavaType.INSTANCE, + "sparse_double_vector" + ) + ); typeConfiguration.getDdlTypeRegistry().addDescriptor( - new DdlTypeImpl( SqlTypes.VECTOR, "vector($l, *)", "vector", dialect ) { - @Override - public String getTypeName(Size size) { - return OracleVectorTypeContributor.replace( - "vector($l, *)", - size.getArrayLength() == null ? null : size.getArrayLength().longValue() - ); - } - } + new VectorDdlType( SqlTypes.VECTOR, "vector($l,*)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_INT8, "vector($l,int8)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "vector($l,float32)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT64, "vector($l,float64)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_BINARY, "vector($l,binary)", "vector", dialect ) ); typeConfiguration.getDdlTypeRegistry().addDescriptor( - new DdlTypeImpl( SqlTypes.VECTOR_INT8, "vector($l, INT8)", "vector", dialect ) { - @Override - public String getTypeName(Size size) { - return OracleVectorTypeContributor.replace( - "vector($l, INT8)", - size.getArrayLength() == null ? null : size.getArrayLength().longValue() - ); - } - } + new VectorDdlType( SqlTypes.SPARSE_VECTOR_INT8, "vector($l,int8,sparse)", "vector", dialect ) ); typeConfiguration.getDdlTypeRegistry().addDescriptor( - new DdlTypeImpl( SqlTypes.VECTOR_FLOAT32, "vector($l, FLOAT32)", "vector", dialect ) { - @Override - public String getTypeName(Size size) { - return OracleVectorTypeContributor.replace( - "vector($l, FLOAT32)", - size.getArrayLength() == null ? 
null : size.getArrayLength().longValue() - ); - } - } + new VectorDdlType( SqlTypes.SPARSE_VECTOR_FLOAT32, "vector($l,float32,sparse)", "vector", dialect ) ); typeConfiguration.getDdlTypeRegistry().addDescriptor( - new DdlTypeImpl( SqlTypes.VECTOR_FLOAT64, "vector($l, FLOAT64)", "vector", dialect ) { - @Override - public String getTypeName(Size size) { - return OracleVectorTypeContributor.replace( - "vector($l, FLOAT64)", - size.getArrayLength() == null ? null : size.getArrayLength().longValue() - ); - } - } + new VectorDdlType( SqlTypes.SPARSE_VECTOR_FLOAT64, "vector($l,float64,sparse)", "vector", dialect ) ); } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGBinaryVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGBinaryVectorJdbcType.java new file mode 100644 index 000000000000..b264ce049e5f --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGBinaryVectorJdbcType.java @@ -0,0 +1,103 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.spi.TypeConfiguration; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Types; + +import static org.hibernate.vector.internal.VectorHelper.parseBitString; +import static org.hibernate.vector.internal.VectorHelper.toBitString; + +public class PGBinaryVectorJdbcType extends ArrayJdbcType { + + public PGBinaryVectorJdbcType(JdbcType elementJdbcType) { + super( elementJdbcType ); + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR_BINARY; + } + + @Override + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + return null; + } + + @Override + public JavaType getJdbcRecommendedJavaTypeMapping( + Integer precision, + Integer scale, + TypeConfiguration typeConfiguration) { + return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( byte[].class ); + } + +// @Override +// public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { +// appender.append( "cast(" ); +// appender.append( writeExpression ); +// appender.append( " as varbit)" ); +// } + + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping) { + return sourceMapping.getJdbcType().isStringLike() ? 
"cast(?1 as varbit)" : null; + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) + throws SQLException { + st.setObject( index, toBitString( getJavaType().unwrap( value, byte[].class, options ) ), Types.OTHER ); + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + st.setObject( name, toBitString( getJavaType().unwrap( value, byte[].class, options ) ), Types.OTHER ); + } + + }; + } + + @Override + public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseBitString( rs.getString( paramIndex ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseBitString( statement.getString( index ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseBitString( statement.getString( name ) ), options ); + } + }; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGSparseFloatVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGSparseFloatVectorJdbcType.java new file mode 100644 index 000000000000..4c21ba7bcac2 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGSparseFloatVectorJdbcType.java @@ -0,0 +1,168 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.dialect.Dialect; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.spi.TypeConfiguration; +import org.hibernate.vector.SparseFloatVector; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +public class PGSparseFloatVectorJdbcType extends ArrayJdbcType { + + public PGSparseFloatVectorJdbcType(JdbcType elementJdbcType) { + super( elementJdbcType ); + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.SPARSE_VECTOR_FLOAT32; + } + + @Override + public JavaType getJdbcRecommendedJavaTypeMapping( + Integer precision, + Integer scale, + TypeConfiguration typeConfiguration) { + return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class ); + } + + @Override + public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { + appender.append( "cast(" ); + appender.append( writeExpression ); + appender.append( " as sparsevec)" ); + } + + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping) { + return sourceMapping.getJdbcType().isStringLike() ? 
"cast(?1 as sparsevec)" : null; + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException { + st.setString( index, getString( value, options ) ); + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + st.setString( name, getString( value, options ) ); + } + + @Override + public Object getBindValue(X value, WrapperOptions options) { + return getString( value, options ); + } + + private String getString(X value, WrapperOptions options) { + final SparseFloatVector vector = getJavaType().unwrap( value, SparseFloatVector.class, options ); + final int size = vector.size(); + final int[] indices = vector.indices(); + final float[] floats = vector.floats(); + final StringBuilder sb = new StringBuilder( indices.length * 50 ); + char separator = '{'; + for ( int i = 0; i < indices.length; i++ ) { + sb.append( separator ); + // The sparvec format is 1 based + sb.append( indices[i] + 1 ); + sb.append( ':' ); + sb.append( floats[i] ); + separator = ','; + } + sb.append("}/"); + sb.append( size ); + return sb.toString(); + } + }; + } + + @Override + public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseSparseFloatVector( rs.getString( paramIndex ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseSparseFloatVector( statement.getString( index ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseSparseFloatVector( statement.getString( name ) ), options ); + } + }; + } + + /** + * Parses the pgvector sparsevec format `{idx1:val1,idx2:val2}/size`. 
+ */ + private static @Nullable SparseFloatVector parseSparseFloatVector(@Nullable String string) { + if ( string == null ) { + return null; + } + + final int slashIndex = string.lastIndexOf( '/' ); + if ( string.charAt( 0 ) != '{' || slashIndex == -1 || string.charAt( slashIndex - 1 ) != '}' ) { + throw new IllegalArgumentException( "Invalid sparse vector string: " + string ); + } + final int size = Integer.parseInt( string, slashIndex + 1, string.length(), 10 ); + final int end = slashIndex - 1; + final int count = countValues( string, end ); + final int[] indices = new int[count]; + final float[] values = new float[count]; + int start = 1; + int index = 0; + for ( int i = start; i < end; i++ ) { + final char c = string.charAt( i ); + if ( c == ':' ) { + // Indices are 1 based in this format, but we need a zero base + indices[index] = Integer.parseInt( string, start, i, 10 ) - 1; + start = i + 1; + } + else if ( c == ',' ) { + values[index++] = Float.parseFloat( string.substring( start, i ) ); + start = i + 1; + } + } + if ( start != end ) { + values[index] = Float.parseFloat( string.substring( start, end ) ); + assert count == index + 1; + } + return new SparseFloatVector( size, indices, values ); + } + + private static int countValues(String string, int end) { + int count = 0; + for ( int i = 1; i < end; i++ ) { + if ( string.charAt( i ) == ':' ) { + count++; + } + } + return count; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorDimsFunction.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorDimsFunction.java new file mode 100644 index 000000000000..d6d8cb0119ce --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorDimsFunction.java @@ -0,0 +1,55 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.metamodel.model.domain.ReturnableType; +import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.SqlAstNode; +import org.hibernate.sql.ast.tree.expression.Expression; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.spi.TypeConfiguration; + +import java.util.List; + +public class PGVectorDimsFunction extends AbstractSqmSelfRenderingFunctionDescriptor { + public PGVectorDimsFunction(TypeConfiguration typeConfiguration) { + super( + "vector_dims", + StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 1 ), + VectorArgumentValidator.INSTANCE + ), + StandardFunctionReturnTypeResolvers.invariant( typeConfiguration.getBasicTypeForJavaType( Integer.class ) ), + VectorArgumentTypeResolver.INSTANCE + ); + } + + @Override + public void render(SqlAppender sqlAppender, List sqlAstArguments, ReturnableType returnType, SqlAstTranslator walker) { + final Expression expression = (Expression) sqlAstArguments.get( 0 ); + final int sqlTypeCode = + expression.getExpressionType().getSingleJdbcMapping().getJdbcType().getDefaultSqlTypeCode(); + if ( sqlTypeCode == SqlTypes.SPARSE_VECTOR_FLOAT32 ) { + sqlAppender.append( "cast(split_part(cast(" ); + expression.accept( walker ); + sqlAppender.append( " as text),'/',2) as integer)" ); + } + else { + if ( sqlTypeCode == SqlTypes.VECTOR_BINARY ) { + sqlAppender.append( "length" ); + } + else { + sqlAppender.append( "vector_dims" ); + } + sqlAppender.append( '(' ); + expression.accept( walker ); + sqlAppender.append( ')' ); + } + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorFunctionContributor.java new file mode 100644 index 000000000000..72d3da9a5913 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorFunctionContributor.java @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.boot.model.FunctionContributions; +import org.hibernate.boot.model.FunctionContributor; +import org.hibernate.dialect.CockroachDialect; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.PostgreSQLDialect; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.spi.TypeConfiguration; + +import static org.hibernate.query.sqm.produce.function.FunctionParameterType.INTEGER; + +public class PGVectorFunctionContributor implements FunctionContributor { + + @Override + public void contributeFunctions(FunctionContributions functionContributions) { + final Dialect dialect = functionContributions.getDialect(); + if (dialect instanceof PostgreSQLDialect || dialect instanceof CockroachDialect) { + final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); + + vectorFunctionFactory.cosineDistance( "?1<=>?2" ); + vectorFunctionFactory.euclideanDistance( "?1<->?2" ); + vectorFunctionFactory.l1Distance( "l1_distance(?1,?2)" ); + vectorFunctionFactory.hammingDistance( "?1<~>?2" ); + vectorFunctionFactory.jaccardDistance( "?1<%>?2" ); + + vectorFunctionFactory.innerProduct( "(?1<#>?2)*-1" ); + vectorFunctionFactory.negativeInnerProduct( "?1<#>?2" ); + + final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration(); + functionContributions.getFunctionRegistry() + .register( "vector_dims", new PGVectorDimsFunction( typeConfiguration ) ); + functionContributions.getFunctionRegistry() + .register( "vector_norm", new PGVectorNormFunction( typeConfiguration ) ); + + functionContributions.getFunctionRegistry().namedDescriptorBuilder( "binary_quantize" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 1 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( + typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.VECTOR_BINARY ) + ) ) + .register(); + functionContributions.getFunctionRegistry().namedDescriptorBuilder( "subvector" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 3 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( StandardFunctionArgumentTypeResolvers.byArgument( + VectorArgumentTypeResolver.INSTANCE, + StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, INTEGER ), + StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, INTEGER ) + ) ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) ) + .register(); + functionContributions.getFunctionRegistry().registerAlternateKey( "l2_norm", "vector_norm" ); + functionContributions.getFunctionRegistry().namedDescriptorBuilder( "l2_normalize" ) + .setArgumentsValidator( VectorArgumentValidator.INSTANCE ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) ) + .register(); + } + } + + @Override + public int ordinal() { + return 200; + } +} diff --git 
a/hibernate-vector/src/main/java/org/hibernate/vector/VectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcType.java similarity index 54% rename from hibernate-vector/src/main/java/org/hibernate/vector/VectorJdbcType.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcType.java index b50144fdbbbd..98bae86fcea5 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/VectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcType.java @@ -2,16 +2,12 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; - -import java.sql.CallableStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.BitSet; +package org.hibernate.vector.internal; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.sql.ast.spi.SqlAppender; -import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.ValueExtractor; import org.hibernate.type.descriptor.WrapperOptions; import org.hibernate.type.descriptor.java.JavaType; @@ -20,16 +16,26 @@ import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.spi.TypeConfiguration; -public class VectorJdbcType extends ArrayJdbcType { +import java.sql.CallableStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +import static org.hibernate.vector.internal.VectorHelper.parseFloatVector; + +public class PGVectorJdbcType extends ArrayJdbcType { - private static final float[] EMPTY = new float[0]; - public VectorJdbcType(JdbcType elementJdbcType) { + private final int sqlType; + private final String typeName; + + public PGVectorJdbcType(JdbcType elementJdbcType, int sqlType, String typeName) { super( elementJdbcType ); + this.sqlType = sqlType; + this.typeName = typeName; } @Override public int getDefaultSqlTypeCode() { - return SqlTypes.VECTOR; + return sqlType; } @Override @@ -44,7 +50,14 @@ public JavaType getJdbcRecommendedJavaTypeMapping( public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { appender.append( "cast(" ); appender.append( writeExpression ); - appender.append( " as vector)" ); + appender.append( " as " ); + appender.append( typeName ); + appender.append( ')' ); + } + + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping) { + return sourceMapping.getJdbcType().isStringLike() ? 
"cast(?1 as " + typeName + ")" : null; } @Override @@ -52,46 +65,30 @@ public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { return new BasicExtractor<>( javaTypeDescriptor, this ) { @Override protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { - return javaTypeDescriptor.wrap( getFloatArray( rs.getString( paramIndex ) ), options ); + return javaTypeDescriptor.wrap( parseFloatVector( rs.getString( paramIndex ) ), options ); } @Override protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { - return javaTypeDescriptor.wrap( getFloatArray( statement.getString( index ) ), options ); + return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( index ) ), options ); } @Override protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { - return javaTypeDescriptor.wrap( getFloatArray( statement.getString( name ) ), options ); - } - - private float[] getFloatArray(String string) { - if ( string == null ) { - return null; - } - if ( string.length() == 2 ) { - return EMPTY; - } - final BitSet commaPositions = new BitSet(); - int size = 1; - for ( int i = 1; i < string.length(); i++ ) { - final char c = string.charAt( i ); - if ( c == ',' ) { - commaPositions.set( i ); - size++; - } - } - final float[] result = new float[size]; - int floatStartIndex = 1; - int commaIndex; - int index = 0; - while ( ( commaIndex = commaPositions.nextSetBit( floatStartIndex ) ) != -1 ) { - result[index++] = Float.parseFloat( string.substring( floatStartIndex, commaIndex ) ); - floatStartIndex = commaIndex + 1; - } - result[index] = Float.parseFloat( string.substring( floatStartIndex, string.length() - 1 ) ); - return result; + return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( name ) ), options ); } }; } + + @Override + public boolean equals(Object that) { + return super.equals( that ) + && that instanceof PGVectorJdbcType vectorJdbcType + && sqlType == vectorJdbcType.sqlType; + } + + @Override + public int hashCode() { + return sqlType + 31 * super.hashCode(); + } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorNormFunction.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorNormFunction.java new file mode 100644 index 000000000000..814a2b07c318 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorNormFunction.java @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.metamodel.model.domain.ReturnableType; +import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.SqlAstNode; +import org.hibernate.sql.ast.tree.expression.Expression; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.spi.TypeConfiguration; + +import java.util.List; + +public class PGVectorNormFunction extends AbstractSqmSelfRenderingFunctionDescriptor { + public PGVectorNormFunction(TypeConfiguration typeConfiguration) { + super( + "vector_norm", + StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 1 ), + VectorArgumentValidator.INSTANCE + ), + StandardFunctionReturnTypeResolvers.invariant( typeConfiguration.getBasicTypeForJavaType( Double.class ) ), + VectorArgumentTypeResolver.INSTANCE + ); + } + + @Override + public void render(SqlAppender sqlAppender, List sqlAstArguments, ReturnableType returnType, SqlAstTranslator walker) { + final Expression expression = (Expression) sqlAstArguments.get( 0 ); + sqlAppender.append( + switch ( expression.getExpressionType().getSingleJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) { + case SqlTypes.SPARSE_VECTOR_FLOAT32, SqlTypes.VECTOR_FLOAT16 -> "l2_norm"; + default -> "vector_norm"; + } + ); + sqlAppender.append( '(' ); + expression.accept( walker ); + sqlAppender.append( ')' ); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorTypeContributor.java new file mode 100644 index 000000000000..ff559be72423 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorTypeContributor.java @@ -0,0 +1,133 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import java.lang.reflect.Type; + +import org.hibernate.boot.model.TypeContributions; +import org.hibernate.boot.model.TypeContributor; +import org.hibernate.dialect.CockroachDialect; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.PostgreSQLDialect; +import org.hibernate.engine.jdbc.spi.JdbcServices; +import org.hibernate.service.ServiceRegistry; +import org.hibernate.type.BasicArrayType; +import org.hibernate.type.BasicCollectionType; +import org.hibernate.type.BasicType; +import org.hibernate.type.BasicTypeRegistry; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; +import org.hibernate.type.spi.TypeConfiguration; + +public class PGVectorTypeContributor implements TypeContributor { + + private static final Type[] VECTOR_JAVA_TYPES = { + Float[].class, + float[].class + }; + + @Override + public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { + final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect(); + if ( dialect instanceof PostgreSQLDialect || + dialect instanceof CockroachDialect ) { + final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration(); + final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry(); + final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry(); + final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); + final BasicType floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ); + final ArrayJdbcType genericVectorJdbcType = new PGVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR, + "vector" + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType ); + final ArrayJdbcType floatVectorJdbcType = new PGVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR_FLOAT32, + "vector" + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType ); + final ArrayJdbcType float16VectorJdbcType = new PGVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR_FLOAT16, + "halfvec" + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT16, float16VectorJdbcType ); + final JdbcType bitVectorJdbcType = new PGBinaryVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.TINYINT ) + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_BINARY, bitVectorJdbcType ); + final JdbcType sparseFloatVectorJdbcType = new PGSparseFloatVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ) + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.SPARSE_VECTOR_FLOAT32, sparseFloatVectorJdbcType ); + + javaTypeRegistry.addDescriptor( SparseFloatVectorJavaType.INSTANCE ); + + for ( Type vectorJavaType : VECTOR_JAVA_TYPES ) { + basicTypeRegistry.register( + new BasicArrayType<>( + floatBasicType, + genericVectorJdbcType, + javaTypeRegistry.getDescriptor( vectorJavaType ) + ), + StandardBasicTypes.VECTOR.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + floatVectorJdbcType, + javaTypeRegistry.getDescriptor( vectorJavaType ) + ), + 
StandardBasicTypes.VECTOR_FLOAT32.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + float16VectorJdbcType, + javaTypeRegistry.getDescriptor( vectorJavaType ) + ), + StandardBasicTypes.VECTOR_FLOAT16.getName() + ); + } + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.BYTE ), + bitVectorJdbcType, + javaTypeRegistry.getDescriptor( byte[].class ) + ), + StandardBasicTypes.VECTOR_BINARY.getName() + ); + basicTypeRegistry.register( + new BasicCollectionType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + sparseFloatVectorJdbcType, + SparseFloatVectorJavaType.INSTANCE, + "sparse_float_vector" + ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "vector($l)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_BINARY, "bit($l)", "bit", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT16, "halfvec($l)", "halfvec", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.SPARSE_VECTOR_FLOAT32, "sparsevec($l)", "sparsevec", dialect ) + ); + } + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseByteVectorJavaType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseByteVectorJavaType.java new file mode 100644 index 000000000000..47c353b80e4b --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseByteVectorJavaType.java @@ -0,0 +1,116 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.tool.schema.extract.spi.ColumnTypeInformation; +import org.hibernate.type.BasicCollectionType; +import org.hibernate.type.BasicType; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.AbstractClassJavaType; +import org.hibernate.type.descriptor.java.BasicPluralJavaType; +import org.hibernate.type.descriptor.java.ByteJavaType; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.java.MutableMutabilityPlan; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.JdbcTypeIndicators; +import org.hibernate.type.spi.TypeConfiguration; +import org.hibernate.vector.SparseByteVector; + +import java.util.Arrays; +import java.util.List; + + +public class SparseByteVectorJavaType extends AbstractClassJavaType implements BasicPluralJavaType { + + public static final SparseByteVectorJavaType INSTANCE = new SparseByteVectorJavaType(); + + public SparseByteVectorJavaType() { + super( SparseByteVector.class, new SparseVectorMutabilityPlan() ); + } + + @Override + public JavaType getElementJavaType() { + return ByteJavaType.INSTANCE; + } + + @Override + public BasicType resolveType(TypeConfiguration typeConfiguration, Dialect dialect, BasicType elementType, ColumnTypeInformation columnTypeInformation, JdbcTypeIndicators stdIndicators) { + final int arrayTypeCode = stdIndicators.getPreferredSqlTypeCodeForArray( elementType.getJdbcType().getDefaultSqlTypeCode() ); + final JdbcType arrayJdbcType = typeConfiguration.getJdbcTypeRegistry() + .resolveTypeConstructorDescriptor( arrayTypeCode, elementType, columnTypeInformation ); + if ( elementType.getValueConverter() != null ) { + throw new IllegalArgumentException( "Can't convert element type of sparse vector" ); + } + return typeConfiguration.getBasicTypeRegistry() + .resolve( this, arrayJdbcType, + () -> new BasicCollectionType<>( elementType, arrayJdbcType, this, "sparse_byte_vector" ) ); + } + + @Override + public JdbcType getRecommendedJdbcType(JdbcTypeIndicators indicators) { + return indicators.getJdbcType( SqlTypes.SPARSE_VECTOR_INT8 ); + } + + @Override + public X unwrap(SparseByteVector value, Class type, WrapperOptions options) { + if ( value == null ) { + return null; + } + else if ( type.isInstance( value ) ) { + //noinspection unchecked + return (X) value; + } + else if ( byte[].class.isAssignableFrom( type ) ) { + return (X) value.toDenseVector(); + } + else if ( Object[].class.isAssignableFrom( type ) ) { + //noinspection unchecked + return (X) value.toArray(); + } + else if ( String.class.isAssignableFrom( type ) ) { + //noinspection unchecked + return (X) value.toString(); + } + else { + throw unknownUnwrap( type ); + } + } + + @Override + public SparseByteVector wrap(X value, WrapperOptions options) { + if ( value == null ) { + return null; + } + else if (value instanceof SparseByteVector vector) { + return vector; + } + else if (value instanceof List list) { + //noinspection unchecked + return new SparseByteVector( (List) list ); + } + else if (value instanceof Object[] array) { + //noinspection unchecked + return new SparseByteVector( (List) (List) Arrays.asList( array ) ); + } + else if (value instanceof byte[] vector) { + return new SparseByteVector( vector ); + } + else if (value instanceof String vector) { + return new SparseByteVector( vector ); + } + 
else {
+			throw unknownWrap( value.getClass() );
+		}
+	}
+
+	private static class SparseVectorMutabilityPlan extends MutableMutabilityPlan {
+		@Override
+		protected SparseByteVector deepCopyNotNull(SparseByteVector value) {
+			return value.clone();
+		}
+	}
+}
diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseDoubleVectorJavaType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseDoubleVectorJavaType.java
new file mode 100644
index 000000000000..d43ae847671a
--- /dev/null
+++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseDoubleVectorJavaType.java
@@ -0,0 +1,116 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright Red Hat Inc. and Hibernate Authors
+ */
+package org.hibernate.vector.internal;
+
+import org.hibernate.dialect.Dialect;
+import org.hibernate.tool.schema.extract.spi.ColumnTypeInformation;
+import org.hibernate.type.BasicCollectionType;
+import org.hibernate.type.BasicType;
+import org.hibernate.type.SqlTypes;
+import org.hibernate.type.descriptor.WrapperOptions;
+import org.hibernate.type.descriptor.java.AbstractClassJavaType;
+import org.hibernate.type.descriptor.java.BasicPluralJavaType;
+import org.hibernate.type.descriptor.java.DoubleJavaType;
+import org.hibernate.type.descriptor.java.JavaType;
+import org.hibernate.type.descriptor.java.MutableMutabilityPlan;
+import org.hibernate.type.descriptor.jdbc.JdbcType;
+import org.hibernate.type.descriptor.jdbc.JdbcTypeIndicators;
+import org.hibernate.type.spi.TypeConfiguration;
+import org.hibernate.vector.SparseDoubleVector;
+
+import java.util.Arrays;
+import java.util.List;
+
+
+public class SparseDoubleVectorJavaType extends AbstractClassJavaType implements BasicPluralJavaType {
+
+	public static final SparseDoubleVectorJavaType INSTANCE = new SparseDoubleVectorJavaType();
+
+	public SparseDoubleVectorJavaType() {
+		super( SparseDoubleVector.class, new SparseVectorMutabilityPlan() );
+	}
+
+	@Override
+	public JavaType getElementJavaType() {
+		return DoubleJavaType.INSTANCE;
+	}
+
+	@Override
+	public BasicType resolveType(TypeConfiguration typeConfiguration, Dialect dialect, BasicType elementType, ColumnTypeInformation columnTypeInformation, JdbcTypeIndicators stdIndicators) {
+		final int arrayTypeCode = stdIndicators.getPreferredSqlTypeCodeForArray( elementType.getJdbcType().getDefaultSqlTypeCode() );
+		final JdbcType arrayJdbcType = typeConfiguration.getJdbcTypeRegistry()
+				.resolveTypeConstructorDescriptor( arrayTypeCode, elementType, columnTypeInformation );
+		if ( elementType.getValueConverter() != null ) {
+			throw new IllegalArgumentException( "Can't convert element type of sparse vector" );
+		}
+		return typeConfiguration.getBasicTypeRegistry()
+				.resolve( this, arrayJdbcType,
+						() -> new BasicCollectionType<>( elementType, arrayJdbcType, this, "sparse_double_vector" ) );
+	}
+
+	@Override
+	public JdbcType getRecommendedJdbcType(JdbcTypeIndicators indicators) {
+		return indicators.getJdbcType( SqlTypes.SPARSE_VECTOR_FLOAT64 );
+	}
+
+	@Override
+	public X unwrap(SparseDoubleVector value, Class type, WrapperOptions options) {
+		if ( value == null ) {
+			return null;
+		}
+		else if ( type.isInstance( value ) ) {
+			//noinspection unchecked
+			return (X) value;
+		}
+		else if ( double[].class.isAssignableFrom( type ) ) {
+			return (X) value.toDenseVector();
+		}
+		else if ( Object[].class.isAssignableFrom( type ) ) {
+			//noinspection unchecked
+			return (X) value.toArray();
+		}
+		else if ( String.class.isAssignableFrom( type ) ) {
+			//noinspection 
unchecked + return (X) value.toString(); + } + else { + throw unknownUnwrap( type ); + } + } + + @Override + public SparseDoubleVector wrap(X value, WrapperOptions options) { + if ( value == null ) { + return null; + } + else if (value instanceof SparseDoubleVector vector) { + return vector; + } + else if (value instanceof List list) { + //noinspection unchecked + return new SparseDoubleVector( (List) list ); + } + else if (value instanceof Object[] array) { + //noinspection unchecked + return new SparseDoubleVector( (List) (List) Arrays.asList( array ) ); + } + else if (value instanceof double[] vector) { + return new SparseDoubleVector( vector ); + } + else if (value instanceof String vector) { + return new SparseDoubleVector( vector ); + } + else { + throw unknownWrap( value.getClass() ); + } + } + + private static class SparseVectorMutabilityPlan extends MutableMutabilityPlan { + @Override + protected SparseDoubleVector deepCopyNotNull(SparseDoubleVector value) { + return value.clone(); + } + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseFloatVectorJavaType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseFloatVectorJavaType.java new file mode 100644 index 000000000000..df1dad7444c8 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseFloatVectorJavaType.java @@ -0,0 +1,116 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.tool.schema.extract.spi.ColumnTypeInformation; +import org.hibernate.type.BasicCollectionType; +import org.hibernate.type.BasicType; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.AbstractClassJavaType; +import org.hibernate.type.descriptor.java.BasicPluralJavaType; +import org.hibernate.type.descriptor.java.FloatJavaType; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.java.MutableMutabilityPlan; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.JdbcTypeIndicators; +import org.hibernate.type.spi.TypeConfiguration; +import org.hibernate.vector.SparseFloatVector; + +import java.util.Arrays; +import java.util.List; + + +public class SparseFloatVectorJavaType extends AbstractClassJavaType implements BasicPluralJavaType { + + public static final SparseFloatVectorJavaType INSTANCE = new SparseFloatVectorJavaType(); + + public SparseFloatVectorJavaType() { + super( SparseFloatVector.class, new SparseVectorMutabilityPlan() ); + } + + @Override + public JavaType getElementJavaType() { + return FloatJavaType.INSTANCE; + } + + @Override + public BasicType resolveType(TypeConfiguration typeConfiguration, Dialect dialect, BasicType elementType, ColumnTypeInformation columnTypeInformation, JdbcTypeIndicators stdIndicators) { + final int arrayTypeCode = stdIndicators.getPreferredSqlTypeCodeForArray( elementType.getJdbcType().getDefaultSqlTypeCode() ); + final JdbcType arrayJdbcType = typeConfiguration.getJdbcTypeRegistry() + .resolveTypeConstructorDescriptor( arrayTypeCode, elementType, columnTypeInformation ); + if ( elementType.getValueConverter() != null ) { + throw new IllegalArgumentException( "Can't convert element type of sparse vector" ); + } + return typeConfiguration.getBasicTypeRegistry() + .resolve( this, arrayJdbcType, + () -> new 
BasicCollectionType<>( elementType, arrayJdbcType, this, "sparse_float_vector" ) );
+	}
+
+	@Override
+	public JdbcType getRecommendedJdbcType(JdbcTypeIndicators indicators) {
+		return indicators.getJdbcType( SqlTypes.SPARSE_VECTOR_FLOAT32 );
+	}
+
+	@Override
+	public X unwrap(SparseFloatVector value, Class type, WrapperOptions options) {
+		if ( value == null ) {
+			return null;
+		}
+		else if ( type.isInstance( value ) ) {
+			//noinspection unchecked
+			return (X) value;
+		}
+		else if ( float[].class.isAssignableFrom( type ) ) {
+			return (X) value.toDenseVector();
+		}
+		else if ( Object[].class.isAssignableFrom( type ) ) {
+			//noinspection unchecked
+			return (X) value.toArray();
+		}
+		else if ( String.class.isAssignableFrom( type ) ) {
+			//noinspection unchecked
+			return (X) value.toString();
+		}
+		else {
+			throw unknownUnwrap( type );
+		}
+	}
+
+	@Override
+	public SparseFloatVector wrap(X value, WrapperOptions options) {
+		if ( value == null ) {
+			return null;
+		}
+		else if (value instanceof SparseFloatVector vector) {
+			return vector;
+		}
+		else if (value instanceof List list) {
+			//noinspection unchecked
+			return new SparseFloatVector( (List) list );
+		}
+		else if (value instanceof Object[] array) {
+			//noinspection unchecked
+			return new SparseFloatVector( (List) (List) Arrays.asList( array ) );
+		}
+		else if (value instanceof float[] vector) {
+			return new SparseFloatVector( vector );
+		}
+		else if (value instanceof String vector) {
+			return new SparseFloatVector( vector );
+		}
+		else {
+			throw unknownWrap( value.getClass() );
+		}
+	}
+
+	private static class SparseVectorMutabilityPlan extends MutableMutabilityPlan {
+		@Override
+		protected SparseFloatVector deepCopyNotNull(SparseFloatVector value) {
+			return value.clone();
+		}
+	}
+}
diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentTypeResolver.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorArgumentTypeResolver.java
similarity index 81%
rename from hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentTypeResolver.java
rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorArgumentTypeResolver.java
index 4af45fd9f44e..d867679a5acf 100644
--- a/hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentTypeResolver.java
+++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorArgumentTypeResolver.java
@@ -2,7 +2,7 @@
  * SPDX-License-Identifier: Apache-2.0
  * Copyright Red Hat Inc. and Hibernate Authors
  */
-package org.hibernate.vector;
+package org.hibernate.vector.internal;
 import java.util.List;
@@ -21,11 +21,18 @@
  */
 public class VectorArgumentTypeResolver implements AbstractFunctionArgumentTypeResolver {
-	public static final FunctionArgumentTypeResolver INSTANCE = new VectorArgumentTypeResolver();
+	public static final FunctionArgumentTypeResolver INSTANCE = new VectorArgumentTypeResolver( 0 );
+	public static final FunctionArgumentTypeResolver DISTANCE_INSTANCE = new VectorArgumentTypeResolver( 0, 1 );
+
+	private final int[] vectorIndices;
+
+	public VectorArgumentTypeResolver(int... 
vectorIndices) { + this.vectorIndices = vectorIndices; + } @Override public @Nullable MappingModelExpressible resolveFunctionArgumentType(List> arguments, int argumentIndex, SqmToSqlAstConverter converter) { - for ( int i = 0; i < arguments.size(); i++ ) { + for ( int i : vectorIndices ) { if ( i != argumentIndex ) { final SqmTypedNode node = arguments.get( i ); if ( node instanceof SqmExpression ) { diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentValidator.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorArgumentValidator.java similarity index 69% rename from hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentValidator.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorArgumentValidator.java index 4bd1632c50be..b15f703d2b97 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentValidator.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorArgumentValidator.java @@ -2,17 +2,17 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import java.util.List; +import org.hibernate.type.BasicType; import org.hibernate.type.BindingContext; import org.hibernate.query.sqm.SqmExpressible; import org.hibernate.query.sqm.produce.function.ArgumentsValidator; import org.hibernate.query.sqm.produce.function.FunctionArgumentException; import org.hibernate.query.sqm.tree.SqmTypedNode; import org.hibernate.query.sqm.tree.domain.SqmDomainType; -import org.hibernate.type.BasicPluralType; import org.hibernate.type.SqlTypes; /** @@ -20,14 +20,21 @@ */ public class VectorArgumentValidator implements ArgumentsValidator { - public static final ArgumentsValidator INSTANCE = new VectorArgumentValidator(); + public static final ArgumentsValidator INSTANCE = new VectorArgumentValidator( 0 ); + public static final ArgumentsValidator DISTANCE_INSTANCE = new VectorArgumentValidator( 0, 1 ); + + private final int[] vectorIndices; + + public VectorArgumentValidator(int... 
vectorIndices) { + this.vectorIndices = vectorIndices; + } @Override public void validate( List> arguments, String functionName, BindingContext bindingContext) { - for ( int i = 0; i < arguments.size(); i++ ) { + for ( int i : vectorIndices ) { final SqmExpressible expressible = arguments.get( i ).getExpressible(); if ( expressible != null ) { final SqmDomainType type = expressible.getSqmType(); @@ -46,9 +53,10 @@ public void validate( } private static boolean isVectorType(SqmExpressible vectorType) { - return vectorType instanceof BasicPluralType basicPluralType - && switch ( basicPluralType.getJdbcType().getDefaultSqlTypeCode() ) { - case SqlTypes.VECTOR, SqlTypes.VECTOR_INT8, SqlTypes.VECTOR_FLOAT32, SqlTypes.VECTOR_FLOAT64 -> true; + return vectorType instanceof BasicType basicType + && switch ( basicType.getJdbcType().getDefaultSqlTypeCode() ) { + case SqlTypes.VECTOR, SqlTypes.VECTOR_INT8, SqlTypes.VECTOR_FLOAT16, SqlTypes.VECTOR_FLOAT32, SqlTypes.VECTOR_FLOAT64, + SqlTypes.VECTOR_BINARY, SqlTypes.SPARSE_VECTOR_INT8, SqlTypes.SPARSE_VECTOR_FLOAT32, SqlTypes.SPARSE_VECTOR_FLOAT64-> true; default -> false; }; } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorDdlType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorDdlType.java new file mode 100644 index 000000000000..d14d360351f6 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorDdlType.java @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; + +/** + * DDL type for vector types. + * + * @since 7.1 + */ +public class VectorDdlType extends DdlTypeImpl { + + public VectorDdlType(int sqlTypeCode, boolean isLob, String typeNamePattern, String castTypeNamePattern, String castTypeName, Dialect dialect) { + super( sqlTypeCode, isLob, typeNamePattern, castTypeNamePattern, castTypeName, dialect ); + } + + public VectorDdlType(int sqlTypeCode, String typeNamePattern, String castTypeNamePattern, String castTypeName, Dialect dialect) { + super( sqlTypeCode, typeNamePattern, castTypeNamePattern, castTypeName, dialect ); + } + + public VectorDdlType(int sqlTypeCode, boolean isLob, String typeNamePattern, String castTypeName, Dialect dialect) { + super( sqlTypeCode, isLob, typeNamePattern, castTypeName, dialect ); + } + + public VectorDdlType(int sqlTypeCode, String typeNamePattern, String castTypeName, Dialect dialect) { + super( sqlTypeCode, typeNamePattern, castTypeName, dialect ); + } + + public VectorDdlType(int sqlTypeCode, String typeNamePattern, Dialect dialect) { + super( sqlTypeCode, typeNamePattern, dialect ); + } + + @Override + public String getTypeName(Size size) { + return getTypeName( + size.getArrayLength() == null ? 
null : size.getArrayLength().longValue(), + null, + null + ); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/VectorFunctionFactory.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorFunctionFactory.java similarity index 92% rename from hibernate-vector/src/main/java/org/hibernate/vector/VectorFunctionFactory.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorFunctionFactory.java index 71679aeac05a..691e9954b650 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/VectorFunctionFactory.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorFunctionFactory.java @@ -2,7 +2,7 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import org.hibernate.boot.model.FunctionContributions; import org.hibernate.query.sqm.function.SqmFunctionRegistry; @@ -58,6 +58,10 @@ public void hammingDistance(String pattern) { registerVectorDistanceFunction( "hamming_distance", pattern ); } + public void jaccardDistance(String pattern) { + registerVectorDistanceFunction( "jaccard_distance", pattern ); + } + public void vectorDimensions() { registerNamedVectorFunction( "vector_dims", integerType, 1 ); } @@ -70,9 +74,9 @@ public void registerVectorDistanceFunction(String functionName, String pattern) functionRegistry.patternDescriptorBuilder( functionName, pattern ) .setArgumentsValidator( StandardArgumentsValidators.composite( StandardArgumentsValidators.exactly( 2 ), - VectorArgumentValidator.INSTANCE + VectorArgumentValidator.DISTANCE_INSTANCE ) ) - .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.DISTANCE_INSTANCE ) .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) ) .register(); } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorHelper.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorHelper.java new file mode 100644 index 000000000000..6ca866d625c2 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorHelper.java @@ -0,0 +1,174 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.util.BitSet; + +/** + * Helper for vector related functionality. 
+ * + * @since 7.1 + */ +public class VectorHelper { + + private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + private static final float[] EMPTY_FLOAT_ARRAY = new float[0]; + private static final double[] EMPTY_DOUBLE_ARRAY = new double[0]; + + public static @Nullable byte[] parseByteVector(@Nullable String string) { + if ( string == null ) { + return null; + } + if ( string.length() == 2 ) { + return EMPTY_BYTE_ARRAY; + } + final BitSet commaPositions = new BitSet(); + int size = 1; + for ( int i = 1; i < string.length(); i++ ) { + final char c = string.charAt( i ); + if ( c == ',' ) { + commaPositions.set( i ); + size++; + } + } + final byte[] result = new byte[size]; + int doubleStartIndex = 1; + int commaIndex; + int index = 0; + while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { + result[index++] = Byte.parseByte( string.substring( doubleStartIndex, commaIndex ) ); + doubleStartIndex = commaIndex + 1; + } + result[index] = Byte.parseByte( string.substring( doubleStartIndex, string.length() - 1 ) ); + return result; + } + + public static @Nullable float[] parseFloatVector(@Nullable String string) { + if ( string == null ) { + return null; + } + if ( string.length() == 2 ) { + return EMPTY_FLOAT_ARRAY; + } + final BitSet commaPositions = new BitSet(); + int size = 1; + for ( int i = 1; i < string.length(); i++ ) { + final char c = string.charAt( i ); + if ( c == ',' ) { + commaPositions.set( i ); + size++; + } + } + final float[] result = new float[size]; + int doubleStartIndex = 1; + int commaIndex; + int index = 0; + while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { + result[index++] = Float.parseFloat( string.substring( doubleStartIndex, commaIndex ) ); + doubleStartIndex = commaIndex + 1; + } + result[index] = Float.parseFloat( string.substring( doubleStartIndex, string.length() - 1 ) ); + return result; + } + + public static @Nullable double[] parseDoubleVector(@Nullable String string) { + if ( string == null ) { + return null; + } + if ( string.length() == 2 ) { + return EMPTY_DOUBLE_ARRAY; + } + final BitSet commaPositions = new BitSet(); + int size = 1; + for ( int i = 1; i < string.length(); i++ ) { + final char c = string.charAt( i ); + if ( c == ',' ) { + commaPositions.set( i ); + size++; + } + } + final double[] result = new double[size]; + int doubleStartIndex = 1; + int commaIndex; + int index = 0; + while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { + result[index++] = Double.parseDouble( string.substring( doubleStartIndex, commaIndex ) ); + doubleStartIndex = commaIndex + 1; + } + result[index] = Double.parseDouble( string.substring( doubleStartIndex, string.length() - 1 ) ); + return result; + } + + public static @Nullable float[] parseFloatVector(@Nullable byte[] bytes) { + if ( bytes == null ) { + return null; + } + if ( bytes.length == 0 ) { + return EMPTY_FLOAT_ARRAY; + } + if ( (bytes.length & 3) != 0 ) { + throw new IllegalArgumentException( + "Invalid byte array length. 
Expected a multiple of 4 but got: " + bytes.length );
+		}
+		final float[] result = new float[bytes.length >> 2];
+		for ( int i = 0, resultLength = result.length; i < resultLength; i++ ) {
+			final int offset = i << 2;
+			final int asInt = (bytes[offset] & 0xFF)
+					| ((bytes[offset + 1] & 0xFF) << 8)
+					| ((bytes[offset + 2] & 0xFF) << 16)
+					| ((bytes[offset + 3] & 0xFF) << 24);
+			result[i] = Float.intBitsToFloat( asInt );
+		}
+		return result;
+	}
+
+	public static byte[] parseBitString(String bitString) {
+		assert new BigInteger( "1" + bitString, 2 ).bitLength() == bitString.length() + 1;
+		final int fullBytesCount = bitString.length() >> 3;
+		final int fullBytesStartPosition = ((bitString.length() & 7) == 0 ? 0 : 1);
+		final int byteCount = fullBytesCount + fullBytesStartPosition;
+		final byte[] bytes = new byte[byteCount];
+		final int fullBytesBitCount = fullBytesCount << 3;
+		final int leadingBits = bitString.length() - fullBytesBitCount;
+		if ( leadingBits > 0 ) {
+			for ( int i = 0; i < leadingBits; i++ ) {
+				bytes[0] |= (byte) (((bitString.charAt( i ) - 48)) << (7 - i));
+			}
+		}
+		// The remaining full bytes start right after the leading partial bits
+		for ( int i = fullBytesStartPosition; i < byteCount; i++ ) {
+			final int offset = leadingBits + ( ( i - fullBytesStartPosition ) << 3 );
+			bytes[i] = (byte) (
+					((bitString.charAt( offset ) - 48) << 7)
+					| ((bitString.charAt( offset + 1 ) - 48) << 6)
+					| ((bitString.charAt( offset + 2 ) - 48) << 5)
+					| ((bitString.charAt( offset + 3 ) - 48) << 4)
+					| ((bitString.charAt( offset + 4 ) - 48) << 3)
+					| ((bitString.charAt( offset + 5 ) - 48) << 2)
+					| ((bitString.charAt( offset + 6 ) - 48) << 1)
+					| ((bitString.charAt( offset + 7 ) - 48) << 0)
+			);
+		}
+		return bytes;
+	}
+
+	public static String toBitString(byte[] bytes) {
+		final byte[] bitBytes = new byte[bytes.length * 8];
+		for ( int i = 0; i < bytes.length; i++ ) {
+			final byte b = bytes[i];
+			bitBytes[i * 8 + 0] = (byte) (((b >>> 7) & 1) + 48);
+			bitBytes[i * 8 + 1] = (byte) (((b >>> 6) & 1) + 48);
+			bitBytes[i * 8 + 2] = (byte) (((b >>> 5) & 1) + 48);
+			bitBytes[i * 8 + 3] = (byte) (((b >>> 4) & 1) + 48);
+			bitBytes[i * 8 + 4] = (byte) (((b >>> 3) & 1) + 48);
+			bitBytes[i * 8 + 5] = (byte) (((b >>> 2) & 1) + 48);
+			bitBytes[i * 8 + 6] = (byte) (((b >>> 1) & 1) + 48);
+			bitBytes[i * 8 + 7] = (byte) (((b >>> 0) & 1) + 48);
+		}
+		return new String( bitBytes, StandardCharsets.UTF_8 );
+	}
+}
diff --git a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor
index 6103956ccbd7..c4cdbd72b910 100644
--- a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor
+++ b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor
@@ -1,3 +1,4 @@
-org.hibernate.vector.PGVectorFunctionContributor
-org.hibernate.vector.OracleVectorFunctionContributor
-org.hibernate.vector.MariaDBFunctionContributor
+org.hibernate.vector.internal.PGVectorFunctionContributor
+org.hibernate.vector.internal.OracleVectorFunctionContributor
+org.hibernate.vector.internal.MariaDBFunctionContributor
+org.hibernate.vector.internal.MySQLFunctionContributor
diff --git a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor
index 11605464c824..860a988fa5ce 100644
--- a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor
+++ 
b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor @@ -1,3 +1,4 @@ -org.hibernate.vector.PGVectorTypeContributor -org.hibernate.vector.OracleVectorTypeContributor -org.hibernate.vector.MariaDBTypeContributor +org.hibernate.vector.internal.PGVectorTypeContributor +org.hibernate.vector.internal.OracleVectorTypeContributor +org.hibernate.vector.internal.MariaDBTypeContributor +org.hibernate.vector.internal.MySQLTypeContributor diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/BinaryVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/BinaryVectorTest.java new file mode 100644 index 000000000000..66ac4d3eccc7 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/BinaryVectorTest.java @@ -0,0 +1,242 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.dialect.OracleDialect; +import org.hibernate.dialect.PostgreSQLDialect; +import org.hibernate.dialect.PostgresPlusDialect; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SkipForDialect; +import org.hibernate.type.SqlTypes; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.hibernate.vector.VectorTestHelper.cosineDistanceBinary; +import static org.hibernate.vector.VectorTestHelper.euclideanDistanceBinary; +import static org.hibernate.vector.VectorTestHelper.euclideanNormBinary; +import static org.hibernate.vector.VectorTestHelper.hammingDistanceBinary; +import static org.hibernate.vector.VectorTestHelper.innerProductBinary; +import static org.hibernate.vector.VectorTestHelper.jaccardDistanceBinary; +import static org.hibernate.vector.VectorTestHelper.taxicabDistanceBinary; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@DomainModel(annotatedClasses = BinaryVectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsBinaryVectorType.class) +@SkipForDialect(dialectClass = PostgresPlusDialect.class, reason = "Test database does not have the extension enabled") +public class BinaryVectorTest { + + private static final byte[] V1 = new byte[]{ 1, 2, 3 }; + private static final byte[] V2 = new byte[]{ 4, 5, 6 }; + + @BeforeEach + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, V1 ) ); + em.persist( new VectorEntity( 2L, V2 ) ); + } ); + } + + @AfterEach + public void cleanup(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); + } ); + } + + @Test + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new byte[]{ 1, 2, 3 }, 
tableRecord.getTheVector() ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new byte[]{ 4, 5, 6 }, tableRecord.getTheVector() ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with bit vectors") + public void testCosineDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( cosineDistanceBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( cosineDistanceBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with bit vectors") + public void testEuclideanDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanDistanceBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanDistanceBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with bit vectors") + public void testTaxicabDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( taxicabDistanceBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( taxicabDistanceBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with bit vectors") + public void testInnerProduct(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, 
results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( innerProductBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( innerProductBinary( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( innerProductBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + assertEquals( innerProductBinary( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) + public void testHammingDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( hammingDistanceBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( hammingDistanceBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsJaccardDistance.class) + public void testJaccardDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, jaccard_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( jaccardDistanceBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( jaccardDistanceBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) + public void testVectorDims(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( V1.length * 8, results.get( 0 ).get( 1 ) ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( V2.length * 8, results.get( 1 ).get( 1 ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with bit vectors") + @SkipForDialect(dialectClass = OracleDialect.class, reason = "Oracle 23.9 bug") + public void testVectorNorm(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanNormBinary( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanNormBinary( V2 ), results.get( 1 ).get( 1, double.class ), 
0D ); + } ); + } + + @Entity( name = "VectorEntity" ) + public static class VectorEntity { + + @Id + private Long id; + + @Column( name = "the_vector" ) + @JdbcTypeCode(SqlTypes.VECTOR_BINARY) + @Array(length = 24) + private byte[] theVector; + + public VectorEntity() { + } + + public VectorEntity(Long id, byte[] theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public byte[] getTheVector() { + return theVector; + } + + public void setTheVector(byte[] theVector) { + this.theVector = theVector; + } + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/OracleByteVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/ByteVectorTest.java similarity index 77% rename from hibernate-vector/src/test/java/org/hibernate/vector/OracleByteVectorTest.java rename to hibernate-vector/src/test/java/org/hibernate/vector/ByteVectorTest.java index bfbf2cc28d34..ad811e5ea594 100644 --- a/hibernate-vector/src/test/java/org/hibernate/vector/OracleByteVectorTest.java +++ b/hibernate-vector/src/test/java/org/hibernate/vector/ByteVectorTest.java @@ -9,11 +9,12 @@ import org.hibernate.annotations.Array; import org.hibernate.annotations.JdbcTypeCode; import org.hibernate.dialect.OracleDialect; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; import org.hibernate.testing.orm.junit.SkipForDialect; import org.hibernate.type.SqlTypes; import org.hibernate.testing.orm.junit.DomainModel; -import org.hibernate.testing.orm.junit.RequiresDialect; import org.hibernate.testing.orm.junit.SessionFactory; import org.hibernate.testing.orm.junit.SessionFactoryScope; import org.junit.jupiter.api.AfterEach; @@ -25,16 +26,22 @@ import jakarta.persistence.Id; import jakarta.persistence.Tuple; +import static org.hibernate.vector.VectorTestHelper.cosineDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanNorm; +import static org.hibernate.vector.VectorTestHelper.hammingDistance; +import static org.hibernate.vector.VectorTestHelper.innerProduct; +import static org.hibernate.vector.VectorTestHelper.taxicabDistance; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Hassan AL Meftah */ -@DomainModel(annotatedClasses = OracleByteVectorTest.VectorEntity.class) +@DomainModel(annotatedClasses = ByteVectorTest.VectorEntity.class) @SessionFactory -@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4) -public class OracleByteVectorTest { +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsByteVectorType.class) +public class ByteVectorTest { private static final byte[] V1 = new byte[]{ 1, 2, 3 }; private static final byte[] V2 = new byte[]{ 4, 5, 6 }; @@ -67,14 +74,13 @@ public void testRead(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) public void testCosineDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::cosine-distance-example[] final byte[] vector = new byte[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector, byte[].class ) + .setParameter( 
"vec", vector ) .getResultList(); - //end::cosine-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); @@ -84,14 +90,13 @@ public void testCosineDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) public void testEuclideanDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::euclidean-distance-example[] final byte[] vector = new byte[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::euclidean-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); @@ -101,14 +106,13 @@ public void testEuclideanDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) public void testTaxicabDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::taxicab-distance-example[] final byte[] vector = new byte[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::taxicab-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -118,14 +122,13 @@ public void testTaxicabDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) public void testInnerProduct(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::inner-product-example[] final byte[] vector = new byte[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::inner-product-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -137,14 +140,13 @@ public void testInnerProduct(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) public void testHammingDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::inner-product-example[] final byte[] vector = new byte[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::inner-product-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -154,12 +156,11 @@ public void testHammingDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) public void testVectorDims(SessionFactoryScope scope) { scope.inTransaction( 
em -> { - //tag::vector-dims-example[] final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) .getResultList(); - //end::vector-dims-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( V1.length, results.get( 0 ).get( 1 ) ); @@ -169,13 +170,12 @@ public void testVectorDims(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) @SkipForDialect(dialectClass = OracleDialect.class, reason = "Oracle 23.9 bug") public void testVectorNorm(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::vector-norm-example[] final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) .getResultList(); - //end::vector-norm-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -184,74 +184,16 @@ public void testVectorNorm(SessionFactoryScope scope) { } ); } - private static double cosineDistance(byte[] f1, byte[] f2) { - return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); - } - - private static double euclideanDistance(byte[] f1, byte[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += Math.pow( (double) f1[i] - f2[i], 2 ); - } - return Math.sqrt( result ); - } - - private static double taxicabDistance(byte[] f1, byte[] f2) { - return norm( f1 ) - norm( f2 ); - } - - private static double innerProduct(byte[] f1, byte[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += ( (double) f1[i] ) * ( (double) f2[i] ); - } - return result; - } - - public static double hammingDistance(byte[] f1, byte[] f2) { - assert f1.length == f2.length; - int distance = 0; - for (int i = 0; i < f1.length; i++) { - if (!(f1[i] == f2[i])) { - distance++; - } - } - return distance; - } - - - private static double euclideanNorm(byte[] f) { - double result = 0; - for ( double v : f ) { - result += Math.pow( v, 2 ); - } - return Math.sqrt( result ); - } - - private static double norm(byte[] f) { - double result = 0; - for ( double v : f ) { - result += Math.abs( v ); - } - return result; - } - @Entity( name = "VectorEntity" ) public static class VectorEntity { @Id private Long id; - //tag::usage-example[] @Column( name = "the_vector" ) @JdbcTypeCode(SqlTypes.VECTOR_INT8) @Array(length = 3) private byte[] theVector; - //end::usage-example[] - - public VectorEntity() { } diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/OracleDoubleVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/DoubleVectorTest.java similarity index 77% rename from hibernate-vector/src/test/java/org/hibernate/vector/OracleDoubleVectorTest.java rename to hibernate-vector/src/test/java/org/hibernate/vector/DoubleVectorTest.java index 49add683b4d4..04a830ccb176 100644 --- a/hibernate-vector/src/test/java/org/hibernate/vector/OracleDoubleVectorTest.java +++ b/hibernate-vector/src/test/java/org/hibernate/vector/DoubleVectorTest.java @@ -9,11 +9,12 @@ import org.hibernate.annotations.Array; import org.hibernate.annotations.JdbcTypeCode; import org.hibernate.dialect.OracleDialect; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import 
org.hibernate.testing.orm.junit.RequiresDialectFeature; import org.hibernate.testing.orm.junit.SkipForDialect; import org.hibernate.type.SqlTypes; import org.hibernate.testing.orm.junit.DomainModel; -import org.hibernate.testing.orm.junit.RequiresDialect; import org.hibernate.testing.orm.junit.SessionFactory; import org.hibernate.testing.orm.junit.SessionFactoryScope; import org.junit.jupiter.api.AfterEach; @@ -25,16 +26,22 @@ import jakarta.persistence.Id; import jakarta.persistence.Tuple; +import static org.hibernate.vector.VectorTestHelper.cosineDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanNorm; +import static org.hibernate.vector.VectorTestHelper.hammingDistance; +import static org.hibernate.vector.VectorTestHelper.innerProduct; +import static org.hibernate.vector.VectorTestHelper.taxicabDistance; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Hassan AL Meftah */ -@DomainModel(annotatedClasses = OracleDoubleVectorTest.VectorEntity.class) +@DomainModel(annotatedClasses = DoubleVectorTest.VectorEntity.class) @SessionFactory -@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4) -public class OracleDoubleVectorTest { +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsDoubleVectorType.class) +public class DoubleVectorTest { private static final double[] V1 = new double[]{ 1, 2, 3 }; private static final double[] V2 = new double[]{ 4, 5, 6 }; @@ -67,14 +74,13 @@ public void testRead(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) public void testCosineDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::cosine-distance-example[] final double[] vector = new double[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::cosine-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); @@ -84,14 +90,13 @@ public void testCosineDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) public void testEuclideanDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::euclidean-distance-example[] final double[] vector = new double[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::euclidean-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.00002D ); @@ -101,14 +106,13 @@ public void testEuclideanDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) public void testTaxicabDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::taxicab-distance-example[] final double[] vector = new double[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select 
e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::taxicab-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -118,14 +122,13 @@ public void testTaxicabDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) public void testInnerProduct(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::inner-product-example[] final double[] vector = new double[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::inner-product-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -137,14 +140,13 @@ public void testInnerProduct(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) public void testHammingDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::inner-product-example[] final double[] vector = new double[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::inner-product-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -153,12 +155,11 @@ public void testHammingDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) public void testVectorDims(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::vector-dims-example[] final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) .getResultList(); - //end::vector-dims-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( V1.length, results.get( 0 ).get( 1 ) ); @@ -168,13 +169,12 @@ public void testVectorDims(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) @SkipForDialect(dialectClass = OracleDialect.class, reason = "Oracle 23.9 bug") public void testVectorNorm(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::vector-norm-example[] final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) .getResultList(); - //end::vector-norm-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -183,72 +183,16 @@ public void testVectorNorm(SessionFactoryScope scope) { } ); } - - private static double cosineDistance(double[] f1, double[] f2) { - return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); - } - - private static double euclideanDistance(double[] f1, double[] f2) { - assert f1.length == 
f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += Math.pow( (double) f1[i] - f2[i], 2 ); - } - return Math.sqrt( result ); - } - - private static double taxicabDistance(double[] f1, double[] f2) { - return norm( f1 ) - norm( f2 ); - } - - public static double hammingDistance(double[] f1, double[] f2) { - assert f1.length == f2.length; - int distance = 0; - for (int i = 0; i < f1.length; i++) { - if (!(f1[i] == f2[i])) { - distance++; - } - } - return distance; - } - - private static double innerProduct(double[] f1, double[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += ( (double) f1[i] ) * ( (double) f2[i] ); - } - return result; - } - - private static double euclideanNorm(double[] f) { - double result = 0; - for ( double v : f ) { - result += Math.pow( v, 2 ); - } - return Math.sqrt( result ); - } - - private static double norm(double[] f) { - double result = 0; - for ( double v : f ) { - result += Math.abs( v ); - } - return result; - } - @Entity( name = "VectorEntity" ) public static class VectorEntity { @Id private Long id; - //tag::usage-example[] @Column( name = "the_vector" ) @JdbcTypeCode(SqlTypes.VECTOR_FLOAT64) @Array(length = 3) private double[] theVector; - //end::usage-example[] public VectorEntity() { } diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/Float16VectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/Float16VectorTest.java new file mode 100644 index 000000000000..599101bc6bd2 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/Float16VectorTest.java @@ -0,0 +1,111 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.type.SqlTypes; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.hibernate.vector.VectorTestHelper.euclideanNormalize; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@DomainModel(annotatedClasses = Float16VectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsFloat16VectorType.class) +public class Float16VectorTest extends FloatVectorTest { + + @BeforeEach + @Override + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, V1 ) ); + em.persist( new VectorEntity( 2L, V2 ) ); + } ); + } + + @Test + @Override + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new float[] { 1, 2, 3 }, tableRecord.getTheVector(), 0 ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new float[] { 4, 5, 6 }, tableRecord.getTheVector(), 0 ); + } ); + } + + // Due to lower 
precision (float16/half-precision floating-point) type usage, + // we have to give a higher allowed delta since we can't easily calculate with the same precision in Java yet + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsL2Normalize.class) + @Override + public void testL2Normalize(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( + "select e.id, l2_normalize(e.theVector) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertArrayEquals( euclideanNormalize( V1 ), results.get( 0 ).get( 1, float[].class ), 0.0002f ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertArrayEquals( euclideanNormalize( V2 ), results.get( 1 ).get( 1, float[].class ), 0.0002f ); + } ); + } + + @Entity(name = "VectorEntity") + public static class VectorEntity { + + @Id + private Long id; + + @Column(name = "the_vector") + @JdbcTypeCode(SqlTypes.VECTOR_FLOAT16) + @Array(length = 3) + private float[] theVector; + + + public VectorEntity() { + } + + public VectorEntity(Long id, float[] theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public float[] getTheVector() { + return theVector; + } + + public void setTheVector(float[] theVector) { + this.theVector = theVector; + } + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/Float32VectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/Float32VectorTest.java new file mode 100644 index 000000000000..90cc3612f45f --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/Float32VectorTest.java @@ -0,0 +1,86 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.type.SqlTypes; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +@DomainModel(annotatedClasses = Float32VectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsFloatVectorType.class) +public class Float32VectorTest extends FloatVectorTest { + + @BeforeEach + @Override + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, V1 ) ); + em.persist( new VectorEntity( 2L, V2 ) ); + } ); + } + + @Test + @Override + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new float[] { 1, 2, 3 }, tableRecord.getTheVector(), 0 ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new float[] { 4, 5, 6 }, tableRecord.getTheVector(), 0 ); + } ); + } + + @Entity(name = "VectorEntity") + public static class VectorEntity { + + @Id + private Long id; + + @Column(name = "the_vector") + @JdbcTypeCode(SqlTypes.VECTOR_FLOAT32) + @Array(length = 3) + private float[] theVector; + + + public VectorEntity() { + } + + public VectorEntity(Long id, float[] theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public float[] getTheVector() { + return theVector; + } + + public void setTheVector(float[] theVector) { + this.theVector = theVector; + } + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/OracleGenericVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/FloatVectorTest.java similarity index 56% rename from hibernate-vector/src/test/java/org/hibernate/vector/OracleGenericVectorTest.java rename to hibernate-vector/src/test/java/org/hibernate/vector/FloatVectorTest.java index d199b91f594b..baadf1cfaf10 100644 --- a/hibernate-vector/src/test/java/org/hibernate/vector/OracleGenericVectorTest.java +++ b/hibernate-vector/src/test/java/org/hibernate/vector/FloatVectorTest.java @@ -4,41 +4,50 @@ */ package org.hibernate.vector; -import java.util.List; - +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; import org.hibernate.annotations.Array; import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.dialect.MySQLDialect; import org.hibernate.dialect.OracleDialect; +import org.hibernate.dialect.PostgreSQLDialect; import org.hibernate.testing.orm.junit.SkipForDialect; -import org.hibernate.type.SqlTypes; - +import org.hibernate.dialect.PostgresPlusDialect; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; import org.hibernate.testing.orm.junit.DomainModel; -import org.hibernate.testing.orm.junit.RequiresDialect; +import 
org.hibernate.testing.orm.junit.RequiresDialectFeature; import org.hibernate.testing.orm.junit.SessionFactory; import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SkipForDialect; +import org.hibernate.type.SqlTypes; +import org.hibernate.vector.internal.VectorHelper; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import jakarta.persistence.Column; -import jakarta.persistence.Entity; -import jakarta.persistence.Id; -import jakarta.persistence.Tuple; +import java.util.List; +import static org.hibernate.vector.VectorTestHelper.cosineDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanNorm; +import static org.hibernate.vector.VectorTestHelper.euclideanNormalize; +import static org.hibernate.vector.VectorTestHelper.hammingDistance; +import static org.hibernate.vector.VectorTestHelper.innerProduct; +import static org.hibernate.vector.VectorTestHelper.taxicabDistance; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -/** - * @author Hassan AL Meftah - */ -@DomainModel(annotatedClasses = OracleGenericVectorTest.VectorEntity.class) +@DomainModel(annotatedClasses = FloatVectorTest.VectorEntity.class) @SessionFactory -@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4) -public class OracleGenericVectorTest { +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorType.class) +@SkipForDialect(dialectClass = PostgresPlusDialect.class, reason = "Test database does not have the extension enabled") +public class FloatVectorTest { - private static final float[] V1 = new float[] { 1, 2, 3 }; - private static final float[] V2 = new float[] { 4, 5, 6 }; + protected static final float[] V1 = new float[] { 1, 2, 3 }; + protected static final float[] V2 = new float[] { 4, 5, 6 }; @BeforeEach public void prepareData(SessionFactoryScope scope) { @@ -68,6 +77,20 @@ public void testRead(SessionFactoryScope scope) { } @Test + public void testCast(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::vector-cast-example[] + final Tuple vector = em.createSelectionQuery( "select cast(e.theVector as string), cast('[1, 1, 1]' as vector) from VectorEntity e where e.id = 1", Tuple.class ) + .getSingleResult(); + //end::vector-cast-example[] + assertArrayEquals( new float[]{ 1, 2, 3 }, VectorHelper.parseFloatVector( vector.get( 0, String.class ) ) ); + assertArrayEquals( new float[]{ 1, 1, 1 }, vector.get( 1, float[].class ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) + @SkipForDialect(dialectClass = MySQLDialect.class, reason = "Only MySQL HeatWave supports this function") public void testCosineDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { //tag::cosine-distance-example[] @@ -88,6 +111,8 @@ public void testCosineDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) + @SkipForDialect(dialectClass = MySQLDialect.class, reason = "Only MySQL HeatWave supports this function") public void testEuclideanDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { //tag::euclidean-distance-example[] @@ -108,6 +133,7 @@ public void testEuclideanDistance(SessionFactoryScope scope) { } @Test + 
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) public void testTaxicabDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { //tag::taxicab-distance-example[] @@ -128,6 +154,8 @@ public void testTaxicabDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) + @SkipForDialect(dialectClass = MySQLDialect.class, reason = "Only MySQL HeatWave supports this function") public void testInnerProduct(SessionFactoryScope scope) { scope.inTransaction( em -> { //tag::inner-product-example[] @@ -150,9 +178,11 @@ public void testInnerProduct(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Only supported with bit vectors") public void testHammingDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::inner-product-example[] + //tag::hamming-distance-example[] final float[] vector = new float[] { 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", @@ -160,7 +190,7 @@ public void testHammingDistance(SessionFactoryScope scope) { ) .setParameter( "vec", vector ) .getResultList(); - //end::inner-product-example[] + //end::hamming-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -170,6 +200,7 @@ public void testHammingDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) public void testVectorDims(SessionFactoryScope scope) { scope.inTransaction( em -> { //tag::vector-dims-example[] @@ -188,6 +219,7 @@ public void testVectorDims(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) @SkipForDialect(dialectClass = OracleDialect.class, reason = "Oracle 23.9 bug") public void testVectorNorm(SessionFactoryScope scope) { scope.inTransaction( em -> { @@ -206,59 +238,83 @@ public void testVectorNorm(SessionFactoryScope scope) { } ); } - - private static double cosineDistance(float[] f1, float[] f2) { - return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); - } - - private static double euclideanDistance(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += Math.pow( (double) f1[i] - f2[i], 2 ); - } - return Math.sqrt( result ); - } - - private static double taxicabDistance(float[] f1, float[] f2) { - return norm( f1 ) - norm( f2 ); - } - - private static double innerProduct(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += ( (double) f1[i] ) * ( (double) f2[i] ); - } - return result; + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsL2Norm.class) + @SkipForDialect(dialectClass = OracleDialect.class, reason = "Oracle 23.9 bug") + public void testL2Norm(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::l2-norm-example[] + final List results = em.createSelectionQuery( + "select e.id, l2_norm(e.theVector) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + //end::l2-norm-example[] + 
assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); } - public static double hammingDistance(float[] f1, float[] f2) { - assert f1.length == f2.length; - int distance = 0; - for ( int i = 0; i < f1.length; i++ ) { - if ( !( f1[i] == f2[i] ) ) { - distance++; - } - } - return distance; + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsL2Normalize.class) + public void testL2Normalize(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::l2-normalize-example[] + final List results = em.createSelectionQuery( + "select e.id, l2_normalize(e.theVector) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + //end::l2-normalize-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertArrayEquals( euclideanNormalize( V1 ), results.get( 0 ).get( 1, float[].class ), 0f ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertArrayEquals( euclideanNormalize( V2 ), results.get( 1 ).get( 1, float[].class ), 0f ); + } ); } - - private static double euclideanNorm(float[] f) { - double result = 0; - for ( double v : f ) { - result += Math.pow( v, 2 ); - } - return Math.sqrt( result ); + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsSubvector.class) + public void testSubvector(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::subvector-example[] + final List results = em.createSelectionQuery( + "select e.id, subvector(e.theVector, 1, 1) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + //end::subvector-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( 1, results.get( 0 ).get( 1, float[].class ).length ); + assertEquals( V1[0], results.get( 0 ).get( 1, float[].class )[0], 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( 1, results.get( 1 ).get( 1, float[].class ).length ); + assertEquals( V2[0], results.get( 1 ).get( 1, float[].class )[0], 0D ); + } ); } - private static double norm(float[] f) { - double result = 0; - for ( double v : f ) { - result += Math.abs( v ); - } - return result; + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsBinaryQuantize.class) + public void testBinaryQuantize(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::binary-quantize-example[] + final List results = em.createSelectionQuery( + "select e.id, binary_quantize(e.theVector) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + //end::binary-quantize-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertArrayEquals( new byte[]{(byte) 0b11100000}, results.get( 0 ).get( 1, byte[].class ) ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertArrayEquals( new byte[]{(byte) 0b11100000}, results.get( 1 ).get( 1, byte[].class ) ); + } ); } @Entity(name = "VectorEntity") diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/MariaDBTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/MariaDBTest.java deleted file mode 100644 index 9fe030842601..000000000000 --- a/hibernate-vector/src/test/java/org/hibernate/vector/MariaDBTest.java +++ /dev/null @@ -1,167 +0,0 @@ -/* - * 
SPDX-License-Identifier: Apache-2.0 - * Copyright Red Hat Inc. and Hibernate Authors - */ -package org.hibernate.vector; - -import jakarta.persistence.Column; -import jakarta.persistence.Entity; -import jakarta.persistence.Id; -import jakarta.persistence.Tuple; -import org.hibernate.annotations.Array; -import org.hibernate.annotations.JdbcTypeCode; -import org.hibernate.dialect.MariaDBDialect; -import org.hibernate.testing.orm.junit.DomainModel; -import org.hibernate.testing.orm.junit.RequiresDialect; -import org.hibernate.testing.orm.junit.SessionFactory; -import org.hibernate.testing.orm.junit.SessionFactoryScope; -import org.hibernate.type.SqlTypes; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; - -/** - * @author Diego Dupin - */ -@DomainModel(annotatedClasses = MariaDBTest.VectorEntity.class) -@SessionFactory -@RequiresDialect(value = MariaDBDialect.class, matchSubTypes = false, majorVersion = 11, minorVersion = 7) -public class MariaDBTest { - - private static final float[] V1 = new float[]{ 1, 2, 3 }; - private static final float[] V2 = new float[]{ 4, 5, 6 }; - - @BeforeEach - public void prepareData(SessionFactoryScope scope) { - scope.inTransaction( em -> { - em.persist( new VectorEntity( 1L, V1 ) ); - em.persist( new VectorEntity( 2L, V2 ) ); - } ); - } - - @AfterEach - public void cleanup(SessionFactoryScope scope) { - scope.inTransaction( em -> { - em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); - } ); - } - - @Test - public void testRead(SessionFactoryScope scope) { - scope.inTransaction( em -> { - VectorEntity tableRecord; - tableRecord = em.find( VectorEntity.class, 1L ); - assertArrayEquals( new float[]{ 1, 2, 3 }, tableRecord.getTheVector(), 0 ); - - tableRecord = em.find( VectorEntity.class, 2L ); - assertArrayEquals( new float[]{ 4, 5, 6 }, tableRecord.getTheVector(), 0 ); - } ); - } - - @Test - public void testCosineDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::cosine-distance-example[] - final float[] vector = new float[]{ 1, 1, 1 }; - final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector ) - .getResultList(); - //end::cosine-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0.0000000000000002D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0.0000000000000002D ); - } ); - } - - @Test - public void testEuclideanDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::euclidean-distance-example[] - final float[] vector = new float[]{ 1, 1, 1 }; - final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector ) - .getResultList(); - //end::euclidean-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( 
euclideanDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D ); - } ); - } - - private static double cosineDistance(float[] f1, float[] f2) { - return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); - } - - private static double euclideanDistance(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += Math.pow( (double) f1[i] - f2[i], 2 ); - } - return Math.sqrt( result ); - } - - private static double innerProduct(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += ( (double) f1[i] ) * ( (double) f2[i] ); - } - return result; - } - - private static double euclideanNorm(float[] f) { - double result = 0; - for ( float v : f ) { - result += Math.pow( v, 2 ); - } - return Math.sqrt( result ); - } - - @Entity( name = "VectorEntity" ) - public static class VectorEntity { - - @Id - private Long id; - - //tag::usage-example[] - @Column( name = "the_vector" ) - @JdbcTypeCode(SqlTypes.VECTOR) - @Array(length = 3) - private float[] theVector; - //end::usage-example[] - - public VectorEntity() { - } - - public VectorEntity(Long id, float[] theVector) { - this.id = id; - this.theVector = theVector; - } - - public Long getId() { - return id; - } - - public void setId(Long id) { - this.id = id; - } - - public float[] getTheVector() { - return theVector; - } - - public void setTheVector(float[] theVector) { - this.theVector = theVector; - } - } -} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/PGVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/PGVectorTest.java index d3c44c4eebc2..9c3c2f734972 100644 --- a/hibernate-vector/src/test/java/org/hibernate/vector/PGVectorTest.java +++ b/hibernate-vector/src/test/java/org/hibernate/vector/PGVectorTest.java @@ -4,28 +4,25 @@ */ package org.hibernate.vector; -import java.util.List; - +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; import org.hibernate.annotations.Array; import org.hibernate.annotations.JdbcTypeCode; import org.hibernate.dialect.CockroachDialect; import org.hibernate.dialect.PostgreSQLDialect; -import org.hibernate.testing.orm.junit.RequiresDialects; -import org.hibernate.testing.orm.junit.SkipForDialect; -import org.hibernate.type.SqlTypes; - import org.hibernate.testing.orm.junit.DomainModel; import org.hibernate.testing.orm.junit.RequiresDialect; import org.hibernate.testing.orm.junit.SessionFactory; import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SkipForDialect; +import org.hibernate.type.SqlTypes; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import jakarta.persistence.Column; -import jakarta.persistence.Entity; -import jakarta.persistence.Id; -import jakarta.persistence.Tuple; +import java.util.List; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -35,10 +32,8 @@ */ @DomainModel(annotatedClasses = PGVectorTest.VectorEntity.class) @SessionFactory -@RequiresDialects({ - @RequiresDialect(value = PostgreSQLDialect.class, matchSubTypes = false), - @RequiresDialect(value = CockroachDialect.class, majorVersion = 24, minorVersion = 2) -}) +@RequiresDialect(value = PostgreSQLDialect.class, matchSubTypes = false) 
+@RequiresDialect(value = CockroachDialect.class, majorVersion = 24, minorVersion = 2) public class PGVectorTest { private static final float[] V1 = new float[]{ 1, 2, 3 }; @@ -59,118 +54,6 @@ public void cleanup(SessionFactoryScope scope) { } ); } - @Test - public void testRead(SessionFactoryScope scope) { - scope.inTransaction( em -> { - VectorEntity tableRecord; - tableRecord = em.find( VectorEntity.class, 1L ); - assertArrayEquals( new float[]{ 1, 2, 3 }, tableRecord.getTheVector(), 0 ); - - tableRecord = em.find( VectorEntity.class, 2L ); - assertArrayEquals( new float[]{ 4, 5, 6 }, tableRecord.getTheVector(), 0 ); - } ); - } - - @Test - public void testCosineDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::cosine-distance-example[] - final float[] vector = new float[]{ 1, 1, 1 }; - final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector ) - .getResultList(); - //end::cosine-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0.0000000000000002D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0.0000000000000002D ); - } ); - } - - @Test - public void testEuclideanDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::euclidean-distance-example[] - final float[] vector = new float[]{ 1, 1, 1 }; - final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector ) - .getResultList(); - //end::euclidean-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D ); - } ); - } - - @Test - public void testTaxicabDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::taxicab-distance-example[] - final float[] vector = new float[]{ 1, 1, 1 }; - final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector ) - .getResultList(); - //end::taxicab-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D ); - } ); - } - - @Test - public void testInnerProduct(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::inner-product-example[] - final float[] vector = new float[]{ 1, 1, 1 }; - final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector ) - .getResultList(); - //end::inner-product-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( innerProduct( V1, vector ), 
results.get( 0 ).get( 1, Double.class ), 0D ); - assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, Double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D ); - assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, Double.class ), 0D ); - } ); - } - - @Test - public void testVectorDims(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::vector-dims-example[] - final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) - .getResultList(); - //end::vector-dims-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( V1.length, results.get( 0 ).get( 1 ) ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( V2.length, results.get( 1 ).get( 1 ) ); - } ); - } - - @Test - public void testVectorNorm(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::vector-norm-example[] - final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) - .getResultList(); - //end::vector-norm-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, Double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, Double.class ), 0D ); - } ); - } - @Test @SkipForDialect(dialectClass = CockroachDialect.class, reason = "CockroachDB does not currently support the sum() function on vector type" ) public void testVectorSum(SessionFactoryScope scope) { @@ -227,60 +110,16 @@ public void testMultiplication(SessionFactoryScope scope) { } ); } - private static double cosineDistance(float[] f1, float[] f2) { - return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); - } - - private static double euclideanDistance(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += Math.pow( (double) f1[i] - f2[i], 2 ); - } - return Math.sqrt( result ); - } - - private static double taxicabDistance(float[] f1, float[] f2) { - return norm( f1 ) - norm( f2 ); - } - - private static double innerProduct(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += ( (double) f1[i] ) * ( (double) f2[i] ); - } - return result; - } - - private static double euclideanNorm(float[] f) { - double result = 0; - for ( float v : f ) { - result += Math.pow( v, 2 ); - } - return Math.sqrt( result ); - } - - private static double norm(float[] f) { - double result = 0; - for ( float v : f ) { - result += Math.abs( v ); - } - return result; - } - @Entity( name = "VectorEntity" ) public static class VectorEntity { @Id private Long id; - //tag::usage-example[] @Column( name = "the_vector" ) @JdbcTypeCode(SqlTypes.VECTOR) @Array(length = 3) private float[] theVector; - //end::usage-example[] public VectorEntity() { } diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/SparseByteVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/SparseByteVectorTest.java new file mode 100644 index 000000000000..a6b2704288a7 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/SparseByteVectorTest.java @@ 
-0,0 +1,214 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.type.SqlTypes; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.hibernate.vector.VectorTestHelper.cosineDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanNorm; +import static org.hibernate.vector.VectorTestHelper.hammingDistance; +import static org.hibernate.vector.VectorTestHelper.innerProduct; +import static org.hibernate.vector.VectorTestHelper.taxicabDistance; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@DomainModel(annotatedClasses = SparseByteVectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsSparseByteVectorType.class) +public class SparseByteVectorTest { + + private static final byte[] V1 = new byte[]{ 0, 2, 3 }; + private static final byte[] V2 = new byte[]{ 0, 5, 6 }; + + @BeforeEach + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, new SparseByteVector( V1 ) ) ); + em.persist( new VectorEntity( 2L, new SparseByteVector( V2 ) ) ); + } ); + } + + @AfterEach + public void cleanup(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); + } ); + } + + @Test + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new byte[]{ 0, 2, 3 }, tableRecord.getTheVector().toDenseVector() ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new byte[]{ 0, 5, 6 }, tableRecord.getTheVector().toDenseVector() ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) + public void testCosineDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseByteVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) + public void testEuclideanDistance(SessionFactoryScope 
scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseByteVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) + public void testTaxicabDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseByteVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) + public void testInnerProduct(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseByteVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) + public void testHammingDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseByteVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) + public void testVectorDims(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from 
VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( V1.length, results.get( 0 ).get( 1 ) ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( V2.length, results.get( 1 ).get( 1 ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) + public void testVectorNorm(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Entity( name = "VectorEntity" ) + public static class VectorEntity { + + @Id + private Long id; + + @Column( name = "the_vector" ) + @JdbcTypeCode(SqlTypes.SPARSE_VECTOR_INT8) + @Array(length = 3) + private SparseByteVector theVector; + + public VectorEntity() { + } + + public VectorEntity(Long id, SparseByteVector theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public SparseByteVector getTheVector() { + return theVector; + } + + public void setTheVector(SparseByteVector theVector) { + this.theVector = theVector; + } + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/SparseByteVectorUnitTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/SparseByteVectorUnitTest.java new file mode 100644 index 000000000000..9fc656be8801 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/SparseByteVectorUnitTest.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +public class SparseByteVectorUnitTest { + + @Test + public void testEmpty() { + final SparseByteVector bytes = new SparseByteVector( 3 ); + bytes.set( 1, (byte) 3 ); + assertArrayEquals( new Object[] {(byte) 0, (byte) 3, (byte) 0}, bytes.toArray() ); + } + + @Test + public void testInsertBefore() { + final SparseByteVector bytes = new SparseByteVector( 3, new int[] {1}, new byte[] {3} ); + bytes.set( 0, (byte) 2 ); + assertArrayEquals( new Object[] {(byte) 2, (byte) 3, (byte) 0}, bytes.toArray() ); + } + + @Test + public void testInsertAfter() { + final SparseByteVector bytes = new SparseByteVector( 3, new int[] {1}, new byte[] {3} ); + bytes.set( 2, (byte) 2 ); + assertArrayEquals( new Object[] {(byte) 0, (byte) 3, (byte) 2}, bytes.toArray() ); + } + + @Test + public void testReplace() { + final SparseByteVector bytes = new SparseByteVector( 3, new int[] {0, 1, 2}, new byte[] {3, 3, 3} ); + bytes.set( 2, (byte) 2 ); + assertArrayEquals( new Object[] {(byte) 3, (byte) 3, (byte) 2}, bytes.toArray() ); + } + + @Test + public void testFromDenseVector() { + final SparseByteVector bytes = new SparseByteVector( new byte[] {0, 3, 0} ); + assertArrayEquals( new Object[] {(byte) 0, (byte) 3, (byte) 0}, bytes.toArray() ); + } + + @Test + public void testFromDenseVectorList() { + final SparseByteVector bytes = new SparseByteVector( List.of( (byte) 0, (byte) 3, (byte) 0 ) ); + assertArrayEquals( new Object[] {(byte) 0, (byte) 3, (byte) 0}, bytes.toArray() ); + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/SparseDoubleVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/SparseDoubleVectorTest.java new file mode 100644 index 000000000000..1fcec4c40e0a --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/SparseDoubleVectorTest.java @@ -0,0 +1,214 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.type.SqlTypes; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.hibernate.vector.VectorTestHelper.cosineDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanNorm; +import static org.hibernate.vector.VectorTestHelper.hammingDistance; +import static org.hibernate.vector.VectorTestHelper.innerProduct; +import static org.hibernate.vector.VectorTestHelper.taxicabDistance; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@DomainModel(annotatedClasses = SparseDoubleVectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsSparseDoubleVectorType.class) +public class SparseDoubleVectorTest { + + private static final double[] V1 = new double[]{ 0, 2, 3 }; + private static final double[] V2 = new double[]{ 0, 5, 6 }; + + @BeforeEach + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, new SparseDoubleVector( V1 ) ) ); + em.persist( new VectorEntity( 2L, new SparseDoubleVector( V2 ) ) ); + } ); + } + + @AfterEach + public void cleanup(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); + } ); + } + + @Test + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new double[]{ 0, 2, 3 }, tableRecord.getTheVector().toDenseVector() ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new double[]{ 0, 5, 6 }, tableRecord.getTheVector().toDenseVector() ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) + public void testCosineDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseDoubleVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) + public void testEuclideanDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final double[] vector 
= new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseDoubleVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) + public void testTaxicabDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseDoubleVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) + public void testInnerProduct(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseDoubleVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) + public void testHammingDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseDoubleVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) + public void testVectorDims(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", 
Tuple.class ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( V1.length, results.get( 0 ).get( 1 ) ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( V2.length, results.get( 1 ).get( 1 ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) + public void testVectorNorm(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Entity( name = "VectorEntity" ) + public static class VectorEntity { + + @Id + private Long id; + + @Column( name = "the_vector" ) + @JdbcTypeCode(SqlTypes.SPARSE_VECTOR_FLOAT64) + @Array(length = 3) + private SparseDoubleVector theVector; + + public VectorEntity() { + } + + public VectorEntity(Long id, SparseDoubleVector theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public SparseDoubleVector getTheVector() { + return theVector; + } + + public void setTheVector(SparseDoubleVector theVector) { + this.theVector = theVector; + } + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/SparseDoubleVectorUnitTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/SparseDoubleVectorUnitTest.java new file mode 100644 index 000000000000..adf14bb2cef2 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/SparseDoubleVectorUnitTest.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +public class SparseDoubleVectorUnitTest { + + @Test + public void testEmpty() { + final SparseDoubleVector doubles = new SparseDoubleVector( 3 ); + doubles.set( 1, (double) 3 ); + assertArrayEquals( new Object[] {(double) 0, (double) 3, (double) 0}, doubles.toArray() ); + } + + @Test + public void testInsertBefore() { + final SparseDoubleVector doubles = new SparseDoubleVector( 3, new int[] {1}, new double[] {3} ); + doubles.set( 0, (double) 2 ); + assertArrayEquals( new Object[] {(double) 2, (double) 3, (double) 0}, doubles.toArray() ); + } + + @Test + public void testInsertAfter() { + final SparseDoubleVector doubles = new SparseDoubleVector( 3, new int[] {1}, new double[] {3} ); + doubles.set( 2, (double) 2 ); + assertArrayEquals( new Object[] {(double) 0, (double) 3, (double) 2}, doubles.toArray() ); + } + + @Test + public void testReplace() { + final SparseDoubleVector doubles = new SparseDoubleVector( 3, new int[] {0, 1, 2}, new double[] {3, 3, 3} ); + doubles.set( 2, (double) 2 ); + assertArrayEquals( new Object[] {(double) 3, (double) 3, (double) 2}, doubles.toArray() ); + } + + @Test + public void testFromDenseVector() { + final SparseDoubleVector doubles = new SparseDoubleVector( new double[] {0, 3, 0} ); + assertArrayEquals( new Object[] {(double) 0, (double) 3, (double) 0}, doubles.toArray() ); + } + + @Test + public void testFromDenseVectorList() { + final SparseDoubleVector doubles = new SparseDoubleVector( List.of( (double) 0, (double) 3, (double) 0 ) ); + assertArrayEquals( new Object[] {(double) 0, (double) 3, (double) 0}, doubles.toArray() ); + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/OracleFloatVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/SparseFloatVectorTest.java similarity index 52% rename from hibernate-vector/src/test/java/org/hibernate/vector/OracleFloatVectorTest.java rename to hibernate-vector/src/test/java/org/hibernate/vector/SparseFloatVectorTest.java index da83e8bf7ef6..92df8e0c60bf 100644 --- a/hibernate-vector/src/test/java/org/hibernate/vector/OracleFloatVectorTest.java +++ b/hibernate-vector/src/test/java/org/hibernate/vector/SparseFloatVectorTest.java @@ -4,46 +4,50 @@ */ package org.hibernate.vector; -import java.util.List; - +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; import org.hibernate.annotations.Array; import org.hibernate.annotations.JdbcTypeCode; -import org.hibernate.dialect.OracleDialect; -import org.hibernate.testing.orm.junit.SkipForDialect; -import org.hibernate.type.SqlTypes; - +import org.hibernate.dialect.PostgreSQLDialect; +import org.hibernate.dialect.PostgresPlusDialect; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; import org.hibernate.testing.orm.junit.DomainModel; -import org.hibernate.testing.orm.junit.RequiresDialect; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; import org.hibernate.testing.orm.junit.SessionFactory; import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SkipForDialect; +import org.hibernate.type.SqlTypes; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import jakarta.persistence.Column; -import jakarta.persistence.Entity; -import 
jakarta.persistence.Id; -import jakarta.persistence.Tuple; +import java.util.List; +import static org.hibernate.vector.VectorTestHelper.cosineDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanNorm; +import static org.hibernate.vector.VectorTestHelper.hammingDistance; +import static org.hibernate.vector.VectorTestHelper.innerProduct; +import static org.hibernate.vector.VectorTestHelper.taxicabDistance; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -/** - * @author Hassan AL Meftah - */ -@DomainModel(annotatedClasses = OracleFloatVectorTest.VectorEntity.class) +@DomainModel(annotatedClasses = SparseFloatVectorTest.VectorEntity.class) @SessionFactory -@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4) -public class OracleFloatVectorTest { +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsSparseFloatVectorType.class) +@SkipForDialect(dialectClass = PostgresPlusDialect.class, reason = "Test database does not have the extension enabled") +public class SparseFloatVectorTest { - private static final float[] V1 = new float[] { 1, 2, 3 }; - private static final float[] V2 = new float[] { 4, 5, 6 }; + private static final float[] V1 = new float[]{ 0, 2, 3 }; + private static final float[] V2 = new float[]{ 0, 5, 6 }; @BeforeEach public void prepareData(SessionFactoryScope scope) { scope.inTransaction( em -> { - em.persist( new VectorEntity( 1L, V1 ) ); - em.persist( new VectorEntity( 2L, V2 ) ); + em.persist( new VectorEntity( 1L, new SparseFloatVector( V1 ) ) ); + em.persist( new VectorEntity( 2L, new SparseFloatVector( V2 ) ) ); } ); } @@ -59,25 +63,21 @@ public void testRead(SessionFactoryScope scope) { scope.inTransaction( em -> { VectorEntity tableRecord; tableRecord = em.find( VectorEntity.class, 1L ); - assertArrayEquals( new float[] { 1, 2, 3 }, tableRecord.getTheVector(), 0 ); + assertArrayEquals( new float[]{ 0, 2, 3 }, tableRecord.getTheVector().toDenseVector() ); tableRecord = em.find( VectorEntity.class, 2L ); - assertArrayEquals( new float[] { 4, 5, 6 }, tableRecord.getTheVector(), 0 ); + assertArrayEquals( new float[]{ 0, 5, 6 }, tableRecord.getTheVector().toDenseVector() ); } ); } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) public void testCosineDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::cosine-distance-example[] - final float[] vector = new float[] { 1, 1, 1 }; - final List results = em.createSelectionQuery( - "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", - Tuple.class - ) - .setParameter( "vec", vector ) + final float[] vector = new float[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseFloatVector( vector ) ) .getResultList(); - //end::cosine-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); @@ -87,17 +87,13 @@ public void testCosineDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) public void testEuclideanDistance(SessionFactoryScope scope) { 
scope.inTransaction( em -> { - //tag::euclidean-distance-example[] - final float[] vector = new float[] { 1, 1, 1 }; - final List results = em.createSelectionQuery( - "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", - Tuple.class - ) - .setParameter( "vec", vector ) + final float[] vector = new float[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseFloatVector( vector ) ) .getResultList(); - //end::euclidean-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); @@ -107,17 +103,13 @@ public void testEuclideanDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) public void testTaxicabDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::taxicab-distance-example[] - final float[] vector = new float[] { 1, 1, 1 }; - final List results = em.createSelectionQuery( - "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", - Tuple.class - ) - .setParameter( "vec", vector ) + final float[] vector = new float[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseFloatVector( vector ) ) .getResultList(); - //end::taxicab-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -127,17 +119,13 @@ public void testTaxicabDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) public void testInnerProduct(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::inner-product-example[] - final float[] vector = new float[] { 1, 1, 1 }; - final List results = em.createSelectionQuery( - "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", - Tuple.class - ) - .setParameter( "vec", vector ) + final float[] vector = new float[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseFloatVector( vector ) ) .getResultList(); - //end::inner-product-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -149,17 +137,14 @@ public void testInnerProduct(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with sparse vectors") public void testHammingDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::inner-product-example[] - final float[] vector = new float[] { 1, 1, 1 }; - final List results = em.createSelectionQuery( - "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", - Tuple.class - ) - 
.setParameter( "vec", vector ) + final float[] vector = new float[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseFloatVector( vector ) ) .getResultList(); - //end::inner-product-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -169,15 +154,11 @@ public void testHammingDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) public void testVectorDims(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::vector-dims-example[] - final List results = em.createSelectionQuery( - "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", - Tuple.class - ) + final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) .getResultList(); - //end::vector-dims-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( V1.length, results.get( 0 ).get( 1 ) ); @@ -187,16 +168,11 @@ public void testVectorDims(SessionFactoryScope scope) { } @Test - @SkipForDialect(dialectClass = OracleDialect.class, reason = "Oracle 23.9 bug") + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) public void testVectorNorm(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::vector-norm-example[] - final List results = em.createSelectionQuery( - "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", - Tuple.class - ) + final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) .getResultList(); - //end::vector-norm-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -205,79 +181,21 @@ public void testVectorNorm(SessionFactoryScope scope) { } ); } - - private static double cosineDistance(float[] f1, float[] f2) { - return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); - } - - private static double euclideanDistance(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += Math.pow( (double) f1[i] - f2[i], 2 ); - } - return Math.sqrt( result ); - } - - private static double taxicabDistance(float[] f1, float[] f2) { - return norm( f1 ) - norm( f2 ); - } - - private static double innerProduct(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += ( (double) f1[i] ) * ( (double) f2[i] ); - } - return result; - } - - public static double hammingDistance(float[] f1, float[] f2) { - assert f1.length == f2.length; - int distance = 0; - for ( int i = 0; i < f1.length; i++ ) { - if ( !( f1[i] == f2[i] ) ) { - distance++; - } - } - return distance; - } - - - private static double euclideanNorm(float[] f) { - double result = 0; - for ( double v : f ) { - result += Math.pow( v, 2 ); - } - return Math.sqrt( result ); - } - - private static double norm(float[] f) { - double result = 0; - for ( double v : f ) { - result += Math.abs( v ); - } - return result; - } - - @Entity(name = "VectorEntity") + 
@Entity( name = "VectorEntity" ) public static class VectorEntity { @Id private Long id; - //tag::usage-example[] - @Column(name = "the_vector") - @JdbcTypeCode(SqlTypes.VECTOR_FLOAT32) + @Column( name = "the_vector" ) + @JdbcTypeCode(SqlTypes.SPARSE_VECTOR_FLOAT32) @Array(length = 3) - private float[] theVector; - //end::usage-example[] - + private SparseFloatVector theVector; public VectorEntity() { } - public VectorEntity(Long id, float[] theVector) { + public VectorEntity(Long id, SparseFloatVector theVector) { this.id = id; this.theVector = theVector; } @@ -290,11 +208,11 @@ public void setId(Long id) { this.id = id; } - public float[] getTheVector() { + public SparseFloatVector getTheVector() { return theVector; } - public void setTheVector(float[] theVector) { + public void setTheVector(SparseFloatVector theVector) { this.theVector = theVector; } } diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/SparseFloatVectorUnitTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/SparseFloatVectorUnitTest.java new file mode 100644 index 000000000000..20eb1dd3c3f8 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/SparseFloatVectorUnitTest.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +public class SparseFloatVectorUnitTest { + + @Test + public void testEmpty() { + final SparseFloatVector floats = new SparseFloatVector( 3 ); + floats.set( 1, (float) 3 ); + assertArrayEquals( new Object[] {(float) 0, (float) 3, (float) 0}, floats.toArray() ); + } + + @Test + public void testInsertBefore() { + final SparseFloatVector floats = new SparseFloatVector( 3, new int[] {1}, new float[] {3} ); + floats.set( 0, (float) 2 ); + assertArrayEquals( new Object[] {(float) 2, (float) 3, (float) 0}, floats.toArray() ); + } + + @Test + public void testInsertAfter() { + final SparseFloatVector floats = new SparseFloatVector( 3, new int[] {1}, new float[] {3} ); + floats.set( 2, (float) 2 ); + assertArrayEquals( new Object[] {(float) 0, (float) 3, (float) 2}, floats.toArray() ); + } + + @Test + public void testReplace() { + final SparseFloatVector floats = new SparseFloatVector( 3, new int[] {0, 1, 2}, new float[] {3, 3, 3} ); + floats.set( 2, (float) 2 ); + assertArrayEquals( new Object[] {(float) 3, (float) 3, (float) 2}, floats.toArray() ); + } + + @Test + public void testFromDenseVector() { + final SparseFloatVector floats = new SparseFloatVector( new float[] {0, 3, 0} ); + assertArrayEquals( new Object[] {(float) 0, (float) 3, (float) 0}, floats.toArray() ); + } + + @Test + public void testFromDenseVectorList() { + final SparseFloatVector floats = new SparseFloatVector( List.of( (float) 0, (float) 3, (float) 0 ) ); + assertArrayEquals( new Object[] {(float) 0, (float) 3, (float) 0}, floats.toArray() ); + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/VectorTestHelper.java b/hibernate-vector/src/test/java/org/hibernate/vector/VectorTestHelper.java new file mode 100644 index 000000000000..083b847578cf --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/VectorTestHelper.java @@ -0,0 +1,236 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors
+ */
+package org.hibernate.vector;
+
+public class VectorTestHelper {
+
+	public static double cosineDistance(float[] f1, float[] f2) {
+		return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
+	}
+
+	public static double cosineDistance(double[] f1, double[] f2) {
+		return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
+	}
+
+	public static double cosineDistance(byte[] f1, byte[] f2) {
+		return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
+	}
+
+	public static double cosineDistanceBinary(byte[] f1, byte[] f2) {
+		return 1D - innerProductBinary( f1, f2 ) / ( euclideanNormBinary( f1 ) * euclideanNormBinary( f2 ) );
+	}
+
+	public static double euclideanDistance(float[] f1, float[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Math.pow( (double) f1[i] - f2[i], 2 );
+		}
+		return Math.sqrt( result );
+	}
+
+	public static double euclideanDistance(double[] f1, double[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Math.pow( (double) f1[i] - f2[i], 2 );
+		}
+		return Math.sqrt( result );
+	}
+
+	public static double euclideanDistance(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Math.pow( (double) f1[i] - f2[i], 2 );
+		}
+		return Math.sqrt( result );
+	}
+
+	public static double euclideanDistanceBinary(byte[] f1, byte[] f2) {
+		// On bit level, the squared Euclidean distance equals the Hamming distance
+		return Math.sqrt( hammingDistanceBinary( f1, f2 ) );
+	}
+
+	public static double taxicabDistance(float[] f1, float[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Math.abs( f1[i] - f2[i] );
+		}
+		return result;
+	}
+
+	public static double taxicabDistance(double[] f1, double[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Math.abs( f1[i] - f2[i] );
+		}
+		return result;
+	}
+
+	public static double taxicabDistance(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Math.abs( f1[i] - f2[i] );
+		}
+		return result;
+	}
+
+	public static double taxicabDistanceBinary(byte[] f1, byte[] f2) {
+		// On bit level, the taxicab (L1) distance equals the Hamming distance
+		return hammingDistanceBinary( f1, f2 );
+	}
+
+	public static double innerProduct(float[] f1, float[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += ( (double) f1[i] ) * ( (double) f2[i] );
+		}
+		return result;
+	}
+
+	public static double innerProduct(double[] f1, double[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += ( (double) f1[i] ) * ( (double) f2[i] );
+		}
+		return result;
+	}
+
+	public static double innerProduct(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += ( (double) f1[i] ) * ( (double) f2[i] );
+		}
+		return result;
+	}
+
+	public static double innerProductBinary(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Integer.bitCount( f1[i] & f2[i] );
+		}
+		return result;
+	}
+
+	public static double hammingDistance(float[] f1, float[] f2) {
+		assert f1.length == f2.length;
+		int distance = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			if ( !( f1[i] == f2[i] ) ) {
+				distance++;
+			}
+		}
+		return distance;
+	}
+
+	public static double hammingDistance(double[] f1, double[] f2) {
+		assert f1.length == f2.length;
+		int distance = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			if ( !( f1[i] == f2[i] ) ) {
+				distance++;
+			}
+		}
+		return distance;
+	}
+
+	public static double hammingDistance(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		int distance = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			if ( !( f1[i] == f2[i] ) ) {
+				distance++;
+			}
+		}
+		return distance;
+	}
+
+	public static double hammingDistanceBinary(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		int distance = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			distance += Integer.bitCount( f1[i] ^ f2[i] );
+		}
+		return distance;
+	}
+
+	public static double euclideanNorm(float[] f) {
+		double result = 0;
+		for ( float v : f ) {
+			result += Math.pow( v, 2 );
+		}
+		return Math.sqrt( result );
+	}
+
+	public static double euclideanNorm(double[] f) {
+		double result = 0;
+		for ( double v : f ) {
+			result += Math.pow( v, 2 );
+		}
+		return Math.sqrt( result );
+	}
+
+	public static double euclideanNorm(byte[] f) {
+		double result = 0;
+		for ( byte v : f ) {
+			result += Math.pow( v, 2 );
+		}
+		return Math.sqrt( result );
+	}
+
+	public static float[] euclideanNormalize(float[] f) {
+		final double norm = euclideanNorm( f );
+		final float[] result = new float[f.length];
+		for ( int i = 0; i < f.length; i++ ) {
+			result[i] = (float) (f[i] / norm);
+		}
+		return result;
+	}
+
+	public static float[] euclideanNormalize(double[] f) {
+		final double norm = euclideanNorm( f );
+		final float[] result = new float[f.length];
+		for ( int i = 0; i < f.length; i++ ) {
+			result[i] = (float) (f[i] / norm);
+		}
+		return result;
+	}
+
+	public static float[] euclideanNormalize(byte[] f) {
+		final double norm = euclideanNorm( f );
+		final float[] result = new float[f.length];
+		for ( int i = 0; i < f.length; i++ ) {
+			result[i] = (float) (f[i] / norm);
+		}
+		return result;
+	}
+
+	public static double euclideanNormBinary(byte[] f) {
+		double result = 0;
+		for ( byte v : f ) {
+			result += Integer.bitCount( v );
+		}
+		return Math.sqrt( result );
+	}
+
+	public static double jaccardDistanceBinary(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		int intersectionSum = 0;
+		int unionSum = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			intersectionSum += Integer.bitCount( f1[i] & f2[i] );
+			unionSum += Integer.bitCount( f1[i] | f2[i] );
+		}
+		return 1d - (double) intersectionSum / unionSum;
+	}
+}