Skip to content

Commit 03c4628

Browse files
committed
GH-3003 - Add vector type support.
This commit combines Spring Data Commons' Vector type with the Neo4j vector functionality. Fields defined as Spring Data Commons `Vector` will get persisted through the `setNodeVectorProperty` procedure. Closes #3003 Signed-off-by: Gerrit Meier <[email protected]>
1 parent 9dbaa48 commit 03c4628

File tree

12 files changed

+211
-43
lines changed

12 files changed

+211
-43
lines changed

src/main/antora/modules/ROOT/pages/appendix/conversions.adoc

+15
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,10 @@ If you require the time zone, use a type that supports it (i.e. `ZoneDateTime`)
188188
|Point with CRS 4326 and x/y corresponding to lat/long
189189
|
190190

191+
|`org.springframework.data.domain.Vector`
192+
|persisted through `setNodeVectorProperty`
193+
|
194+
191195
|Instances of `Enum`
192196
|String (The name value of the enum)
193197
|
@@ -210,6 +214,17 @@ If you require the time zone, use a type that supports it (i.e. `ZoneDateTime`)
210214

211215
|===
212216

217+
[[build-in.conversions.vector]]
218+
=== Vector type
219+
Spring Data has its own type for vector representation `org.springframework.data.domain.Vector`.
220+
While this can be used as a wrapper around a `float` or `double` array, Spring Data Neo4j supports only the `double` variant right now.
221+
From a user perspective, it is possible to only define the `Vector` interface on the property definition and use either `double` or `float`.
222+
Neo4j will store both `double` and `float` variants as a 64-bit Cypher `FLOAT` value, which is consistent with values persisted through Cypher and the dedicated `setNodeVectorProperty` function that Spring Data Neo4j uses to persist the property.
223+
224+
NOTE: Spring Data Neo4j only allows one `Vector` property to be present in an entity definition.
225+
226+
NOTE: Please be aware that a persisted `float` value differs from a read back value due to the nature of floating numbers.
227+
213228
[[custom.conversions]]
214229
== Custom conversions
215230

src/main/java/org/springframework/data/neo4j/core/convert/AdditionalTypes.java

+11-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import org.springframework.data.convert.ConverterBuilder;
5151
import org.springframework.data.convert.ReadingConverter;
5252
import org.springframework.data.convert.WritingConverter;
53+
import org.springframework.data.domain.Vector;
5354
import org.springframework.data.mapping.MappingException;
5455
import org.springframework.util.Assert;
5556
import org.springframework.util.StringUtils;
@@ -104,14 +105,18 @@ final class AdditionalTypes {
104105
hlp.add(ConverterBuilder.reading(Value.class, Node.class, Value::asNode));
105106
hlp.add(ConverterBuilder.reading(Value.class, Relationship.class, Value::asRelationship));
106107
hlp.add(ConverterBuilder.reading(Value.class, Map.class, Value::asMap).andWriting(AdditionalTypes::value));
107-
108+
hlp.add(ConverterBuilder.reading(Value.class, Vector.class, AdditionalTypes::asVector).andWriting(AdditionalTypes::value));
108109
CONVERTERS = Collections.unmodifiableList(hlp);
109110
}
110111

111112
static Value value(Map<?, ?> map) {
112113
return Values.value(map);
113114
}
114115

116+
static Value value(Vector vector) {
117+
return Values.value(vector.toDoubleArray());
118+
}
119+
115120
static TimeZone asTimeZone(Value value) {
116121
return TimeZone.getTimeZone(value.asString());
117122
}
@@ -462,6 +467,11 @@ static short[] asShortArray(Value value) {
462467
return array;
463468
}
464469

