diff --git a/.evergreen/run-atlas-search-tests.sh b/.evergreen/run-atlas-search-tests.sh index 7669c87ae5d..36cc981b3f4 100755 --- a/.evergreen/run-atlas-search-tests.sh +++ b/.evergreen/run-atlas-search-tests.sh @@ -16,4 +16,4 @@ echo "Running Atlas Search tests" ./gradlew --stacktrace --info \ -Dorg.mongodb.test.atlas.search=true \ -Dorg.mongodb.test.uri=${MONGODB_URI} \ - driver-core:test --tests AggregatesSearchIntegrationTest + driver-core:test --tests AggregatesSearchIntegrationTest --tests AggregatesVectorSearchIntegrationTest diff --git a/driver-core/src/main/com/mongodb/client/model/Aggregates.java b/driver-core/src/main/com/mongodb/client/model/Aggregates.java index 4bb3a03771c..7d6306cdd23 100644 --- a/driver-core/src/main/com/mongodb/client/model/Aggregates.java +++ b/driver-core/src/main/com/mongodb/client/model/Aggregates.java @@ -37,6 +37,7 @@ import org.bson.BsonType; import org.bson.BsonValue; import org.bson.Document; +import org.bson.Vector; import org.bson.codecs.configuration.CodecRegistry; import org.bson.conversions.Bson; @@ -963,28 +964,37 @@ public static Bson vectorSearch( notNull("queryVector", queryVector); notNull("index", index); notNull("options", options); - return new Bson() { - @Override - public <TDocument> BsonDocument toBsonDocument(final Class<TDocument> documentClass, final CodecRegistry codecRegistry) { - Document specificationDoc = new Document("path", path.toValue()) - .append("queryVector", queryVector) - .append("index", index) - .append("limit", limit); - specificationDoc.putAll(options.toBsonDocument(documentClass, codecRegistry)); - return new Document("$vectorSearch", specificationDoc).toBsonDocument(documentClass, codecRegistry); - } + return new VectorSearchBson(path, queryVector, index, limit, options); + } - @Override - public String toString() { - return "Stage{name=$vectorSearch" - + ", path=" + path - + ", queryVector=" + queryVector - + ", index=" + index - + ", limit=" + limit - + ", options=" + options - + '}'; - } - }; + /** + * Creates a {@code $vectorSearch} pipeline stage supported by MongoDB Atlas. + * You may use the {@code $meta: "vectorSearchScore"} expression, e.g., via {@link Projections#metaVectorSearchScore(String)}, + * to extract the relevance score assigned to each found document. + * + * @param queryVector The {@linkplain Vector query vector}. The number of dimensions must match that of the {@code index}. + * @param path The field to be searched. + * @param index The name of the index to use. + * @param limit The limit on the number of documents produced by the pipeline stage. + * @param options Optional {@code $vectorSearch} pipeline stage fields. + * @return The {@code $vectorSearch} pipeline stage. + * @mongodb.atlas.manual atlas-vector-search/vector-search-stage/ $vectorSearch + * @mongodb.atlas.manual atlas-search/scoring/ Scoring + * @mongodb.server.release 6.0 + * @see Vector + * @since 5.3 + */ + public static Bson vectorSearch( + final FieldSearchPath path, + final Vector queryVector, + final String index, + final long limit, + final VectorSearchOptions options) { + notNull("path", path); + notNull("queryVector", queryVector); + notNull("index", index); + notNull("options", options); + return new VectorSearchBson(path, queryVector, index, limit, options); } /** @@ -2145,6 +2155,45 @@ public String toString() { } } + private static class VectorSearchBson implements Bson { + private final FieldSearchPath path; + private final Object queryVector; + private final String index; + private final long limit; + private final VectorSearchOptions options; + + VectorSearchBson(final FieldSearchPath path, final Object queryVector, + final String index, final long limit, + final VectorSearchOptions options) { + this.path = path; + this.queryVector = queryVector; + this.index = index; + this.limit = limit; + this.options = options; + } + + @Override + public <TDocument> BsonDocument toBsonDocument(final Class<TDocument> documentClass, final CodecRegistry codecRegistry) { + Document specificationDoc = new Document("path", path.toValue()) + .append("queryVector", queryVector) + .append("index", index) + .append("limit", limit); + specificationDoc.putAll(options.toBsonDocument(documentClass, codecRegistry)); + return new Document("$vectorSearch", specificationDoc).toBsonDocument(documentClass, codecRegistry); + } + + @Override + public String toString() { + return "Stage{name=$vectorSearch" + + ", path=" + path + + ", queryVector=" + queryVector + + ", index=" + index + + ", limit=" + limit + + ", options=" + options + + '}'; + } + } + private Aggregates() { } } diff --git a/driver-core/src/main/com/mongodb/internal/operation/CreateSearchIndexesOperation.java b/driver-core/src/main/com/mongodb/internal/operation/CreateSearchIndexesOperation.java index 2e52e3fa0ae..a57087e9217 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/CreateSearchIndexesOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/CreateSearchIndexesOperation.java @@ -32,11 +32,11 @@ * * <p>This class is not part of the public API and may be removed or changed at any time</p> */ -final class CreateSearchIndexesOperation extends AbstractWriteSearchIndexOperation { +public final class CreateSearchIndexesOperation extends AbstractWriteSearchIndexOperation { private static final String COMMAND_NAME = "createSearchIndexes"; private final List<SearchIndexRequest> indexRequests; - CreateSearchIndexesOperation(final MongoNamespace namespace, final List<SearchIndexRequest> indexRequests) { + public CreateSearchIndexesOperation(final MongoNamespace namespace, final List<SearchIndexRequest> indexRequests) { super(namespace); this.indexRequests = assertNotNull(indexRequests); } diff --git a/driver-core/src/main/com/mongodb/internal/operation/ListSearchIndexesOperation.java b/driver-core/src/main/com/mongodb/internal/operation/ListSearchIndexesOperation.java index 0f9a81dbf19..3dfde30511d 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ListSearchIndexesOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ListSearchIndexesOperation.java @@ -42,7 +42,7 @@ * * <p>This class is not part of the public API and may be removed or changed at any time</p> */ -final class ListSearchIndexesOperation<T> +public final class ListSearchIndexesOperation<T> implements AsyncExplainableReadOperation<AsyncBatchCursor<T>>, ExplainableReadOperation<BatchCursor<T>> { private static final String STAGE_LIST_SEARCH_INDEXES = "$listSearchIndexes"; private final MongoNamespace namespace; @@ -59,9 +59,10 @@ final class ListSearchIndexesOperation<T> private final String indexName; private final boolean retryReads; - ListSearchIndexesOperation(final MongoNamespace namespace, final Decoder<T> decoder, @Nullable final String indexName, - @Nullable final Integer batchSize, @Nullable final Collation collation, @Nullable final BsonValue comment, - @Nullable final Boolean allowDiskUse, final boolean retryReads) { + public ListSearchIndexesOperation(final MongoNamespace namespace, final Decoder<T> decoder, @Nullable final String indexName, + @Nullable final Integer batchSize, @Nullable final Collation collation, + @Nullable final BsonValue comment, + @Nullable final Boolean allowDiskUse, final boolean retryReads) { this.namespace = namespace; this.decoder = decoder; this.allowDiskUse = allowDiskUse; diff --git a/driver-core/src/main/com/mongodb/internal/operation/SearchIndexRequest.java b/driver-core/src/main/com/mongodb/internal/operation/SearchIndexRequest.java index 0d37d2c2178..29b9b1ef34d 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/SearchIndexRequest.java +++ b/driver-core/src/main/com/mongodb/internal/operation/SearchIndexRequest.java @@ -31,14 +31,15 @@ * * <p>This class is not part of the public API and may be removed or changed at any time</p> */ -final class SearchIndexRequest { +public final class SearchIndexRequest { private final BsonDocument definition; @Nullable private final String indexName; @Nullable private final SearchIndexType searchIndexType; - SearchIndexRequest(final BsonDocument definition, @Nullable final String indexName, @Nullable final SearchIndexType searchIndexType) { + public SearchIndexRequest(final BsonDocument definition, @Nullable final String indexName, + @Nullable final SearchIndexType searchIndexType) { assertNotNull(definition); this.definition = definition; this.indexName = indexName; diff --git a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java new file mode 100644 index 00000000000..15def0f5d71 --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java @@ -0,0 +1,353 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.model.search; + +import com.mongodb.MongoInterruptedException; +import com.mongodb.MongoNamespace; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.SearchIndexType; +import com.mongodb.client.test.CollectionHelper; +import com.mongodb.internal.operation.SearchIndexRequest; +import org.bson.BsonDocument; +import org.bson.Document; +import org.bson.Vector; +import org.bson.codecs.DocumentCodec; +import org.bson.conversions.Bson; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import static com.mongodb.ClusterFixture.isAtlasSearchTest; +import static com.mongodb.ClusterFixture.serverVersionAtLeast; +import static com.mongodb.client.model.Filters.and; +import static com.mongodb.client.model.Filters.eq; +import static com.mongodb.client.model.Filters.gt; +import static com.mongodb.client.model.Filters.gte; +import static com.mongodb.client.model.Filters.in; +import static com.mongodb.client.model.Filters.lt; +import static com.mongodb.client.model.Filters.lte; +import static com.mongodb.client.model.Filters.ne; +import static com.mongodb.client.model.Filters.nin; +import static com.mongodb.client.model.Filters.or; +import static com.mongodb.client.model.Projections.fields; +import static com.mongodb.client.model.Projections.metaVectorSearchScore; +import static com.mongodb.client.model.search.SearchPath.fieldPath; +import static com.mongodb.client.model.search.VectorSearchOptions.approximateVectorSearchOptions; +import static com.mongodb.client.model.search.VectorSearchOptions.exactVectorSearchOptions; +import static java.lang.String.format; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +class AggregatesVectorSearchIntegrationTest { + private static final String EXCEED_WAIT_ATTEMPTS_ERROR_MESSAGE = + "Exceeded maximum attempts waiting for Search Index creation in Atlas cluster. Index document: %s"; + + private static final String VECTOR_INDEX = "vector_search_index"; + private static final String VECTOR_FIELD_INT_8 = "int8Vector"; + private static final String VECTOR_FIELD_FLOAT_32 = "float32Vector"; + private static final String VECTOR_FIELD_LEGACY_DOUBLE_LIST = "legacyDoubleVector"; + private static final int LIMIT = 5; + private static final String FIELD_YEAR = "year"; + private static CollectionHelper<Document> collectionHelper; + private static final BsonDocument VECTOR_SEARCH_INDEX_DEFINITION = BsonDocument.parse( + "{" + + " fields: [" + + " {" + + " path: '" + VECTOR_FIELD_INT_8 + "'," + + " numDimensions: 5," + + " similarity: 'cosine'," + + " type: 'vector'," + + " }," + + " {" + + " path: '" + VECTOR_FIELD_FLOAT_32 + "'," + + " numDimensions: 5," + + " similarity: 'cosine'," + + " type: 'vector'," + + " }," + + " {" + + " path: '" + VECTOR_FIELD_LEGACY_DOUBLE_LIST + "'," + + " numDimensions: 5," + + " similarity: 'cosine'," + + " type: 'vector'," + + " }," + + " {" + + " path: '" + FIELD_YEAR + "'," + + " type: 'filter'," + + " }," + + " ]" + + "}"); + + @BeforeAll + static void beforeAll() { + assumeTrue(isAtlasSearchTest()); + assumeTrue(serverVersionAtLeast(6, 0)); + + collectionHelper = + new CollectionHelper<>(new DocumentCodec(), new MongoNamespace("javaVectorSearchTest", AggregatesVectorSearchIntegrationTest.class.getSimpleName())); + collectionHelper.drop(); + collectionHelper.insertDocuments( + new Document() + .append("_id", 0) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{0, 1, 2, 3, 4})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{0.0001, 1.12345, 2.23456, 3.34567, 4.45678}) + .append(FIELD_YEAR, 2016), + new Document() + .append("_id", 1) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{1, 2, 3, 4, 5})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{1.0001f, 2.12345f, 3.23456f, 4.34567f, 5.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{1.0001, 2.12345, 3.23456, 4.34567, 5.45678}) + .append(FIELD_YEAR, 2017), + new Document() + .append("_id", 2) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{2, 3, 4, 5, 6})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{2.0002f, 3.12345f, 4.23456f, 5.34567f, 6.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{2.0002, 3.12345, 4.23456, 5.34567, 6.45678}) + .append(FIELD_YEAR, 2018), + new Document() + .append("_id", 3) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{3, 4, 5, 6, 7})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{3.0003f, 4.12345f, 5.23456f, 6.34567f, 7.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{3.0003, 4.12345, 5.23456, 6.34567, 7.45678}) + .append(FIELD_YEAR, 2019), + new Document() + .append("_id", 4) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{4, 5, 6, 7, 8})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{4.0004f, 5.12345f, 6.23456f, 7.34567f, 8.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{4.0004, 5.12345, 6.23456, 7.34567, 8.45678}) + .append(FIELD_YEAR, 2020), + new Document() + .append("_id", 5) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{5, 6, 7, 8, 9})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{5.0005f, 6.12345f, 7.23456f, 8.34567f, 9.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{5.0005, 6.12345, 7.23456, 8.34567, 9.45678}) + .append(FIELD_YEAR, 2021), + new Document() + .append("_id", 6) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{6, 7, 8, 9, 10})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{6.0006f, 7.12345f, 8.23456f, 9.34567f, 10.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{6.0006, 7.12345, 8.23456, 9.34567, 10.45678}) + .append(FIELD_YEAR, 2022), + new Document() + .append("_id", 7) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{7, 8, 9, 10, 11})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{7.0007f, 8.12345f, 9.23456f, 10.34567f, 11.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{7.0007, 8.12345, 9.23456, 10.34567, 11.45678}) + .append(FIELD_YEAR, 2023), + new Document() + .append("_id", 8) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{8, 9, 10, 11, 12})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{8.0008f, 9.12345f, 10.23456f, 11.34567f, 12.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{8.0008, 9.12345, 10.23456, 11.34567, 12.45678}) + .append(FIELD_YEAR, 2024), + new Document() + .append("_id", 9) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{9, 10, 11, 12, 13})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{9.0009f, 10.12345f, 11.23456f, 12.34567f, 13.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{9.0009, 10.12345, 11.23456, 12.34567, 13.45678}) + .append(FIELD_YEAR, 2025) + ); + + collectionHelper.createSearchIndex( + new SearchIndexRequest(VECTOR_SEARCH_INDEX_DEFINITION, VECTOR_INDEX, + SearchIndexType.vectorSearch())); + awaitIndexCreation(); + } + + @AfterAll + static void afterAll() { + if (collectionHelper != null) { + collectionHelper.drop(); + } + } + + private static Stream<Arguments> provideSupportedVectors() { + return Stream.of( + arguments(Vector.int8Vector(new byte[]{0, 1, 2, 3, 4}), + // `multi` is used here only to verify that it is tolerated + fieldPath(VECTOR_FIELD_INT_8).multi("ignored"), + approximateVectorSearchOptions(LIMIT * 2)), + arguments(Vector.int8Vector(new byte[]{0, 1, 2, 3, 4}), + fieldPath(VECTOR_FIELD_INT_8), + approximateVectorSearchOptions(LIMIT * 2)), + + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + // `multi` is used here only to verify that it is tolerated + fieldPath(VECTOR_FIELD_FLOAT_32).multi("ignored"), + approximateVectorSearchOptions(LIMIT * 2)), + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + fieldPath(VECTOR_FIELD_FLOAT_32), + approximateVectorSearchOptions(LIMIT * 2)), + + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + // `multi` is used here only to verify that it is tolerated + fieldPath(VECTOR_FIELD_FLOAT_32).multi("ignored"), + exactVectorSearchOptions()), + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + fieldPath(VECTOR_FIELD_FLOAT_32), + exactVectorSearchOptions()), + + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + // `multi` is used here only to verify that it is tolerated + fieldPath(VECTOR_FIELD_LEGACY_DOUBLE_LIST).multi("ignored"), + exactVectorSearchOptions()), + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + fieldPath(VECTOR_FIELD_LEGACY_DOUBLE_LIST), + exactVectorSearchOptions()), + + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + // `multi` is used here only to verify that it is tolerated + fieldPath(VECTOR_FIELD_LEGACY_DOUBLE_LIST).multi("ignored"), + approximateVectorSearchOptions(LIMIT * 2)), + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + fieldPath(VECTOR_FIELD_LEGACY_DOUBLE_LIST), + approximateVectorSearchOptions(LIMIT * 2)) + ); + } + + @ParameterizedTest + @MethodSource("provideSupportedVectors") + void shouldSearchByVectorWithSearchScore(final Vector vector, + final FieldSearchPath fieldSearchPath, + final VectorSearchOptions vectorSearchOptions) { + //given + List<Bson> pipeline = asList( + Aggregates.vectorSearch( + fieldSearchPath, + vector, + VECTOR_INDEX, LIMIT, + vectorSearchOptions), + Aggregates.project( + fields( + metaVectorSearchScore("vectorSearchScore") + )) + ); + + //when + List<Document> aggregate = collectionHelper.aggregate(pipeline); + + //then + Assertions.assertEquals(LIMIT, aggregate.size()); + assertScoreIsDecreasing(aggregate); + Document highestScoreDocument = aggregate.get(0); + assertEquals(1, highestScoreDocument.getDouble("vectorSearchScore")); + } + + @ParameterizedTest + @MethodSource("provideSupportedVectors") + void shouldSearchByVector(final Vector vector, + final FieldSearchPath fieldSearchPath, + final VectorSearchOptions vectorSearchOptions) { + //given + List<Bson> pipeline = asList( + Aggregates.vectorSearch( + fieldSearchPath, + vector, + VECTOR_INDEX, LIMIT, + vectorSearchOptions) + ); + + //when + List<Document> aggregate = collectionHelper.aggregate(pipeline); + + //then + Assertions.assertEquals(LIMIT, aggregate.size()); + assertFalse( + aggregate.stream() + .anyMatch(document -> document.containsKey("vectorSearchScore")) + ); + } + + @ParameterizedTest + @MethodSource("provideSupportedVectors") + void shouldSearchByVectorWithFilter(final Vector vector, + final FieldSearchPath fieldSearchPath, + final VectorSearchOptions vectorSearchOptions) { + Consumer<Bson> asserter = filter -> { + List<Bson> pipeline = singletonList( + Aggregates.vectorSearch( + fieldSearchPath, vector, VECTOR_INDEX, 1, + vectorSearchOptions.filter(filter)) + ); + + List<Document> aggregate = collectionHelper.aggregate(pipeline); + Assertions.assertFalse(aggregate.isEmpty()); + }; + + assertAll( + () -> asserter.accept(lt("year", 2020)), + () -> asserter.accept(lte("year", 2020)), + () -> asserter.accept(eq("year", 2020)), + () -> asserter.accept(gte("year", 2016)), + () -> asserter.accept(gt("year", 2015)), + () -> asserter.accept(ne("year", 2016)), + () -> asserter.accept(in("year", 2000, 2024)), + () -> asserter.accept(nin("year", 2000, 2024)), + () -> asserter.accept(and(gte("year", 2015), lte("year", 2017))), + () -> asserter.accept(or(eq("year", 2015), eq("year", 2017))) + ); + } + + private static void assertScoreIsDecreasing(final List<Document> aggregate) { + double previousScore = Integer.MAX_VALUE; + for (Document document : aggregate) { + Double vectorSearchScore = document.getDouble("vectorSearchScore"); + assertTrue(vectorSearchScore > 0, "Expected positive score"); + assertTrue(vectorSearchScore < previousScore, "Expected decreasing score"); + previousScore = vectorSearchScore; + } + } + + private static void awaitIndexCreation() { + int attempts = 10; + Optional<Document> searchIndex = Optional.empty(); + + while (attempts-- > 0) { + searchIndex = collectionHelper.listSearchIndex(VECTOR_INDEX); + if (searchIndex.filter(document -> document.getBoolean("queryable")) + .isPresent()) { + return; + } + + try { + TimeUnit.SECONDS.sleep(5); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new MongoInterruptedException(null, e); + } + } + + searchIndex.ifPresent(document -> + Assertions.fail(format(EXCEED_WAIT_ATTEMPTS_ERROR_MESSAGE, document.toJson()))); + Assertions.fail(format(EXCEED_WAIT_ATTEMPTS_ERROR_MESSAGE, "null")); + } +} diff --git a/driver-core/src/test/functional/com/mongodb/client/test/CollectionHelper.java b/driver-core/src/test/functional/com/mongodb/client/test/CollectionHelper.java index e297726d325..adce165ee51 100644 --- a/driver-core/src/test/functional/com/mongodb/client/test/CollectionHelper.java +++ b/driver-core/src/test/functional/com/mongodb/client/test/CollectionHelper.java @@ -43,11 +43,14 @@ import com.mongodb.internal.operation.CountDocumentsOperation; import com.mongodb.internal.operation.CreateCollectionOperation; import com.mongodb.internal.operation.CreateIndexesOperation; +import com.mongodb.internal.operation.CreateSearchIndexesOperation; import com.mongodb.internal.operation.DropCollectionOperation; import com.mongodb.internal.operation.DropDatabaseOperation; import com.mongodb.internal.operation.FindOperation; import com.mongodb.internal.operation.ListIndexesOperation; +import com.mongodb.internal.operation.ListSearchIndexesOperation; import com.mongodb.internal.operation.MixedBulkWriteOperation; +import com.mongodb.internal.operation.SearchIndexRequest; import org.bson.BsonArray; import org.bson.BsonDocument; import org.bson.BsonDocumentWrapper; @@ -56,6 +59,7 @@ import org.bson.BsonString; import org.bson.BsonValue; import org.bson.Document; +import org.bson.assertions.Assertions; import org.bson.codecs.BsonDocumentCodec; import org.bson.codecs.Codec; import org.bson.codecs.Decoder; @@ -65,6 +69,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import static com.mongodb.ClusterFixture.executeAsync; @@ -297,6 +302,25 @@ public List<T> find() { return find(codec); } + public Optional<T> listSearchIndex(final String indexName) { + ListSearchIndexesOperation<T> listSearchIndexesOperation = + new ListSearchIndexesOperation<>(namespace, codec, indexName, null, null, null, null, true); + BatchCursor<T> cursor = listSearchIndexesOperation.execute(getBinding()); + + List<T> results = new ArrayList<>(); + while (cursor.hasNext()) { + results.addAll(cursor.next()); + } + Assertions.assertTrue("Expected at most one result, but found " + results.size(), results.size() <= 1); + return results.isEmpty() ? Optional.empty() : Optional.of(results.get(0)); + } + + public void createSearchIndex(final SearchIndexRequest searchIndexModel) { + CreateSearchIndexesOperation searchIndexesOperation = + new CreateSearchIndexesOperation(namespace, singletonList(searchIndexModel)); + searchIndexesOperation.execute(getBinding()); + } + public <D> List<D> find(final Codec<D> codec) { BatchCursor<D> cursor = new FindOperation<>(namespace, codec) .sort(new BsonDocument("_id", new BsonInt32(1))) diff --git a/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy b/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy index 21df76e401e..3af81fc992c 100644 --- a/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy @@ -23,6 +23,7 @@ import com.mongodb.client.model.search.SearchOperator import org.bson.BsonDocument import org.bson.BsonInt32 import org.bson.Document +import org.bson.Vector import org.bson.conversions.Bson import spock.lang.IgnoreIf import spock.lang.Specification @@ -855,7 +856,7 @@ class AggregatesSpecification extends Specification { BsonDocument vectorSearchDoc = toBson( vectorSearch( fieldPath('fieldName').multi('ignored'), - [1.0d, 2.0d], + vector, 'indexName', 1, approximateVectorSearchOptions(2) @@ -868,13 +869,20 @@ class AggregatesSpecification extends Specification { vectorSearchDoc == parse('''{ "$vectorSearch": { "path": "fieldName", - "queryVector": [1.0, 2.0], + "queryVector": ''' + queryVector + ''', "index": "indexName", "numCandidates": {"$numberLong": "2"}, "limit": {"$numberLong": "1"}, "filter": {"fieldName": {"$ne": "fieldValue"}} } }''') + + where: + vector | queryVector + Vector.int8Vector(new byte[]{127, 7}) | '{"$binary": {"base64": "AwB/Bw==", "subType": "09"}}' + Vector.floatVector(new float[]{127.0f, 7.0f}) | '{"$binary": {"base64": "JwAAAP5CAADgQA==", "subType": "09"}}' + Vector.packedBitVector(new byte[]{127, 7}, (byte) 0) | '{"$binary": {"base64": "EAB/Bw==", "subType": "09"}}' + [1.0d, 2.0d] | "[1.0, 2.0]" } def 'should render exact $vectorSearch'() { @@ -882,7 +890,7 @@ class AggregatesSpecification extends Specification { BsonDocument vectorSearchDoc = toBson( vectorSearch( fieldPath('fieldName').multi('ignored'), - [1.0d, 2.0d], + vector, 'indexName', 1, exactVectorSearchOptions() @@ -895,13 +903,19 @@ class AggregatesSpecification extends Specification { vectorSearchDoc == parse('''{ "$vectorSearch": { "path": "fieldName", - "queryVector": [1.0, 2.0], + "queryVector": ''' + queryVector + ''', "index": "indexName", "exact": true, "limit": {"$numberLong": "1"}, "filter": {"fieldName": {"$ne": "fieldValue"}} } }''') + + where: + vector | queryVector + Vector.int8Vector(new byte[]{127, 7}) | '{"$binary": {"base64": "AwB/Bw==", "subType": "09"}}' + Vector.floatVector(new float[]{127.0f, 7.0f}) | '{"$binary": {"base64": "JwAAAP5CAADgQA==", "subType": "09"}}' + [1.0d, 2.0d] | "[1.0, 2.0]" } def 'should create string representation for simple stages'() {