Skip to content

Commit aa6d64c

Browse files
authored
[stdlib] [Sema] Add broadcasting 'adding(_:)' and 'subtracting(_:)' to 'VectorProtocol'. (#25525)
- Add broadcasting 'adding(_:)' and 'subtracting(_:)' to 'VectorProtocol'. These are important for aggregate algorithms such as machine learning optimizers (tensorflow/swift-apis#218). - Make their implementations compiler-derivable. - Add operators for `adding(_:)` and `subtracting(_:)` in a protocol extension. - Comment out all operators in protocol extensions for `VectorProtocol` because a source breakage has been found.
1 parent 8d39256 commit aa6d64c

File tree

9 files changed

+151
-30
lines changed

9 files changed

+151
-30
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,12 @@ IDENTIFIER_(typeList)
135135
// AdditiveArithmetic, VectorProtocol
136136
IDENTIFIER(zero)
137137
IDENTIFIER(VectorSpaceScalar)
138+
IDENTIFIER(adding)
139+
IDENTIFIER(subtracting)
138140
IDENTIFIER(scaled)
141+
IDENTIFIER(by)
142+
IDENTIFIER(scale)
143+
IDENTIFIER(x)
139144
// Differentiable
140145
IDENTIFIER(AllDifferentiableVariables)
141146
IDENTIFIER(TangentVector)

lib/Sema/DerivedConformanceVectorProtocol.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -202,30 +202,23 @@ static void deriveBodyVectorProtocol_method(AbstractFunctionDecl *funcDecl,
202202
BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true));
203203
}
204204

205-
// Synthesize body for `scaled(by:)`.
206-
static void deriveBodyVectorProtocol_scalarMultiply(
207-
AbstractFunctionDecl *funcDecl, void *) {
208-
auto &C = funcDecl->getASTContext();
209-
deriveBodyVectorProtocol_method(funcDecl, C.Id_scaled, C.getIdentifier("by"));
210-
}
211-
212205
// Synthesize function declaration for a `VectorProtocol` method requirement.
213206
static ValueDecl *deriveVectorProtocol_method(
214-
DerivedConformance &derived, Identifier methodName, Identifier argumentName,
215-
Identifier parameterName, Type parameterType, Type returnType,
216-
AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
207+
DerivedConformance &derived, Identifier methodBaseName,
208+
Identifier argumentLabel, Identifier parameterName, Type parameterType,
209+
Type returnType, AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
217210
auto nominal = derived.Nominal;
218211
auto &TC = derived.TC;
219212
auto &C = derived.TC.Context;
220213
auto parentDC = derived.getConformanceContext();
221214

222215
auto *param =
223216
new (C) ParamDecl(VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
224-
argumentName, SourceLoc(), parameterName, parentDC);
217+
argumentLabel, SourceLoc(), parameterName, parentDC);
225218
param->setInterfaceType(parameterType);
226219
ParameterList *params = ParameterList::create(C, {param});
227220

