Skip to content

Add Vector type support #3009

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/main/antora/modules/ROOT/pages/appendix/conversions.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ If you require the time zone, use a type that supports it (i.e. `ZoneDateTime`)
|Point with CRS 4326 and x/y corresponding to lat/long
|

|`org.springframework.data.domain.Vector`
|persisted through `setNodeVectorProperty`
|

|Instances of `Enum`
|String (The name value of the enum)
|
Expand All @@ -210,6 +214,17 @@ If you require the time zone, use a type that supports it (i.e. `ZoneDateTime`)

|===

[[build-in.conversions.vector]]
=== Vector type
Spring Data has its own type for vector representation `org.springframework.data.domain.Vector`.
While this can be used as a wrapper around a `float` or `double` array, Spring Data Neo4j supports only the `double` variant right now.
From a user perspective, it is possible to only define the `Vector` interface on the property definition and use either `double` or `float`.
Neo4j will store both `double` and `float` variants as a 64-bit Cypher `FLOAT` value, which is consistent with values persisted through Cypher and the dedicated `setNodeVectorProperty` function that Spring Data Neo4j uses to persist the property.

NOTE: Spring Data Neo4j only allows one `Vector` property to be present in an entity definition.

NOTE: Please be aware that a persisted `float` value differs from a read back value due to the nature of floating numbers.

[[custom.conversions]]
== Custom conversions

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.springframework.data.convert.ConverterBuilder;
import org.springframework.data.convert.ReadingConverter;
import org.springframework.data.convert.WritingConverter;
import org.springframework.data.domain.Vector;
import org.springframework.data.mapping.MappingException;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -104,14 +105,18 @@ final class AdditionalTypes {
hlp.add(ConverterBuilder.reading(Value.class, Node.class, Value::asNode));
hlp.add(ConverterBuilder.reading(Value.class, Relationship.class, Value::asRelationship));
hlp.add(ConverterBuilder.reading(Value.class, Map.class, Value::asMap).andWriting(AdditionalTypes::value));

hlp.add(ConverterBuilder.reading(Value.class, Vector.class, AdditionalTypes::asVector).andWriting(AdditionalTypes::value));
CONVERTERS = Collections.unmodifiableList(hlp);
}

static Value value(Map<?, ?> map) {
return Values.value(map);
}

static Value value(Vector vector) {
return Values.value(vector.toDoubleArray());
}

static TimeZone asTimeZone(Value value) {
return TimeZone.getTimeZone(value.asString());
}
Expand Down Expand Up @@ -462,6 +467,11 @@ static short[] asShortArray(Value value) {
return array;
}

static Vector asVector(Value value) {
double[] array = asDoubleArray(value);
return Vector.of(array);
}

