diff --git a/integration_test/cases/repo.exs b/integration_test/cases/repo.exs index 42cdaa8b43..75474a77bc 100644 --- a/integration_test/cases/repo.exs +++ b/integration_test/cases/repo.exs @@ -563,6 +563,7 @@ defmodule Ecto.Integration.RepoTest do test "get_by(!)" do post1 = TestRepo.insert!(%Post{title: "1", text: "hai"}) post2 = TestRepo.insert!(%Post{title: "2", text: "hello"}) + post3 = TestRepo.insert!(%Post{title: "3", text: nil}) assert post1 == TestRepo.get_by(Post, id: post1.id) assert post1 == TestRepo.get_by(Post, text: post1.text) @@ -570,11 +571,13 @@ defmodule Ecto.Integration.RepoTest do assert post2 == TestRepo.get_by(Post, id: to_string(post2.id)) # With casting assert nil == TestRepo.get_by(Post, text: "hey") assert nil == TestRepo.get_by(Post, id: post2.id, text: "hey") + assert post3 == TestRepo.get_by(Post, text: nil) assert post1 == TestRepo.get_by!(Post, id: post1.id) assert post1 == TestRepo.get_by!(Post, text: post1.text) assert post1 == TestRepo.get_by!(Post, id: post1.id, text: post1.text) assert post2 == TestRepo.get_by!(Post, id: to_string(post2.id)) # With casting + assert post3 == TestRepo.get_by!(Post, text: nil) assert post1 == TestRepo.get_by!(Post, %{id: post1.id}) diff --git a/lib/ecto/repo/queryable.ex b/lib/ecto/repo/queryable.ex index 38223dd957..cffe26393a 100644 --- a/lib/ecto/repo/queryable.ex +++ b/lib/ecto/repo/queryable.ex @@ -341,7 +341,14 @@ defmodule Ecto.Repo.Queryable do end defp query_for_get_by(_repo, queryable, clauses) do - Query.where(queryable, [], ^Enum.to_list(clauses)) + clauses + |> Enum.to_list + |> Enum.reduce(queryable, fn + ({key, nil}, query) -> + Query.where(query, [x], is_nil(field(x, ^key))) + ({key, value}, query) -> + Query.where(query, [x], field(x, ^key) == ^value) + end) end defp query_for_aggregate(queryable, aggregate, field) do diff --git a/test/ecto/repo_test.exs b/test/ecto/repo_test.exs index 953c87a768..66fe9a5d70 100644 --- a/test/ecto/repo_test.exs +++ b/test/ecto/repo_test.exs @@ -111,6 +111,7 @@ defmodule Ecto.RepoTest do test "validates get_by" do TestRepo.get_by(MySchema, id: 123) TestRepo.get_by(MySchema, %{id: 123}) + TestRepo.get_by(MySchema, id: nil) message = ~r"value `:atom` in `where` cannot be cast to type :id in query" assert_raise Ecto.Query.CastError, message, fn ->