228-
DeclName declName(C, methodName, params);
221+
DeclName declName(C, methodBaseName, params);
229222
auto funcDecl = FuncDecl::create(C, SourceLoc(), StaticSpellingKind::None,
230223
SourceLoc(), declName, SourceLoc(),
231224
/*Throws*/ false, SourceLoc(),
@@ -258,8 +251,13 @@ static ValueDecl *deriveVectorProtocol_method(
258251
return funcDecl;
259252
}
260253

261-
// Synthesize the `scaled(by:)` function declaration.
262-
static ValueDecl *deriveVectorProtocol_scaled(DerivedConformance &derived) {
254+
/// Synthesize a method declaration that has the following signture:
255+
/// func {methodBaseName}(
256+
/// {argumentLabel} {parameterName}: VectorSpaceScalar
257+
/// ) -> Self
258+
static ValueDecl *deriveVectorProtocol_unaryMethodOnScalar(
259+
DerivedConformance &derived, Identifier methodBaseName,
260+
Identifier argumentLabel, Identifier parameterName) {
263261
auto &C = derived.TC.Context;
264262
auto *nominal = derived.Nominal;
265263
auto *parentDC = derived.getConformanceContext();
@@ -268,18 +266,32 @@ static ValueDecl *deriveVectorProtocol_scaled(DerivedConformance &derived) {
268266
auto scalarType = deriveVectorProtocol_VectorSpaceScalar(nominal, parentDC)
269267
->mapTypeOutOfContext();
270268

269+
auto bodySynthesizer = [](AbstractFunctionDecl *funcDecl, void *ctx) {
270+
auto methodNameAndLabel = reinterpret_cast<Identifier *>(ctx);
271+
deriveBodyVectorProtocol_method(
272+
funcDecl, methodNameAndLabel[0], methodNameAndLabel[1]);
273+
};
274+
Identifier baseNameAndLabel[2] = {methodBaseName, argumentLabel};
271275
return deriveVectorProtocol_method(
272-
derived, C.Id_scaled, C.getIdentifier("by"), C.getIdentifier("scalar"),
273-
scalarType, selfInterfaceType,
274-
{deriveBodyVectorProtocol_scalarMultiply, nullptr});
276+
derived, methodBaseName, argumentLabel, parameterName, scalarType,
277+
selfInterfaceType,
278+
{bodySynthesizer, C.AllocateCopy(baseNameAndLabel).data()});
275279
}
276280

277281
ValueDecl *DerivedConformance::deriveVectorProtocol(ValueDecl *requirement) {
278282
// Diagnose conformances in disallowed contexts.
279283
if (checkAndDiagnoseDisallowedContext(requirement))
280284
return nullptr;
285+
auto &C = requirement->getASTContext();
281286
if (requirement->getBaseName() == TC.Context.Id_scaled)
282-
return deriveVectorProtocol_scaled(*this);
287+
return deriveVectorProtocol_unaryMethodOnScalar(
288+
*this, C.Id_scaled, C.Id_by, C.Id_scale);
289+
if (requirement->getBaseName() == TC.Context.Id_adding)
290+
return deriveVectorProtocol_unaryMethodOnScalar(
291+
*this, C.Id_adding, Identifier(), C.Id_x);
292+
if (requirement->getBaseName() == TC.Context.Id_subtracting)
293+
return deriveVectorProtocol_unaryMethodOnScalar(
294+
*this, C.Id_subtracting, Identifier(), C.Id_x);
283295
TC.diagnose(requirement->getLoc(), diag::broken_vector_protocol_requirement);
284296
return nullptr;
285297
}

lib/Sema/DerivedConformances.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,17 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
346346
return getRequirement(KnownProtocolKind::VectorProtocol);
347347
}
348348

349+
// SWIFT_ENABLE_TENSORFLOW
350+
// VectorProtocol.adding(_:)
351+
// VectorProtocol.subtracting(_:)
352+
if (name.isCompoundName() &&
353+
(name.getBaseName() == ctx.Id_adding ||
354+
name.getBaseName() == ctx.Id_subtracting)) {
355+
auto argumentNames = name.getArgumentNames();
356+
if (argumentNames.size() == 1 && argumentNames[0].empty())
357+
return getRequirement(KnownProtocolKind::VectorProtocol);
358+
}
359+
349360
// SWIFT_ENABLE_TENSORFLOW
350361
// TensorArrayProtocol._unpackTensorHandles(into:)
351362
if (name.isCompoundName() &&

stdlib/public/core/AutoDiff.swift

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ public protocol VectorProtocol : AdditiveArithmetic {
2626
/// The type of scalars in the vector space.
2727
associatedtype VectorSpaceScalar : AdditiveArithmetic
2828

29+
func adding(_ x: VectorSpaceScalar) -> Self
30+
31+
mutating func add(_ x: VectorSpaceScalar)
32+
33+
func subtracting(_ x: VectorSpaceScalar) -> Self
34+
35+
mutating func subtract(_ x: VectorSpaceScalar)
36+
2937
/// Returns `self` multiplied by the given scalar.
3038
func scaled(by scalar: VectorSpaceScalar) -> Self
3139

@@ -34,28 +42,66 @@ public protocol VectorProtocol : AdditiveArithmetic {
3442
}
3543

3644
public extension VectorProtocol {
45+
mutating func add(_ x: VectorSpaceScalar) {
46+
self = adding(x)
47+
}
48+
49+
mutating func subtract(_ x: VectorSpaceScalar) {
50+
self = subtracting(x)
51+
}
52+
3753
mutating func scale(by scalar: VectorSpaceScalar) {
3854
self = scaled(by: scalar)
3955
}
56+
}
57+
58+
/* Note: These default-implemented opreators will slow down type-checking
59+
performance and break existing code.
60+
61+
public extension VectorProtocol {
62+
static func + (lhs: Self, rhs: VectorSpaceScalar) -> Self {
63+
lhs.adding(rhs)
64+
}
65+
66+
static func + (lhs: VectorSpaceScalar, rhs: Self) -> Self {
67+
rhs.adding(lhs)
68+
}
69+
70+
static func += (lhs: inout Self, rhs: VectorSpaceScalar) {
71+
lhs.add(rhs)
72+
}
73+
74+
static func - (lhs: Self, rhs: VectorSpaceScalar) -> Self {
75+
lhs.subtracting(rhs)
76+
}
77+
78+
static func -= (lhs: inout Self, rhs: VectorSpaceScalar) {
79+
lhs.subtract(rhs)
80+
}
4081

4182
static func * (lhs: Self, rhs: VectorSpaceScalar) -> Self {
42-
return lhs.scaled(by: rhs)
83+
lhs.scaled(by: rhs)
4384
}
4485

4586
static func * (lhs: VectorSpaceScalar, rhs: Self) -> Self {
46-
return rhs.scaled(by: lhs)
87+
rhs.scaled(by: lhs)
4788
}
4889

4990
static func *= (lhs: inout Self, rhs: VectorSpaceScalar) {
5091
lhs.scale(by: rhs)
5192
}
5293
}
5394

54-
public extension VectorProtocol where VectorSpaceScalar: SignedNumeric {
95+
public extension VectorProtocol where VectorSpaceScalar : SignedNumeric {
96+
static func - (lhs: VectorSpaceScalar, rhs: Self) -> Self {
97+
-rhs.adding(lhs)
98+
}
99+
55100
static prefix func - (x: Self) -> Self {
56101
.zero - x
57102
}
58103
}
104+
*/
59105

60106
/// A type that mathematically represents a differentiable manifold whose
61107
/// tangent spaces are finite-dimensional.

stdlib/public/core/FloatingPointTypes.swift.gyb

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1879,8 +1879,28 @@ extension ${Self} : Strideable {
18791879
extension ${Self} : VectorProtocol {
18801880
public typealias VectorSpaceScalar = ${Self}
18811881

1882+
public func adding(_ x: ${Self}) -> ${Self} {
1883+
self + x
1884+
}
1885+
1886+
public mutating func add(_ x: ${Self}) {
1887+
self += x
1888+
}
1889+
1890+
public func subtracting(_ x: ${Self}) -> ${Self} {
1891+
self - x
1892+
}
1893+
1894+
public mutating func subtract(_ x: ${Self}) {
1895+
self -= x
1896+
}
1897+
18821898
public func scaled(by scalar: ${Self}) -> ${Self} {
1883-
return self * scalar
1899+
self * scalar
1900+
}
1901+
1902+
public mutating func scale(by scalar: ${Self}) {
1903+
self *= scalar
18841904
}
18851905
}
18861906

test/AutoDiff/refcounting.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ public struct Vector : AdditiveArithmetic, VectorProtocol, Differentiable, Equat
2323
@differentiable(vjp: fakeVJP)
2424
public static func - (lhs: Vector, rhs: Vector) -> Vector { abort() }
2525

26+
public func adding(_ scalar: Float) -> Vector { abort() }
27+
public func subtracting(_ scalar: Float) -> Vector { abort() }
2628
public func scaled(by scalar: Float) -> Vector { abort() }
2729

2830
public static func fakeVJP(lhs: Vector, rhs: Vector) -> (Vector, (Vector) -> (Vector, Vector)) { abort() }

test/Sema/struct_key_path_iterable.swift

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ extension Tensor : Equatable where Scalar : Equatable {}
1111
extension Tensor : AdditiveArithmetic where Scalar : AdditiveArithmetic {}
1212
extension Tensor : VectorProtocol where Scalar : AdditiveArithmetic {
1313
typealias VectorSpaceScalar = Scalar
14+
func adding(_: Scalar) -> Self { self }
15+
func subtracting(_: Scalar) -> Self { self }
1416
func scaled(by scalar: Scalar) -> Self { self }
1517
}
1618

@@ -55,6 +57,12 @@ extension TensorParameters : VectorProtocol {
5557
return TensorParameters(w: lhs.w + rhs.w, b: lhs.b + rhs.b)
5658
}
5759
typealias VectorSpaceScalar = Float
60+
func adding(_ x: VectorSpaceScalar) -> TensorParameters {
61+
return TensorParameters(w: w.adding(x), b: b.adding(x))
62+
}
63+
func subtracting(_ x: VectorSpaceScalar) -> TensorParameters {
64+
return TensorParameters(w: w.subtracting(x), b: b.subtracting(x))
65+
}
5866
func scaled(by scalar: VectorSpaceScalar) -> TensorParameters {
5967
return TensorParameters(w: w.scaled(by: scalar), b: b.scaled(by: scalar))
6068
}
@@ -100,8 +108,8 @@ struct DummyOptimizer<P : KeyPathIterable, Scalar : BinaryFloatingPoint>
100108
parameters: inout P, withGradients gradients: P
101109
) {
102110
for kp in parameters.recursivelyAllWritableKeyPaths(to: Tensor<Scalar>.self) {
103-
firstMoments[keyPath: kp] *= learningRate
104-
parameters[keyPath: kp] -= learningRate * parameters[keyPath: kp]
111+
firstMoments[keyPath: kp].scale(by: learningRate)
112+
parameters[keyPath: kp] -= parameters[keyPath: kp].scaled(by: learningRate)
105113
}
106114
}
107115
}

test/Sema/struct_vector_protocol.swift

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,28 @@ _ x: inout T, scalar: T.VectorSpaceScalar
77
// Test `AdditiveArithmetic` requirements: `zero`, `+`, `-`.
88
let zero = T.zero
99
x += x + zero
10-
x -= x - zero
11-
// Test `VectorProtocol` requirements: `VectorSpaceScalar`, `*`.
12-
x *= scalar
13-
_ = scalar * x
14-
_ = x * scalar
10+
x += x - zero
11+
// Test `VectorProtocol` requirements: `VectorSpaceScalar`, `adding(_:)`, `add(_:)`
12+
// `subtracting(_:)`, `subtract(_:)`, `scaled(by:)`, and `scale(by:)`.
13+
x.add(scalar)
14+
x.add(scalar)
15+
x.scale(by: scalar)
16+
_ = x.adding(scalar)
17+
_ = x.subtracting(scalar)
18+
_ = x.scaled(by: scalar)
19+
20+
// NOTE: Operators have been disabled for type checker performance reasons.
21+
// x += x + zero
22+
// x -= x - zero
23+
// Test `VectorProtocol` requirements: `VectorSpaceScalar`, `+`, `-`, `*`.
24+
// x += scalar
25+
// x -= scalar
26+
// x *= scalar
27+
// _ = x + scalar
28+
// _ = scalar + x
29+
// _ = x - scalar
30+
// _ = scalar * x
31+
// _ = x * scalar
1532
}
1633

1734
struct Float2: VectorProtocol {

utils/update_checkout/update-checkout-config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@
375375
"clang-tools-extra": "swift-DEVELOPMENT-SNAPSHOT-2019-06-06-a",
376376
"libcxx": "swift-DEVELOPMENT-SNAPSHOT-2019-06-06-a",
377377
"tensorflow": "ebc41609e27dcf0998d8970e77a2e1f53e13ac86",
378-
"tensorflow-swift-apis": "5d3ef57b501781f1bab3b4ca85e8b8fc91671c14",
378+
"tensorflow-swift-apis": "7a3ed481bba53a7cd82f8a46c0df9f09a6e9747f",
379379
"indexstore-db": "swift-DEVELOPMENT-SNAPSHOT-2019-06-06-a",
380380
"sourcekit-lsp": "swift-DEVELOPMENT-SNAPSHOT-2019-06-06-a"
381381
}

0 commit comments

Comments
 (0)