From 09ca632af17d91a834c502b145ce97d5dc75fa1f Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <carlo.lucibello@gmail.com>
Date: Mon, 31 Mar 2025 10:06:10 +0200
Subject: [PATCH 1/2] julia 1.10

---
 Project.toml                |  2 +-
 src/dim_helpers/ConvDims.jl |  2 +-
 src/gemm.jl                 |  2 +-
 src/utils.jl                | 18 ------------------
 test.jl                     | 11 +++++++++++
 test/conv.jl                |  2 +-
 test/dropout.jl             |  5 +----
 test/pooling.jl             |  2 +-
 test/runtests.jl            |  2 +-
 test/testsuite/gather.jl    |  2 +-
 test/testsuite/scatter.jl   |  2 +-
 11 files changed, 20 insertions(+), 30 deletions(-)
 create mode 100644 test.jl

diff --git a/Project.toml b/Project.toml
index c9647205c..02e14607b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -48,4 +48,4 @@ ScopedValues = "1.3.0"
 SpecialFunctions = "2"
 Statistics = "1"
 cuDNN = "1"
-julia = "1.9"
+julia = "1.10"
diff --git a/src/dim_helpers/ConvDims.jl b/src/dim_helpers/ConvDims.jl
index e8bcc08f4..9e02010d3 100644
--- a/src/dim_helpers/ConvDims.jl
+++ b/src/dim_helpers/ConvDims.jl
@@ -73,7 +73,7 @@ function im2col_dims(c::ConvDims)
         # Size of single dotproduct within convolution
         prod(kernel_size(c))*channels_in(c),
         # One workspace per thread
-        VERSION > v"1.9.0-0" ? Threads.nthreads(:default) : Threads.nthreads(),
+        Threads.nthreads(:default),
     )
 end
 
diff --git a/src/gemm.jl b/src/gemm.jl
index 9a3c6cd57..e05174d17 100644
--- a/src/gemm.jl
+++ b/src/gemm.jl
@@ -95,7 +95,7 @@ for (gemm, elt) in gemm_datatype_mappings
             strC = Base.stride(C, 3)
 
             n_threads = min(
-                VERSION > v"1.9.0-0" ? Threads.nthreads(:default) : Threads.nthreads(),
+                Threads.nthreads(:default),
                 1 + max(length(A), length(B)) ÷ 8000)
             # In some tests, size (20,20,20) is worth splitting between two threads,
             # as is size (32,32,8).
diff --git a/src/utils.jl b/src/utils.jl
index baf95c8da..6d82a81ec 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -144,21 +144,3 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!), f:
     rrule_via_ad(cfg, broadcast, f, x, ys...)
 end
 
-# Could get this from Compat.jl instead
-# https://github.com/JuliaLang/julia/pull/39794
-if VERSION < v"1.7.0-DEV.793"
-    struct Returns{V} <: Function
-        value::V
-        Returns{V}(value) where {V} = new{V}(value)
-        Returns(value) = new{Core.Typeof(value)}(value)
-    end
-
-    (obj::Returns)(args...; kw...) = obj.value
-    function Base.show(io::IO, obj::Returns)
-        show(io, typeof(obj))
-        print(io, "(")
-        show(io, obj.value)
-        print(io, ")")
-    end
-end
-
diff --git a/test.jl b/test.jl
new file mode 100644
index 000000000..a20105f24
--- /dev/null
+++ b/test.jl
@@ -0,0 +1,11 @@
+import Metal, NNlib, Flux
+
+dev = Flux.get_device()
+
+src, idx = Int32[1 2 3 4; 5 6 7 8], Int32[2,1,1,5]
+srcd, idxd = dev(x), dev(idx)
+y = NNlib.scatter(+, src, idx)
+yd = dev(zero(y))
+NNlib.scatter!(+, yd, srcd, idxd)
+
+
diff --git a/test/conv.jl b/test/conv.jl
index cf3232778..8e52c846a 100644
--- a/test/conv.jl
+++ b/test/conv.jl
@@ -908,7 +908,7 @@ end
   gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
 end
 
-@static if Test_Enzyme
+if NNLIB_TEST_ENZYME
 
 @testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
   x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
