diff --git a/doc/snippets/Microsoft.Data.SqlClient/SqlCommand.xml b/doc/snippets/Microsoft.Data.SqlClient/SqlCommand.xml index 5e28842000..78b6fefe5e 100644 --- a/doc/snippets/Microsoft.Data.SqlClient/SqlCommand.xml +++ b/doc/snippets/Microsoft.Data.SqlClient/SqlCommand.xml @@ -3604,7 +3604,9 @@ Before you call , specify t If you call an `Execute` method after calling , any parameter value that is larger than the value specified by the property is automatically truncated to the original specified size of the parameter, and no truncation errors are returned. -Output parameters (whether prepared or not) must have a user-specified data type. If you specify a variable length data type, you must also specify the maximum . +Output parameters (whether prepared or not) must have a user-specified data type. If you specify a variable length data type except vector, you must also specify the maximum . + +For vector data types, the property is ignored. The size of the vector is inferred from the of type . Prior to Visual Studio 2010, threw an exception. Beginning in Visual Studio 2010, this method does not throw an exception. diff --git a/doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml b/doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml index 0157a05787..7dea042145 100644 --- a/doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml +++ b/doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml @@ -5,20 +5,11 @@ Represents a vector value in SQL Server. - - - Constructs a null vector of the given length. SQL Server requires vector arguments to specify their length even when null. - - - Vector length must be non-negative. - - - Constructs a vector with the given values. - + @@ -37,13 +28,17 @@ Returns the number of elements in the vector. - - - Returns the number of bytes required to represent this vector when communicating with SQL Server. - - Returns the vector values as a memory region. No copies are made. + + + + Constructs a null vector of the given length. SQL Server requires vector arguments to specify their length even when null. + + + Vector length must be non-negative. + + diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs index 55b68fd9b3..3e26aa251a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs @@ -123,12 +123,10 @@ public SqlJson(System.Text.Json.JsonDocument jsonDoc) { } } /// - public sealed class SqlVector : System.Data.SqlTypes.INullable + public readonly struct SqlVector : System.Data.SqlTypes.INullable where T : unmanaged { /// - public SqlVector(int length) { } - /// public SqlVector(System.ReadOnlyMemory memory) { } /// public bool IsNull => throw null; @@ -136,10 +134,10 @@ public SqlVector(System.ReadOnlyMemory memory) { } public static SqlVector Null => throw null; /// public int Length { get { throw null; } } - /// - public int Size { get { throw null; } } /// public System.ReadOnlyMemory Memory { get { throw null; } } + /// + public static SqlVector CreateNull(int length) { throw null; } } } namespace Microsoft.Data.SqlClient diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs index d7f280ca33..a3bb3b46c2 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs @@ -2417,12 +2417,10 @@ public SqlJson(System.Text.Json.JsonDocument jsonDoc) { } } /// - public sealed class SqlVector : System.Data.SqlTypes.INullable + public readonly struct SqlVector : System.Data.SqlTypes.INullable where T : unmanaged { /// - public SqlVector(int length) { } - /// public SqlVector(System.ReadOnlyMemory memory) { } /// public bool IsNull => throw null; @@ -2430,9 +2428,9 @@ public SqlVector(System.ReadOnlyMemory memory) { } public static SqlVector Null => throw null; /// public int Length { get { throw null; } } - /// - public int Size { get { throw null; } } /// public System.ReadOnlyMemory Memory { get { throw null; } } + /// + public static SqlVector CreateNull(int length) { throw null; } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBuffer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBuffer.cs index 39d2758d62..bf27581b51 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBuffer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBuffer.cs @@ -993,7 +993,7 @@ internal SqlVector GetSqlVector() where T : unmanaged { if (IsNull) { - return new SqlVector(_value._vectorInfo._elementCount); + return SqlVector.CreateNull(_value._vectorInfo._elementCount); } return new SqlVector(SqlBinary.Value); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs index d1f3f5c1a5..926b714462 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -3365,10 +3365,7 @@ private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met { if (typeof(T) == typeof(string) && metaData.metaType.SqlDbType == SqlDbTypeExtensions.Vector) { - if (data.IsNull) - return (T)(object)data.String; - else - return (T)(object)data.GetSqlVector().GetString(); + return (T)(object)data.String; } // the requested type is likely to be one that isn't supported so try the cast and // unless there is a null value conversion then feedback the cast exception with diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlParameter.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlParameter.cs index f9b940fec3..fb9cdbb8e0 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlParameter.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlParameter.cs @@ -773,7 +773,7 @@ private object GetVectorReturnValue() switch (elementType) { case MetaType.SqlVectorElementType.Float32: - return new SqlVector(elementCount); + return SqlVector.CreateNull(elementCount); default: throw SQL.VectorTypeNotSupported(elementType.ToString()); } @@ -857,8 +857,13 @@ public override int Size { throw ADP.InvalidSizeValue(value); } - PropertyChanging(); - _size = value; + + // We ignore the Size property for Vector types, as it is not applicable. + if (_metaType == null || _metaType.SqlDbType != SqlDbTypeExtensions.Vector) + { + PropertyChanging(); + _size = value; + } } } } @@ -1970,7 +1975,8 @@ internal void Prepare(SqlCommand cmd) { throw ADP.PrepareParameterType(cmd); } - else if (!ShouldSerializeSize() && !_metaType.IsFixed) + // For vector datatype we do not require size to be specified. It is inferred from the SqlParameter.Value. + else if (!ShouldSerializeSize() && !_metaType.IsFixed && _metaType.SqlDbType != SqlDbTypeExtensions.Vector) { throw ADP.PrepareParameterSize(cmd); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlVector.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlVector.cs index e63a5b7462..cf04ff8636 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlVector.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlVector.cs @@ -16,7 +16,7 @@ namespace Microsoft.Data.SqlTypes; /// -public sealed class SqlVector : INullable, ISqlVector +public readonly struct SqlVector : INullable, ISqlVector where T : unmanaged { #region Constants @@ -31,13 +31,13 @@ public sealed class SqlVector : INullable, ISqlVector private readonly byte _elementType; private readonly byte _elementSize; private readonly byte[] _tdsBytes; - + private readonly int _size; + #endregion #region Constructors - /// - public SqlVector(int length) + private SqlVector(int length) { if (length < 0) { @@ -49,13 +49,16 @@ public SqlVector(int length) IsNull = true; Length = length; - Size = TdsEnums.VECTOR_HEADER_SIZE + (_elementSize * Length); + _size = TdsEnums.VECTOR_HEADER_SIZE + (_elementSize * Length); _tdsBytes = Array.Empty(); Memory = new(); } - /// + /// + public static SqlVector CreateNull(int length) => new(length); + + /// public SqlVector(ReadOnlyMemory memory) { (_elementType, _elementSize) = GetTypeFieldsOrThrow(); @@ -63,7 +66,7 @@ public SqlVector(ReadOnlyMemory memory) IsNull = false; Length = memory.Length; - Size = TdsEnums.VECTOR_HEADER_SIZE + (_elementSize * Length); + _size = TdsEnums.VECTOR_HEADER_SIZE + (_elementSize * Length); _tdsBytes = MakeTdsBytes(memory); Memory = memory; @@ -73,7 +76,7 @@ internal SqlVector(byte[] tdsBytes) { (_elementType, _elementSize) = GetTypeFieldsOrThrow(); - (Length, Size) = GetCountsOrThrow(tdsBytes); + (Length, _size) = GetCountsOrThrow(tdsBytes); IsNull = false; @@ -99,18 +102,16 @@ internal string GetString() #region Properties /// - public bool IsNull { get; init; } + public bool IsNull { get; } /// public static SqlVector? Null => null; /// - public int Length { get; init; } - /// - public int Size { get; init; } - + public int Length { get; } + /// - public ReadOnlyMemory Memory { get; init; } + public ReadOnlyMemory Memory { get; } #endregion @@ -118,6 +119,8 @@ internal string GetString() byte ISqlVector.ElementType => _elementType; byte ISqlVector.ElementSize => _elementSize; byte[] ISqlVector.VectorPayload => _tdsBytes; + int ISqlVector.Size => _size; + #endregion #region Helpers @@ -154,7 +157,7 @@ private byte[] MakeTdsBytes(ReadOnlyMemory values) // | Stream of Values | NN * sizeof(T) | [element bytes...] | Raw bytes for vector elements | // +------------------------+-----------------+----------------------+--------------------------------------------------------------+ - byte[] result = new byte[Size]; + byte[] result = new byte[_size]; // Header Bytes result[0] = VecHeaderMagicNo; diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs index 8d205cfc9c..d2d8c52716 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs @@ -18,37 +18,38 @@ public static class VectorFloat32TestData { public const int VectorHeaderSize = 8; public static float[] testData = new float[] { 1.1f, 2.2f, 3.3f }; - public static int sizeInbytes = VectorHeaderSize + testData.Length * sizeof(float); public static int vectorColumnLength = testData.Length; + // Incorrect size for SqlParameter.Size + public static int IncorrectParamSize = 3234; public static IEnumerable GetVectorFloat32TestData() { // Pattern 1-4 with SqlVector(values: testData) - yield return new object[] { 1, new SqlVector(testData), testData, sizeInbytes, vectorColumnLength }; - yield return new object[] { 2, new SqlVector(testData), testData, sizeInbytes, vectorColumnLength }; - yield return new object[] { 3, new SqlVector(testData), testData, sizeInbytes, vectorColumnLength }; - yield return new object[] { 4, new SqlVector(testData), testData, sizeInbytes, vectorColumnLength }; + yield return new object[] { 1, new SqlVector(testData), testData, vectorColumnLength }; + yield return new object[] { 2, new SqlVector(testData), testData, vectorColumnLength }; + yield return new object[] { 3, new SqlVector(testData), testData, vectorColumnLength }; + yield return new object[] { 4, new SqlVector(testData), testData, vectorColumnLength }; // Pattern 1-4 with SqlVector(n) - yield return new object[] { 1, new SqlVector(vectorColumnLength), Array.Empty(), sizeInbytes, vectorColumnLength }; - yield return new object[] { 2, new SqlVector(vectorColumnLength), Array.Empty(), sizeInbytes, vectorColumnLength }; - yield return new object[] { 3, new SqlVector(vectorColumnLength), Array.Empty(), sizeInbytes, vectorColumnLength }; - yield return new object[] { 4, new SqlVector(vectorColumnLength), Array.Empty(), sizeInbytes, vectorColumnLength }; + yield return new object[] { 1, SqlVector.CreateNull(vectorColumnLength), Array.Empty(), vectorColumnLength }; + yield return new object[] { 2, SqlVector.CreateNull(vectorColumnLength), Array.Empty(), vectorColumnLength }; + yield return new object[] { 3, SqlVector.CreateNull(vectorColumnLength), Array.Empty(), vectorColumnLength }; + yield return new object[] { 4, SqlVector.CreateNull(vectorColumnLength), Array.Empty(), vectorColumnLength }; // Pattern 1-4 with DBNull - yield return new object[] { 1, DBNull.Value, Array.Empty(), sizeInbytes, vectorColumnLength }; - yield return new object[] { 2, DBNull.Value, Array.Empty(), sizeInbytes, vectorColumnLength }; - yield return new object[] { 3, DBNull.Value, Array.Empty(), sizeInbytes, vectorColumnLength }; - yield return new object[] { 4, DBNull.Value, Array.Empty(), sizeInbytes, vectorColumnLength }; + yield return new object[] { 1, DBNull.Value, Array.Empty(), vectorColumnLength }; + yield return new object[] { 2, DBNull.Value, Array.Empty(), vectorColumnLength }; + yield return new object[] { 3, DBNull.Value, Array.Empty(), vectorColumnLength }; + yield return new object[] { 4, DBNull.Value, Array.Empty(), vectorColumnLength }; // Pattern 1-4 with SqlVector.Null - yield return new object[] { 1, SqlVector.Null, Array.Empty(), sizeInbytes, vectorColumnLength }; + yield return new object[] { 1, SqlVector.Null, Array.Empty(), vectorColumnLength }; // Following scenario is not supported in SqlClient. // This can only be fixed with a behavior change that SqlParameter.Value is internally set to DBNull.Value if it is set to null. - //yield return new object[] { 2, SqlVector.Null, Array.Empty(), sizeInbytes, vectorColumnLength }; + //yield return new object[] { 2, SqlVector.Null, Array.Empty(), vectorColumnLength }; - yield return new object[] { 3, SqlVector.Null, Array.Empty(), sizeInbytes, vectorColumnLength }; - yield return new object[] { 4, SqlVector.Null, Array.Empty(), sizeInbytes, vectorColumnLength }; + yield return new object[] { 3, SqlVector.Null, Array.Empty(), vectorColumnLength }; + yield return new object[] { 4, SqlVector.Null, Array.Empty(), vectorColumnLength }; } } @@ -101,10 +102,9 @@ public void Dispose() DataTestUtility.DropStoredProcedure(connection, s_storedProcName); } - private void ValidateSqlVectorFloat32Object(bool isNull, SqlVector sqlVectorFloat32, float[] expectedData, int expectedSize, int expectedLength) + private void ValidateSqlVectorFloat32Object(bool isNull, SqlVector sqlVectorFloat32, float[] expectedData, int expectedLength) { Assert.Equal(expectedData, sqlVectorFloat32.Memory.ToArray()); - Assert.Equal(expectedSize, sqlVectorFloat32.Size); Assert.Equal(expectedLength, sqlVectorFloat32.Length); if (!isNull) { @@ -116,22 +116,22 @@ private void ValidateSqlVectorFloat32Object(bool isNull, SqlVector sqlVec } } - private void ValidateInsertedData(SqlConnection connection, float[] expectedData, int expectedSize, int expectedLength) + private void ValidateInsertedData(SqlConnection connection, float[] expectedData, int expectedLength) { using var selectCmd = new SqlCommand(s_selectCmdString, connection); using var reader = selectCmd.ExecuteReader(); Assert.True(reader.Read(), "No data found in the table."); //For both null and non-null cases, validate the SqlVector object - ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader.GetSqlVector(0), expectedData, expectedSize, expectedLength); - ValidateSqlVectorFloat32Object(reader.IsDBNull(0), reader.GetFieldValue>(0), expectedData, expectedSize, expectedLength); - ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader.GetSqlValue(0), expectedData, expectedSize, expectedLength); + ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader.GetSqlVector(0), expectedData, expectedLength); + ValidateSqlVectorFloat32Object(reader.IsDBNull(0), reader.GetFieldValue>(0), expectedData, expectedLength); + ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader.GetSqlValue(0), expectedData, expectedLength); if (!reader.IsDBNull(0)) { - ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader.GetValue(0), expectedData, expectedSize, expectedLength); - ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader[0], expectedData, expectedSize, expectedLength); - ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader["VectorData"], expectedData, expectedSize, expectedLength); + ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader.GetValue(0), expectedData, expectedLength); + ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader[0], expectedData, expectedLength); + ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader["VectorData"], expectedData, expectedLength); Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetString(0))); Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetSqlString(0).Value)); Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetFieldValue(0))); @@ -153,7 +153,6 @@ public void TestSqlVectorFloat32ParameterInsertionAndReads( int pattern, object value, float[] expectedValues, - int expectedSize, int expectedLength) { using var conn = new SqlConnection(s_connectionString); @@ -171,7 +170,8 @@ public void TestSqlVectorFloat32ParameterInsertionAndReads( }, 2 => new SqlParameter(s_vectorParamName, value), 3 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector) { Value = value }, - 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, new SqlVector(3).Size) { Value = value }, + // Even if size is specified, the actual size is determined by the value passed and specified size is ignored. + 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, VectorFloat32TestData.IncorrectParamSize) { Value = value }, _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") }; @@ -179,25 +179,25 @@ public void TestSqlVectorFloat32ParameterInsertionAndReads( Assert.Equal(1, insertCmd.ExecuteNonQuery()); insertCmd.Parameters.Clear(); - ValidateInsertedData(conn, expectedValues, expectedSize, expectedLength); + ValidateInsertedData(conn, expectedValues, expectedLength); } - private async Task ValidateInsertedDataAsync(SqlConnection connection, float[] expectedData, int expectedSize, int expectedLength) + private async Task ValidateInsertedDataAsync(SqlConnection connection, float[] expectedData, int expectedLength) { using var selectCmd = new SqlCommand(s_selectCmdString, connection); using var reader = await selectCmd.ExecuteReaderAsync(); Assert.True(await reader.ReadAsync(), "No data found in the table."); //For both null and non-null cases, validate the SqlVector object - ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader.GetSqlVector(0), expectedData, expectedSize, expectedLength); - ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), await reader.GetFieldValueAsync>(0), expectedData, expectedSize, expectedLength); - ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader.GetSqlValue(0), expectedData, expectedSize, expectedLength); + ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader.GetSqlVector(0), expectedData, expectedLength); + ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), await reader.GetFieldValueAsync>(0), expectedData, expectedLength); + ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader.GetSqlValue(0), expectedData, expectedLength); if (!await reader.IsDBNullAsync(0)) { - ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader.GetValue(0), expectedData, expectedSize, expectedLength); - ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader[0], expectedData, expectedSize, expectedLength); - ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader["VectorData"], expectedData, expectedSize, expectedLength); + ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader.GetValue(0), expectedData, expectedLength); + ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader[0], expectedData, expectedLength); + ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader["VectorData"], expectedData, expectedLength); Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetString(0))); Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetSqlString(0).Value)); Assert.Equal(expectedData, JsonSerializer.Deserialize(await reader.GetFieldValueAsync(0))); @@ -219,7 +219,6 @@ public async Task TestSqlVectorFloat32ParameterInsertionAndReadsAsync( int pattern, object value, float[] expectedValues, - int expectedSize, int expectedLength) { using var conn = new SqlConnection(s_connectionString); @@ -237,7 +236,7 @@ public async Task TestSqlVectorFloat32ParameterInsertionAndReadsAsync( }, 2 => new SqlParameter(s_vectorParamName, value), 3 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector) { Value = value }, - 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, new SqlVector(3).Size) { Value = value }, + 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, VectorFloat32TestData.IncorrectParamSize) { Value = value }, _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") }; @@ -245,7 +244,7 @@ public async Task TestSqlVectorFloat32ParameterInsertionAndReadsAsync( Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); insertCmd.Parameters.Clear(); - await ValidateInsertedDataAsync(conn, expectedValues, expectedSize, expectedLength); + await ValidateInsertedDataAsync(conn, expectedValues, expectedLength); } [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] @@ -254,7 +253,6 @@ public void TestStoredProcParamsForVectorFloat32( int pattern, object value, float[] expectedValues, - int expectedSize, int expectedLength) { //Create SP for test @@ -277,7 +275,7 @@ public void TestStoredProcParamsForVectorFloat32( }, 2 => new SqlParameter(s_vectorParamName, value), 3 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector) { Value = value }, - 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, new SqlVector(3).Size) { Value = value }, + 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, VectorFloat32TestData.IncorrectParamSize) { Value = value }, _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") }; command.Parameters.Add(inputParam); @@ -287,7 +285,7 @@ public void TestStoredProcParamsForVectorFloat32( ParameterName = s_outputVectorParamName, SqlDbType = SqlDbTypeExtensions.Vector, Direction = ParameterDirection.Output, - Value = new SqlVector(3) + Value = SqlVector.CreateNull(VectorFloat32TestData.vectorColumnLength) }; command.Parameters.Add(outputParam); @@ -295,13 +293,13 @@ public void TestStoredProcParamsForVectorFloat32( command.ExecuteNonQuery(); // Validate the output parameter - var vector = outputParam.Value as SqlVector; - ValidateSqlVectorFloat32Object(vector.IsNull, vector, expectedValues, expectedSize, expectedLength); + var vector = (SqlVector)outputParam.Value; + ValidateSqlVectorFloat32Object(vector.IsNull, vector, expectedValues, expectedLength); // Validate error for conventional way of setting output parameters command.Parameters.Clear(); command.Parameters.Add(inputParam); - var outputParamWithoutVal = new SqlParameter(s_outputVectorParamName, SqlDbTypeExtensions.Vector, new SqlVector(3).Size) { Direction = ParameterDirection.Output }; + var outputParamWithoutVal = new SqlParameter(s_outputVectorParamName, SqlDbTypeExtensions.Vector, VectorFloat32TestData.IncorrectParamSize) { Direction = ParameterDirection.Output }; command.Parameters.Add(outputParamWithoutVal); Assert.Throws(() => command.ExecuteNonQuery()); } @@ -312,7 +310,6 @@ public async Task TestStoredProcParamsForVectorFloat32Async( int pattern, object value, float[] expectedValues, - int expectedSize, int expectedLength) { //Create SP for test @@ -335,7 +332,7 @@ public async Task TestStoredProcParamsForVectorFloat32Async( }, 2 => new SqlParameter(s_vectorParamName, value), 3 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector) { Value = value }, - 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, new SqlVector(3).Size) { Value = value }, + 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, VectorFloat32TestData.IncorrectParamSize) { Value = value }, _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") }; command.Parameters.Add(inputParam); @@ -345,7 +342,7 @@ public async Task TestStoredProcParamsForVectorFloat32Async( ParameterName = s_outputVectorParamName, SqlDbType = SqlDbTypeExtensions.Vector, Direction = ParameterDirection.Output, - Value = new SqlVector(3) + Value = SqlVector.CreateNull(VectorFloat32TestData.vectorColumnLength) }; command.Parameters.Add(outputParam); @@ -353,13 +350,13 @@ public async Task TestStoredProcParamsForVectorFloat32Async( await command.ExecuteNonQueryAsync(); // Validate the output parameter - var vector = outputParam.Value as SqlVector; - ValidateSqlVectorFloat32Object(vector.IsNull, vector, expectedValues, expectedSize, expectedLength); + var vector = (SqlVector)outputParam.Value; + ValidateSqlVectorFloat32Object(vector.IsNull, vector, expectedValues, expectedLength); // Validate error for conventional way of setting output parameters command.Parameters.Clear(); command.Parameters.Add(inputParam); - var outputParamWithoutVal = new SqlParameter(s_outputVectorParamName, SqlDbTypeExtensions.Vector, new SqlVector(3).Size) { Direction = ParameterDirection.Output }; + var outputParamWithoutVal = new SqlParameter(s_outputVectorParamName, SqlDbTypeExtensions.Vector, VectorFloat32TestData.IncorrectParamSize) { Direction = ParameterDirection.Output }; command.Parameters.Add(outputParamWithoutVal); await Assert.ThrowsAsync(async () => await command.ExecuteNonQueryAsync()); } @@ -453,7 +450,6 @@ public void TestBulkCopyFromSqlTable(int bulkCopySourceMode) Assert.True(!verifyReader.IsDBNull(0), "First row in the table is null."); Assert.Equal(VectorFloat32TestData.testData, ((SqlVector)verifyReader.GetSqlVector(0)).Memory.ToArray()); Assert.Equal(VectorFloat32TestData.testData.Length, ((SqlVector)verifyReader.GetSqlVector(0)).Length); - Assert.Equal(VectorFloat32TestData.sizeInbytes, ((SqlVector)verifyReader.GetSqlVector(0)).Size); // Verify that we have another row Assert.True(verifyReader.Read(), "Second row not found in the table"); @@ -462,7 +458,6 @@ public void TestBulkCopyFromSqlTable(int bulkCopySourceMode) Assert.True(verifyReader.IsDBNull(0)); Assert.Equal(Array.Empty(), ((SqlVector)verifyReader.GetSqlVector(0)).Memory.ToArray()); Assert.Equal(VectorFloat32TestData.testData.Length, ((SqlVector)verifyReader.GetSqlVector(0)).Length); - Assert.Equal(VectorFloat32TestData.sizeInbytes, ((SqlVector)verifyReader.GetSqlVector(0)).Size); } [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] @@ -554,7 +549,6 @@ public async Task TestBulkCopyFromSqlTableAsync(int bulkCopySourceMode) var vector = await verifyReader.GetFieldValueAsync>(0); Assert.Equal(VectorFloat32TestData.testData, vector.Memory.ToArray()); Assert.Equal(VectorFloat32TestData.testData.Length, vector.Length); - Assert.Equal(VectorFloat32TestData.sizeInbytes, vector.Size); // Verify that we have another row Assert.True(await verifyReader.ReadAsync(), "Second row not found in the table"); @@ -564,7 +558,6 @@ public async Task TestBulkCopyFromSqlTableAsync(int bulkCopySourceMode) vector = await verifyReader.GetFieldValueAsync>(0); Assert.Equal(Array.Empty(), vector.Memory.ToArray()); Assert.Equal(VectorFloat32TestData.testData.Length, vector.Length); - Assert.Equal(VectorFloat32TestData.sizeInbytes, vector.Size); } [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] @@ -573,7 +566,7 @@ public void TestInsertVectorsFloat32WithPrepare() SqlConnection conn = new SqlConnection(s_connectionString); conn.Open(); SqlCommand command = new SqlCommand(s_insertCmdString, conn); - SqlParameter vectorParam = new SqlParameter("@VectorData", SqlDbTypeExtensions.Vector, new SqlVector(3).Size); + SqlParameter vectorParam = new SqlParameter("@VectorData", SqlDbTypeExtensions.Vector); command.Parameters.Add(vectorParam); command.Prepare(); for (int i = 0; i < 10; i++) diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SqlVectorTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SqlVectorTest.cs index 3390d95c02..c3d9869201 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/SqlVectorTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SqlVectorTest.cs @@ -18,24 +18,23 @@ public class SqlVectorTest [Fact] public void UnsupportedType() { - Assert.Throws(() => new SqlVector(5)); - Assert.Throws(() => new SqlVector(5)); - Assert.Throws(() => new SqlVector(5)); + Assert.Throws(() => SqlVector.CreateNull(5)); + Assert.Throws(() => SqlVector.CreateNull(5)); + Assert.Throws(() => SqlVector.CreateNull(5)); } [Fact] public void Construct_Length_Negative() { - Assert.Throws(() => new SqlVector(-1)); + Assert.Throws(() => SqlVector.CreateNull(-1)); } [Fact] public void Construct_Length() { - var vec = new SqlVector(5); + var vec = SqlVector.CreateNull(5); Assert.True(vec.IsNull); Assert.Equal(5, vec.Length); - Assert.Equal(28, vec.Size); // Note that ReadOnlyMemory<> equality checks that both instances point // to the same memory. We want to check memory content equality, so we // compare their arrays instead. @@ -45,16 +44,17 @@ public void Construct_Length() var ivec = vec as ISqlVector; Assert.Equal(0x00, ivec.ElementType); Assert.Equal(0x04, ivec.ElementSize); + Assert.Equal(28, ivec.Size); Assert.Empty(ivec.VectorPayload); } [Fact] public void Construct_WithLengthZero() { - var vec = new SqlVector(0); + var vec = SqlVector.CreateNull(0); Assert.True(vec.IsNull); Assert.Equal(0, vec.Length); - Assert.Equal(8, vec.Size); + // Note that ReadOnlyMemory<> equality checks that both instances point // to the same memory. We want to check memory content equality, so we // compare their arrays instead. @@ -64,6 +64,7 @@ public void Construct_WithLengthZero() var ivec = vec as ISqlVector; Assert.Equal(0x00, ivec.ElementType); Assert.Equal(0x04, ivec.ElementSize); + Assert.Equal(8, ivec.Size); Assert.Empty(ivec.VectorPayload); } @@ -73,13 +74,13 @@ public void Construct_Memory_Empty() SqlVector vec = new(new ReadOnlyMemory()); Assert.False(vec.IsNull); Assert.Equal(0, vec.Length); - Assert.Equal(8, vec.Size); Assert.Equal(new ReadOnlyMemory().ToArray(), vec.Memory.ToArray()); Assert.Equal("[]", vec.GetString()); var ivec = vec as ISqlVector; Assert.Equal(0x00, ivec.ElementType); Assert.Equal(0x04, ivec.ElementSize); + Assert.Equal(8, ivec.Size); Assert.Equal( new byte[] { 0xA9, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, ivec.VectorPayload); @@ -93,7 +94,6 @@ public void Construct_Memory() SqlVector vec = new(memory); Assert.False(vec.IsNull); Assert.Equal(2, vec.Length); - Assert.Equal(16, vec.Size); Assert.Equal(memory.ToArray(), vec.Memory.ToArray()); Assert.Equal(data, vec.Memory.ToArray()); #if NETFRAMEWORK @@ -104,6 +104,7 @@ public void Construct_Memory() var ivec = vec as ISqlVector; Assert.Equal(0x00, ivec.ElementType); Assert.Equal(0x04, ivec.ElementSize); + Assert.Equal(16, ivec.Size); Assert.Equal( MakeTdsPayload( new byte[] { 0xA9, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00 }, @@ -118,7 +119,6 @@ public void Construct_Memory_ImplicitConversionFromFloatArray() var vec = new SqlVector(data); Assert.False(vec.IsNull); Assert.Equal(3, vec.Length); - Assert.Equal(20, vec.Size); Assert.Equal(new ReadOnlyMemory(data).ToArray(), vec.Memory.ToArray()); Assert.Equal(data, vec.Memory.ToArray()); #if NETFRAMEWORK @@ -130,6 +130,7 @@ public void Construct_Memory_ImplicitConversionFromFloatArray() var ivec = vec as ISqlVector; Assert.Equal(0x00, ivec.ElementType); Assert.Equal(0x04, ivec.ElementSize); + Assert.Equal(20, ivec.Size); Assert.Equal( MakeTdsPayload( new byte[] { 0xA9, 0x01, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00 }, @@ -149,7 +150,6 @@ public void Construct_Bytes() var vec = new SqlVector(bytes); Assert.False(vec.IsNull); Assert.Equal(2, vec.Length); - Assert.Equal(16, vec.Size); Assert.Equal(new ReadOnlyMemory(data).ToArray(), vec.Memory.ToArray()); Assert.Equal(data, vec.Memory.ToArray()); #if NETFRAMEWORK @@ -161,6 +161,7 @@ public void Construct_Bytes() var ivec = vec as ISqlVector; Assert.Equal(0x00, ivec.ElementType); Assert.Equal(0x04, ivec.ElementSize); + Assert.Equal(16, ivec.Size); Assert.Equal(bytes, ivec.VectorPayload); }