Skip to content

Commit 8dc5f8d

Browse files
authored
Merge branch 'main' into dw/fix_to_composite
2 parents 0b21548 + 394f539 commit 8dc5f8d

File tree

3 files changed

+9
-25
lines changed

3 files changed

+9
-25
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name: CI
22
on:
33
push:
4-
branches: [master]
4+
branches: [main]
55
tags: [v*]
66
pull_request:
77
schedule:

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "1.2.4"
3+
version = "1.3.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "1"
14+
ChainRulesCore = "1.11.2"
1515
Compat = "3"
1616
FiniteDifferences = "0.12.12"
1717
julia = "1"

src/rand_tangent.jl

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,6 @@ Rather it is an arbitary value, that is generated using the `rng`.
77
"""
88
rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x)
99

10-
rand_tangent(rng::AbstractRNG, x::Symbol) = NoTangent()
11-
rand_tangent(rng::AbstractRNG, x::AbstractChar) = NoTangent()
12-
rand_tangent(rng::AbstractRNG, x::AbstractString) = NoTangent()
13-
14-
rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent()
15-
1610
# Try and make nice numbers with short decimal representations for good error messages
1711
# while also not biasing the sample space too much
1812
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Number}
@@ -24,25 +18,11 @@ function rand_tangent(rng::AbstractRNG, x::ComplexF64)
2418
return ComplexF64(rand(rng, -9:0.1:9), rand(rng, -9:0.1:9))
2519
end
2620

27-
#BigFloat/MPFR is finicky about short numbers, this doesn't always work as well as it should
28-
21+
# BigFloat/MPFR is finicky about short numbers, this doesn't always work as well as it should
2922
# multiply by 9 to give a bigger range of values tested: no so tightly clustered around 0.
3023
rand_tangent(rng::AbstractRNG, ::BigFloat) = round(big(9 * randn(rng)), digits=5, base=2)
3124

32-
33-
rand_tangent(rng::AbstractRNG, x::Array{<:Any, 0}) = _compress_notangent(fill(rand_tangent(rng, x[])))
34-
rand_tangent(rng::AbstractRNG, x::Array) = _compress_notangent(rand_tangent.(Ref(rng), x))
35-
36-
# All other AbstractArray's can be handled using the ProjectTo mechanics.
37-
# and follow the same requirements
38-
function rand_tangent(rng::AbstractRNG, x::AbstractArray)
39-
return _compress_notangent(ProjectTo(x)(rand_tangent(rng, collect(x))))
40-
end
41-
42-
# TODO: arguably ProjectTo should handle this for us for AbstactArrays
43-
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/410
44-
_compress_notangent(::AbstractArray{NoTangent}) = NoTangent()
45-
_compress_notangent(x) = x
25+
rand_tangent(rng::AbstractRNG, x::AbstractArray) = ProjectTo(x)(rand_tangent.(Ref(rng), x))
4626

4727
function rand_tangent(rng::AbstractRNG, x::T) where {T}
4828
if !isstructtype(T)
@@ -65,5 +45,9 @@ function rand_tangent(rng::AbstractRNG, x::T) where {T}
6545
end
6646
end
6747

48+
rand_tangent(rng::AbstractRNG, x::Symbol) = NoTangent()
49+
rand_tangent(rng::AbstractRNG, x::AbstractChar) = NoTangent()
50+
rand_tangent(rng::AbstractRNG, x::AbstractString) = NoTangent()
51+
rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent()
6852
rand_tangent(rng::AbstractRNG, ::Type) = NoTangent()
6953
rand_tangent(rng::AbstractRNG, ::Module) = NoTangent()

0 commit comments

Comments
 (0)