diff --git a/test/dropout.jl b/test/dropout.jl
index 0da70111e..65aac8b62 100644
--- a/test/dropout.jl
+++ b/test/dropout.jl
@@ -16,9 +16,6 @@ using Zygote, StableRNGs, ChainRulesCore, Enzyme
     @test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)
 
     x2 = Diagonal(randn(Float32, 10))  # Just to check it runs on weird matrices.
-    if VERSION > v"1.8-"  # on 1.6 this makes a sparse array.
-        @test dropout(x2, 0.3) isa Matrix{Float32}  # does not infer, but that's OK?
-    end
 
     # Values
     @test dropout(x1, 0) == x1
@@ -76,7 +73,7 @@ using Zygote, StableRNGs, ChainRulesCore, Enzyme
     @test_throws ArgumentError dropout!(y1, x1, 3)
 end
 
-@static if Test_Enzyme
+if NNLIB_TEST_ENZYME
 
 @testset "EnzymeRules: dropout " begin
     rng = Random.default_rng()
diff --git a/test/pooling.jl b/test/pooling.jl
index f9d57ade7..1b11a1aea 100644
--- a/test/pooling.jl
+++ b/test/pooling.jl
@@ -948,7 +948,7 @@ end
   gradtest(x -> sum(meanpool(x, k)), x)
 end
 
-@static if Test_Enzyme
+if NNLIB_TEST_ENZYME
 
 @testset "EnzymeRules: pooling! $pool spatial_rank=$spatial_rank " for spatial_rank in (1, 2),
                                                                                 (pool, pool!) in ((maxpool, maxpool!), (meanpool, meanpool!))
diff --git a/test/runtests.jl b/test/runtests.jl
index b8080b6ba..6805672e7 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -18,10 +18,10 @@ import ReverseDiff as RD        # used in `pooling.jl`
 import Pkg
 using SpecialFunctions
 
-const Test_Enzyme = VERSION <= v"1.10-"
 
 DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true)
 
+const NNLIB_TEST_ENZYME = true
 # ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests
 # ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests
 # ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests
diff --git a/test/testsuite/gather.jl b/test/testsuite/gather.jl
index 92e3bfb7d..189533385 100644
--- a/test/testsuite/gather.jl
+++ b/test/testsuite/gather.jl
@@ -154,7 +154,7 @@ function gather_testsuite(Backend)
             gradtest_fn((s, i) -> gather(s, i), src, idx)
     end
 
-    @static if Test_Enzyme
+    if NNLIB_TEST_ENZYME
 
     @testset "EnzymeRules: gather! gradient for scalar index" begin
         src = device(Float64[3, 4, 5, 6, 7])
diff --git a/test/testsuite/scatter.jl b/test/testsuite/scatter.jl
index aa0b1c41e..ddbf8eb67 100644
--- a/test/testsuite/scatter.jl
+++ b/test/testsuite/scatter.jl
@@ -208,7 +208,7 @@ function scatter_testsuite(Backend)
         end
 
 
-        @static if Test_Enzyme
+        if NNLIB_TEST_ENZYME
 
         @testset "EnzymeRules" begin
             idx = device([2, 2, 3, 4, 4])

From bcac8d6ab2e14f1a07c15b024b9db6245177519b Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <carlo.lucibello@gmail.com>
Date: Tue, 8 Apr 2025 09:22:00 +0200
Subject: [PATCH 2/2] cleanup

---
 test.jl | 11 -----------
 1 file changed, 11 deletions(-)
 delete mode 100644 test.jl

diff --git a/test.jl b/test.jl
deleted file mode 100644
index a20105f24..000000000
--- a/test.jl
+++ /dev/null
@@ -1,11 +0,0 @@
-import Metal, NNlib, Flux
-
-dev = Flux.get_device()
-
-src, idx = Int32[1 2 3 4; 5 6 7 8], Int32[2,1,1,5]
-srcd, idxd = dev(x), dev(idx)
-y = NNlib.scatter(+, src, idx)
-yd = dev(zero(y))
-NNlib.scatter!(+, yd, srcd, idxd)
-
-