470+
static Vector asVector(Value value) {
471+
double[] array = asDoubleArray(value);
472+
return Vector.of(array);
473+
}
474+
465475
static Value value(short[] aShortArray) {
466476
if (aShortArray == null) {
467477
return Values.NULL;

src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ public final class Constants {
5858
public static final String NAME_OF_ID = "__id__";
5959
public static final String NAME_OF_VERSION_PARAM = "__version__";
6060
public static final String NAME_OF_PROPERTIES_PARAM = "__properties__";
61+
public static final String NAME_OF_VECTOR_PROPERTY = "__vectorProperty__";
62+
public static final String NAME_OF_VECTOR_VALUE = "__vectorValue__";
6163
/**
6264
* Indicates the parameter that contains the static labels which are required to correctly compute the difference
6365
* in the list of dynamic labels when saving a node.

src/main/java/org/springframework/data/neo4j/core/mapping/CypherGenerator.java

+29-31
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,17 @@ public Statement prepareSaveOf(NodeDescription<?> nodeDescription,
308308
Assert.notNull(idDescription, "Cannot save individual nodes without an id attribute");
309309
Parameter<?> idParameter = parameter(Constants.NAME_OF_ID);
310310

311+
Function<StatementBuilder.OngoingMatchAndUpdate, Statement> vectorProcedureCall = (bs) -> {
312+
if (((Neo4jPersistentEntity<?>) nodeDescription).hasVectorProperty()) {
313+
return bs.with(rootNode)
314+
.call("db.create.setNodeVectorProperty")
315+
.withArgs(rootNode.getRequiredSymbolicName(), parameter(Constants.NAME_OF_VECTOR_PROPERTY), parameter(Constants.NAME_OF_VECTOR_VALUE))
316+
.withoutResults()
317+
.returning(rootNode).build();
318+
}
319+
return bs.returning(rootNode).build();
320+
};
321+
311322
if (!idDescription.isInternallyGeneratedId()) {
312323
GraphPropertyDescription idPropertyDescription = ((Neo4jPersistentEntity<?>) nodeDescription).getRequiredIdProperty();
313324

@@ -316,92 +327,79 @@ public Statement prepareSaveOf(NodeDescription<?> nodeDescription,
316327
String nameOfPossibleExistingNode = "hlp";
317328
Node possibleExistingNode = node(primaryLabel, additionalLabels).named(nameOfPossibleExistingNode);
318329

319-
Statement createIfNew = updateDecorator.apply(optionalMatch(possibleExistingNode)
330+
Statement createIfNew = vectorProcedureCall.apply(updateDecorator.apply(optionalMatch(possibleExistingNode)
320331
.where(createCompositePropertyCondition(idPropertyDescription, possibleExistingNode.getRequiredSymbolicName(), idParameter))
321332
.with(possibleExistingNode)
322333
.where(possibleExistingNode.isNull())
323334
.create(rootNode.withProperties(versionProperty, literalOf(0)))
324335
.with(rootNode)
325-
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))).returning(rootNode)
326-
.build();
336+
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));
327337

328-
Statement updateIfExists = updateDecorator.apply(match(rootNode)
338+
Statement updateIfExists = vectorProcedureCall.apply(updateDecorator.apply(match(rootNode)
329339
.where(createCompositePropertyCondition(idPropertyDescription, rootNode.getRequiredSymbolicName(), idParameter))
330340
.and(versionProperty.isEqualTo(parameter(Constants.NAME_OF_VERSION_PARAM))) // Initial check
331341
.set(versionProperty.to(versionProperty.add(literalOf(1)))) // Acquire lock
332342
.with(rootNode)
333343
.where(versionProperty.isEqualTo(coalesce(parameter(Constants.NAME_OF_VERSION_PARAM), literalOf(0)).add(
334344
literalOf(1))))
335-
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM)))
336-
.returning(rootNode)
337-
.build();
345+
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));
338346
return Cypher.union(createIfNew, updateIfExists);
339347

340348
} else {
341349
String nameOfPossibleExistingNode = "hlp";
342350
Node possibleExistingNode = node(primaryLabel, additionalLabels).named(nameOfPossibleExistingNode);
343351

344-
Statement createIfNew = updateDecorator.apply(optionalMatch(possibleExistingNode)
352+
Statement createIfNew = vectorProcedureCall.apply(updateDecorator.apply(optionalMatch(possibleExistingNode)
345353
.where(createCompositePropertyCondition(idPropertyDescription, possibleExistingNode.getRequiredSymbolicName(), idParameter))
346354
.with(possibleExistingNode)
347355
.where(possibleExistingNode.isNull())
348356
.create(rootNode)
349357
.with(rootNode)
350-
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))).returning(rootNode)
351-
.build();
358+
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));
352359

353-
Statement updateIfExists = updateDecorator.apply(match(rootNode)
360+
Statement updateIfExists = vectorProcedureCall.apply(updateDecorator.apply(match(rootNode)
354361
.where(createCompositePropertyCondition(idPropertyDescription, rootNode.getRequiredSymbolicName(), idParameter))
355362
.with(rootNode)
356-
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM)))
357-
.returning(rootNode)
358-
.build();
363+
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));
359364
return Cypher.union(createIfNew, updateIfExists);
360365
}
361366
} else {
362367
String nameOfPossibleExistingNode = "hlp";
363368
Node possibleExistingNode = node(primaryLabel, additionalLabels).named(nameOfPossibleExistingNode);
364369

365-
Statement createIfNew;
366-
Statement updateIfExists;
367-
368370
var neo4jPersistentEntity = (Neo4jPersistentEntity<?>) nodeDescription;
369371
var nodeIdFunction = getNodeIdFunction(neo4jPersistentEntity, canUseElementId);
370372

371373
if (neo4jPersistentEntity.hasVersionProperty()) {
372374
Property versionProperty = rootNode.property(neo4jPersistentEntity.getRequiredVersionProperty().getName());
373375

374-
createIfNew = updateDecorator.apply(optionalMatch(possibleExistingNode)
376+
var createIfNew = vectorProcedureCall.apply(updateDecorator.apply(optionalMatch(possibleExistingNode)
375377
.where(nodeIdFunction.apply(possibleExistingNode).isEqualTo(idParameter))
376378
.with(possibleExistingNode)
377379
.where(possibleExistingNode.isNull())
378380
.create(rootNode.withProperties(versionProperty, literalOf(0)))
379381
.with(rootNode)
380-
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM)))
381-
.returning(rootNode)
382-
.build();
382+
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));
383383

384-
updateIfExists = updateDecorator.apply(match(rootNode)
384+
var updateIfExists = vectorProcedureCall.apply(updateDecorator.apply(match(rootNode)
385385
.where(nodeIdFunction.apply(rootNode).isEqualTo(idParameter))
386386
.and(versionProperty.isEqualTo(parameter(Constants.NAME_OF_VERSION_PARAM))) // Initial check
387387
.set(versionProperty.to(versionProperty.add(literalOf(1)))) // Acquire lock
388388
.with(rootNode)
389389
.where(versionProperty.isEqualTo(coalesce(parameter(Constants.NAME_OF_VERSION_PARAM), literalOf(0)).add(
390390
literalOf(1))))
391-
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM)))
392-
.returning(rootNode).build();
391+
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));
392+
return Cypher.union(createIfNew, updateIfExists);
393393
} else {
394-
createIfNew = updateDecorator
395-
.apply(optionalMatch(possibleExistingNode).where(nodeIdFunction.apply(possibleExistingNode).isEqualTo(idParameter))
394+
var createStatement = vectorProcedureCall.apply(updateDecorator.apply(optionalMatch(possibleExistingNode).where(nodeIdFunction.apply(possibleExistingNode).isEqualTo(idParameter))
396395
.with(possibleExistingNode).where(possibleExistingNode.isNull()).create(rootNode)
397-
.set(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM)))
398-
.returning(rootNode).build();
396+
.set(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));
397+
var updateStatement = vectorProcedureCall.apply(updateDecorator.apply(match(rootNode).where(nodeIdFunction.apply(rootNode).isEqualTo(idParameter))
398+
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));
399399

400-
updateIfExists = updateDecorator.apply(match(rootNode).where(nodeIdFunction.apply(rootNode).isEqualTo(idParameter))
401-
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))).returning(rootNode).build();
400+
return Cypher.union(createStatement, updateStatement);
402401
}
403402

404-
return Cypher.union(createIfNew, updateIfExists);
405403
}
406404
}
407405

src/main/java/org/springframework/data/neo4j/core/mapping/DefaultNeo4jEntityConverter.java

+9-2
Original file line numberDiff line numberDiff line change
@@ -242,14 +242,13 @@ public void write(Object source, Map<String, Object> parameters) {
242242
PropertyHandlerSupport.of(nodeDescription).doWithProperties((Neo4jPersistentProperty p) -> {
243243

244244
// Skip the internal properties, we don't want them to end up stored as properties
245-
if (p.isInternalIdProperty() || p.isDynamicLabels() || p.isEntity() || p.isVersionProperty() || p.isReadOnly()) {
245+
if (p.isInternalIdProperty() || p.isDynamicLabels() || p.isEntity() || p.isVersionProperty() || p.isReadOnly() || p.isVectorProperty()) {
246246
return;
247247
}
248248

249249
final Value value = conversionService.writeValue(propertyAccessor.getProperty(p), p.getTypeInformation(), p.getOptionalConverter());
250250
if (p.isComposite()) {
251251
properties.put(p.getPropertyName(), new MapValueWrapper(value));
252-
//value.keys().forEach(k -> properties.put(k, value.get(k)));
253252
} else {
254253
properties.put(p.getPropertyName(), value);
255254
}
@@ -270,6 +269,14 @@ public void write(Object source, Map<String, Object> parameters) {
270269
// we incremented this upfront the persist operation so the matching version would be one "before"
271270
parameters.put(Constants.NAME_OF_VERSION_PARAM, versionProperty);
272271
}
272+
273+
// special handling for vector property to provide the needed procedure information
274+
if (nodeDescription.hasVectorProperty()) {
275+
Neo4jPersistentProperty vectorProperty = nodeDescription.getRequiredVectorProperty();
276+
parameters.put(Constants.NAME_OF_VECTOR_PROPERTY, vectorProperty.getPropertyName());
277+
parameters.put(Constants.NAME_OF_VECTOR_VALUE, conversionService.writeValue(propertyAccessor.getProperty(vectorProperty), vectorProperty.getTypeInformation(), vectorProperty.getOptionalConverter()));
278+
return;
279+
}
273280
}
274281

275282
/**

src/main/java/org/springframework/data/neo4j/core/mapping/DefaultNeo4jPersistentEntity.java

+35
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.springframework.core.annotation.AnnotatedElementUtils;
3636
import org.springframework.core.log.LogAccessor;
3737
import org.springframework.data.annotation.Persistent;
38+
import org.springframework.data.domain.Vector;
3839
import org.springframework.data.mapping.Association;
3940
import org.springframework.data.mapping.model.BasicPersistentEntity;
4041
import org.springframework.data.neo4j.core.schema.DynamicLabels;
@@ -88,6 +89,8 @@ final class DefaultNeo4jPersistentEntity<T> extends BasicPersistentEntity<T, Neo
8889

8990
private List<NodeDescription<?>> childNodeDescriptionsInHierarchy;
9091

92+
private final Lazy<Neo4jPersistentProperty> vectorProperty;
93+
9194
DefaultNeo4jPersistentEntity(TypeInformation<T> information) {
9295
super(information);
9396

@@ -99,6 +102,8 @@ final class DefaultNeo4jPersistentEntity<T> extends BasicPersistentEntity<T, Neo
99102
this.isRelationshipPropertiesEntity = Lazy.of(() -> isAnnotationPresent(RelationshipProperties.class));
100103
this.idDescription = Lazy.of(this::computeIdDescription);
101104
this.childNodeDescriptionsInHierarchy = computeChildNodeDescriptionInHierarchy();
105+
this.vectorProperty = Lazy.of(() -> getGraphProperties().stream().map(Neo4jPersistentProperty.class::cast)
106+
.filter(Neo4jPersistentProperty::isVectorProperty).findFirst().orElse(null));
102107
}
103108

104109
/*
@@ -212,6 +217,7 @@ public void verify() {
212217
verifyDynamicAssociations();
213218
verifyAssociationsWithProperties();
214219
verifyDynamicLabels();
220+
verifyAtMostOneVectorDefinition();
215221
}
216222

217223
private void verifyIdDescription() {
@@ -301,6 +307,18 @@ private void verifyDynamicLabels() {
301307
DynamicLabels.class.getSimpleName(), namesOfPropertiesWithDynamicLabels));
302308
}
303309

310+
private void verifyAtMostOneVectorDefinition() {
311+
List<Neo4jPersistentProperty> foundVectorDefinition = new ArrayList<>();
312+
PropertyHandlerSupport.of(this).doWithProperties(persistentProperty -> {
313+
if (persistentProperty.getType().isAssignableFrom(Vector.class)) {
314+
foundVectorDefinition.add(persistentProperty);
315+
}
316+
});
317+
318+
Assert.state(foundVectorDefinition.size() <= 1, () -> String.format("There are multiple fields of type %s in entity %s: %s",
319+
Vector.class.toString(), this.getName(), foundVectorDefinition.stream().map(p -> p.getPropertyName()).toList()));
320+
}
321+
304322
/**
305323
* The primary label will get computed and returned by following rules:<br>
306324
* 1. If there is no {@link Node} annotation, use the class name.<br>
@@ -410,6 +428,23 @@ public boolean describesInterface() {
410428
return this.getTypeInformation().getRawTypeInformation().getType().isInterface();
411429
}
412430

431+
@Override
432+
public boolean hasVectorProperty() {
433+
return Optional.ofNullable(getVectorProperty()).map(v -> true).orElse(false);
434+
}
435+
436+
public Neo4jPersistentProperty getVectorProperty() {
437+
return this.vectorProperty.getNullable();
438+
}
439+
440+
public Neo4jPersistentProperty getRequiredVectorProperty() {
441+
Neo4jPersistentProperty property = getVectorProperty();
442+
if (property != null) {
443+
return property;
444+
}
445+
throw new IllegalStateException(String.format("Required vector property not found for %s", this.getType()));
446+
}
447+
413448
private static boolean hasEmptyLabelInformation(Node nodeAnnotation) {
414449
return nodeAnnotation.labels().length < 1 && !StringUtils.hasText(nodeAnnotation.primaryLabel());
415450
}

src/main/java/org/springframework/data/neo4j/core/mapping/Neo4jPersistentEntity.java

+5
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,9 @@ default boolean isUsingDeprecatedInternalId() {
7575
}
7676
return isUsingInternalIds() && Neo4jPersistentEntity.DEPRECATED_GENERATED_ID_TYPES.contains(getRequiredIdProperty().getType());
7777
}
78+
79+
boolean hasVectorProperty();
80+
81+
Neo4jPersistentProperty getVectorProperty();
82+
Neo4jPersistentProperty getRequiredVectorProperty();
7883
}

src/main/java/org/springframework/data/neo4j/core/mapping/Neo4jPersistentProperty.java

+5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.Optional;
1919

2020
import org.apiguardian.api.API;
21+
import org.springframework.data.domain.Vector;
2122
import org.springframework.data.mapping.PersistentProperty;
2223
import org.springframework.data.neo4j.core.convert.Neo4jPersistentPropertyConverter;
2324
import org.springframework.data.neo4j.core.schema.CompositeProperty;
@@ -67,6 +68,10 @@ default boolean isDynamicLabels() {
6768
return this.isAnnotationPresent(DynamicLabels.class) && this.isCollectionLike();
6869
}
6970

71+
default boolean isVectorProperty() {
72+
return this.getType().isAssignableFrom(Vector.class);
73+
}
74+
7075
@Nullable
7176
Neo4jPersistentPropertyConverter<?> getOptionalConverter();
7277

0 commit comments

Comments
 (0)