From b261537485c11cd4a8f57651648ed702e6498087 Mon Sep 17 00:00:00 2001 From: Foxtrek_64 Date: Sun, 28 Feb 2021 14:43:30 -0600 Subject: [PATCH 01/14] Fix #20064 --- .../src/System/Linq/CachedReflection.cs | 42 ++++++++++ .../src/System/Linq/Queryable.cs | 82 +++++++++++++++++++ .../System.Linq/src/System/Linq/First.cs | 24 ++++-- .../System.Linq/src/System/Linq/Last.cs | 20 +++-- .../System.Linq/src/System/Linq/Single.cs | 12 ++- 5 files changed, 165 insertions(+), 15 deletions(-) diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs index f0eb0f169bf3c0..2d37fe2f53016a 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs @@ -284,6 +284,20 @@ public static MethodInfo FirstOrDefault_TSource_2(Type TSource) => (s_FirstOrDefault_TSource_2 ??= new Func, Expression>, object?>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource); + private static MethodInfo? s_FirstOrDefault_TSource_3; + + public static MethodInfo FirstOrDefault_TSource_3(Type TSource) => + (s_FirstOrDefault_TSource_3 ?? + (s_FirstOrDefault_TSource_3 = new Func, object?, object?>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + .MakeGenericMethod(); + + private static MethodInfo? s_FirstOrDefault_TSource_4; + + public static MethodInfo FirstOrDefault_TSource_4(Type TSource) => + (s_FirstOrDefault_TSource_4 ?? + (s_FirstOrDefault_TSource_4 = new Func, Expression>, object?, object?>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + .MakeGenericMethod(TSource); + private static MethodInfo? s_GroupBy_TSource_TKey_2; public static MethodInfo GroupBy_TSource_TKey_2(Type TSource, Type TKey) => @@ -392,6 +406,20 @@ public static MethodInfo LastOrDefault_TSource_2(Type TSource) => (s_LastOrDefault_TSource_2 ??= new Func, Expression>, object?>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource); + private static MethodInfo? s_LastOrDefault_TSource_3; + + public static MethodInfo LastOrDefault_TSource_3(Type TSource) => + (s_LastOrDefault_TSource_3 ?? + (s_LastOrDefault_TSource_3 = new Func, object?, object?>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + .MakeGenericMethod(TSource); + + private static MethodInfo? s_LastOrDefault_TSource_4; + + public static MethodInfo LastOrDefault_TSource_4(Type TSource) => + (s_LastOrDefault_TSource_4 ?? + (s_LastOrDefault_TSource_4 = new Func, Expression>, object?, object?>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + .MakeGenericMethod(TSource); + private static MethodInfo? s_LongCount_TSource_1; public static MethodInfo LongCount_TSource_1(Type TSource) => @@ -536,6 +564,20 @@ public static MethodInfo SingleOrDefault_TSource_2(Type TSource) => (s_SingleOrDefault_TSource_2 ??= new Func, Expression>, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource); + private static MethodInfo? s_SingleOrDefault_TSource_3; + + public static MethodInfo SingleOrDefault_TSource_3(Type TSource) => + (s_SingleOrDefault_TSource_3 ?? + (s_SingleOrDefault_TSource_3 = new Func, object?, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + .MakeGenericMethod(TSource); + + private static MethodInfo? s_SingleOrDefault_TSource_4; + + public static MethodInfo SingleOrDefault_TSource_4(Type TSource) => + (s_SingleOrDefault_TSource_4 ?? + (s_SingleOrDefault_TSource_4 = new Func, Expression>, object?, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + .MakeGenericMethod(); + private static MethodInfo? s_Skip_TSource_2; public static MethodInfo Skip_TSource_2(Type TSource) => diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs index c08bcab79af79c..065218f1ba12ce 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs @@ -864,6 +864,18 @@ public static TSource First(this IQueryable source, Expression CachedReflectionInfo.FirstOrDefault_TSource_1(typeof(TSource)), source.Expression)); } + [DynamicDependency("FirstOrDefault`1", typeof(Enumerable))] + public static TSource? FirstOrDefault(this IQueryable source, TSource? defaultValue) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.FirstOrDefault_TSource_3(typeof(TSource)), + source.Expression, Expression.Constant(defaultValue, typeof(TSource)))); + } + [DynamicDependency("FirstOrDefault`1", typeof(Enumerable))] public static TSource? FirstOrDefault(this IQueryable source, Expression> predicate) { @@ -879,6 +891,21 @@ public static TSource First(this IQueryable source, Expression )); } + [DynamicDependency("FirstOrDefault`1", typeof(Enumerable))] + public static TSource? FirstOrDefault(this IQueryable source, Expression> predicate, TSource? defaultValue) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + if (predicate == null) + throw Error.ArgumentNull(nameof(predicate)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.FirstOrDefault_TSource_4(typeof(TSource)), + source.Expression, Expression.Quote(predicate), Expression.Constant(defaultValue, typeof(TSource)) + )); + } + [DynamicDependency("Last`1", typeof(Enumerable))] public static TSource Last(this IQueryable source) { @@ -916,6 +943,18 @@ public static TSource Last(this IQueryable source, Expression< CachedReflectionInfo.LastOrDefault_TSource_1(typeof(TSource)), source.Expression)); } + [DynamicDependency("LastOrDefault`1", typeof(Enumerable))] + public static TSource? LastOrDefault(this IQueryable source, TSource? defaultValue) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.LastOrDefault_TSource_3(typeof(TSource)), + source.Expression, Expression.Constant(defaultValue, typeof(TSource)))); + } + [DynamicDependency("LastOrDefault`1", typeof(Enumerable))] public static TSource? LastOrDefault(this IQueryable source, Expression> predicate) { @@ -931,6 +970,21 @@ public static TSource Last(this IQueryable source, Expression< )); } + [DynamicDependency("LastOrDefault`1", typeof(Enumerable))] + public static TSource? LastOrDefault(this IQueryable source, Expression> predicate, TSource? defaultValue) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + if (predicate == null) + throw Error.ArgumentNull(nameof(predicate)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.LastOrDefault_TSource_4(typeof(TSource)), + source.Expression, Expression.Quote(predicate), Expression.Constant(defaultValue, typeof(TSource)) + )); + } + [DynamicDependency("Single`1", typeof(Enumerable))] public static TSource Single(this IQueryable source) { @@ -968,6 +1022,19 @@ public static TSource Single(this IQueryable source, Expressio CachedReflectionInfo.SingleOrDefault_TSource_1(typeof(TSource)), source.Expression)); } + [DynamicDependency("SingleOrDefault`1", typeof(Enumerable))] + public static TSource? SingleOrDefault(this IQueryable source, TSource? defaultValue) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.SingleOrDefault_TSource_3(typeof(TSource)), + source.Expression, Expression.Constant(defaultValue, typeof(TSource)))); + + } + [DynamicDependency("SingleOrDefault`1", typeof(Enumerable))] public static TSource? SingleOrDefault(this IQueryable source, Expression> predicate) { @@ -983,6 +1050,21 @@ public static TSource Single(this IQueryable source, Expressio )); } + [DynamicDependency("SingleOrDefault`1", typeof(Enumerable))] + public static TSource? SingleOrDefault(this IQueryable source, Expression> predicate, TSource? defaultValue) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + if (predicate == null) + throw Error.ArgumentNull(nameof(predicate)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.SingleOrDefault_TSource_4(typeof(TSource)), + source.Expression, Expression.Quote(predicate), Expression.Constant(defaultValue, typeof(TSource)) + )); + } + [DynamicDependency("ElementAt`1", typeof(Enumerable))] public static TSource ElementAt(this IQueryable source, int index) { diff --git a/src/libraries/System.Linq/src/System/Linq/First.cs b/src/libraries/System.Linq/src/System/Linq/First.cs index 64b89915c259f8..6751a14e9f1550 100644 --- a/src/libraries/System.Linq/src/System/Linq/First.cs +++ b/src/libraries/System.Linq/src/System/Linq/First.cs @@ -31,12 +31,21 @@ public static TSource First(this IEnumerable source, Func(this IEnumerable source) => - source.TryGetFirst(out bool _); + source.TryGetFirst(out _); + + public static TSource? FirstOrDefault(this IEnumerable source, TSource? defaultValue) => + source.TryGetFirst(defaultValue, out _); public static TSource? FirstOrDefault(this IEnumerable source, Func predicate) => - source.TryGetFirst(predicate, out bool _); + source.TryGetFirst(predicate, out _); + + public static TSource? FirstOrDefault(this IEnumerable source, Func predicate, TSource? defaultValue) => + source.TryGetFirst(predicate, defaultValue, out _); - private static TSource? TryGetFirst(this IEnumerable source, out bool found) + private static TSource? TryGetFirst(this IEnumerable source, out bool found) => + source.TryGetFirst(default(TSource), out found); + + private static TSource? TryGetFirst(this IEnumerable source, TSource? defaultValue, out bool found) { if (source == null) { @@ -69,10 +78,13 @@ public static TSource First(this IEnumerable source, Func(this IEnumerable source, Func predicate, out bool found) + private static TSource? TryGetFirst(this IEnumerable source, Func predicate, out bool found) => + source.TryGetFirst(predicate, default(TSource), out found); + + private static TSource? TryGetFirst(this IEnumerable source, Func predicate, TSource? defaultValue, out bool found) { if (source == null) { @@ -94,7 +106,7 @@ public static TSource First(this IEnumerable source, Func(this IEnumerable source, Func(this IEnumerable source) => - source.TryGetLast(out bool _); + public static TSource? LastOrDefault(this IEnumerable source) + => source.TryGetLast(out _); + public static TSource? LastOrDefault(this IEnumerable source, TSource? defaultValue) + => source.TryGetLast(defaultValue, out _); - public static TSource? LastOrDefault(this IEnumerable source, Func predicate) => - source.TryGetLast(predicate, out bool _); + public static TSource? LastOrDefault(this IEnumerable source, Func predicate) + => source.TryGetLast(predicate, out bool _); + public static TSource? LastOrDefault(this IEnumerable source, Func predicate, TSource? defaultValue) + => source.TryGetLast(predicate, defaultValue, out bool _); private static TSource? TryGetLast(this IEnumerable source, out bool found) + => source.TryGetLast(default(TSource?), out found); + private static TSource? TryGetLast(this IEnumerable source, TSource? defaultValue, out bool found) { if (source == null) { @@ -77,10 +83,12 @@ public static TSource Last(this IEnumerable source, Func(this IEnumerable source, Func predicate, out bool found) + => source.TryGetLast(predicate, default(TSource?), out found); + private static TSource? TryGetLast(this IEnumerable source, Func predicate, TSource? defaultValue, out bool found) { if (source == null) { @@ -135,7 +143,7 @@ public static TSource Last(this IEnumerable source, Func(this IEnumerable source, Func(this IEnumerable source) + => source.SingleOrDefault(default(TSource)); + + public static TSource? SingleOrDefault(this IEnumerable source, TSource? defaultValue) { if (source == null) { @@ -95,7 +98,7 @@ public static TSource Single(this IEnumerable source, Func(this IEnumerable source, Func(this IEnumerable source, Func(this IEnumerable source, Func predicate) + => source.SingleOrDefault(predicate, default); + + public static TSource? SingleOrDefault(this IEnumerable source, Func predicate, TSource? defaultValue) { if (source == null) { @@ -153,7 +159,7 @@ public static TSource Single(this IEnumerable source, Func Date: Sun, 28 Feb 2021 16:45:20 -0600 Subject: [PATCH 02/14] Add API to ref assembly --- .../System.Linq.Queryable/ref/System.Linq.Queryable.cs | 6 ++++++ src/libraries/System.Linq/ref/System.Linq.cs | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs b/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs index 692f4bc23ac5df..3635765940eeee 100644 --- a/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs +++ b/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs @@ -86,7 +86,9 @@ public static partial class Queryable public static System.Linq.IQueryable Except(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2) { throw null; } public static System.Linq.IQueryable Except(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? FirstOrDefault(this System.Linq.IQueryable source) { throw null; } + public static TSource? FirstOrDefault(this System.Linq.IQueryable source, TSource? defaultValue) { throw null; } public static TSource? FirstOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } + public static TSource? FirstOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate, TSource? defaultValue) { throw null; } public static TSource First(this System.Linq.IQueryable source) { throw null; } public static TSource First(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static System.Linq.IQueryable> GroupBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector) { throw null; } @@ -104,7 +106,9 @@ public static partial class Queryable public static System.Linq.IQueryable Join(this System.Linq.IQueryable outer, System.Collections.Generic.IEnumerable inner, System.Linq.Expressions.Expression> outerKeySelector, System.Linq.Expressions.Expression> innerKeySelector, System.Linq.Expressions.Expression> resultSelector) { throw null; } public static System.Linq.IQueryable Join(this System.Linq.IQueryable outer, System.Collections.Generic.IEnumerable inner, System.Linq.Expressions.Expression> outerKeySelector, System.Linq.Expressions.Expression> innerKeySelector, System.Linq.Expressions.Expression> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? LastOrDefault(this System.Linq.IQueryable source) { throw null; } + public static TSource? LastOrDefault(this System.Linq.IQueryable source, TSource? defaultValue) { throw null; } public static TSource? LastOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } + public static TSource? LastOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate, TSource? defaultValue) { throw null; } public static TSource Last(this System.Linq.IQueryable source) { throw null; } public static TSource Last(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static long LongCount(this System.Linq.IQueryable source) { throw null; } @@ -129,7 +133,9 @@ public static partial class Queryable public static bool SequenceEqual(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2) { throw null; } public static bool SequenceEqual(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? SingleOrDefault(this System.Linq.IQueryable source) { throw null; } + public static TSource? SingleOrDefault(this System.Linq.IQueryable source, TSource? defaultValue) { throw null; } public static TSource? SingleOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } + public static TSource? SingleOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate, TSource? defaultValue) { throw null; } public static TSource Single(this System.Linq.IQueryable source) { throw null; } public static TSource Single(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static System.Linq.IQueryable SkipLast(this System.Linq.IQueryable source, int count) { throw null; } diff --git a/src/libraries/System.Linq/ref/System.Linq.cs b/src/libraries/System.Linq/ref/System.Linq.cs index 7a86417711aa74..716fb0ba8ed9e1 100644 --- a/src/libraries/System.Linq/ref/System.Linq.cs +++ b/src/libraries/System.Linq/ref/System.Linq.cs @@ -59,7 +59,9 @@ public static System.Collections.Generic.IEnumerable< public static System.Collections.Generic.IEnumerable Except(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second) { throw null; } public static System.Collections.Generic.IEnumerable Except(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? FirstOrDefault(this System.Collections.Generic.IEnumerable source) { throw null; } + public static TSource? FirstOrDefault(this System.Collections.Generic.IEnumerable source, TSource? defaultValue) { throw null; } public static TSource? FirstOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } + public static TSource? FirstOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate, TSource? defaultValue) { throw null; } public static TSource First(this System.Collections.Generic.IEnumerable source) { throw null; } public static TSource First(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } public static System.Collections.Generic.IEnumerable> GroupBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector) { throw null; } @@ -77,7 +79,9 @@ public static System.Collections.Generic.IEnumerable< public static System.Collections.Generic.IEnumerable Join(this System.Collections.Generic.IEnumerable outer, System.Collections.Generic.IEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func resultSelector) { throw null; } public static System.Collections.Generic.IEnumerable Join(this System.Collections.Generic.IEnumerable outer, System.Collections.Generic.IEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? LastOrDefault(this System.Collections.Generic.IEnumerable source) { throw null; } + public static TSource? LastOrDefault(this System.Collections.Generic.IEnumerable source, TSource? defaultValue) { throw null; } public static TSource? LastOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } + public static TSource? LastOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate, TSource? defaultValue) { throw null; } public static TSource Last(this System.Collections.Generic.IEnumerable source) { throw null; } public static TSource Last(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } public static long LongCount(this System.Collections.Generic.IEnumerable source) { throw null; } @@ -144,7 +148,9 @@ public static System.Collections.Generic.IEnumerable< public static bool SequenceEqual(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second) { throw null; } public static bool SequenceEqual(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? SingleOrDefault(this System.Collections.Generic.IEnumerable source) { throw null; } + public static TSource? SingleOrDefault(this System.Collections.Generic.IEnumerable source, TSource? defaultValue) { throw null; } public static TSource? SingleOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } + public static TSource? SingleOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate, TSource? defaultValue) { throw null; } public static TSource Single(this System.Collections.Generic.IEnumerable source) { throw null; } public static TSource Single(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } public static System.Collections.Generic.IEnumerable SkipLast(this System.Collections.Generic.IEnumerable source, int count) { throw null; } From d7e120bdb21e3912f232bf0f7739f4f8f744d591 Mon Sep 17 00:00:00 2001 From: Foxtrek_64 Date: Mon, 1 Mar 2021 19:05:08 -0600 Subject: [PATCH 03/14] Make overloads with defaultValue not nullable --- .../ref/System.Linq.Queryable.cs | 12 ++++++------ .../src/System/Linq/CachedReflection.cs | 6 +++--- .../src/System/Linq/Queryable.cs | 12 ++++++------ src/libraries/System.Linq/ref/System.Linq.cs | 12 ++++++------ src/libraries/System.Linq/src/System/Linq/First.cs | 12 ++++++------ src/libraries/System.Linq/src/System/Linq/Last.cs | 14 +++++++------- .../System.Linq/src/System/Linq/Single.cs | 2 +- 7 files changed, 35 insertions(+), 35 deletions(-) diff --git a/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs b/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs index 3635765940eeee..bd3497d53c6580 100644 --- a/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs +++ b/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs @@ -86,9 +86,9 @@ public static partial class Queryable public static System.Linq.IQueryable Except(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2) { throw null; } public static System.Linq.IQueryable Except(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? FirstOrDefault(this System.Linq.IQueryable source) { throw null; } - public static TSource? FirstOrDefault(this System.Linq.IQueryable source, TSource? defaultValue) { throw null; } + public static TSource FirstOrDefault(this System.Linq.IQueryable source, TSource defaultValue) { throw null; } public static TSource? FirstOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } - public static TSource? FirstOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate, TSource? defaultValue) { throw null; } + public static TSource FirstOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate, TSource defaultValue) { throw null; } public static TSource First(this System.Linq.IQueryable source) { throw null; } public static TSource First(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static System.Linq.IQueryable> GroupBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector) { throw null; } @@ -106,9 +106,9 @@ public static partial class Queryable public static System.Linq.IQueryable Join(this System.Linq.IQueryable outer, System.Collections.Generic.IEnumerable inner, System.Linq.Expressions.Expression> outerKeySelector, System.Linq.Expressions.Expression> innerKeySelector, System.Linq.Expressions.Expression> resultSelector) { throw null; } public static System.Linq.IQueryable Join(this System.Linq.IQueryable outer, System.Collections.Generic.IEnumerable inner, System.Linq.Expressions.Expression> outerKeySelector, System.Linq.Expressions.Expression> innerKeySelector, System.Linq.Expressions.Expression> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? LastOrDefault(this System.Linq.IQueryable source) { throw null; } - public static TSource? LastOrDefault(this System.Linq.IQueryable source, TSource? defaultValue) { throw null; } + public static TSource LastOrDefault(this System.Linq.IQueryable source, TSource defaultValue) { throw null; } public static TSource? LastOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } - public static TSource? LastOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate, TSource? defaultValue) { throw null; } + public static TSource LastOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate, TSource defaultValue) { throw null; } public static TSource Last(this System.Linq.IQueryable source) { throw null; } public static TSource Last(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static long LongCount(this System.Linq.IQueryable source) { throw null; } @@ -133,9 +133,9 @@ public static partial class Queryable public static bool SequenceEqual(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2) { throw null; } public static bool SequenceEqual(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? SingleOrDefault(this System.Linq.IQueryable source) { throw null; } - public static TSource? SingleOrDefault(this System.Linq.IQueryable source, TSource? defaultValue) { throw null; } + public static TSource SingleOrDefault(this System.Linq.IQueryable source, TSource defaultValue) { throw null; } public static TSource? SingleOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } - public static TSource? SingleOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate, TSource? defaultValue) { throw null; } + public static TSource SingleOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate, TSource defaultValue) { throw null; } public static TSource Single(this System.Linq.IQueryable source) { throw null; } public static TSource Single(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static System.Linq.IQueryable SkipLast(this System.Linq.IQueryable source, int count) { throw null; } diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs index 2d37fe2f53016a..37ac202c0f45bc 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs @@ -295,7 +295,7 @@ public static MethodInfo FirstOrDefault_TSource_3(Type TSource) => public static MethodInfo FirstOrDefault_TSource_4(Type TSource) => (s_FirstOrDefault_TSource_4 ?? - (s_FirstOrDefault_TSource_4 = new Func, Expression>, object?, object?>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + (s_FirstOrDefault_TSource_4 = new Func, Expression>, object, object?>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition())) .MakeGenericMethod(TSource); private static MethodInfo? s_GroupBy_TSource_TKey_2; @@ -417,7 +417,7 @@ public static MethodInfo LastOrDefault_TSource_3(Type TSource) => public static MethodInfo LastOrDefault_TSource_4(Type TSource) => (s_LastOrDefault_TSource_4 ?? - (s_LastOrDefault_TSource_4 = new Func, Expression>, object?, object?>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + (s_LastOrDefault_TSource_4 = new Func, Expression>, object, object?>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition())) .MakeGenericMethod(TSource); private static MethodInfo? s_LongCount_TSource_1; @@ -575,7 +575,7 @@ public static MethodInfo SingleOrDefault_TSource_3(Type TSource) => public static MethodInfo SingleOrDefault_TSource_4(Type TSource) => (s_SingleOrDefault_TSource_4 ?? - (s_SingleOrDefault_TSource_4 = new Func, Expression>, object?, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + (s_SingleOrDefault_TSource_4 = new Func, Expression>, object, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) .MakeGenericMethod(); private static MethodInfo? s_Skip_TSource_2; diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs index 065218f1ba12ce..36a7789d464f84 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs @@ -865,7 +865,7 @@ public static TSource First(this IQueryable source, Expression } [DynamicDependency("FirstOrDefault`1", typeof(Enumerable))] - public static TSource? FirstOrDefault(this IQueryable source, TSource? defaultValue) + public static TSource FirstOrDefault(this IQueryable source, TSource defaultValue) { if (source == null) throw Error.ArgumentNull(nameof(source)); @@ -892,7 +892,7 @@ public static TSource First(this IQueryable source, Expression } [DynamicDependency("FirstOrDefault`1", typeof(Enumerable))] - public static TSource? FirstOrDefault(this IQueryable source, Expression> predicate, TSource? defaultValue) + public static TSource FirstOrDefault(this IQueryable source, Expression> predicate, TSource defaultValue) { if (source == null) throw Error.ArgumentNull(nameof(source)); @@ -944,7 +944,7 @@ public static TSource Last(this IQueryable source, Expression< } [DynamicDependency("LastOrDefault`1", typeof(Enumerable))] - public static TSource? LastOrDefault(this IQueryable source, TSource? defaultValue) + public static TSource LastOrDefault(this IQueryable source, TSource defaultValue) { if (source == null) throw Error.ArgumentNull(nameof(source)); @@ -971,7 +971,7 @@ public static TSource Last(this IQueryable source, Expression< } [DynamicDependency("LastOrDefault`1", typeof(Enumerable))] - public static TSource? LastOrDefault(this IQueryable source, Expression> predicate, TSource? defaultValue) + public static TSource LastOrDefault(this IQueryable source, Expression> predicate, TSource defaultValue) { if (source == null) throw Error.ArgumentNull(nameof(source)); @@ -1023,7 +1023,7 @@ public static TSource Single(this IQueryable source, Expressio } [DynamicDependency("SingleOrDefault`1", typeof(Enumerable))] - public static TSource? SingleOrDefault(this IQueryable source, TSource? defaultValue) + public static TSource SingleOrDefault(this IQueryable source, TSource defaultValue) { if (source == null) throw Error.ArgumentNull(nameof(source)); @@ -1051,7 +1051,7 @@ public static TSource Single(this IQueryable source, Expressio } [DynamicDependency("SingleOrDefault`1", typeof(Enumerable))] - public static TSource? SingleOrDefault(this IQueryable source, Expression> predicate, TSource? defaultValue) + public static TSource SingleOrDefault(this IQueryable source, Expression> predicate, TSource defaultValue) { if (source == null) throw Error.ArgumentNull(nameof(source)); diff --git a/src/libraries/System.Linq/ref/System.Linq.cs b/src/libraries/System.Linq/ref/System.Linq.cs index 716fb0ba8ed9e1..0f09240a0249de 100644 --- a/src/libraries/System.Linq/ref/System.Linq.cs +++ b/src/libraries/System.Linq/ref/System.Linq.cs @@ -59,9 +59,9 @@ public static System.Collections.Generic.IEnumerable< public static System.Collections.Generic.IEnumerable Except(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second) { throw null; } public static System.Collections.Generic.IEnumerable Except(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? FirstOrDefault(this System.Collections.Generic.IEnumerable source) { throw null; } - public static TSource? FirstOrDefault(this System.Collections.Generic.IEnumerable source, TSource? defaultValue) { throw null; } + public static TSource FirstOrDefault(this System.Collections.Generic.IEnumerable source, TSource defaultValue) { throw null; } public static TSource? FirstOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } - public static TSource? FirstOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate, TSource? defaultValue) { throw null; } + public static TSource FirstOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate, TSource defaultValue) { throw null; } public static TSource First(this System.Collections.Generic.IEnumerable source) { throw null; } public static TSource First(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } public static System.Collections.Generic.IEnumerable> GroupBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector) { throw null; } @@ -79,9 +79,9 @@ public static System.Collections.Generic.IEnumerable< public static System.Collections.Generic.IEnumerable Join(this System.Collections.Generic.IEnumerable outer, System.Collections.Generic.IEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func resultSelector) { throw null; } public static System.Collections.Generic.IEnumerable Join(this System.Collections.Generic.IEnumerable outer, System.Collections.Generic.IEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? LastOrDefault(this System.Collections.Generic.IEnumerable source) { throw null; } - public static TSource? LastOrDefault(this System.Collections.Generic.IEnumerable source, TSource? defaultValue) { throw null; } + public static TSource LastOrDefault(this System.Collections.Generic.IEnumerable source, TSource defaultValue) { throw null; } public static TSource? LastOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } - public static TSource? LastOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate, TSource? defaultValue) { throw null; } + public static TSource LastOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate, TSource defaultValue) { throw null; } public static TSource Last(this System.Collections.Generic.IEnumerable source) { throw null; } public static TSource Last(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } public static long LongCount(this System.Collections.Generic.IEnumerable source) { throw null; } @@ -148,9 +148,9 @@ public static System.Collections.Generic.IEnumerable< public static bool SequenceEqual(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second) { throw null; } public static bool SequenceEqual(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? SingleOrDefault(this System.Collections.Generic.IEnumerable source) { throw null; } - public static TSource? SingleOrDefault(this System.Collections.Generic.IEnumerable source, TSource? defaultValue) { throw null; } + public static TSource SingleOrDefault(this System.Collections.Generic.IEnumerable source, TSource defaultValue) { throw null; } public static TSource? SingleOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } - public static TSource? SingleOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate, TSource? defaultValue) { throw null; } + public static TSource SingleOrDefault(this System.Collections.Generic.IEnumerable source, System.Func predicate, TSource defaultValue) { throw null; } public static TSource Single(this System.Collections.Generic.IEnumerable source) { throw null; } public static TSource Single(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } public static System.Collections.Generic.IEnumerable SkipLast(this System.Collections.Generic.IEnumerable source, int count) { throw null; } diff --git a/src/libraries/System.Linq/src/System/Linq/First.cs b/src/libraries/System.Linq/src/System/Linq/First.cs index 6751a14e9f1550..dcb410c9e54300 100644 --- a/src/libraries/System.Linq/src/System/Linq/First.cs +++ b/src/libraries/System.Linq/src/System/Linq/First.cs @@ -33,19 +33,19 @@ public static TSource First(this IEnumerable source, Func(this IEnumerable source) => source.TryGetFirst(out _); - public static TSource? FirstOrDefault(this IEnumerable source, TSource? defaultValue) => - source.TryGetFirst(defaultValue, out _); + public static TSource FirstOrDefault(this IEnumerable source, TSource defaultValue) => + source.TryGetFirst(defaultValue, out _)!; public static TSource? FirstOrDefault(this IEnumerable source, Func predicate) => source.TryGetFirst(predicate, out _); - public static TSource? FirstOrDefault(this IEnumerable source, Func predicate, TSource? defaultValue) => - source.TryGetFirst(predicate, defaultValue, out _); + public static TSource FirstOrDefault(this IEnumerable source, Func predicate, TSource defaultValue) => + source.TryGetFirst(predicate, defaultValue, out _)!; private static TSource? TryGetFirst(this IEnumerable source, out bool found) => source.TryGetFirst(default(TSource), out found); - private static TSource? TryGetFirst(this IEnumerable source, TSource? defaultValue, out bool found) + private static TSource? TryGetFirst(this IEnumerable source, TSource defaultValue, out bool found) { if (source == null) { @@ -82,7 +82,7 @@ public static TSource First(this IEnumerable source, Func(this IEnumerable source, Func predicate, out bool found) => - source.TryGetFirst(predicate, default(TSource), out found); + source.TryGetFirst(predicate, default, out found); private static TSource? TryGetFirst(this IEnumerable source, Func predicate, TSource? defaultValue, out bool found) { diff --git a/src/libraries/System.Linq/src/System/Linq/Last.cs b/src/libraries/System.Linq/src/System/Linq/Last.cs index 9d0b67e0a131a4..9a83a0a59bd0e3 100644 --- a/src/libraries/System.Linq/src/System/Linq/Last.cs +++ b/src/libraries/System.Linq/src/System/Linq/Last.cs @@ -32,17 +32,17 @@ public static TSource Last(this IEnumerable source, Func(this IEnumerable source) => source.TryGetLast(out _); - public static TSource? LastOrDefault(this IEnumerable source, TSource? defaultValue) - => source.TryGetLast(defaultValue, out _); + public static TSource LastOrDefault(this IEnumerable source, TSource defaultValue) + => source.TryGetLast(defaultValue, out _)!; public static TSource? LastOrDefault(this IEnumerable source, Func predicate) => source.TryGetLast(predicate, out bool _); - public static TSource? LastOrDefault(this IEnumerable source, Func predicate, TSource? defaultValue) - => source.TryGetLast(predicate, defaultValue, out bool _); + public static TSource LastOrDefault(this IEnumerable source, Func predicate, TSource defaultValue) + => source.TryGetLast(predicate, defaultValue, out bool _)!; private static TSource? TryGetLast(this IEnumerable source, out bool found) - => source.TryGetLast(default(TSource?), out found); - private static TSource? TryGetLast(this IEnumerable source, TSource? defaultValue, out bool found) + => source.TryGetLast(default(TSource), out found); + private static TSource? TryGetLast(this IEnumerable source, TSource defaultValue, out bool found) { if (source == null) { @@ -87,7 +87,7 @@ public static TSource Last(this IEnumerable source, Func(this IEnumerable source, Func predicate, out bool found) - => source.TryGetLast(predicate, default(TSource?), out found); + => source.TryGetLast(predicate, default, out found); private static TSource? TryGetLast(this IEnumerable source, Func predicate, TSource? defaultValue, out bool found) { if (source == null) diff --git a/src/libraries/System.Linq/src/System/Linq/Single.cs b/src/libraries/System.Linq/src/System/Linq/Single.cs index 396c1f1aeaee3a..d8f5f236b44614 100644 --- a/src/libraries/System.Linq/src/System/Linq/Single.cs +++ b/src/libraries/System.Linq/src/System/Linq/Single.cs @@ -86,7 +86,7 @@ public static TSource Single(this IEnumerable source, Func(this IEnumerable source) => source.SingleOrDefault(default(TSource)); - public static TSource? SingleOrDefault(this IEnumerable source, TSource? defaultValue) + public static TSource SingleOrDefault(this IEnumerable source, TSource defaultValue) { if (source == null) { From 2ff34cd65d749072d53ac428d6a03778d85ce8c2 Mon Sep 17 00:00:00 2001 From: Foxtrek_64 Date: Wed, 10 Mar 2021 18:29:28 -0600 Subject: [PATCH 04/14] Add unit tests, simplify implementation --- .../System.Linq/src/System/Linq/First.cs | 14 ++----- .../System.Linq/tests/FirstOrDefaultTests.cs | 42 +++++++++++++++++++ .../System.Linq/tests/LastOrDefaultTests.cs | 17 ++++++++ .../System.Linq/tests/SingleOrDefaultTests.cs | 18 ++++++++ 4 files changed, 81 insertions(+), 10 deletions(-) diff --git a/src/libraries/System.Linq/src/System/Linq/First.cs b/src/libraries/System.Linq/src/System/Linq/First.cs index dcb410c9e54300..45c06ac64d3396 100644 --- a/src/libraries/System.Linq/src/System/Linq/First.cs +++ b/src/libraries/System.Linq/src/System/Linq/First.cs @@ -10,7 +10,7 @@ public static partial class Enumerable { public static TSource First(this IEnumerable source) { - TSource? first = source.TryGetFirst(out bool found); + TSource? first = source.TryGetFirst(default, out bool found); if (!found) { ThrowHelper.ThrowNoElementsException(); @@ -21,7 +21,7 @@ public static TSource First(this IEnumerable source) public static TSource First(this IEnumerable source, Func predicate) { - TSource? first = source.TryGetFirst(predicate, out bool found); + TSource? first = source.TryGetFirst(predicate, default, out bool found); if (!found) { ThrowHelper.ThrowNoMatchException(); @@ -31,20 +31,17 @@ public static TSource First(this IEnumerable source, Func(this IEnumerable source) => - source.TryGetFirst(out _); + source.TryGetFirst(default, out _); public static TSource FirstOrDefault(this IEnumerable source, TSource defaultValue) => source.TryGetFirst(defaultValue, out _)!; public static TSource? FirstOrDefault(this IEnumerable source, Func predicate) => - source.TryGetFirst(predicate, out _); + source.TryGetFirst(predicate, default, out _); public static TSource FirstOrDefault(this IEnumerable source, Func predicate, TSource defaultValue) => source.TryGetFirst(predicate, defaultValue, out _)!; - private static TSource? TryGetFirst(this IEnumerable source, out bool found) => - source.TryGetFirst(default(TSource), out found); - private static TSource? TryGetFirst(this IEnumerable source, TSource defaultValue, out bool found) { if (source == null) @@ -81,9 +78,6 @@ public static TSource FirstOrDefault(this IEnumerable source, return defaultValue; } - private static TSource? TryGetFirst(this IEnumerable source, Func predicate, out bool found) => - source.TryGetFirst(predicate, default, out found); - private static TSource? TryGetFirst(this IEnumerable source, Func predicate, TSource? defaultValue, out bool found) { if (source == null) diff --git a/src/libraries/System.Linq/tests/FirstOrDefaultTests.cs b/src/libraries/System.Linq/tests/FirstOrDefaultTests.cs index f0642c3e8165ab..0e6f27dc9f515b 100644 --- a/src/libraries/System.Linq/tests/FirstOrDefaultTests.cs +++ b/src/libraries/System.Linq/tests/FirstOrDefaultTests.cs @@ -39,6 +39,15 @@ private static void TestEmptyIList() Assert.Equal(expected, source.RunOnce().FirstOrDefault()); } + private static void TestEmptyIListDefault(T defaultValue) + { + T[] source = { }; + + Assert.IsAssignableFrom>(source); + + Assert.Equal(defaultValue, source.RunOnce().FirstOrDefault(defaultValue)); + } + [Fact] public void EmptyIListT() { @@ -48,6 +57,14 @@ public void EmptyIListT() TestEmptyIList(); } + [Fact] + public void EmptyIListDefault() + { + TestEmptyIListDefault(5); // int + TestEmptyIListDefault("Hello"); // string + TestEmptyIListDefault(DateTime.UnixEpoch); //DateTime + } + [Fact] public void IListTOneElement() { @@ -59,6 +76,17 @@ public void IListTOneElement() Assert.Equal(expected, source.FirstOrDefault()); } + [Fact] + public void IListOneElementDefault() + { + int[] source = { 5 }; + int expected = 5; + + Assert.IsAssignableFrom>(source); + + Assert.Equal(expected, source.FirstOrDefault(3)); + } + [Fact] public void IListTManyElementsFirstIsDefault() { @@ -96,6 +124,20 @@ static IEnumerable EmptySource() Assert.Equal(expected, source.RunOnce().FirstOrDefault()); } + private static void TestEmptyNotIListDefault(T defaultValue) + { + static IEnumerable EmptySource() + { + yield break; + } + + var source = EmptySource(); + + Assert.Null(source as IList); + + Assert.Equal(defaultValue, source.RunOnce().FirstOrDefault(defaultValue)); + } + [Fact] public void EmptyNotIListT() { diff --git a/src/libraries/System.Linq/tests/LastOrDefaultTests.cs b/src/libraries/System.Linq/tests/LastOrDefaultTests.cs index ec8d6c40b62448..5d5fe5e49d6190 100644 --- a/src/libraries/System.Linq/tests/LastOrDefaultTests.cs +++ b/src/libraries/System.Linq/tests/LastOrDefaultTests.cs @@ -39,6 +39,15 @@ private static void TestEmptyIList() Assert.Equal(expected, source.RunOnce().LastOrDefault()); } + private static void TestEmptyIListDefault(T defaultValue) + { + T[] source = { }; + + Assert.IsAssignableFrom>(source); + + Assert.Equal(defaultValue, source.RunOnce().LastOrDefault(defaultValue)); + } + [Fact] public void EmptyIListT() { @@ -48,6 +57,14 @@ public void EmptyIListT() TestEmptyIList(); } + [Fact] + public void EmptyIList() + { + TestEmptyIListDefault(5); // int + TestEmptyIListDefault("Hello"); // string + TestEmptyIListDefault(DateTime.UnixEpoch); + } + [Fact] public void IListTOneElement() { diff --git a/src/libraries/System.Linq/tests/SingleOrDefaultTests.cs b/src/libraries/System.Linq/tests/SingleOrDefaultTests.cs index 1dfca02713808e..8660edb944828c 100644 --- a/src/libraries/System.Linq/tests/SingleOrDefaultTests.cs +++ b/src/libraries/System.Linq/tests/SingleOrDefaultTests.cs @@ -36,6 +36,15 @@ public void EmptyIList() Assert.Equal(expected, source.SingleOrDefault()); } + [Fact] + public void EmptyIListDefault() + { + int?[] source = { }; + int expected = 5; + + Assert.Equal(expected, source.SingleOrDefault(5)); + } + [Fact] public void SingleElementIList() { @@ -45,6 +54,15 @@ public void SingleElementIList() Assert.Equal(expected, source.SingleOrDefault()); } + [Fact] + public void SingleElementIListDefault() + { + int[] source = { 4 }; + int expected = 4; + + Assert.Equal(expected, source.SingleOrDefault(5)); + } + [Fact] public void ManyElementIList() { From b54edb40e3d20572d709725cd2012cb5222bead3 Mon Sep 17 00:00:00 2001 From: Foxtrek_64 Date: Wed, 10 Mar 2021 18:30:01 -0600 Subject: [PATCH 05/14] Add LastOrDefault tests --- src/libraries/System.Linq/tests/LastOrDefaultTests.cs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/libraries/System.Linq/tests/LastOrDefaultTests.cs b/src/libraries/System.Linq/tests/LastOrDefaultTests.cs index 5d5fe5e49d6190..dff8fb5843df17 100644 --- a/src/libraries/System.Linq/tests/LastOrDefaultTests.cs +++ b/src/libraries/System.Linq/tests/LastOrDefaultTests.cs @@ -76,6 +76,17 @@ public void IListTOneElement() Assert.Equal(expected, source.LastOrDefault()); } + [Fact] + public void IListTOneElementDefault() + { + int[] source = { 5 }; + int expected = 5; + + Assert.IsAssignableFrom>(source); + + Assert.Equal(expected, source.LastOrDefault(4)); + } + [Fact] public void IListTManyElementsLastIsDefault() From ed55a24b58af76787abda4e998dd75e5392cb5a1 Mon Sep 17 00:00:00 2001 From: Foxtrek_64 Date: Thu, 11 Mar 2021 08:31:28 -0600 Subject: [PATCH 06/14] Add Queryable tests --- .../tests/FirstOrDefaultTests.cs | 16 ++++++++++++++++ .../tests/LastOrDefaultTests.cs | 8 ++++++++ .../tests/SingleOrDefaultTests.cs | 14 ++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs b/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs index 95093c0e8596af..ff13e3bcfe04cb 100644 --- a/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs @@ -16,6 +16,15 @@ public void Empty() Assert.Equal(0, source.AsQueryable().FirstOrDefault()); } + [Fact] + public void EmptyDefault() + { + int[] source = { }; + int defaultValue = 5; + + Assert.Equal(defaultValue, source.AsQueryable().FirstOrDefault(defaultValue)); + } + [Fact] public void ManyElementsFirstIsDefault() { @@ -37,6 +46,13 @@ public void OneElementTruePredicate() Assert.Equal(4, source.AsQueryable().FirstOrDefault(i => i % 2 == 0)); } + [Fact] + public void OneElementFalsePredicate() + { + int[] source = { 3 }; + Assert.Equal(5, source.AsQueryable().FirstOrDefault(i => i % 2 == 0, 5)); + } + [Fact] public void ManyElementsPredicateFalseForAll() { diff --git a/src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs b/src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs index 62bfe0a9d29f2c..8083c77738bb01 100644 --- a/src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs @@ -14,6 +14,14 @@ public void Empty() Assert.Null(Enumerable.Empty().AsQueryable().LastOrDefault()); } + [Fact] + public void EmptyDefault() + { + int[] source = { }; + int defaultValue = 5; + Assert.Equal(defaultValue, source.AsQueryable().LastOrDefault(defaultValue)); + } + [Fact] public void OneElement() { diff --git a/src/libraries/System.Linq.Queryable/tests/SingleOrDefaultTests.cs b/src/libraries/System.Linq.Queryable/tests/SingleOrDefaultTests.cs index 032ca44ba83ce5..f34edae8b9f84d 100644 --- a/src/libraries/System.Linq.Queryable/tests/SingleOrDefaultTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/SingleOrDefaultTests.cs @@ -22,12 +22,26 @@ public void Empty() Assert.Null(Enumerable.Empty().AsQueryable().SingleOrDefault()); } + [Fact] + public void EmptyDefault() + { + int[] source = { }; + int defaultValue = 5; + Assert.Equal(defaultValue, source.AsQueryable().SingleOrDefault(5)); + } + [Fact] public void EmptySourceWithPredicate() { Assert.Null(Enumerable.Empty().AsQueryable().SingleOrDefault(i => i % 2 == 0)); } + [Fact] + public void EmptySourceWithPredicateDefault() + { + Assert.Equal(5, Enumerable.Empty().AsQueryable().SingleOrDefault(i => i % 2 == 0, 5)); + } + [Theory] [InlineData(1, 100)] [InlineData(42, 100)] From fa5bc806f022f2e2f092685fc747ae95c8cdbdc2 Mon Sep 17 00:00:00 2001 From: Foxtrek_64 Date: Thu, 11 Mar 2021 17:59:16 -0600 Subject: [PATCH 07/14] Additional tests. Reformatting TryGet Methods. --- .../tests/FirstOrDefaultTests.cs | 30 +++++ .../tests/LastOrDefaultTests.cs | 7 ++ .../System.Linq/src/System/Linq/First.cs | 31 +++-- .../System.Linq/src/System/Linq/Last.cs | 19 +-- .../System.Linq/src/System/Linq/Single.cs | 108 ++++++------------ .../System.Linq/tests/FirstOrDefaultTests.cs | 43 +++++++ .../System.Linq/tests/LastOrDefaultTests.cs | 42 +++++++ .../System.Linq/tests/SingleOrDefaultTests.cs | 77 +++++++++++++ 8 files changed, 264 insertions(+), 93 deletions(-) diff --git a/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs b/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs index ff13e3bcfe04cb..3b02a7ac8f2241 100644 --- a/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs @@ -46,8 +46,22 @@ public void OneElementTruePredicate() Assert.Equal(4, source.AsQueryable().FirstOrDefault(i => i % 2 == 0)); } + [Fact] + public void OneElementTruePredicateDefault() + { + int[] source = { 4 }; + Assert.Equal(4, source.AsQueryable().FirstOrDefault(i => i % 2 == 0, 5)); + } + [Fact] public void OneElementFalsePredicate() + { + int[] source = { 3 }; + Assert.Equal(0, source.AsQueryable().FirstOrDefault(i => i % 2 == 0)); + } + + [Fact] + public void OneElementFalsePredicateDefault() { int[] source = { 3 }; Assert.Equal(5, source.AsQueryable().FirstOrDefault(i => i % 2 == 0, 5)); @@ -60,23 +74,38 @@ public void ManyElementsPredicateFalseForAll() Assert.Equal(0, source.AsQueryable().FirstOrDefault(i => i % 2 == 0)); } + [Fact] + public void ManyElementsPredicateFalseForAllDefault() + { + int[] source = { 9, 5, 1, 3, 17, 21 }; + Assert.Equal(2, source.AsQueryable().FirstOrDefault(i => i % 2 == 0), 2); + } + [Fact] public void PredicateTrueForSome() { int[] source = { 3, 7, 10, 7, 9, 2, 11, 17, 13, 8 }; Assert.Equal(10, source.AsQueryable().FirstOrDefault(i => i % 2 == 0)); } + [Fact] + public void PredicateTrueForSomeDefault() + { + int[] source = { 3, 7, 10, 7, 9, 2, 11, 17, 13, 8 }; + Assert.Equal(10, source.AsQueryable().FirstOrDefault(i => i % 2 == 0, 5)); + } [Fact] public void NullSource() { AssertExtensions.Throws("source", () => ((IQueryable)null).FirstOrDefault()); + AssertExtensions.Throws("source", () => ((IQueryable)null).FirstOrDefault(5)); } [Fact] public void NullSourcePredicateUsed() { AssertExtensions.Throws("source", () => ((IQueryable)null).FirstOrDefault(i => i != 2)); + AssertExtensions.Throws("source", () => ((IQueryable)null).FirstOrDefault(i => i != 2, 5)); } [Fact] @@ -84,6 +113,7 @@ public void NullPredicate() { Expression> predicate = null; AssertExtensions.Throws("predicate", () => Enumerable.Range(0, 3).AsQueryable().FirstOrDefault(predicate)); + AssertExtensions.Throws("predicate", () => Enumerable.Range(0, 3).AsQueryable().FirstOrDefault(predicate, 5)); } [Fact] diff --git a/src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs b/src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs index 8083c77738bb01..c66c42fce4a4fe 100644 --- a/src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs @@ -29,6 +29,13 @@ public void OneElement() Assert.Equal(5, source.AsQueryable().LastOrDefault()); } + [Fact] + public void OneElementFalsePredicate() + { + int[] source = { 3 }; + Assert.Equal(5, source.AsQueryable().LastOrDefault(i => i % 2 == 0, 5)); + } + [Fact] public void ManyElementsLastIsDefault() { diff --git a/src/libraries/System.Linq/src/System/Linq/First.cs b/src/libraries/System.Linq/src/System/Linq/First.cs index 45c06ac64d3396..89e56f8803f141 100644 --- a/src/libraries/System.Linq/src/System/Linq/First.cs +++ b/src/libraries/System.Linq/src/System/Linq/First.cs @@ -10,7 +10,7 @@ public static partial class Enumerable { public static TSource First(this IEnumerable source) { - TSource? first = source.TryGetFirst(default, out bool found); + TSource? first = source.TryGetFirst(out bool found); if (!found) { ThrowHelper.ThrowNoElementsException(); @@ -21,7 +21,7 @@ public static TSource First(this IEnumerable source) public static TSource First(this IEnumerable source, Func predicate) { - TSource? first = source.TryGetFirst(predicate, default, out bool found); + TSource? first = source.TryGetFirst(predicate, out bool found); if (!found) { ThrowHelper.ThrowNoMatchException(); @@ -31,18 +31,25 @@ public static TSource First(this IEnumerable source, Func(this IEnumerable source) => - source.TryGetFirst(default, out _); + source.TryGetFirst(out _); - public static TSource FirstOrDefault(this IEnumerable source, TSource defaultValue) => - source.TryGetFirst(defaultValue, out _)!; + public static TSource FirstOrDefault(this IEnumerable source, TSource defaultValue) + { + var first = source.TryGetFirst(out bool found); + return found ? first! : defaultValue; + } public static TSource? FirstOrDefault(this IEnumerable source, Func predicate) => - source.TryGetFirst(predicate, default, out _); + source.TryGetFirst(predicate, out _); + + public static TSource FirstOrDefault(this IEnumerable source, Func predicate, TSource defaultValue) + { + var first = source.TryGetFirst(predicate, out bool found); + return found ? first! : defaultValue; + } - public static TSource FirstOrDefault(this IEnumerable source, Func predicate, TSource defaultValue) => - source.TryGetFirst(predicate, defaultValue, out _)!; - private static TSource? TryGetFirst(this IEnumerable source, TSource defaultValue, out bool found) + private static TSource? TryGetFirst(this IEnumerable source, out bool found) { if (source == null) { @@ -75,10 +82,10 @@ public static TSource FirstOrDefault(this IEnumerable source, } found = false; - return defaultValue; + return default; } - private static TSource? TryGetFirst(this IEnumerable source, Func predicate, TSource? defaultValue, out bool found) + private static TSource? TryGetFirst(this IEnumerable source, Func predicate, out bool found) { if (source == null) { @@ -100,7 +107,7 @@ public static TSource FirstOrDefault(this IEnumerable source, } found = false; - return defaultValue; + return default; } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Last.cs b/src/libraries/System.Linq/src/System/Linq/Last.cs index 9a83a0a59bd0e3..300838d48c7c9f 100644 --- a/src/libraries/System.Linq/src/System/Linq/Last.cs +++ b/src/libraries/System.Linq/src/System/Linq/Last.cs @@ -33,16 +33,20 @@ public static TSource Last(this IEnumerable source, Func(this IEnumerable source) => source.TryGetLast(out _); public static TSource LastOrDefault(this IEnumerable source, TSource defaultValue) - => source.TryGetLast(defaultValue, out _)!; + { + var last = source.TryGetLast(out bool found); + return found ? last! : defaultValue; + } public static TSource? LastOrDefault(this IEnumerable source, Func predicate) => source.TryGetLast(predicate, out bool _); public static TSource LastOrDefault(this IEnumerable source, Func predicate, TSource defaultValue) - => source.TryGetLast(predicate, defaultValue, out bool _)!; + { + var last = source.TryGetLast(out bool found); + return found ? last! : defaultValue; + } private static TSource? TryGetLast(this IEnumerable source, out bool found) - => source.TryGetLast(default(TSource), out found); - private static TSource? TryGetLast(this IEnumerable source, TSource defaultValue, out bool found) { if (source == null) { @@ -83,12 +87,9 @@ public static TSource LastOrDefault(this IEnumerable source, F } found = false; - return defaultValue; + return default; } - private static TSource? TryGetLast(this IEnumerable source, Func predicate, out bool found) - => source.TryGetLast(predicate, default, out found); - private static TSource? TryGetLast(this IEnumerable source, Func predicate, TSource? defaultValue, out bool found) { if (source == null) { @@ -143,7 +144,7 @@ public static TSource LastOrDefault(this IEnumerable source, F } found = false; - return defaultValue; + return default; } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Single.cs b/src/libraries/System.Linq/src/System/Linq/Single.cs index d8f5f236b44614..e84c0f9c9cbffd 100644 --- a/src/libraries/System.Linq/src/System/Linq/Single.cs +++ b/src/libraries/System.Linq/src/System/Linq/Single.cs @@ -10,85 +10,46 @@ public static partial class Enumerable { public static TSource Single(this IEnumerable source) { - if (source == null) + var single = source.TryGetSingle(out bool found); + if (!found) { - ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); + ThrowHelper.ThrowNoElementsException(); } - if (source is IList list) - { - switch (list.Count) - { - case 0: - ThrowHelper.ThrowNoElementsException(); - return default; - case 1: - return list[0]; - } - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) - { - ThrowHelper.ThrowNoElementsException(); - } - - TSource result = e.Current; - if (!e.MoveNext()) - { - return result; - } - } - } - - ThrowHelper.ThrowMoreThanOneElementException(); - return default; + return single!; } - public static TSource Single(this IEnumerable source, Func predicate) { - if (source == null) - { - ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); - } - - if (predicate == null) + var single = source.TryGetSingle(predicate, out bool found); + if (!found) { - ThrowHelper.ThrowArgumentNullException(ExceptionArgument.predicate); + ThrowHelper.ThrowNoElementsException(); } - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - TSource result = e.Current; - if (predicate(result)) - { - while (e.MoveNext()) - { - if (predicate(e.Current)) - { - ThrowHelper.ThrowMoreThanOneMatchException(); - } - } - - return result; - } - } - } - - ThrowHelper.ThrowNoMatchException(); - return default; + return single!; } public static TSource? SingleOrDefault(this IEnumerable source) - => source.SingleOrDefault(default(TSource)); + => source.TryGetSingle(out _); public static TSource SingleOrDefault(this IEnumerable source, TSource defaultValue) { - if (source == null) + var single = source.TryGetSingle(out bool found); + return found ? single! : defaultValue; + } + + public static TSource? SingleOrDefault(this IEnumerable source, Func predicate) + => source.TryGetSingle(predicate, out _); + + public static TSource SingleOrDefault(this IEnumerable source, Func predicate, TSource defaultValue) + { + var single = source.TryGetSingle(predicate, out bool found); + return found ? single! : defaultValue; + } + + private static TSource? TryGetSingle(this IEnumerable source, out bool found) + { + if (source is null) { ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); } @@ -98,8 +59,10 @@ public static TSource SingleOrDefault(this IEnumerable source, switch (list.Count) { case 0: - return defaultValue; + found = false; + return default; case 1: + found = true; return list[0]; } } @@ -109,25 +72,25 @@ public static TSource SingleOrDefault(this IEnumerable source, { if (!e.MoveNext()) { - return defaultValue; + found = false; + return default; } TSource result = e.Current; if (!e.MoveNext()) { + found = true; return result; } } } + found = false; ThrowHelper.ThrowMoreThanOneElementException(); return default; } - public static TSource? SingleOrDefault(this IEnumerable source, Func predicate) - => source.SingleOrDefault(predicate, default); - - public static TSource? SingleOrDefault(this IEnumerable source, Func predicate, TSource? defaultValue) + private static TSource? TryGetSingle(this IEnumerable source, Func predicate, out bool found) { if (source == null) { @@ -153,13 +116,14 @@ public static TSource SingleOrDefault(this IEnumerable source, ThrowHelper.ThrowMoreThanOneMatchException(); } } - + found = true; return result; } } } - return defaultValue; + found = false; + return default; } } } diff --git a/src/libraries/System.Linq/tests/FirstOrDefaultTests.cs b/src/libraries/System.Linq/tests/FirstOrDefaultTests.cs index 0e6f27dc9f515b..b0fe384389dcba 100644 --- a/src/libraries/System.Linq/tests/FirstOrDefaultTests.cs +++ b/src/libraries/System.Linq/tests/FirstOrDefaultTests.cs @@ -188,6 +188,16 @@ public void OneElementTruePredicate() Assert.Equal(expected, source.FirstOrDefault(predicate)); } + [Fact] + public void OneElementTruePredicateDefault() + { + int[] source = { 4 }; + Func predicate = IsEven; + int expected = 4; + + Assert.Equal(expected, source.FirstOrDefault(predicate, 5)); + } + [Fact] public void ManyElementsPredicateFalseForAll() { @@ -198,6 +208,16 @@ public void ManyElementsPredicateFalseForAll() Assert.Equal(expected, source.FirstOrDefault(predicate)); } + [Fact] + public void ManyElementsPredicateFalseForAllDefault() + { + int[] source = { 9, 5, 1, 3, 17, 21 }; + Func predicate = IsEven; + int expected = 5; + + Assert.Equal(expected, source.FirstOrDefault(predicate, 5)); + } + [Fact] public void PredicateTrueOnlyForLast() { @@ -208,6 +228,16 @@ public void PredicateTrueOnlyForLast() Assert.Equal(expected, source.FirstOrDefault(predicate)); } + [Fact] + public void PredicateTrueOnlyForLastDefault() + { + int[] source = { 9, 5, 1, 3, 17, 21, 50 }; + Func predicate = IsEven; + int expected = 50; + + Assert.Equal(expected, source.FirstOrDefault(predicate, 5)); + } + [Fact] public void PredicateTrueForSome() { @@ -218,6 +248,16 @@ public void PredicateTrueForSome() Assert.Equal(expected, source.FirstOrDefault(predicate)); } + [Fact] + public void PredicateTrueForSomeDefault() + { + int[] source = { 3, 7, 10, 7, 9, 2, 11, 17, 13, 8 }; + Func predicate = IsEven; + int expected = 10; + + Assert.Equal(expected, source.FirstOrDefault(predicate, 5)); + } + [Fact] public void PredicateTrueForSomeRunOnce() { @@ -232,12 +272,14 @@ public void PredicateTrueForSomeRunOnce() public void NullSource() { AssertExtensions.Throws("source", () => ((IEnumerable)null).FirstOrDefault()); + AssertExtensions.Throws("source", () => ((IEnumerable)null).FirstOrDefault(5)); } [Fact] public void NullSourcePredicateUsed() { AssertExtensions.Throws("source", () => ((IEnumerable)null).FirstOrDefault(i => i != 2)); + AssertExtensions.Throws("source", () => ((IEnumerable)null).FirstOrDefault(i => i != 2, 5)); } [Fact] @@ -245,6 +287,7 @@ public void NullPredicate() { Func predicate = null; AssertExtensions.Throws("predicate", () => Enumerable.Range(0, 3).FirstOrDefault(predicate)); + AssertExtensions.Throws("predicate", () => Enumerable.Range(0, 3).FirstOrDefault(predicate, 5)); } } } diff --git a/src/libraries/System.Linq/tests/LastOrDefaultTests.cs b/src/libraries/System.Linq/tests/LastOrDefaultTests.cs index dff8fb5843df17..ab3a3fa1ad8a1b 100644 --- a/src/libraries/System.Linq/tests/LastOrDefaultTests.cs +++ b/src/libraries/System.Linq/tests/LastOrDefaultTests.cs @@ -110,6 +110,28 @@ public void IListTManyElementsLastIsNotDefault() Assert.Equal(expected, source.LastOrDefault()); } + [Fact] + public void IListTManyElementsLastHasDefault() + { + int?[] source = { -10, 2, 4, 3, 0, 2, null }; + int? expected = null; + + Assert.IsAssignableFrom>(source); + + Assert.Equal(expected, source.LastOrDefault(5)); + } + + [Fact] + public void IListTManyElementsLastIsHasDefault() + { + int?[] source = { -10, 2, 4, 3, 0, 2, null, 19 }; + int? expected = 19; + + Assert.IsAssignableFrom>(source); + + Assert.Equal(expected, source.LastOrDefault(5)); + } + private static IEnumerable EmptySource() { yield break; @@ -175,6 +197,16 @@ public void OneElementIListTruePredicate() Assert.Equal(expected, source.LastOrDefault(predicate)); } + [Fact] + public void OneElementIListTruePredicateDefault() + { + int[] source = { 4 }; + Func predicate = IsEven; + int expected = 4; + + Assert.Equal(expected, source.LastOrDefault(predicate, 5)); + } + [Fact] public void ManyElementsIListPredicateFalseForAll() { @@ -185,6 +217,16 @@ public void ManyElementsIListPredicateFalseForAll() Assert.Equal(expected, source.LastOrDefault(predicate)); } + [Fact] + public void ManyElementsIListPredicateFalseForAllDefault() + { + int[] source = { 9, 5, 1, 3, 17, 21 }; + Func predicate = IsEven; + int expected = 5; + + Assert.Equal(expected, source.LastOrDefault(predicate, 5)); + } + [Fact] public void IListPredicateTrueOnlyForLast() { diff --git a/src/libraries/System.Linq/tests/SingleOrDefaultTests.cs b/src/libraries/System.Linq/tests/SingleOrDefaultTests.cs index 8660edb944828c..9ac26e76a8fa23 100644 --- a/src/libraries/System.Linq/tests/SingleOrDefaultTests.cs +++ b/src/libraries/System.Linq/tests/SingleOrDefaultTests.cs @@ -71,6 +71,14 @@ public void ManyElementIList() Assert.Throws(() => source.SingleOrDefault()); } + [Fact] + public void ManyElementIListDefault() + { + int[] source = { 4, 4, 4, 4, 4 }; + + Assert.Throws(() => source.SingleOrDefault(5)); + } + [Fact] public void EmptyNotIList() { @@ -106,6 +114,15 @@ public void EmptySourceWithPredicate() Assert.Equal(expected, source.SingleOrDefault(i => i % 2 == 0)); } + [Fact] + public void EmptySourceWithPredicateDefault() + { + int[] source = { }; + int expected = 5; + + Assert.Equal(expected, source.SingleOrDefault(i => i % 2 == 0, 5)); + } + [Fact] public void SingleElementPredicateTrue() { @@ -115,6 +132,15 @@ public void SingleElementPredicateTrue() Assert.Equal(expected, source.SingleOrDefault(i => i % 2 == 0)); } + [Fact] + public void SingleElementPredicateTrueDefault() + { + int[] source = { 4 }; + int expected = 4; + + Assert.Equal(expected, source.SingleOrDefault(i => i % 2 == 0, 5)); + } + [Fact] public void SingleElementPredicateFalse() { @@ -124,6 +150,15 @@ public void SingleElementPredicateFalse() Assert.Equal(expected, source.SingleOrDefault(i => i % 2 == 0)); } + [Fact] + public void SingleElementPredicateFalseDefault() + { + int[] source = { 3 }; + int expected = 5; + + Assert.Equal(expected, source.SingleOrDefault(i => i % 2 == 0, 5)); + } + [Fact] public void ManyElementsPredicateFalseForAll() { @@ -133,6 +168,15 @@ public void ManyElementsPredicateFalseForAll() Assert.Equal(expected, source.SingleOrDefault(i => i % 2 == 0)); } + [Fact] + public void ManyElementsPredicateFalseForAllDefault() + { + int[] source = { 3, 1, 7, 9, 13, 19 }; + int expected = 5; + + Assert.Equal(expected, source.SingleOrDefault(i => i % 2 == 0, 5)); + } + [Fact] public void ManyElementsPredicateTrueForLast() { @@ -142,6 +186,15 @@ public void ManyElementsPredicateTrueForLast() Assert.Equal(expected, source.SingleOrDefault(i => i % 2 == 0)); } + [Fact] + public void ManyElementsPredicateTrueForLastDefault() + { + int[] source = { 3, 1, 7, 9, 13, 19, 20 }; + int expected = 20; + + Assert.Equal(expected, source.SingleOrDefault(i => i % 2 == 0, 5)); + } + [Fact] public void ManyElementsPredicateTrueForFirstAndFifth() { @@ -150,6 +203,14 @@ public void ManyElementsPredicateTrueForFirstAndFifth() Assert.Throws(() => source.SingleOrDefault(i => i % 2 == 0)); } + [Fact] + public void ManyElementsPredicateTrueForFirstAndFifthDefault() + { + int[] source = { 2, 3, 1, 7, 10, 13, 19, 9 }; + + Assert.Throws(() => source.SingleOrDefault(i => i % 2 == 0, 5)); + } + [Theory] [InlineData(1, 100)] [InlineData(42, 100)] @@ -174,6 +235,14 @@ public void ThrowsOnNullSource() AssertExtensions.Throws("source", () => source.SingleOrDefault(i => i % 2 == 0)); } + [Fact] + public void ThrowsOnNullSourceDefault() + { + int[] source = null; + AssertExtensions.Throws("source", () => source.SingleOrDefault(5)); + AssertExtensions.Throws("source", () => source.SingleOrDefault(i => i % 2 == 0, 5)); + } + [Fact] public void ThrowsOnNullPredicate() { @@ -181,5 +250,13 @@ public void ThrowsOnNullPredicate() Func nullPredicate = null; AssertExtensions.Throws("predicate", () => source.SingleOrDefault(nullPredicate)); } + + [Fact] + public void ThrowsOnNullPredicateDefault() + { + int[] source = { }; + Func nullPredicate = null; + AssertExtensions.Throws("predicate", () => source.SingleOrDefault(nullPredicate, 5)); + } } } From 808245921b756177502efb4f96357dd52cd985c9 Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Fri, 12 Mar 2021 13:02:13 +0000 Subject: [PATCH 08/14] Update src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs --- .../System.Linq.Queryable/tests/FirstOrDefaultTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs b/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs index 3b02a7ac8f2241..ac74ca5bd68365 100644 --- a/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs @@ -78,7 +78,7 @@ public void ManyElementsPredicateFalseForAll() public void ManyElementsPredicateFalseForAllDefault() { int[] source = { 9, 5, 1, 3, 17, 21 }; - Assert.Equal(2, source.AsQueryable().FirstOrDefault(i => i % 2 == 0), 2); + Assert.Equal(2, source.AsQueryable().FirstOrDefault(i => i % 2 == 0, 2)); } [Fact] From 24cf263ad2845bcdc3f89489a522c1582e4b64f8 Mon Sep 17 00:00:00 2001 From: Foxtrek_64 Date: Sun, 14 Mar 2021 00:04:10 +0000 Subject: [PATCH 09/14] Apply suggestions from code review Co-authored-by: Eirik Tsarpalis --- src/libraries/System.Linq/src/System/Linq/First.cs | 4 ++-- src/libraries/System.Linq/src/System/Linq/Last.cs | 8 ++++++-- src/libraries/System.Linq/src/System/Linq/Single.cs | 6 +++--- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/libraries/System.Linq/src/System/Linq/First.cs b/src/libraries/System.Linq/src/System/Linq/First.cs index 89e56f8803f141..7f59d1c63b5872 100644 --- a/src/libraries/System.Linq/src/System/Linq/First.cs +++ b/src/libraries/System.Linq/src/System/Linq/First.cs @@ -35,7 +35,7 @@ public static TSource First(this IEnumerable source, Func(this IEnumerable source, TSource defaultValue) { - var first = source.TryGetFirst(out bool found); + TSource? first = source.TryGetFirst(out bool found); return found ? first! : defaultValue; } @@ -44,7 +44,7 @@ public static TSource FirstOrDefault(this IEnumerable source, public static TSource FirstOrDefault(this IEnumerable source, Func predicate, TSource defaultValue) { - var first = source.TryGetFirst(predicate, out bool found); + TSource? first = source.TryGetFirst(predicate, out bool found); return found ? first! : defaultValue; } diff --git a/src/libraries/System.Linq/src/System/Linq/Last.cs b/src/libraries/System.Linq/src/System/Linq/Last.cs index 300838d48c7c9f..de730f1c30d26a 100644 --- a/src/libraries/System.Linq/src/System/Linq/Last.cs +++ b/src/libraries/System.Linq/src/System/Linq/Last.cs @@ -32,14 +32,17 @@ public static TSource Last(this IEnumerable source, Func(this IEnumerable source) => source.TryGetLast(out _); + + public static TSource LastOrDefault(this IEnumerable source, TSource defaultValue) { - var last = source.TryGetLast(out bool found); + TSource? last = source.TryGetLast(out bool found); return found ? last! : defaultValue; } public static TSource? LastOrDefault(this IEnumerable source, Func predicate) - => source.TryGetLast(predicate, out bool _); + => source.TryGetLast(predicate, out _); + public static TSource LastOrDefault(this IEnumerable source, Func predicate, TSource defaultValue) { var last = source.TryGetLast(out bool found); @@ -89,6 +92,7 @@ public static TSource LastOrDefault(this IEnumerable source, F found = false; return default; } + private static TSource? TryGetLast(this IEnumerable source, Func predicate, out bool found) { if (source == null) diff --git a/src/libraries/System.Linq/src/System/Linq/Single.cs b/src/libraries/System.Linq/src/System/Linq/Single.cs index e84c0f9c9cbffd..6455a841c9c46f 100644 --- a/src/libraries/System.Linq/src/System/Linq/Single.cs +++ b/src/libraries/System.Linq/src/System/Linq/Single.cs @@ -10,7 +10,7 @@ public static partial class Enumerable { public static TSource Single(this IEnumerable source) { - var single = source.TryGetSingle(out bool found); + TSource? single = source.TryGetSingle(out bool found); if (!found) { ThrowHelper.ThrowNoElementsException(); @@ -20,10 +20,10 @@ public static TSource Single(this IEnumerable source) } public static TSource Single(this IEnumerable source, Func predicate) { - var single = source.TryGetSingle(predicate, out bool found); + TSource? single = source.TryGetSingle(predicate, out bool found); if (!found) { - ThrowHelper.ThrowNoElementsException(); + ThrowHelper.ThrowNoMatchException(); } return single!; From ed9728dbdb0e87fdc0b1da1d5f602fb2e840ee6b Mon Sep 17 00:00:00 2001 From: Foxtrek_64 Date: Mon, 15 Mar 2021 17:57:14 -0500 Subject: [PATCH 10/14] Fix ref methods --- .../System.Linq.Queryable/src/System/Linq/CachedReflection.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs index 37ac202c0f45bc..e7bfc2f28b5100 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs @@ -568,7 +568,7 @@ public static MethodInfo SingleOrDefault_TSource_2(Type TSource) => public static MethodInfo SingleOrDefault_TSource_3(Type TSource) => (s_SingleOrDefault_TSource_3 ?? - (s_SingleOrDefault_TSource_3 = new Func, object?, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + (s_SingleOrDefault_TSource_3 = new Func, object, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) .MakeGenericMethod(TSource); private static MethodInfo? s_SingleOrDefault_TSource_4; @@ -576,7 +576,7 @@ public static MethodInfo SingleOrDefault_TSource_3(Type TSource) => public static MethodInfo SingleOrDefault_TSource_4(Type TSource) => (s_SingleOrDefault_TSource_4 ?? (s_SingleOrDefault_TSource_4 = new Func, Expression>, object, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) - .MakeGenericMethod(); + .MakeGenericMethod(TSource); private static MethodInfo? s_Skip_TSource_2; From 604ee5983e1a42aca4265752647b7ecee6dd78aa Mon Sep 17 00:00:00 2001 From: Foxtrek_64 Date: Mon, 15 Mar 2021 22:36:35 -0500 Subject: [PATCH 11/14] Further adjust nullability --- .../System.Linq.Queryable/src/System/Linq/CachedReflection.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs index e7bfc2f28b5100..1fe10e814fc13f 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs @@ -568,14 +568,14 @@ public static MethodInfo SingleOrDefault_TSource_2(Type TSource) => public static MethodInfo SingleOrDefault_TSource_3(Type TSource) => (s_SingleOrDefault_TSource_3 ?? - (s_SingleOrDefault_TSource_3 = new Func, object, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + (s_SingleOrDefault_TSource_3 = new Func, object, object>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) .MakeGenericMethod(TSource); private static MethodInfo? s_SingleOrDefault_TSource_4; public static MethodInfo SingleOrDefault_TSource_4(Type TSource) => (s_SingleOrDefault_TSource_4 ?? - (s_SingleOrDefault_TSource_4 = new Func, Expression>, object, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + (s_SingleOrDefault_TSource_4 = new Func, Expression>, object, object>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) .MakeGenericMethod(TSource); private static MethodInfo? s_Skip_TSource_2; From 5a05802504542b655c83d5df7f1c29825b9c2332 Mon Sep 17 00:00:00 2001 From: Foxtrek_64 Date: Tue, 16 Mar 2021 00:56:10 -0500 Subject: [PATCH 12/14] Fix more nullables --- .../src/System/Linq/CachedReflection.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs index 1fe10e814fc13f..2d11a128a33b08 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs @@ -288,14 +288,14 @@ public static MethodInfo FirstOrDefault_TSource_2(Type TSource) => public static MethodInfo FirstOrDefault_TSource_3(Type TSource) => (s_FirstOrDefault_TSource_3 ?? - (s_FirstOrDefault_TSource_3 = new Func, object?, object?>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + (s_FirstOrDefault_TSource_3 = new Func, object, object>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition())) .MakeGenericMethod(); private static MethodInfo? s_FirstOrDefault_TSource_4; public static MethodInfo FirstOrDefault_TSource_4(Type TSource) => (s_FirstOrDefault_TSource_4 ?? - (s_FirstOrDefault_TSource_4 = new Func, Expression>, object, object?>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + (s_FirstOrDefault_TSource_4 = new Func, Expression>, object, object>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition())) .MakeGenericMethod(TSource); private static MethodInfo? s_GroupBy_TSource_TKey_2; @@ -410,14 +410,14 @@ public static MethodInfo LastOrDefault_TSource_2(Type TSource) => public static MethodInfo LastOrDefault_TSource_3(Type TSource) => (s_LastOrDefault_TSource_3 ?? - (s_LastOrDefault_TSource_3 = new Func, object?, object?>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + (s_LastOrDefault_TSource_3 = new Func, object, object>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition())) .MakeGenericMethod(TSource); private static MethodInfo? s_LastOrDefault_TSource_4; public static MethodInfo LastOrDefault_TSource_4(Type TSource) => (s_LastOrDefault_TSource_4 ?? - (s_LastOrDefault_TSource_4 = new Func, Expression>, object, object?>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + (s_LastOrDefault_TSource_4 = new Func, Expression>, object, object>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition())) .MakeGenericMethod(TSource); private static MethodInfo? s_LongCount_TSource_1; From 9e8699e5ec06ff368581c0b346703703a56e51ad Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Wed, 17 Mar 2021 15:53:54 +0000 Subject: [PATCH 13/14] fix failing tests --- .../System.Linq.Queryable/src/System/Linq/CachedReflection.cs | 2 +- .../System.Linq.Queryable/tests/TrimCompatibilityTests.cs | 2 +- src/libraries/System.Linq/src/System/Linq/Last.cs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs index 2d11a128a33b08..be11c326812cb5 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs @@ -289,7 +289,7 @@ public static MethodInfo FirstOrDefault_TSource_2(Type TSource) => public static MethodInfo FirstOrDefault_TSource_3(Type TSource) => (s_FirstOrDefault_TSource_3 ?? (s_FirstOrDefault_TSource_3 = new Func, object, object>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition())) - .MakeGenericMethod(); + .MakeGenericMethod(TSource); private static MethodInfo? s_FirstOrDefault_TSource_4; diff --git a/src/libraries/System.Linq.Queryable/tests/TrimCompatibilityTests.cs b/src/libraries/System.Linq.Queryable/tests/TrimCompatibilityTests.cs index 3d02b20ae0ec13..b1463804c69086 100644 --- a/src/libraries/System.Linq.Queryable/tests/TrimCompatibilityTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/TrimCompatibilityTests.cs @@ -61,7 +61,7 @@ public static void CachedReflectionInfoMethodsNoAnnotations() .Where(m => m.GetParameters().Length > 0); // If you are adding a new method to this class, ensure the method meets these requirements - Assert.Equal(111, methods.Count()); + Assert.Equal(117, methods.Count()); foreach (MethodInfo method in methods) { ParameterInfo[] parameters = method.GetParameters(); diff --git a/src/libraries/System.Linq/src/System/Linq/Last.cs b/src/libraries/System.Linq/src/System/Linq/Last.cs index de730f1c30d26a..b28235bd48b336 100644 --- a/src/libraries/System.Linq/src/System/Linq/Last.cs +++ b/src/libraries/System.Linq/src/System/Linq/Last.cs @@ -45,7 +45,7 @@ public static TSource LastOrDefault(this IEnumerable source, T public static TSource LastOrDefault(this IEnumerable source, Func predicate, TSource defaultValue) { - var last = source.TryGetLast(out bool found); + var last = source.TryGetLast(predicate, out bool found); return found ? last! : defaultValue; } From 9499df695ac73d299422a1dc6e1aa4a12d216c18 Mon Sep 17 00:00:00 2001 From: Foxtrek_64 Date: Wed, 17 Mar 2021 16:07:10 +0000 Subject: [PATCH 14/14] Restore coding style --- src/libraries/System.Linq/src/System/Linq/Last.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Linq/src/System/Linq/Last.cs b/src/libraries/System.Linq/src/System/Linq/Last.cs index b28235bd48b336..16318de2574ccc 100644 --- a/src/libraries/System.Linq/src/System/Linq/Last.cs +++ b/src/libraries/System.Linq/src/System/Linq/Last.cs @@ -30,8 +30,8 @@ public static TSource Last(this IEnumerable source, Func(this IEnumerable source) - => source.TryGetLast(out _); + public static TSource? LastOrDefault(this IEnumerable source) => + source.TryGetLast(out _); public static TSource LastOrDefault(this IEnumerable source, TSource defaultValue)