Skip to content

Implement arithmetic expressions #1037

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.bson.BsonArray;
import org.bson.BsonBoolean;
import org.bson.BsonDateTime;
import org.bson.BsonDecimal128;
import org.bson.BsonDocument;
import org.bson.BsonDouble;
import org.bson.BsonInt32;
Expand All @@ -28,13 +29,15 @@
import org.bson.BsonString;
import org.bson.BsonValue;
import org.bson.conversions.Bson;
import org.bson.types.Decimal128;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import static com.mongodb.client.model.expressions.MqlExpression.AstPlaceholder;

/**
* Convenience methods related to {@link Expression}.
Expand All @@ -52,7 +55,7 @@ private Expressions() {}
*/
public static BooleanExpression of(final boolean of) {
// we intentionally disallow ofBoolean(null)
return new MqlExpression<>((codecRegistry) -> new BsonBoolean(of));
return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonBoolean(of)));
}

/**
Expand All @@ -63,16 +66,21 @@ public static BooleanExpression of(final boolean of) {
* @return the integer expression
*/
public static IntegerExpression of(final int of) {
return new MqlExpression<>((codecRegistry) -> new BsonInt32(of));
return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonInt32(of)));
}
public static IntegerExpression of(final long of) {
return new MqlExpression<>((codecRegistry) -> new BsonInt64(of));
return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonInt64(of)));
}
public static NumberExpression of(final double of) {
return new MqlExpression<>((codecRegistry) -> new BsonDouble(of));
return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonDouble(of)));
}
public static NumberExpression of(final Decimal128 of) {
Assertions.notNull("Decimal128", of);
return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonDecimal128(of)));
}
public static DateExpression of(final Instant of) {
return new MqlExpression<>((codecRegistry) -> new BsonDateTime(of.toEpochMilli()));
Assertions.notNull("Instant", of);
return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonDateTime(of.toEpochMilli())));
}

/**
Expand All @@ -84,7 +92,7 @@ public static DateExpression of(final Instant of) {
*/
public static StringExpression of(final String of) {
Assertions.notNull("String", of);
return new MqlExpression<>((codecRegistry) -> new BsonString(of));
return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonString(of)));
}

