From 6fe50d0af9585385cf01c39e77493304e624ac70 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 14 Oct 2024 08:12:47 +0200 Subject: [PATCH 1/4] cl/error --- src/interface.jl | 6 ++++++ test/runtests.jl | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index ac9b90bc..8a436edb 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -67,6 +67,12 @@ function update(tree, model, grad, higher...) update!(t′, x′, grad, higher...) end +function update!(::AbstractRule, model, grad, higher...) + error("""update! must be called with an optimiser state, not a rule. + Call `opt_state = setup(rule, model)` first, then `update!(opt_state, model, grad)`. + """) +end + function update!(tree, model, grad, higher...) # First walk is to accumulate the gradient. This recursion visits every copy of # shared leaves, but stops when branches are absent from the gradient: diff --git a/test/runtests.jl b/test/runtests.jl index fc0fe57f..bf3056c3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -104,6 +104,11 @@ end @test isnan(m3n.γ[3]) end + @testset "friendly error when using rule instead of state" begin + @test_throws ErrorException Optimisers.update!(Adam(), m, gs) # friendly + @test_throws ErrorException Optimisers.update!(Adam(), m, gs[1]) # friendly + end + @testset "Dict support" begin @testset "simple dict" begin d = Dict(:a => [1.0,2.0], :b => [3.0,4.0], :c => 1) From 2bf9a09de3303b29848e8f4689ea02924c8f6168 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 14 Oct 2024 08:26:17 +0200 Subject: [PATCH 2/4] fix docs --- docs/src/api.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/api.md b/docs/src/api.md index 378bf72a..434ee70a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -85,4 +85,5 @@ It is defined in Functors.jl and re-exported by Optimisers.jl here for convenien Functors.KeyPath Functors.haskeypath Functors.getkeypath +Functors.setkeypath! ``` From e8dae73085fe82b4a51eee3304aa7f4275b6d03c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 14 Oct 2024 09:05:12 +0200 Subject: [PATCH 3/4] fix --- test/runtests.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index bf3056c3..bb5cbec9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -105,8 +105,7 @@ end end @testset "friendly error when using rule instead of state" begin - @test_throws ErrorException Optimisers.update!(Adam(), m, gs) # friendly - @test_throws ErrorException Optimisers.update!(Adam(), m, gs[1]) # friendly + @test_throws ErrorException Optimisers.update!(Adam(), rand(2), rand(2)) end @testset "Dict support" begin From 8ceedbe2a8633296aef7090440699e2e6877f586 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 14 Oct 2024 20:25:59 +0200 Subject: [PATCH 4/4] Update src/interface.jl Co-authored-by: Brian Chen --- src/interface.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 8a436edb..be0427ab 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -68,9 +68,9 @@ function update(tree, model, grad, higher...) end function update!(::AbstractRule, model, grad, higher...) - error("""update! must be called with an optimiser state, not a rule. + throw(ArgumentError("""update! must be called with an optimiser state tree, not a rule. Call `opt_state = setup(rule, model)` first, then `update!(opt_state, model, grad)`. - """) + """)) end function update!(tree, model, grad, higher...)