diff --git a/ext/NNlibCUDACUDNNExt/batchnorm.jl b/ext/NNlibCUDACUDNNExt/batchnorm.jl
index 2c38f009e..4b7793b91 100644
--- a/ext/NNlibCUDACUDNNExt/batchnorm.jl
+++ b/ext/NNlibCUDACUDNNExt/batchnorm.jl
@@ -3,6 +3,8 @@ using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
              cudnnBatchNormalizationForwardTraining
 import NNlib: batchnorm, ∇batchnorm
 
+using EnzymeCore
+
 # TODO: replace with new cudnn normalization interface
 # https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl
 
@@ -153,3 +155,144 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
         scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
         xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
 end
+
+
+
+function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(cudnnBNForward!)}, ::Type{RT}, 
+                    y::OutType, 
+                    g,
+                    b,
+                    x,
+                    running_mean, running_var, momentum::EnzymeCore.Const{<:Real}; kws...) where {OutType, RT}
+
+    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
+        func.val(y.val, b.val, x.val, running_mean.val, running_var.val, momentum.val; kws...)
+    end
+
+    primal = if EnzymeCore.EnzymeRules.needs_primal(config)
+        y.val
+    else
+        nothing
+    end
+    shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
+        y.dval
+    else
+        nothing
+    end
+
+    cache_g = nothing
+    cache_x = nothing
+    cache_running_mean = nothing
+    cache_running_var = nothing
+
+    if !(typeof(y) <: EnzymeCore.Const)
+      if !(typeof(x) <: EnzymeCore.Const) || !(typeof(g) <: EnzymeCore.Const) || !(typeof(b) <: EnzymeCore.Const)
+
+        if EnzymeCore.EnzymeRules.overwritten(config)[3]
+          cache_g = copy(g.val)
+        end
+        if EnzymeCore.EnzymeRules.overwritten(config)[5]
+          cache_x = copy(x.val)
+        end
+        if EnzymeCore.EnzymeRules.overwritten(config)[6]
+          cache_running_mean = copy(running_mean.val)
+        end
+        if EnzymeCore.EnzymeRules.overwritten(config)[7]
+          cache_running_var = copy(running_var.val)
+        end
+
+      end
+    end
+
+    cache = (cache_g, cache_x, cache_running_mean, cache_running_var)
+
+    return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
+end
+
+function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(cudnnBNForward!)}, ::Type{RT}, 
+                                        cache,
+                                        y::OutType, g, b, x, running_mean, running_var, momentum::EnzymeCore.Const{<:Real}; kws...) where {OutType, RT}
+
+    cache_g, cache_x, cache_running_mean, cache_running_var = cache
+    
+    if !(typeof(y) <: EnzymeCore.Const)
+      if !(typeof(x) <: EnzymeCore.Const) || !(typeof(g) <: EnzymeCore.Const) || !(typeof(b) <: EnzymeCore.Const)
+
+        if EnzymeCore.EnzymeRules.overwritten(config)[3]
+          cache_g = g.val
+        end
+        if EnzymeCore.EnzymeRules.overwritten(config)[5]
+          cache_x = x.val
+        end
+        if EnzymeCore.EnzymeRules.overwritten(config)[6]
+          cache_running_mean = running_mean.val
+        end
+        if EnzymeCore.EnzymeRules.overwritten(config)[7]
+          cache_running_var = running_var.val
+        end
+
+      end
+    end
+
+    dys = y.dval
+    dgs = (typeof(g) <: EnzymeCore.Const) ? dys : g.dval
+    dbs = (typeof(b) <: EnzymeCore.Const) ? dbs : b.dval
+    dxs = (typeof(x) <: EnzymeCore.Const) ? dxs : x.dval
+
+    if EnzymeCore.EnzymeRules.width(config) == 1
+        dys = (dys,)
+        dxs = (dxs,)
+        dgs = (dgs,)
+        dbs = (dbs,)
+    end
+
+    for (dy, dx, dg, db) in zip(dys, dxs, dgs, dbs)
+        if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val
+
+          if !((typeof(x) <: EnzymeCore.Const) || dx === x.val)
+             || !((typeof(g) <: EnzymeCore.Const) || dg === g.val)
+             || !((typeof(b) <: EnzymeCore.Const) || db === b.val)
+
+            # dx values
+            alpha = T(1)
+            beta = T(1)
+
+            # dx = alpha * newVal + beta old(dx)
+            # if x is constant, we can use zero for both
+            # otherwise we want to do dx += newVal, aka alpha=beta=1
+            if x <: EnzymeCore.Const
+              alpha = T(0)
+              beta = T(0)
+              dx = similar(x.val)
+            end
+
+            # dg / db values
+            alpha = T(1)
+            beta = T(1)
+
+            if g <: EnzymeCore.Const && b <: EnzymeCore.Const
+              dalpha = T(0)
+              dbeta = T(0)
+            end
+
+            if g <: EnzymeCore.Const
+              dg = similar(g.val)
+            end
+
+            if b <: EnzymeCore.Const
+              db = similar(b.val)
+            end
+
+            cudnnBNBackward!(dg, cache_g, db, dx, cache_x, dy,
+                             cache_running_mean, cache_running_var,
+                              momentum.val; alpha, beta, dalpha, dbeta; kw...)
+
+          end
+
+          dy .= 0
+
+        end
+    end
+
+    return (nothing, nothing, nothing, nothing, nothing, nothing, nothing)
+end