/**
Expand All @@ -99,27 +107,42 @@ public static ArrayExpression<BooleanExpression> ofBooleanArray(final boolean...
for (boolean b : array) {
result.add(new BsonBoolean(b));
}
return new MqlExpression<>((cr) -> new BsonArray(result));
return new MqlExpression<>((cr) -> new AstPlaceholder(new BsonArray(result)));
}


public static ArrayExpression<IntegerExpression> ofIntegerArray(final int... ofIntegerArray) {
List<BsonValue> array = Arrays.stream(ofIntegerArray)
.mapToObj(BsonInt32::new)
.collect(Collectors.toList());
return new MqlExpression<>((cr) -> new BsonArray(array));
return new MqlExpression<>((cr) -> new AstPlaceholder(new BsonArray(array)));
}

public static DocumentExpression ofDocument(final Bson document) {
Assertions.notNull("document", document);
// All documents are wrapped in a $literal. If we don't wrap, we need to
// check for empty documents and documents that are actually expressions
// (and need to be wrapped in $literal anyway). This would be brittle.
return new MqlExpression<>((cr) -> new BsonDocument("$literal",
document.toBsonDocument(BsonDocument.class, cr)));
return new MqlExpression<>((cr) -> new AstPlaceholder(new BsonDocument("$literal",
document.toBsonDocument(BsonDocument.class, cr))));
}

public static <R extends Expression> R ofNull() {
return new MqlExpression<>((cr) -> new BsonNull()).assertImplementsAllExpressions();
return new MqlExpression<>((cr) -> new AstPlaceholder(new BsonNull()))
.assertImplementsAllExpressions();
}

static NumberExpression numberToExpression(final Number number) {
if (number instanceof Integer) {
return of((int) number);
} else if (number instanceof Long) {
return of((long) number);
} else if (number instanceof Double) {
return of((double) number);
} else if (number instanceof Decimal128) {
return of((Decimal128) number);
} else {
throw new IllegalArgumentException("Number must be one of: Integer, Long, Double, Decimal128");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,26 @@
* Expresses an integer value.
*/
public interface IntegerExpression extends NumberExpression {
IntegerExpression multiply(IntegerExpression i);

default IntegerExpression multiply(final int multiply) {
return this.multiply(Expressions.of(multiply));
}

IntegerExpression add(IntegerExpression i);

default IntegerExpression add(final int add) {
return this.add(Expressions.of(add));
}

IntegerExpression subtract(IntegerExpression i);

default IntegerExpression subtract(final int subtract) {
return this.subtract(Expressions.of(subtract));
}

IntegerExpression max(IntegerExpression i);
IntegerExpression min(IntegerExpression i);

IntegerExpression abs();
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ final class MqlExpression<T extends Expression>
implements Expression, BooleanExpression, IntegerExpression, NumberExpression,
StringExpression, DateExpression, DocumentExpression, ArrayExpression<T> {

private final Function<CodecRegistry, BsonValue> fn;
private final Function<CodecRegistry, AstPlaceholder> fn;

MqlExpression(final Function<CodecRegistry, BsonValue> fn) {
MqlExpression(final Function<CodecRegistry, AstPlaceholder> fn) {
this.fn = fn;
}

Expand All @@ -41,33 +41,41 @@ final class MqlExpression<T extends Expression>
* {@link MqlExpressionCodec}.
*/
BsonValue toBsonValue(final CodecRegistry codecRegistry) {
return fn.apply(codecRegistry);
return fn.apply(codecRegistry).bsonValue;
}

private Function<CodecRegistry, BsonValue> astDoc(final String name, final BsonDocument value) {
return (cr) -> new BsonDocument(name, value);
private AstPlaceholder astDoc(final String name, final BsonDocument value) {
return new AstPlaceholder(new BsonDocument(name, value));
}

private Function<CodecRegistry, BsonValue> ast(final String name) {
return (cr) -> new BsonDocument(name, this.toBsonValue(cr));
static final class AstPlaceholder {
private final BsonValue bsonValue;

AstPlaceholder(final BsonValue bsonValue) {
this.bsonValue = bsonValue;
}
}

private Function<CodecRegistry, AstPlaceholder> ast(final String name) {
return (cr) -> new AstPlaceholder(new BsonDocument(name, this.toBsonValue(cr)));
}

private Function<CodecRegistry, BsonValue> ast(final String name, final Expression param1) {
private Function<CodecRegistry, AstPlaceholder> ast(final String name, final Expression param1) {
return (cr) -> {
BsonArray value = new BsonArray();
value.add(this.toBsonValue(cr));
value.add(extractBsonValue(cr, param1));
return new BsonDocument(name, value);
return new AstPlaceholder(new BsonDocument(name, value));
};
}

private Function<CodecRegistry, BsonValue> ast(final String name, final Expression param1, final Expression param2) {
private Function<CodecRegistry, AstPlaceholder> ast(final String name, final Expression param1, final Expression param2) {
return (cr) -> {
BsonArray value = new BsonArray();
value.add(this.toBsonValue(cr));
value.add(extractBsonValue(cr, param1));
value.add(extractBsonValue(cr, param2));
return new BsonDocument(name, value);
return new AstPlaceholder(new BsonDocument(name, value));
};
}

Expand All @@ -89,12 +97,12 @@ <R extends Expression> R assertImplementsAllExpressions() {
return (R) this;
}

private static <R extends Expression> R newMqlExpression(final Function<CodecRegistry, BsonValue> ast) {
private static <R extends Expression> R newMqlExpression(final Function<CodecRegistry, AstPlaceholder> ast) {
return new MqlExpression<>(ast).assertImplementsAllExpressions();
}

private <R extends Expression> R variable(final String variable) {
return newMqlExpression((cr) -> new BsonString(variable));
return newMqlExpression((cr) -> new AstPlaceholder(new BsonString(variable)));
}

/** @see BooleanExpression */
Expand Down Expand Up @@ -159,15 +167,15 @@ public <R extends Expression> ArrayExpression<R> map(final Function<? super T, ?
T varThis = variable("$$this");
return new MqlExpression<>((cr) -> astDoc("$map", new BsonDocument()
.append("input", this.toBsonValue(cr))
.append("in", extractBsonValue(cr, in.apply(varThis)))).apply(cr));
.append("in", extractBsonValue(cr, in.apply(varThis)))));
}

@Override
public ArrayExpression<T> filter(final Function<? super T, ? extends BooleanExpression> cond) {
T varThis = variable("$$this");
return new MqlExpression<T>((cr) -> astDoc("$filter", new BsonDocument()
.append("input", this.toBsonValue(cr))
.append("cond", extractBsonValue(cr, cond.apply(varThis)))).apply(cr));
.append("cond", extractBsonValue(cr, cond.apply(varThis)))));
}

@Override
Expand All @@ -177,7 +185,81 @@ public T reduce(final T initialValue, final BinaryOperator<T> in) {
return newMqlExpression((cr) -> astDoc("$reduce", new BsonDocument()
.append("input", this.toBsonValue(cr))
.append("initialValue", extractBsonValue(cr, initialValue))
.append("in", extractBsonValue(cr, in.apply(varThis, varValue)))).apply(cr));
.append("in", extractBsonValue(cr, in.apply(varThis, varValue)))));
}


/** @see IntegerExpression
* @see NumberExpression */

@Override
public IntegerExpression multiply(final NumberExpression n) {
return newMqlExpression(ast("$multiply", n));
}

@Override
public NumberExpression add(final NumberExpression n) {
return new MqlExpression<>(ast("$add", n));
}

@Override
public NumberExpression divide(final NumberExpression n) {
return new MqlExpression<>(ast("$divide", n));
}

@Override
public NumberExpression max(final NumberExpression n) {
return new MqlExpression<>(ast("$max", n));
}

@Override
public NumberExpression min(final NumberExpression n) {
return new MqlExpression<>(ast("$min", n));
}

@Override
public IntegerExpression round() {
return new MqlExpression<>(ast("$round"));
}

@Override
public NumberExpression round(final IntegerExpression place) {
return new MqlExpression<>(ast("$round", place));
}

@Override
public IntegerExpression multiply(final IntegerExpression i) {
return new MqlExpression<>(ast("$multiply", i));
}

@Override
public IntegerExpression abs() {
return newMqlExpression(ast("$abs"));
}

@Override
public NumberExpression subtract(final NumberExpression n) {
return new MqlExpression<>(ast("$subtract", n));
}

@Override
public IntegerExpression add(final IntegerExpression i) {
return new MqlExpression<>(ast("$add", i));
}

@Override
public IntegerExpression subtract(final IntegerExpression i) {
return new MqlExpression<>(ast("$subtract", i));
}

@Override
public IntegerExpression max(final IntegerExpression i) {
return new MqlExpression<>(ast("$max", i));
}

@Override
public IntegerExpression min(final IntegerExpression i) {
return new MqlExpression<>(ast("$min", i));
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,37 @@
*/
public interface NumberExpression extends Expression {

NumberExpression multiply(NumberExpression n);

default NumberExpression multiply(final Number multiply) {
return this.multiply(Expressions.numberToExpression(multiply));
}

NumberExpression divide(NumberExpression n);

default NumberExpression divide(final Number divide) {
return this.divide(Expressions.numberToExpression(divide));
}

NumberExpression add(NumberExpression n);

default NumberExpression add(final Number add) {
return this.add(Expressions.numberToExpression(add));
}

NumberExpression subtract(NumberExpression n);

default NumberExpression subtract(final Number subtract) {
return this.subtract(Expressions.numberToExpression(subtract));
}

NumberExpression max(NumberExpression n);

NumberExpression min(NumberExpression n);

IntegerExpression round();

NumberExpression round(IntegerExpression place);

NumberExpression abs();
}
Loading