static Value value(short[] aShortArray) {
if (aShortArray == null) {
return Values.NULL;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ public final class Constants {
public static final String NAME_OF_ID = "__id__";
public static final String NAME_OF_VERSION_PARAM = "__version__";
public static final String NAME_OF_PROPERTIES_PARAM = "__properties__";
public static final String NAME_OF_VECTOR_PROPERTY = "__vectorProperty__";
public static final String NAME_OF_VECTOR_VALUE = "__vectorValue__";
/**
* Indicates the parameter that contains the static labels which are required to correctly compute the difference
* in the list of dynamic labels when saving a node.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,17 @@ public Statement prepareSaveOf(NodeDescription<?> nodeDescription,
Assert.notNull(idDescription, "Cannot save individual nodes without an id attribute");
Parameter<?> idParameter = parameter(Constants.NAME_OF_ID);

Function<StatementBuilder.OngoingMatchAndUpdate, Statement> vectorProcedureCall = (bs) -> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

;)

if (((Neo4jPersistentEntity<?>) nodeDescription).hasVectorProperty()) {
return bs.with(rootNode)
.call("db.create.setNodeVectorProperty")
.withArgs(rootNode.getRequiredSymbolicName(), parameter(Constants.NAME_OF_VECTOR_PROPERTY), parameter(Constants.NAME_OF_VECTOR_VALUE))
.withoutResults()
.returning(rootNode).build();
}
return bs.returning(rootNode).build();
};

if (!idDescription.isInternallyGeneratedId()) {
GraphPropertyDescription idPropertyDescription = ((Neo4jPersistentEntity<?>) nodeDescription).getRequiredIdProperty();

Expand All @@ -316,92 +327,79 @@ public Statement prepareSaveOf(NodeDescription<?> nodeDescription,
String nameOfPossibleExistingNode = "hlp";
Node possibleExistingNode = node(primaryLabel, additionalLabels).named(nameOfPossibleExistingNode);

Statement createIfNew = updateDecorator.apply(optionalMatch(possibleExistingNode)
Statement createIfNew = vectorProcedureCall.apply(updateDecorator.apply(optionalMatch(possibleExistingNode)
.where(createCompositePropertyCondition(idPropertyDescription, possibleExistingNode.getRequiredSymbolicName(), idParameter))
.with(possibleExistingNode)
.where(possibleExistingNode.isNull())
.create(rootNode.withProperties(versionProperty, literalOf(0)))
.with(rootNode)
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))).returning(rootNode)
.build();
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));

Statement updateIfExists = updateDecorator.apply(match(rootNode)
Statement updateIfExists = vectorProcedureCall.apply(updateDecorator.apply(match(rootNode)
.where(createCompositePropertyCondition(idPropertyDescription, rootNode.getRequiredSymbolicName(), idParameter))
.and(versionProperty.isEqualTo(parameter(Constants.NAME_OF_VERSION_PARAM))) // Initial check
.set(versionProperty.to(versionProperty.add(literalOf(1)))) // Acquire lock
.with(rootNode)
.where(versionProperty.isEqualTo(coalesce(parameter(Constants.NAME_OF_VERSION_PARAM), literalOf(0)).add(
literalOf(1))))
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM)))
.returning(rootNode)
.build();
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));
return Cypher.union(createIfNew, updateIfExists);

} else {
String nameOfPossibleExistingNode = "hlp";
Node possibleExistingNode = node(primaryLabel, additionalLabels).named(nameOfPossibleExistingNode);

Statement createIfNew = updateDecorator.apply(optionalMatch(possibleExistingNode)
Statement createIfNew = vectorProcedureCall.apply(updateDecorator.apply(optionalMatch(possibleExistingNode)
.where(createCompositePropertyCondition(idPropertyDescription, possibleExistingNode.getRequiredSymbolicName(), idParameter))
.with(possibleExistingNode)
.where(possibleExistingNode.isNull())
.create(rootNode)
.with(rootNode)
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))).returning(rootNode)
.build();
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));

Statement updateIfExists = updateDecorator.apply(match(rootNode)
Statement updateIfExists = vectorProcedureCall.apply(updateDecorator.apply(match(rootNode)
.where(createCompositePropertyCondition(idPropertyDescription, rootNode.getRequiredSymbolicName(), idParameter))
.with(rootNode)
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM)))
.returning(rootNode)
.build();
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));
return Cypher.union(createIfNew, updateIfExists);
}
} else {
String nameOfPossibleExistingNode = "hlp";
Node possibleExistingNode = node(primaryLabel, additionalLabels).named(nameOfPossibleExistingNode);

Statement createIfNew;
Statement updateIfExists;

var neo4jPersistentEntity = (Neo4jPersistentEntity<?>) nodeDescription;
var nodeIdFunction = getNodeIdFunction(neo4jPersistentEntity, canUseElementId);

if (neo4jPersistentEntity.hasVersionProperty()) {
Property versionProperty = rootNode.property(neo4jPersistentEntity.getRequiredVersionProperty().getName());

createIfNew = updateDecorator.apply(optionalMatch(possibleExistingNode)
var createIfNew = vectorProcedureCall.apply(updateDecorator.apply(optionalMatch(possibleExistingNode)
.where(nodeIdFunction.apply(possibleExistingNode).isEqualTo(idParameter))
.with(possibleExistingNode)
.where(possibleExistingNode.isNull())
.create(rootNode.withProperties(versionProperty, literalOf(0)))
.with(rootNode)
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM)))
.returning(rootNode)
.build();
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));

updateIfExists = updateDecorator.apply(match(rootNode)
var updateIfExists = vectorProcedureCall.apply(updateDecorator.apply(match(rootNode)
.where(nodeIdFunction.apply(rootNode).isEqualTo(idParameter))
.and(versionProperty.isEqualTo(parameter(Constants.NAME_OF_VERSION_PARAM))) // Initial check
.set(versionProperty.to(versionProperty.add(literalOf(1)))) // Acquire lock
.with(rootNode)
.where(versionProperty.isEqualTo(coalesce(parameter(Constants.NAME_OF_VERSION_PARAM), literalOf(0)).add(
literalOf(1))))
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM)))
.returning(rootNode).build();
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));
return Cypher.union(createIfNew, updateIfExists);
} else {
createIfNew = updateDecorator
.apply(optionalMatch(possibleExistingNode).where(nodeIdFunction.apply(possibleExistingNode).isEqualTo(idParameter))
var createStatement = vectorProcedureCall.apply(updateDecorator.apply(optionalMatch(possibleExistingNode).where(nodeIdFunction.apply(possibleExistingNode).isEqualTo(idParameter))
.with(possibleExistingNode).where(possibleExistingNode.isNull()).create(rootNode)
.set(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM)))
.returning(rootNode).build();
.set(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));
var updateStatement = vectorProcedureCall.apply(updateDecorator.apply(match(rootNode).where(nodeIdFunction.apply(rootNode).isEqualTo(idParameter))
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))));

updateIfExists = updateDecorator.apply(match(rootNode).where(nodeIdFunction.apply(rootNode).isEqualTo(idParameter))
.mutate(rootNode, parameter(Constants.NAME_OF_PROPERTIES_PARAM))).returning(rootNode).build();
return Cypher.union(createStatement, updateStatement);
}

return Cypher.union(createIfNew, updateIfExists);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,13 @@ public void write(Object source, Map<String, Object> parameters) {
PropertyHandlerSupport.of(nodeDescription).doWithProperties((Neo4jPersistentProperty p) -> {

// Skip the internal properties, we don't want them to end up stored as properties
if (p.isInternalIdProperty() || p.isDynamicLabels() || p.isEntity() || p.isVersionProperty() || p.isReadOnly()) {
if (p.isInternalIdProperty() || p.isDynamicLabels() || p.isEntity() || p.isVersionProperty() || p.isReadOnly() || p.isVectorProperty()) {
return;
}

final Value value = conversionService.writeValue(propertyAccessor.getProperty(p), p.getTypeInformation(), p.getOptionalConverter());
if (p.isComposite()) {
properties.put(p.getPropertyName(), new MapValueWrapper(value));
//value.keys().forEach(k -> properties.put(k, value.get(k)));
} else {
properties.put(p.getPropertyName(), value);
}
Expand All @@ -270,6 +269,14 @@ public void write(Object source, Map<String, Object> parameters) {
// we incremented this upfront the persist operation so the matching version would be one "before"
parameters.put(Constants.NAME_OF_VERSION_PARAM, versionProperty);
}

// special handling for vector property to provide the needed procedure information
if (nodeDescription.hasVectorProperty()) {
Neo4jPersistentProperty vectorProperty = nodeDescription.getRequiredVectorProperty();
parameters.put(Constants.NAME_OF_VECTOR_PROPERTY, vectorProperty.getPropertyName());
parameters.put(Constants.NAME_OF_VECTOR_VALUE, conversionService.writeValue(propertyAccessor.getProperty(vectorProperty), vectorProperty.getTypeInformation(), vectorProperty.getOptionalConverter()));
return;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.log.LogAccessor;
import org.springframework.data.annotation.Persistent;
import org.springframework.data.domain.Vector;
import org.springframework.data.mapping.Association;
import org.springframework.data.mapping.model.BasicPersistentEntity;
import org.springframework.data.neo4j.core.schema.DynamicLabels;
Expand Down Expand Up @@ -88,6 +89,8 @@ final class DefaultNeo4jPersistentEntity<T> extends BasicPersistentEntity<T, Neo

private List<NodeDescription<?>> childNodeDescriptionsInHierarchy;

private final Lazy<Neo4jPersistentProperty> vectorProperty;

DefaultNeo4jPersistentEntity(TypeInformation<T> information) {
super(information);

Expand All @@ -99,6 +102,8 @@ final class DefaultNeo4jPersistentEntity<T> extends BasicPersistentEntity<T, Neo
this.isRelationshipPropertiesEntity = Lazy.of(() -> isAnnotationPresent(RelationshipProperties.class));
this.idDescription = Lazy.of(this::computeIdDescription);
this.childNodeDescriptionsInHierarchy = computeChildNodeDescriptionInHierarchy();
this.vectorProperty = Lazy.of(() -> getGraphProperties().stream().map(Neo4jPersistentProperty.class::cast)
.filter(Neo4jPersistentProperty::isVectorProperty).findFirst().orElse(null));
}

/*
Expand Down Expand Up @@ -212,6 +217,7 @@ public void verify() {
verifyDynamicAssociations();
verifyAssociationsWithProperties();
verifyDynamicLabels();
verifyAtMostOneVectorDefinition();
}

private void verifyIdDescription() {
Expand Down Expand Up @@ -301,6 +307,18 @@ private void verifyDynamicLabels() {
DynamicLabels.class.getSimpleName(), namesOfPropertiesWithDynamicLabels));
}

private void verifyAtMostOneVectorDefinition() {
List<Neo4jPersistentProperty> foundVectorDefinition = new ArrayList<>();
PropertyHandlerSupport.of(this).doWithProperties(persistentProperty -> {
if (persistentProperty.getType().isAssignableFrom(Vector.class)) {
foundVectorDefinition.add(persistentProperty);
}
});

Assert.state(foundVectorDefinition.size() <= 1, () -> String.format("There are multiple fields of type %s in entity %s: %s",
Vector.class.toString(), this.getName(), foundVectorDefinition.stream().map(p -> p.getPropertyName()).toList()));
}

/**
* The primary label will get computed and returned by following rules:<br>
* 1. If there is no {@link Node} annotation, use the class name.<br>
Expand Down Expand Up @@ -410,6 +428,23 @@ public boolean describesInterface() {
return this.getTypeInformation().getRawTypeInformation().getType().isInterface();
}

@Override
public boolean hasVectorProperty() {
return Optional.ofNullable(getVectorProperty()).map(v -> true).orElse(false);
}

public Neo4jPersistentProperty getVectorProperty() {
return this.vectorProperty.getNullable();
}

public Neo4jPersistentProperty getRequiredVectorProperty() {
Neo4jPersistentProperty property = getVectorProperty();
if (property != null) {
return property;
}
throw new IllegalStateException(String.format("Required vector property not found for %s", this.getType()));
}

private static boolean hasEmptyLabelInformation(Node nodeAnnotation) {
return nodeAnnotation.labels().length < 1 && !StringUtils.hasText(nodeAnnotation.primaryLabel());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,9 @@ default boolean isUsingDeprecatedInternalId() {
}
return isUsingInternalIds() && Neo4jPersistentEntity.DEPRECATED_GENERATED_ID_TYPES.contains(getRequiredIdProperty().getType());
}

boolean hasVectorProperty();

Neo4jPersistentProperty getVectorProperty();
Neo4jPersistentProperty getRequiredVectorProperty();
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.Optional;

import org.apiguardian.api.API;
import org.springframework.data.domain.Vector;
import org.springframework.data.mapping.PersistentProperty;
import org.springframework.data.neo4j.core.convert.Neo4jPersistentPropertyConverter;
import org.springframework.data.neo4j.core.schema.CompositeProperty;
Expand Down Expand Up @@ -67,6 +68,10 @@ default boolean isDynamicLabels() {
return this.isAnnotationPresent(DynamicLabels.class) && this.isCollectionLike();
}

default boolean isVectorProperty() {
return this.getType().isAssignableFrom(Vector.class);
}

@Nullable
Neo4jPersistentPropertyConverter<?> getOptionalConverter();

Expand Down
Loading