From c86950a2be236fdc497976882c9c4fc80b957a9b Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Sat, 24 Feb 2024 20:34:09 -0500 Subject: [PATCH 1/7] Make Grouped{Result}Enumerable derive from Iterator --- .../src/System/Linq/Grouping.SpeedOpt.cs | 71 ++-- .../System.Linq/src/System/Linq/Grouping.cs | 320 +++++++++++++----- .../System.Linq/src/System/Linq/Lookup.cs | 2 +- 3 files changed, 264 insertions(+), 129 deletions(-) diff --git a/src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs index 0a1d57df361bda..94a24fa8c112b6 100644 --- a/src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs @@ -5,51 +5,54 @@ namespace System.Linq { - internal sealed partial class GroupedResultEnumerable : IIListProvider + public static partial class Enumerable { - public TResult[] ToArray() => - Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToArray(_resultSelector); + internal sealed partial class GroupByResultIterator : IIListProvider + { + public TResult[] ToArray() => + Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToArray(_resultSelector); - public List ToList() => - Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToList(_resultSelector); + public List ToList() => + Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToList(_resultSelector); - public int GetCount(bool onlyIfCheap) => - onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _elementSelector, _comparer).Count; - } + public int GetCount(bool onlyIfCheap) => + onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _elementSelector, _comparer).Count; + } - internal sealed partial class GroupedResultEnumerable : IIListProvider - { - public TResult[] ToArray() => - Lookup.Create(_source, _keySelector, _comparer).ToArray(_resultSelector); + internal sealed partial class GroupByResultIterator : IIListProvider + { + public TResult[] ToArray() => + Lookup.Create(_source, _keySelector, _comparer).ToArray(_resultSelector); - public List ToList() => - Lookup.Create(_source, _keySelector, _comparer).ToList(_resultSelector); + public List ToList() => + Lookup.Create(_source, _keySelector, _comparer).ToList(_resultSelector); - public int GetCount(bool onlyIfCheap) => - onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _comparer).Count; - } + public int GetCount(bool onlyIfCheap) => + onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _comparer).Count; + } - internal sealed partial class GroupedEnumerable : IIListProvider> - { - public IGrouping[] ToArray() => - Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToArray(); + internal sealed partial class GroupByIterator : IIListProvider> + { + public IGrouping[] ToArray() => + Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToArray(); - public List> ToList() => - Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToList(); + public List> ToList() => + Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToList(); - public int GetCount(bool onlyIfCheap) => - onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _elementSelector, _comparer).Count; - } + public int GetCount(bool onlyIfCheap) => + onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _elementSelector, _comparer).Count; + } - internal sealed partial class GroupedEnumerable : IIListProvider> - { - public IGrouping[] ToArray() => - Lookup.Create(_source, _keySelector, _comparer).ToArray(); + internal sealed partial class GroupByIterator : IIListProvider> + { + public IGrouping[] ToArray() => + Lookup.Create(_source, _keySelector, _comparer).ToArray(); - public List> ToList() => - Lookup.Create(_source, _keySelector, _comparer).ToList(); + public List> ToList() => + Lookup.Create(_source, _keySelector, _comparer).ToList(); - public int GetCount(bool onlyIfCheap) => - onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _comparer).Count; + public int GetCount(bool onlyIfCheap) => + onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _comparer).Count; + } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Grouping.cs b/src/libraries/System.Linq/src/System/Linq/Grouping.cs index f0517c953ebca3..c9b55b85cfad89 100644 --- a/src/libraries/System.Linq/src/System/Linq/Grouping.cs +++ b/src/libraries/System.Linq/src/System/Linq/Grouping.cs @@ -29,7 +29,7 @@ public static IEnumerable> GroupBy(this return []; } - return new GroupedEnumerable(source, keySelector, comparer); + return new GroupByIterator(source, keySelector, comparer); } public static IEnumerable> GroupBy(this IEnumerable source, Func keySelector, Func elementSelector) => @@ -57,7 +57,7 @@ public static IEnumerable> GroupBy(source, keySelector, elementSelector, comparer); + return new GroupByIterator(source, keySelector, elementSelector, comparer); } public static IEnumerable GroupBy(this IEnumerable source, Func keySelector, Func, TResult> resultSelector) => @@ -85,7 +85,7 @@ public static IEnumerable GroupBy(this IEnumera return []; } - return new GroupedResultEnumerable(source, keySelector, resultSelector, comparer); + return new GroupByResultIterator(source, keySelector, resultSelector, comparer); } public static IEnumerable GroupBy(this IEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector) => @@ -118,7 +118,229 @@ public static IEnumerable GroupBy(thi return []; } - return new GroupedResultEnumerable(source, keySelector, elementSelector, resultSelector, comparer); + return new GroupByResultIterator(source, keySelector, elementSelector, resultSelector, comparer); + } + + internal sealed partial class GroupByResultIterator : Iterator + { + private readonly IEnumerable _source; + private readonly Func _keySelector; + private readonly Func _elementSelector; + private readonly IEqualityComparer? _comparer; + private readonly Func, TResult> _resultSelector; + + private Lookup? _lookup; + private Grouping? _g; + + public GroupByResultIterator(IEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, IEqualityComparer? comparer) + { + _source = source; + _keySelector = keySelector; + _elementSelector = elementSelector; + _comparer = comparer; + _resultSelector = resultSelector; + } + + public override Iterator Clone() => new GroupByResultIterator(_source, _keySelector, _elementSelector, _resultSelector, _comparer); + + public override bool MoveNext() + { + switch (_state) + { + case 1: + _lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); + _g = _lookup._lastGrouping; + if (_g is not null) + { + _state = 2; + goto ValidItem; + } + break; + + case 2: + Debug.Assert(_g is not null); + Debug.Assert(_lookup is not null); + if (_g != _lookup._lastGrouping) + { + goto ValidItem; + } + break; + } + + Dispose(); + return false; + + ValidItem: + _g = _g._next; + Debug.Assert(_g is not null); + _g.Trim(); + _current = _resultSelector(_g.Key, _g._elements); + return true; + } + } + + internal sealed partial class GroupByResultIterator : Iterator + { + private readonly IEnumerable _source; + private readonly Func _keySelector; + private readonly IEqualityComparer? _comparer; + private readonly Func, TResult> _resultSelector; + + private Lookup? _lookup; + private Grouping? _g; + + public GroupByResultIterator(IEnumerable source, Func keySelector, Func, TResult> resultSelector, IEqualityComparer? comparer) + { + _source = source; + _keySelector = keySelector; + _resultSelector = resultSelector; + _comparer = comparer; + } + + public override Iterator Clone() => new GroupByResultIterator(_source, _keySelector, _resultSelector, _comparer); + + public override bool MoveNext() + { + switch (_state) + { + case 1: + _lookup = Lookup.Create(_source, _keySelector, _comparer); + _g = _lookup._lastGrouping; + if (_g is not null) + { + _state = 2; + goto ValidItem; + } + break; + + case 2: + Debug.Assert(_g is not null); + Debug.Assert(_lookup is not null); + if (_g != _lookup._lastGrouping) + { + goto ValidItem; + } + break; + } + + Dispose(); + return false; + + ValidItem: + _g = _g._next; + Debug.Assert(_g is not null); + _g.Trim(); + _current = _resultSelector(_g.Key, _g._elements); + return true; + } + } + + internal sealed partial class GroupByIterator : Iterator> + { + private readonly IEnumerable _source; + private readonly Func _keySelector; + private readonly Func _elementSelector; + private readonly IEqualityComparer? _comparer; + + private Lookup? _lookup; + private Grouping? _g; + + public GroupByIterator(IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer? comparer) + { + _source = source; + _keySelector = keySelector; + _elementSelector = elementSelector; + _comparer = comparer; + } + + public override Iterator> Clone() => new GroupByIterator(_source, _keySelector, _elementSelector, _comparer); + + public override bool MoveNext() + { + switch (_state) + { + case 1: + _lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); + _g = _lookup._lastGrouping; + if (_g is not null) + { + _state = 2; + goto ValidItem; + } + break; + + case 2: + Debug.Assert(_g is not null); + Debug.Assert(_lookup is not null); + if (_g != _lookup._lastGrouping) + { + goto ValidItem; + } + break; + } + + Dispose(); + return false; + + ValidItem: + _g = _g._next; + Debug.Assert(_g is not null); + _current = _g; + return true; + } + } + + internal sealed partial class GroupByIterator : Iterator> + { + private readonly IEnumerable _source; + private readonly Func _keySelector; + private readonly IEqualityComparer? _comparer; + + private Lookup? _lookup; + private Grouping? _g; + + public GroupByIterator(IEnumerable source, Func keySelector, IEqualityComparer? comparer) + { + _source = source; + _keySelector = keySelector; + _comparer = comparer; + } + + public override Iterator> Clone() => new GroupByIterator(_source, _keySelector, _comparer); + + public override bool MoveNext() + { + switch (_state) + { + case 1: + _lookup = Lookup.Create(_source, _keySelector, _comparer); + _g = _lookup._lastGrouping; + if (_g is not null) + { + _state = 2; + goto ValidItem; + } + break; + + case 2: + Debug.Assert(_g is not null); + Debug.Assert(_lookup is not null); + if (_g != _lookup._lastGrouping) + { + goto ValidItem; + } + break; + } + + Dispose(); + return false; + + ValidItem: + _g = _g._next; + Debug.Assert(_g is not null); + _current = _g; + return true; + } } } @@ -219,94 +441,4 @@ TElement IList.this[int index] } } } - - internal sealed partial class GroupedResultEnumerable : IEnumerable - { - private readonly IEnumerable _source; - private readonly Func _keySelector; - private readonly Func _elementSelector; - private readonly IEqualityComparer? _comparer; - private readonly Func, TResult> _resultSelector; - - public GroupedResultEnumerable(IEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, IEqualityComparer? comparer) - { - _source = source; - _keySelector = keySelector; - _elementSelector = elementSelector; - _comparer = comparer; - _resultSelector = resultSelector; - } - - public IEnumerator GetEnumerator() - { - Lookup lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); - return lookup.ApplyResultSelector(_resultSelector).GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - } - - internal sealed partial class GroupedResultEnumerable : IEnumerable - { - private readonly IEnumerable _source; - private readonly Func _keySelector; - private readonly IEqualityComparer? _comparer; - private readonly Func, TResult> _resultSelector; - - public GroupedResultEnumerable(IEnumerable source, Func keySelector, Func, TResult> resultSelector, IEqualityComparer? comparer) - { - _source = source; - _keySelector = keySelector; - _resultSelector = resultSelector; - _comparer = comparer; - } - - public IEnumerator GetEnumerator() - { - Lookup lookup = Lookup.Create(_source, _keySelector, _comparer); - return lookup.ApplyResultSelector(_resultSelector).GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - } - - internal sealed partial class GroupedEnumerable : IEnumerable> - { - private readonly IEnumerable _source; - private readonly Func _keySelector; - private readonly Func _elementSelector; - private readonly IEqualityComparer? _comparer; - - public GroupedEnumerable(IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer? comparer) - { - _source = source; - _keySelector = keySelector; - _elementSelector = elementSelector; - _comparer = comparer; - } - - public IEnumerator> GetEnumerator() => - Lookup.Create(_source, _keySelector, _elementSelector, _comparer).GetEnumerator(); - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - } - - internal sealed partial class GroupedEnumerable : IEnumerable> - { - private readonly IEnumerable _source; - private readonly Func _keySelector; - private readonly IEqualityComparer? _comparer; - - public GroupedEnumerable(IEnumerable source, Func keySelector, IEqualityComparer? comparer) - { - _source = source; - _keySelector = keySelector; - _comparer = comparer; - } - - public IEnumerator> GetEnumerator() => - Lookup.Create(_source, _keySelector, _comparer).GetEnumerator(); - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - } } diff --git a/src/libraries/System.Linq/src/System/Linq/Lookup.cs b/src/libraries/System.Linq/src/System/Linq/Lookup.cs index 7aca5c732e72d1..2b7fbb20b8594f 100644 --- a/src/libraries/System.Linq/src/System/Linq/Lookup.cs +++ b/src/libraries/System.Linq/src/System/Linq/Lookup.cs @@ -76,7 +76,7 @@ public partial class Lookup : ILookup { private readonly IEqualityComparer _comparer; private Grouping[] _groupings; - private protected Grouping? _lastGrouping; + internal Grouping? _lastGrouping; private int _count; internal static Lookup Create(IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer? comparer) From b5d64297d86cace30216b816e44a163738368d03 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Sun, 25 Feb 2024 21:29:13 -0500 Subject: [PATCH 2/7] Make OrderedPartition derive from Iterator --- .../System/Linq/OrderedEnumerable.SpeedOpt.cs | 108 ++++++++++++++++++ .../src/System/Linq/OrderedEnumerable.cs | 33 +----- .../src/System/Linq/Partition.SpeedOpt.cs | 59 ---------- 3 files changed, 110 insertions(+), 90 deletions(-) diff --git a/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.SpeedOpt.cs index 6dfde9efd62301..a666f9f1df5748 100644 --- a/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.SpeedOpt.cs @@ -434,5 +434,113 @@ public override List ToList() return default; } } + + internal sealed class OrderedPartition : Iterator, IPartition + { + private readonly OrderedEnumerable _source; + private readonly int _minIndexInclusive; + private readonly int _maxIndexInclusive; + + private TElement[]? _buffer; + private int[]? _map; + private int _maxIdx; + + public OrderedPartition(OrderedEnumerable source, int minIdxInclusive, int maxIdxInclusive) + { + _source = source; + _minIndexInclusive = minIdxInclusive; + _maxIndexInclusive = maxIdxInclusive; + } + + public override Iterator Clone() => new OrderedPartition(_source, _minIndexInclusive, _maxIndexInclusive); + + public override bool MoveNext() + { + int state = _state; + + Initialized: + if (state > 1) + { + Debug.Assert(_buffer is not null); + Debug.Assert(_map is not null); + + int[] map = _map; + int i = state - 2 + _minIndexInclusive; + if (i <= _maxIdx) + { + _current = _buffer[map[i]]; + _state++; + return true; + } + } + else if (state == 1) + { + TElement[] buffer = _source.ToArray(); + int count = buffer.Length; + if (count > _minIndexInclusive) + { + _maxIdx = _maxIndexInclusive; + if (count <= _maxIdx) + { + _maxIdx = count - 1; + } + + if (_minIndexInclusive == _maxIdx) + { + _current = _source.GetEnumerableSorter().ElementAt(buffer, count, _minIndexInclusive); + _state = -1; + return true; + } + + _map = _source.SortedMap(buffer, _minIndexInclusive, _maxIdx); + _buffer = buffer; + _state = state = 2; + goto Initialized; + } + } + + Dispose(); + return false; + } + + public IPartition? Skip(int count) + { + int minIndex = _minIndexInclusive + count; + return (uint)minIndex > (uint)_maxIndexInclusive ? null : new OrderedPartition(_source, minIndex, _maxIndexInclusive); + } + + public IPartition Take(int count) + { + int maxIndex = _minIndexInclusive + count - 1; + if ((uint)maxIndex >= (uint)_maxIndexInclusive) + { + return this; + } + + return new OrderedPartition(_source, _minIndexInclusive, maxIndex); + } + + public TElement? TryGetElementAt(int index, out bool found) + { + if ((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive)) + { + return _source.TryGetElementAt(index + _minIndexInclusive, out found); + } + + found = false; + return default; + } + + public TElement? TryGetFirst(out bool found) => _source.TryGetElementAt(_minIndexInclusive, out found); + + public TElement? TryGetLast(out bool found) => + _source.TryGetLast(_minIndexInclusive, _maxIndexInclusive, out found); + + public TElement[] ToArray() => _source.ToArray(_minIndexInclusive, _maxIndexInclusive); + + public List ToList() => _source.ToList(_minIndexInclusive, _maxIndexInclusive); + + public int GetCount(bool onlyIfCheap) => _source.GetCount(_minIndexInclusive, _maxIndexInclusive, onlyIfCheap); + } } } diff --git a/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.cs b/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.cs index 1e888c974b3ada..b87f4fbc3dadc5 100644 --- a/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.cs +++ b/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.cs @@ -17,39 +17,10 @@ internal abstract partial class OrderedEnumerable : Iterator private protected int[] SortedMap(TElement[] buffer) => GetEnumerableSorter().Sort(buffer, buffer.Length); - private int[] SortedMap(TElement[] buffer, int minIdx, int maxIdx) => + internal int[] SortedMap(TElement[] buffer, int minIdx, int maxIdx) => GetEnumerableSorter().Sort(buffer, buffer.Length, minIdx, maxIdx); - internal IEnumerator GetEnumerator(int minIdx, int maxIdx) - { - TElement[] buffer = _source.ToArray(); - int count = buffer.Length; - if (count > minIdx) - { - if (count <= maxIdx) - { - maxIdx = count - 1; - } - - if (minIdx == maxIdx) - { - yield return GetEnumerableSorter().ElementAt(buffer, count, minIdx); - } - else - { - int[] map = SortedMap(buffer, minIdx, maxIdx); - while (minIdx <= maxIdx) - { - yield return buffer[map[minIdx]]; - ++minIdx; - } - } - } - } - - private EnumerableSorter GetEnumerableSorter() => GetEnumerableSorter(null); - - internal abstract EnumerableSorter GetEnumerableSorter(EnumerableSorter? next); + internal abstract EnumerableSorter GetEnumerableSorter(EnumerableSorter? next = null); internal abstract CachingComparer GetComparer(CachingComparer? childComparer = null); diff --git a/src/libraries/System.Linq/src/System/Linq/Partition.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Partition.SpeedOpt.cs index d7b04eb82b917e..b8e3d32c282747 100644 --- a/src/libraries/System.Linq/src/System/Linq/Partition.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Partition.SpeedOpt.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Collections; using System.Collections.Generic; using System.Diagnostics; @@ -9,64 +8,6 @@ namespace System.Linq { public static partial class Enumerable { - internal sealed class OrderedPartition : IPartition - { - private readonly OrderedEnumerable _source; - private readonly int _minIndexInclusive; - private readonly int _maxIndexInclusive; - - public OrderedPartition(OrderedEnumerable source, int minIdxInclusive, int maxIdxInclusive) - { - _source = source; - _minIndexInclusive = minIdxInclusive; - _maxIndexInclusive = maxIdxInclusive; - } - - public IEnumerator GetEnumerator() => _source.GetEnumerator(_minIndexInclusive, _maxIndexInclusive); - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - public IPartition? Skip(int count) - { - int minIndex = _minIndexInclusive + count; - return (uint)minIndex > (uint)_maxIndexInclusive ? null : new OrderedPartition(_source, minIndex, _maxIndexInclusive); - } - - public IPartition Take(int count) - { - int maxIndex = _minIndexInclusive + count - 1; - if ((uint)maxIndex >= (uint)_maxIndexInclusive) - { - return this; - } - - return new OrderedPartition(_source, _minIndexInclusive, maxIndex); - } - - public TElement? TryGetElementAt(int index, out bool found) - { - if ((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive)) - { - return _source.TryGetElementAt(index + _minIndexInclusive, out found); - } - - found = false; - return default; - } - - public TElement? TryGetFirst(out bool found) => _source.TryGetElementAt(_minIndexInclusive, out found); - - public TElement? TryGetLast(out bool found) => - _source.TryGetLast(_minIndexInclusive, _maxIndexInclusive, out found); - - public TElement[] ToArray() => _source.ToArray(_minIndexInclusive, _maxIndexInclusive); - - public List ToList() => _source.ToList(_minIndexInclusive, _maxIndexInclusive); - - public int GetCount(bool onlyIfCheap) => _source.GetCount(_minIndexInclusive, _maxIndexInclusive, onlyIfCheap); - } - - /// /// An iterator that yields the items of part of an . /// From 51b3adcd002875618a39ab277f5ce7aa63fdaf52 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 26 Feb 2024 13:09:19 -0500 Subject: [PATCH 3/7] Consolidate LINQ's internal IIListProvider/IPartition into base Iterator LINQ has an internal base Iterator class that's used when operators manually implement enumeration rather than having the compiler implement it with an iterator. That base class includes several abstract/virtual methods, including a Select and Where virtual method that have been present since the beginning of LINQ: those are used in a case of A().B(), where B is Select or Where and where A can then improve the processing of B by returning a customized implementation aware of some aspect of both A and B (e.g. the enumerable returned from .Where().Select() includes both the where and select functionality in that single object). Over the years, other specialization has been added to LINQ, in the form of additional internal interfaces: IIListProvider and IPartition. These interfaces similarly enable optimizing sequences A().B(), where B is other LINQ methods, e.g. IIListProvider enables optimizing ToArray/ToList/Count, and IPartition enables optimizing Skip/Take/First/Last/ElementAt. There was a complicated venn diagram of which types implemented which interfaces and base type. This PR merges IIListProvider/IPartition into the base Iterator class. Everything from IPartition is virtual, enabling derivations to specialize just a subset of the functionality, and deduplicating some implementations that were providing the same implementation instead of having it shared in a base. Code that was type testing for the interfaces now type tests for the base class, which means we can delete some type testing where both the interfaces and the base class were previously being tested for. We no longer have this strange split across multiple optimization-focused internal implementation details, and instead have everything consolidated in the one base class. This also means that all of the calls that were previously interface dispatch are now virtual dispatch. --- .../System.Linq/src/System.Linq.csproj | 4 +- .../src/System/Linq/AppendPrepend.SpeedOpt.cs | 66 ++++- .../src/System/Linq/Cast.SpeedOpt.cs | 21 +- .../src/System/Linq/Concat.SpeedOpt.cs | 18 +- .../System.Linq/src/System/Linq/Count.cs | 12 +- .../System/Linq/DefaultIfEmpty.SpeedOpt.cs | 50 +++- .../src/System/Linq/Distinct.SpeedOpt.cs | 10 +- .../System.Linq/src/System/Linq/ElementAt.cs | 72 +++-- .../System.Linq/src/System/Linq/First.cs | 14 +- .../src/System/Linq/Grouping.SpeedOpt.cs | 32 +-- .../src/System/Linq/IIListProvider.cs | 33 --- .../System.Linq/src/System/Linq/IPartition.cs | 47 ---- .../src/System/Linq/Iterator.SpeedOpt.cs | 71 +++++ .../System.Linq/src/System/Linq/Iterator.cs | 31 +-- .../System.Linq/src/System/Linq/Last.cs | 14 +- .../System.Linq/src/System/Linq/OrderBy.cs | 12 +- .../System/Linq/OrderedEnumerable.SpeedOpt.cs | 56 ++-- .../src/System/Linq/OrderedEnumerable.cs | 20 +- .../src/System/Linq/Partition.SpeedOpt.cs | 97 ++++--- .../src/System/Linq/Range.SpeedOpt.cs | 20 +- .../src/System/Linq/Repeat.SpeedOpt.cs | 21 +- .../src/System/Linq/Reverse.SpeedOpt.cs | 26 +- .../src/System/Linq/Select.SpeedOpt.cs | 245 +++++++++++------- .../System.Linq/src/System/Linq/Select.cs | 55 ++-- .../src/System/Linq/SelectMany.SpeedOpt.cs | 8 +- .../src/System/Linq/Skip.SpeedOpt.cs | 4 +- .../System.Linq/src/System/Linq/Skip.cs | 8 +- .../src/System/Linq/Take.SpeedOpt.cs | 18 +- .../src/System/Linq/ToCollection.cs | 12 +- .../src/System/Linq/Union.SpeedOpt.cs | 24 +- .../src/System/Linq/Where.SpeedOpt.cs | 108 +++----- .../System.Linq/src/System/Linq/Where.cs | 60 ++--- .../System.Linq/tests/AppendPrependTests.cs | 22 ++ .../System.Linq/tests/ElementAtTests.cs | 12 +- 34 files changed, 724 insertions(+), 599 deletions(-) delete mode 100644 src/libraries/System.Linq/src/System/Linq/IIListProvider.cs delete mode 100644 src/libraries/System.Linq/src/System/Linq/IPartition.cs create mode 100644 src/libraries/System.Linq/src/System/Linq/Iterator.SpeedOpt.cs diff --git a/src/libraries/System.Linq/src/System.Linq.csproj b/src/libraries/System.Linq/src/System.Linq.csproj index cbc128d961e880..c9bcf4ac0470a8 100644 --- a/src/libraries/System.Linq/src/System.Linq.csproj +++ b/src/libraries/System.Linq/src/System.Linq.csproj @@ -9,6 +9,7 @@ $([MSBuild]::GetTargetPlatformIdentifier('$(TargetFramework)')) true + $(DefineConstants);OPTIMIZE_FOR_SIZE @@ -23,6 +24,7 @@ + @@ -61,8 +63,6 @@ - - diff --git a/src/libraries/System.Linq/src/System/Linq/AppendPrepend.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/AppendPrepend.SpeedOpt.cs index 80ee23998603fa..cfd6ea79527438 100644 --- a/src/libraries/System.Linq/src/System/Linq/AppendPrepend.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/AppendPrepend.SpeedOpt.cs @@ -8,15 +8,6 @@ namespace System.Linq { public static partial class Enumerable { - private abstract partial class AppendPrependIterator : IIListProvider - { - public abstract TSource[] ToArray(); - - public abstract List ToList(); - - public abstract int GetCount(bool onlyIfCheap); - } - private sealed partial class AppendPrepend1Iterator { private TSource[] LazyToArray() @@ -130,14 +121,63 @@ public override List ToList() public override int GetCount(bool onlyIfCheap) { - if (_source is IIListProvider listProv) + if (_source is Iterator iterator) { - int count = listProv.GetCount(onlyIfCheap); + int count = iterator.GetCount(onlyIfCheap); return count == -1 ? -1 : count + 1; } return !onlyIfCheap || _source is ICollection ? _source.Count() + 1 : -1; } + + public override TSource? TryGetFirst(out bool found) + { + if (_appending) + { + TSource? first = _source.TryGetFirst(out found); + if (found) + { + return first; + } + } + + found = true; + return _item; + } + + public override TSource? TryGetLast(out bool found) + { + if (!_appending) + { + TSource? last = _source.TryGetLast(out found); + if (found) + { + return last; + } + } + + found = true; + return _item; + } + + public override TSource? TryGetElementAt(int index, out bool found) + { + if (!_appending) + { + if (index == 0) + { + found = true; + return _item; + } + + index--; + return + _source is Iterator iterator ? iterator.TryGetElementAt(index, out found) : + TryGetElementAtNonIterator(_source, index, out found); + } + + return base.TryGetElementAt(index, out found); + } } private sealed partial class AppendPrependN @@ -232,9 +272,9 @@ public override List ToList() public override int GetCount(bool onlyIfCheap) { - if (_source is IIListProvider listProv) + if (_source is Iterator iterator) { - int count = listProv.GetCount(onlyIfCheap); + int count = iterator.GetCount(onlyIfCheap); return count == -1 ? -1 : count + _appendCount + _prependCount; } diff --git a/src/libraries/System.Linq/src/System/Linq/Cast.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Cast.SpeedOpt.cs index 6ab50403c9c6d9..3cded1625e8df0 100644 --- a/src/libraries/System.Linq/src/System/Linq/Cast.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Cast.SpeedOpt.cs @@ -8,11 +8,11 @@ namespace System.Linq { public static partial class Enumerable { - private sealed partial class CastICollectionIterator : IPartition + private sealed partial class CastICollectionIterator { - public int GetCount(bool onlyIfCheap) => _source.Count; + public override int GetCount(bool onlyIfCheap) => _source.Count; - public TResult[] ToArray() + public override TResult[] ToArray() { TResult[] array = new TResult[_source.Count]; @@ -25,7 +25,7 @@ public TResult[] ToArray() return array; } - public List ToList() + public override List ToList() { List list = new(_source.Count); @@ -37,7 +37,7 @@ public List ToList() return list; } - public TResult? TryGetElementAt(int index, out bool found) + public override TResult? TryGetElementAt(int index, out bool found) { if (index >= 0) { @@ -65,7 +65,7 @@ public List ToList() return default; } - public TResult? TryGetFirst(out bool found) + public override TResult? TryGetFirst(out bool found) { IEnumerator e = _source.GetEnumerator(); try @@ -85,7 +85,7 @@ public List ToList() return default; } - public TResult? TryGetLast(out bool found) + public override TResult? TryGetLast(out bool found) { IEnumerator e = _source.GetEnumerator(); try @@ -110,13 +110,6 @@ public List ToList() (e as IDisposable)?.Dispose(); } } - - public override IEnumerable Select(Func selector) => - new SelectIPartitionIterator(this, selector); - - public IPartition? Skip(int count) => new EnumerablePartition(this, count, -1); - - public IPartition? Take(int count) => new EnumerablePartition(this, 0, count - 1); } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Concat.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Concat.SpeedOpt.cs index a5ad64f78584f6..6ee051c00c2790 100644 --- a/src/libraries/System.Linq/src/System/Linq/Concat.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Concat.SpeedOpt.cs @@ -342,13 +342,9 @@ private TSource[] PreallocatingToArray() } } - private abstract partial class ConcatIterator : IPartition + private abstract partial class ConcatIterator { - public abstract int GetCount(bool onlyIfCheap); - - public abstract TSource[] ToArray(); - - public List ToList() + public override List ToList() { int count = GetCount(onlyIfCheap: true); var list = count != -1 ? new List(count) : new List(); @@ -367,16 +363,6 @@ public List ToList() return list; } - public abstract TSource? TryGetElementAt(int index, out bool found); - - public abstract TSource? TryGetFirst(out bool found); - - public abstract TSource? TryGetLast(out bool found); - - public IPartition? Skip(int count) => new EnumerablePartition(this, count, -1); - - public IPartition? Take(int count) => new EnumerablePartition(this, 0, count - 1); - } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Count.cs b/src/libraries/System.Linq/src/System/Linq/Count.cs index 14f3d457f6ea35..048ebc0891e76c 100644 --- a/src/libraries/System.Linq/src/System/Linq/Count.cs +++ b/src/libraries/System.Linq/src/System/Linq/Count.cs @@ -20,10 +20,12 @@ public static int Count(this IEnumerable source) return collectionoft.Count; } - if (source is IIListProvider listProv) +#if !OPTIMIZE_FOR_SIZE + if (source is Iterator iterator) { - return listProv.GetCount(onlyIfCheap: false); + return iterator.GetCount(onlyIfCheap: false); } +#endif if (source is ICollection collection) { @@ -105,15 +107,17 @@ public static bool TryGetNonEnumeratedCount(this IEnumerable s return true; } - if (source is IIListProvider listProv) +#if !OPTIMIZE_FOR_SIZE + if (source is Iterator iterator) { - int c = listProv.GetCount(onlyIfCheap: true); + int c = iterator.GetCount(onlyIfCheap: true); if (c >= 0) { count = c; return true; } } +#endif if (source is ICollection collection) { diff --git a/src/libraries/System.Linq/src/System/Linq/DefaultIfEmpty.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/DefaultIfEmpty.SpeedOpt.cs index 24619cc438136b..e147870eb0d9f7 100644 --- a/src/libraries/System.Linq/src/System/Linq/DefaultIfEmpty.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/DefaultIfEmpty.SpeedOpt.cs @@ -8,15 +8,15 @@ namespace System.Linq { public static partial class Enumerable { - private sealed partial class DefaultIfEmptyIterator : IIListProvider + private sealed partial class DefaultIfEmptyIterator { - public TSource[] ToArray() + public override TSource[] ToArray() { TSource[] array = _source.ToArray(); return array.Length == 0 ? [_default] : array; } - public List ToList() + public override List ToList() { List list = _source.ToList(); if (list.Count == 0) @@ -27,7 +27,7 @@ public List ToList() return list; } - public int GetCount(bool onlyIfCheap) + public override int GetCount(bool onlyIfCheap) { int count; if (!onlyIfCheap || _source is ICollection || _source is ICollection) @@ -36,11 +36,51 @@ public int GetCount(bool onlyIfCheap) } else { - count = _source is IIListProvider listProv ? listProv.GetCount(onlyIfCheap: true) : -1; + count = _source is Iterator iterator ? iterator.GetCount(onlyIfCheap: true) : -1; } return count == 0 ? 1 : count; } + + public override TSource? TryGetFirst(out bool found) + { + TSource? first = _source.TryGetFirst(out found); + if (found) + { + return first; + } + + found = true; + return default; + } + + public override TSource? TryGetLast(out bool found) + { + TSource? last = _source.TryGetLast(out found); + if (found) + { + return last; + } + + found = true; + return default; + } + + public override TSource? TryGetElementAt(int index, out bool found) + { + TSource? item = _source.TryGetElementAt(index, out found); + if (found) + { + return item; + } + + if (index == 0) + { + found = true; + } + + return default; + } } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Distinct.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Distinct.SpeedOpt.cs index 97ac750843e5bc..a3dbb6458969cf 100644 --- a/src/libraries/System.Linq/src/System/Linq/Distinct.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Distinct.SpeedOpt.cs @@ -7,13 +7,15 @@ namespace System.Linq { public static partial class Enumerable { - private sealed partial class DistinctIterator : IIListProvider + private sealed partial class DistinctIterator { - public TSource[] ToArray() => HashSetToArray(new HashSet(_source, _comparer)); + public override TSource[] ToArray() => HashSetToArray(new HashSet(_source, _comparer)); - public List ToList() => new List(new HashSet(_source, _comparer)); + public override List ToList() => new List(new HashSet(_source, _comparer)); - public int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : new HashSet(_source, _comparer).Count; + public override int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : new HashSet(_source, _comparer).Count; + + public override TSource? TryGetFirst(out bool found) => _source.TryGetFirst(out found); } } } diff --git a/src/libraries/System.Linq/src/System/Linq/ElementAt.cs b/src/libraries/System.Linq/src/System/Linq/ElementAt.cs index b33fcaddff924a..f2bec12be649aa 100644 --- a/src/libraries/System.Linq/src/System/Linq/ElementAt.cs +++ b/src/libraries/System.Linq/src/System/Linq/ElementAt.cs @@ -16,25 +16,13 @@ public static TSource ElementAt(this IEnumerable source, int i ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); } - if (source is IPartition partition) + TSource? element = TryGetElementAt(source, index, out bool found); + if (!found) { - TSource? element = partition.TryGetElementAt(index, out bool found); - if (found) - { - return element!; - } - } - else if (source is IList list) - { - return list[index]; - } - else if (TryGetElement(source, index, out TSource? element)) - { - return element; + ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.index); } - ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.index); - return default; + return element!; } /// Returns the element at a specified index in a sequence. @@ -80,18 +68,7 @@ public static TSource ElementAt(this IEnumerable source, Index ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); } - if (source is IPartition partition) - { - return partition.TryGetElementAt(index, out bool _); - } - - if (source is IList list) - { - return (uint)index < (uint)list.Count ? list[index] : default; - } - - TryGetElement(source, index, out TSource? element); - return element; + return TryGetElementAt(source, index, out _); } /// Returns the element at a specified index in a sequence or a default value if the index is out of range. @@ -125,27 +102,44 @@ public static TSource ElementAt(this IEnumerable source, Index return element; } - private static bool TryGetElement(IEnumerable source, int index, [MaybeNullWhen(false)] out TSource element) + private static TSource? TryGetElementAt(this IEnumerable source, int index, out bool found) => +#if !OPTIMIZE_FOR_SIZE + source is Iterator iterator ? iterator.TryGetElementAt(index, out found) : +#endif + TryGetElementAtNonIterator(source, index, out found); + + private static TSource? TryGetElementAtNonIterator(IEnumerable source, int index, out bool found) { Debug.Assert(source != null); - if (index >= 0) + if (source is IList list) { - using IEnumerator e = source.GetEnumerator(); - while (e.MoveNext()) + if ((uint)index < (uint)list.Count) + { + found = true; + return list[index]; + } + } + else + { + if (index >= 0) { - if (index == 0) + using IEnumerator e = source.GetEnumerator(); + while (e.MoveNext()) { - element = e.Current; - return true; - } + if (index == 0) + { + found = true; + return e.Current; + } - index--; + index--; + } } } - element = default; - return false; + found = false; + return default; } private static bool TryGetElementFromEnd(IEnumerable source, int indexFromEnd, [MaybeNullWhen(false)] out TSource element) diff --git a/src/libraries/System.Linq/src/System/Linq/First.cs b/src/libraries/System.Linq/src/System/Linq/First.cs index 1c62f547d9a03c..6879be5fc3c82e 100644 --- a/src/libraries/System.Linq/src/System/Linq/First.cs +++ b/src/libraries/System.Linq/src/System/Linq/First.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; +using System.Diagnostics; namespace System.Linq { @@ -69,11 +69,15 @@ public static TSource FirstOrDefault(this IEnumerable source, ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); } - if (source is IPartition partition) - { - return partition.TryGetFirst(out found); - } + return +#if !OPTIMIZE_FOR_SIZE + source is Iterator iterator ? iterator.TryGetFirst(out found) : +#endif + TryGetFirstNonIterator(source, out found); + } + private static TSource? TryGetFirstNonIterator(IEnumerable source, out bool found) + { if (source is IList list) { if (list.Count > 0) diff --git a/src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs index 94a24fa8c112b6..1153dc7ff4977a 100644 --- a/src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs @@ -7,51 +7,51 @@ namespace System.Linq { public static partial class Enumerable { - internal sealed partial class GroupByResultIterator : IIListProvider + internal sealed partial class GroupByResultIterator { - public TResult[] ToArray() => + public override TResult[] ToArray() => Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToArray(_resultSelector); - public List ToList() => + public override List ToList() => Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToList(_resultSelector); - public int GetCount(bool onlyIfCheap) => + public override int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _elementSelector, _comparer).Count; } - internal sealed partial class GroupByResultIterator : IIListProvider + internal sealed partial class GroupByResultIterator { - public TResult[] ToArray() => + public override TResult[] ToArray() => Lookup.Create(_source, _keySelector, _comparer).ToArray(_resultSelector); - public List ToList() => + public override List ToList() => Lookup.Create(_source, _keySelector, _comparer).ToList(_resultSelector); - public int GetCount(bool onlyIfCheap) => + public override int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _comparer).Count; } - internal sealed partial class GroupByIterator : IIListProvider> + internal sealed partial class GroupByIterator { - public IGrouping[] ToArray() => + public override IGrouping[] ToArray() => Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToArray(); - public List> ToList() => + public override List> ToList() => Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToList(); - public int GetCount(bool onlyIfCheap) => + public override int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _elementSelector, _comparer).Count; } - internal sealed partial class GroupByIterator : IIListProvider> + internal sealed partial class GroupByIterator { - public IGrouping[] ToArray() => + public override IGrouping[] ToArray() => Lookup.Create(_source, _keySelector, _comparer).ToArray(); - public List> ToList() => + public override List> ToList() => Lookup.Create(_source, _keySelector, _comparer).ToList(); - public int GetCount(bool onlyIfCheap) => + public override int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _comparer).Count; } } diff --git a/src/libraries/System.Linq/src/System/Linq/IIListProvider.cs b/src/libraries/System.Linq/src/System/Linq/IIListProvider.cs deleted file mode 100644 index 9eefc6e61e0ce4..00000000000000 --- a/src/libraries/System.Linq/src/System/Linq/IIListProvider.cs +++ /dev/null @@ -1,33 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; - -namespace System.Linq -{ - /// - /// An iterator that can produce an array or through an optimized path. - /// - internal interface IIListProvider : IEnumerable - { - /// - /// Produce an array of the sequence through an optimized path. - /// - /// The array. - TElement[] ToArray(); - - /// - /// Produce a of the sequence through an optimized path. - /// - /// The . - List ToList(); - - /// - /// Returns the count of elements in the sequence. - /// - /// If true then the count should only be calculated if doing - /// so is quick (sure or likely to be constant time), otherwise -1 should be returned. - /// The number of elements. - int GetCount(bool onlyIfCheap); - } -} diff --git a/src/libraries/System.Linq/src/System/Linq/IPartition.cs b/src/libraries/System.Linq/src/System/Linq/IPartition.cs deleted file mode 100644 index 86db1921b12f61..00000000000000 --- a/src/libraries/System.Linq/src/System/Linq/IPartition.cs +++ /dev/null @@ -1,47 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Linq -{ - /// - /// An iterator that supports random access and can produce a partial sequence of its items through an optimized path. - /// - internal interface IPartition : IIListProvider - { - /// - /// Creates a new partition that skips the specified number of elements from this sequence. - /// - /// The number of elements to skip. - /// An with the first items removed, or null if known empty. - IPartition? Skip(int count); - - /// - /// Creates a new partition that takes the specified number of elements from this sequence. - /// - /// The number of elements to take. - /// An with only the first items, or null if known empty. - IPartition? Take(int count); - - /// - /// Gets the item associated with a 0-based index in this sequence. - /// - /// The 0-based index to access. - /// true if the sequence contains an element at that index, false otherwise. - /// The element if is true, otherwise, the default value of . - TElement? TryGetElementAt(int index, out bool found); - - /// - /// Gets the first item in this sequence. - /// - /// true if the sequence contains an element, false otherwise. - /// The element if is true, otherwise, the default value of . - TElement? TryGetFirst(out bool found); - - /// - /// Gets the last item in this sequence. - /// - /// true if the sequence contains an element, false otherwise. - /// The element if is true, otherwise, the default value of . - TElement? TryGetLast(out bool found); - } -} diff --git a/src/libraries/System.Linq/src/System/Linq/Iterator.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Iterator.SpeedOpt.cs new file mode 100644 index 00000000000000..d641faa93e3a1f --- /dev/null +++ b/src/libraries/System.Linq/src/System/Linq/Iterator.SpeedOpt.cs @@ -0,0 +1,71 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + internal abstract partial class Iterator + { + /// + /// Produce an array of the sequence through an optimized path. + /// + /// The array. + public abstract TSource[] ToArray(); + + /// + /// Produce a of the sequence through an optimized path. + /// + /// The . + public abstract List ToList(); + + /// + /// Returns the count of elements in the sequence. + /// + /// If true then the count should only be calculated if doing + /// so is quick (sure or likely to be constant time), otherwise -1 should be returned. + /// The number of elements. + public abstract int GetCount(bool onlyIfCheap); + + /// + /// Creates a new iterator that skips the specified number of elements from this sequence. + /// + /// The number of elements to skip. + /// An with the first items removed, or null if known empty. + public virtual Iterator? Skip(int count) => new IEnumerableSkipTakeIterator(this, count, -1); + + /// + /// Creates a new iterator that takes the specified number of elements from this sequence. + /// + /// The number of elements to take. + /// An with only the first items, or null if known empty. + public virtual Iterator? Take(int count) => new IEnumerableSkipTakeIterator(this, 0, count - 1); + + /// + /// Gets the item associated with a 0-based index in this sequence. + /// + /// The 0-based index to access. + /// true if the sequence contains an element at that index, false otherwise. + /// The element if is true, otherwise, the default value of . + public virtual TSource? TryGetElementAt(int index, out bool found) => + index == 0 ? TryGetFirst(out found) : + TryGetElementAtNonIterator(this, index, out found); + + /// + /// Gets the first item in this sequence. + /// + /// true if the sequence contains an element, false otherwise. + /// The element if is true, otherwise, the default value of . + public virtual TSource? TryGetFirst(out bool found) => TryGetFirstNonIterator(this, out found); + + /// + /// Gets the last item in this sequence. + /// + /// true if the sequence contains an element, false otherwise. + /// The element if is true, otherwise, the default value of . + public virtual TSource? TryGetLast(out bool found) => TryGetLastNonIterator(this, out found); + } + } +} diff --git a/src/libraries/System.Linq/src/System/Linq/Iterator.cs b/src/libraries/System.Linq/src/System/Linq/Iterator.cs index b9e8c7b58c0548..933868212428bf 100644 --- a/src/libraries/System.Linq/src/System/Linq/Iterator.cs +++ b/src/libraries/System.Linq/src/System/Linq/Iterator.cs @@ -28,20 +28,13 @@ public static partial class Enumerable /// /// /// - internal abstract class Iterator : IEnumerable, IEnumerator + internal abstract partial class Iterator : IEnumerable, IEnumerator { - private readonly int _threadId; + private readonly int _threadId = Environment.CurrentManagedThreadId; + internal int _state; internal TSource _current = default!; - /// - /// Initializes a new instance of the class. - /// - protected Iterator() - { - _threadId = Environment.CurrentManagedThreadId; - } - /// /// The item currently yielded by this iterator. /// @@ -94,19 +87,21 @@ public IEnumerator GetEnumerator() /// /// The type of the mapped items. /// The selector used to map each item. - public virtual IEnumerable Select(Func selector) - { - return new SelectEnumerableIterator(this, selector); - } + public virtual IEnumerable Select(Func selector) => + new +#if OPTIMIZE_FOR_SIZE + IEnumerableSelectIterator +#else + IteratorSelectIterator +#endif + (this, selector); /// /// Returns an enumerable that filters each item in this iterator based on a predicate. /// /// The predicate used to filter each item. - public virtual IEnumerable Where(Func predicate) - { - return new WhereEnumerableIterator(this, predicate); - } + public virtual IEnumerable Where(Func predicate) => + new IEnumerableWhereIterator(this, predicate); object? IEnumerator.Current => Current; diff --git a/src/libraries/System.Linq/src/System/Linq/Last.cs b/src/libraries/System.Linq/src/System/Linq/Last.cs index 568f0d8670faf8..e7052c2b48dd6d 100644 --- a/src/libraries/System.Linq/src/System/Linq/Last.cs +++ b/src/libraries/System.Linq/src/System/Linq/Last.cs @@ -68,11 +68,15 @@ public static TSource LastOrDefault(this IEnumerable source, F ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); } - if (source is IPartition partition) - { - return partition.TryGetLast(out found); - } + return +#if !OPTIMIZE_FOR_SIZE + source is Iterator iterator ? iterator.TryGetLast(out found) : +#endif + TryGetLastNonIterator(source, out found); + } + private static TSource? TryGetLastNonIterator(IEnumerable source, out bool found) + { if (source is IList list) { int count = list.Count; @@ -117,7 +121,7 @@ public static TSource LastOrDefault(this IEnumerable source, F ThrowHelper.ThrowArgumentNullException(ExceptionArgument.predicate); } - if (source is OrderedEnumerable ordered) + if (source is OrderedIterator ordered) { return ordered.TryGetLast(predicate, out found); } diff --git a/src/libraries/System.Linq/src/System/Linq/OrderBy.cs b/src/libraries/System.Linq/src/System/Linq/OrderBy.cs index aa7a08ee81c9f6..700410fe579f3f 100644 --- a/src/libraries/System.Linq/src/System/Linq/OrderBy.cs +++ b/src/libraries/System.Linq/src/System/Linq/OrderBy.cs @@ -44,14 +44,14 @@ public static IOrderedEnumerable Order(this IEnumerable source) => /// public static IOrderedEnumerable Order(this IEnumerable source, IComparer? comparer) => TypeIsImplicitlyStable() && (comparer is null || comparer == Comparer.Default) ? - new OrderedImplicitlyStableEnumerable(source, descending: false) : + new ImplicitlyStableOrderedIterator(source, descending: false) : OrderBy(source, EnumerableSorter.IdentityFunc, comparer); public static IOrderedEnumerable OrderBy(this IEnumerable source, Func keySelector) - => new OrderedEnumerable(source, keySelector, null, false, null); + => new OrderedIterator(source, keySelector, null, false, null); public static IOrderedEnumerable OrderBy(this IEnumerable source, Func keySelector, IComparer? comparer) - => new OrderedEnumerable(source, keySelector, comparer, false, null); + => new OrderedIterator(source, keySelector, comparer, false, null); /// /// Sorts the elements of a sequence in descending order. @@ -89,14 +89,14 @@ public static IOrderedEnumerable OrderDescending(this IEnumerable sourc /// public static IOrderedEnumerable OrderDescending(this IEnumerable source, IComparer? comparer) => TypeIsImplicitlyStable() && (comparer is null || comparer == Comparer.Default) ? - new OrderedImplicitlyStableEnumerable(source, descending: true) : + new ImplicitlyStableOrderedIterator(source, descending: true) : OrderByDescending(source, EnumerableSorter.IdentityFunc, comparer); public static IOrderedEnumerable OrderByDescending(this IEnumerable source, Func keySelector) => - new OrderedEnumerable(source, keySelector, null, true, null); + new OrderedIterator(source, keySelector, null, true, null); public static IOrderedEnumerable OrderByDescending(this IEnumerable source, Func keySelector, IComparer? comparer) => - new OrderedEnumerable(source, keySelector, comparer, true, null); + new OrderedIterator(source, keySelector, comparer, true, null); public static IOrderedEnumerable ThenBy(this IOrderedEnumerable source, Func keySelector) { diff --git a/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.SpeedOpt.cs index a666f9f1df5748..ffa533dd84e6f3 100644 --- a/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.SpeedOpt.cs @@ -10,9 +10,9 @@ namespace System.Linq { public static partial class Enumerable { - internal abstract partial class OrderedEnumerable : IPartition + internal abstract partial class OrderedIterator { - public virtual TElement[] ToArray() + public override TElement[] ToArray() { TElement[] buffer = _source.ToArray(); if (buffer.Length == 0) @@ -25,7 +25,7 @@ public virtual TElement[] ToArray() return array; } - public virtual List ToList() + public override List ToList() { TElement[] buffer = _source.ToArray(); @@ -47,11 +47,11 @@ private void Fill(TElement[] buffer, Span destination) } } - public int GetCount(bool onlyIfCheap) + public override int GetCount(bool onlyIfCheap) { - if (_source is IIListProvider listProv) + if (_source is Iterator iterator) { - return listProv.GetCount(onlyIfCheap); + return iterator.GetCount(onlyIfCheap); } return !onlyIfCheap || _source is ICollection || _source is ICollection ? _source.Count() : -1; @@ -133,11 +133,11 @@ internal int GetCount(int minIdx, int maxIdx, bool onlyIfCheap) return (count <= maxIdx ? count : maxIdx + 1) - minIdx; } - public IPartition Skip(int count) => new OrderedPartition(this, count, int.MaxValue); + public override Iterator Skip(int count) => new SkipTakeOrderedIterator(this, count, int.MaxValue); - public IPartition Take(int count) => new OrderedPartition(this, 0, count - 1); + public override Iterator Take(int count) => new SkipTakeOrderedIterator(this, 0, count - 1); - public TElement? TryGetElementAt(int index, out bool found) + public override TElement? TryGetElementAt(int index, out bool found) { if (index == 0) { @@ -158,7 +158,7 @@ internal int GetCount(int minIdx, int maxIdx, bool onlyIfCheap) return default; } - public virtual TElement? TryGetFirst(out bool found) + public override TElement? TryGetFirst(out bool found) { CachingComparer comparer = GetComparer(); using (IEnumerator e = _source.GetEnumerator()) @@ -185,7 +185,7 @@ internal int GetCount(int minIdx, int maxIdx, bool onlyIfCheap) } } - public virtual TElement? TryGetLast(out bool found) + public override TElement? TryGetLast(out bool found) { using (IEnumerator e = _source.GetEnumerator()) { @@ -247,7 +247,7 @@ private TElement Last(TElement[] items) } } - internal sealed partial class OrderedEnumerable : OrderedEnumerable + internal sealed partial class OrderedIterator : OrderedIterator { // For complicated cases, rely on the base implementation that's more comprehensive. // For the simple case of OrderBy(...).First() or OrderByDescending(...).First() (i.e. where @@ -358,7 +358,7 @@ internal sealed partial class OrderedEnumerable : OrderedEnumera } } - internal sealed partial class OrderedImplicitlyStableEnumerable : OrderedEnumerable + internal sealed partial class ImplicitlyStableOrderedIterator : OrderedIterator { public override TElement[] ToArray() { @@ -435,9 +435,9 @@ public override List ToList() } } - internal sealed class OrderedPartition : Iterator, IPartition + internal sealed class SkipTakeOrderedIterator : Iterator { - private readonly OrderedEnumerable _source; + private readonly OrderedIterator _source; private readonly int _minIndexInclusive; private readonly int _maxIndexInclusive; @@ -445,20 +445,20 @@ internal sealed class OrderedPartition : Iterator, IPartitio private int[]? _map; private int _maxIdx; - public OrderedPartition(OrderedEnumerable source, int minIdxInclusive, int maxIdxInclusive) + public SkipTakeOrderedIterator(OrderedIterator source, int minIdxInclusive, int maxIdxInclusive) { _source = source; _minIndexInclusive = minIdxInclusive; _maxIndexInclusive = maxIdxInclusive; } - public override Iterator Clone() => new OrderedPartition(_source, _minIndexInclusive, _maxIndexInclusive); + public override Iterator Clone() => new SkipTakeOrderedIterator(_source, _minIndexInclusive, _maxIndexInclusive); public override bool MoveNext() { int state = _state; - Initialized: + Initialized: if (state > 1) { Debug.Assert(_buffer is not null); @@ -503,13 +503,13 @@ public override bool MoveNext() return false; } - public IPartition? Skip(int count) + public override Iterator? Skip(int count) { int minIndex = _minIndexInclusive + count; - return (uint)minIndex > (uint)_maxIndexInclusive ? null : new OrderedPartition(_source, minIndex, _maxIndexInclusive); + return (uint)minIndex > (uint)_maxIndexInclusive ? null : new SkipTakeOrderedIterator(_source, minIndex, _maxIndexInclusive); } - public IPartition Take(int count) + public override Iterator Take(int count) { int maxIndex = _minIndexInclusive + count - 1; if ((uint)maxIndex >= (uint)_maxIndexInclusive) @@ -517,10 +517,10 @@ public IPartition Take(int count) return this; } - return new OrderedPartition(_source, _minIndexInclusive, maxIndex); + return new SkipTakeOrderedIterator(_source, _minIndexInclusive, maxIndex); } - public TElement? TryGetElementAt(int index, out bool found) + public override TElement? TryGetElementAt(int index, out bool found) { if ((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive)) { @@ -531,16 +531,16 @@ public IPartition Take(int count) return default; } - public TElement? TryGetFirst(out bool found) => _source.TryGetElementAt(_minIndexInclusive, out found); + public override TElement? TryGetFirst(out bool found) => _source.TryGetElementAt(_minIndexInclusive, out found); - public TElement? TryGetLast(out bool found) => + public override TElement? TryGetLast(out bool found) => _source.TryGetLast(_minIndexInclusive, _maxIndexInclusive, out found); - public TElement[] ToArray() => _source.ToArray(_minIndexInclusive, _maxIndexInclusive); + public override TElement[] ToArray() => _source.ToArray(_minIndexInclusive, _maxIndexInclusive); - public List ToList() => _source.ToList(_minIndexInclusive, _maxIndexInclusive); + public override List ToList() => _source.ToList(_minIndexInclusive, _maxIndexInclusive); - public int GetCount(bool onlyIfCheap) => _source.GetCount(_minIndexInclusive, _maxIndexInclusive, onlyIfCheap); + public override int GetCount(bool onlyIfCheap) => _source.GetCount(_minIndexInclusive, _maxIndexInclusive, onlyIfCheap); } } } diff --git a/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.cs b/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.cs index b87f4fbc3dadc5..40ee075a9b3141 100644 --- a/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.cs +++ b/src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.cs @@ -9,11 +9,11 @@ namespace System.Linq { public static partial class Enumerable { - internal abstract partial class OrderedEnumerable : Iterator, IOrderedEnumerable + internal abstract partial class OrderedIterator : Iterator, IOrderedEnumerable { internal readonly IEnumerable _source; - protected OrderedEnumerable(IEnumerable source) => _source = source; + protected OrderedIterator(IEnumerable source) => _source = source; private protected int[] SortedMap(TElement[] buffer) => GetEnumerableSorter().Sort(buffer, buffer.Length); @@ -27,7 +27,7 @@ internal int[] SortedMap(TElement[] buffer, int minIdx, int maxIdx) => IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); IOrderedEnumerable IOrderedEnumerable.CreateOrderedEnumerable(Func keySelector, IComparer? comparer, bool descending) => - new OrderedEnumerable(_source, keySelector, comparer, @descending, this); + new OrderedIterator(_source, keySelector, comparer, @descending, this); public TElement? TryGetLast(Func predicate, out bool found) { @@ -63,16 +63,16 @@ IOrderedEnumerable IOrderedEnumerable.CreateOrderedEnumerabl } } - internal sealed partial class OrderedEnumerable : OrderedEnumerable + internal sealed partial class OrderedIterator : OrderedIterator { - private readonly OrderedEnumerable? _parent; + private readonly OrderedIterator? _parent; private readonly Func _keySelector; private readonly IComparer _comparer; private readonly bool _descending; private TElement[]? _buffer; private int[]? _map; - internal OrderedEnumerable(IEnumerable source, Func keySelector, IComparer? comparer, bool descending, OrderedEnumerable? parent) : + internal OrderedIterator(IEnumerable source, Func keySelector, IComparer? comparer, bool descending, OrderedIterator? parent) : base(source) { if (source is null) @@ -90,7 +90,7 @@ internal OrderedEnumerable(IEnumerable source, Func ke _descending = descending; } - public override Iterator Clone() => new OrderedEnumerable(_source, _keySelector, _comparer, _descending, _parent); + public override Iterator Clone() => new OrderedIterator(_source, _keySelector, _comparer, _descending, _parent); internal override EnumerableSorter GetEnumerableSorter(EnumerableSorter? next) { @@ -165,12 +165,12 @@ public override void Dispose() } /// An ordered enumerable used by Order/OrderDescending for Ts that are bitwise indistinguishable for any considered equal. - internal sealed partial class OrderedImplicitlyStableEnumerable : OrderedEnumerable + internal sealed partial class ImplicitlyStableOrderedIterator : OrderedIterator { private readonly bool _descending; private TElement[]? _buffer; - public OrderedImplicitlyStableEnumerable(IEnumerable source, bool descending) : base(source) + public ImplicitlyStableOrderedIterator(IEnumerable source, bool descending) : base(source) { Debug.Assert(TypeIsImplicitlyStable()); @@ -182,7 +182,7 @@ public OrderedImplicitlyStableEnumerable(IEnumerable source, bool desc _descending = descending; } - public override Iterator Clone() => new OrderedImplicitlyStableEnumerable(_source, _descending); + public override Iterator Clone() => new ImplicitlyStableOrderedIterator(_source, _descending); internal override CachingComparer GetComparer(CachingComparer? childComparer) => childComparer == null ? diff --git a/src/libraries/System.Linq/src/System/Linq/Partition.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Partition.SpeedOpt.cs index b8e3d32c282747..021097de87b42a 100644 --- a/src/libraries/System.Linq/src/System/Linq/Partition.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Partition.SpeedOpt.cs @@ -13,13 +13,13 @@ public static partial class Enumerable /// /// The type of the source list. [DebuggerDisplay("Count = {Count}")] - private sealed class ListPartition : Iterator, IPartition, IList, IReadOnlyList + private sealed class IListSkipTakeIterator : Iterator, IList, IReadOnlyList { private readonly IList _source; private readonly int _minIndexInclusive; private readonly int _maxIndexInclusive; - public ListPartition(IList source, int minIndexInclusive, int maxIndexInclusive) + public IListSkipTakeIterator(IList source, int minIndexInclusive, int maxIndexInclusive) { Debug.Assert(source != null); Debug.Assert(minIndexInclusive >= 0); @@ -30,7 +30,7 @@ public ListPartition(IList source, int minIndexInclusive, int maxIndexI } public override Iterator Clone() => - new ListPartition(_source, _minIndexInclusive, _maxIndexInclusive); + new IListSkipTakeIterator(_source, _minIndexInclusive, _maxIndexInclusive); public override bool MoveNext() { @@ -50,21 +50,21 @@ public override bool MoveNext() } public override IEnumerable Select(Func selector) => - new SelectListPartitionIterator(_source, selector, _minIndexInclusive, _maxIndexInclusive); + new IListSkipTakeSelectIterator(_source, selector, _minIndexInclusive, _maxIndexInclusive); - public IPartition? Skip(int count) + public override Iterator? Skip(int count) { int minIndex = _minIndexInclusive + count; - return (uint)minIndex > (uint)_maxIndexInclusive ? null : new ListPartition(_source, minIndex, _maxIndexInclusive); + return (uint)minIndex > (uint)_maxIndexInclusive ? null : new IListSkipTakeIterator(_source, minIndex, _maxIndexInclusive); } - public IPartition Take(int count) + public override Iterator Take(int count) { int maxIndex = _minIndexInclusive + count - 1; - return (uint)maxIndex >= (uint)_maxIndexInclusive ? this : new ListPartition(_source, _minIndexInclusive, maxIndex); + return (uint)maxIndex >= (uint)_maxIndexInclusive ? this : new IListSkipTakeIterator(_source, _minIndexInclusive, maxIndex); } - public TSource? TryGetElementAt(int index, out bool found) + public override TSource? TryGetElementAt(int index, out bool found) { if ((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive) && index < _source.Count - _minIndexInclusive) { @@ -76,7 +76,7 @@ public IPartition Take(int count) return default; } - public TSource? TryGetFirst(out bool found) + public override TSource? TryGetFirst(out bool found) { if (_source.Count > _minIndexInclusive) { @@ -88,7 +88,7 @@ public IPartition Take(int count) return default; } - public TSource? TryGetLast(out bool found) + public override TSource? TryGetLast(out bool found) { int lastIndex = _source.Count - 1; if (lastIndex >= _minIndexInclusive) @@ -115,9 +115,9 @@ public int Count } } - public int GetCount(bool onlyIfCheap) => Count; + public override int GetCount(bool onlyIfCheap) => Count; - public TSource[] ToArray() + public override TSource[] ToArray() { int count = Count; if (count == 0) @@ -130,16 +130,16 @@ public TSource[] ToArray() return array; } - public List ToList() + public override List ToList() { int count = Count; - if (count == 0) + + List list = []; + if (count != 0) { - return new List(); + Fill(_source, SetCountAndGetSpan(list, count), _minIndexInclusive); } - List list = new List(count); - Fill(_source, SetCountAndGetSpan(list, count), _minIndexInclusive); return list; } @@ -199,7 +199,7 @@ public TSource this[int index] /// An iterator that yields the items of part of an . /// /// The type of the source enumerable. - private sealed class EnumerablePartition : Iterator, IPartition + private sealed class IEnumerableSkipTakeIterator : Iterator { private readonly IEnumerable _source; private readonly int _minIndexInclusive; @@ -207,7 +207,7 @@ private sealed class EnumerablePartition : Iterator, IPartitio // If this is -1, it's impossible to set a limit on the count. private IEnumerator? _enumerator; - internal EnumerablePartition(IEnumerable source, int minIndexInclusive, int maxIndexInclusive) + internal IEnumerableSkipTakeIterator(IEnumerable source, int minIndexInclusive, int maxIndexInclusive) { Debug.Assert(source != null); Debug.Assert(!(source is IList), $"The caller needs to check for {nameof(IList)}."); @@ -231,7 +231,7 @@ internal EnumerablePartition(IEnumerable source, int minIndexInclusive, private int Limit => _maxIndexInclusive + 1 - _minIndexInclusive; // This is that upper bound. public override Iterator Clone() => - new EnumerablePartition(_source, _minIndexInclusive, _maxIndexInclusive); + new IEnumerableSkipTakeIterator(_source, _minIndexInclusive, _maxIndexInclusive); public override void Dispose() { @@ -244,7 +244,7 @@ public override void Dispose() base.Dispose(); } - public int GetCount(bool onlyIfCheap) + public override int GetCount(bool onlyIfCheap) { if (onlyIfCheap) { @@ -261,7 +261,7 @@ public int GetCount(bool onlyIfCheap) using (IEnumerator en = _source.GetEnumerator()) { // We only want to iterate up to _maxIndexInclusive + 1. - // Past that, we know the enumerable will be able to fit this partition, + // Past that, we know the enumerable will be able to fit this subset, // so the count will just be _maxIndexInclusive + 1 - _minIndexInclusive. // Note that it is possible for _maxIndexInclusive to be int.MaxValue here, @@ -273,7 +273,6 @@ public int GetCount(bool onlyIfCheap) Debug.Assert(count != (uint)int.MaxValue + 1 || _minIndexInclusive > 0, "Our return value will be incorrect."); return Math.Max((int)count - _minIndexInclusive, 0); } - } public override bool MoveNext() @@ -325,10 +324,7 @@ public override bool MoveNext() return false; } - public override IEnumerable Select(Func selector) => - new SelectIPartitionIterator(this, selector); - - public IPartition? Skip(int count) + public override Iterator? Skip(int count) { int minIndex = _minIndexInclusive + count; @@ -339,7 +335,7 @@ public override IEnumerable Select(Func sele // If we don't know our max count and minIndex can no longer fit in a positive int, // then we will need to wrap ourselves in another iterator. // This can happen, for example, during e.Skip(int.MaxValue).Skip(int.MaxValue). - return new EnumerablePartition(this, count, -1); + return new IEnumerableSkipTakeIterator(this, count, -1); } } else if ((uint)minIndex > (uint)_maxIndexInclusive) @@ -351,10 +347,10 @@ public override IEnumerable Select(Func sele } Debug.Assert(minIndex >= 0, $"We should have taken care of all cases when {nameof(minIndex)} overflows."); - return new EnumerablePartition(_source, minIndex, _maxIndexInclusive); + return new IEnumerableSkipTakeIterator(_source, minIndex, _maxIndexInclusive); } - public IPartition Take(int count) + public override Iterator Take(int count) { int maxIndex = _minIndexInclusive + count - 1; if (!HasLimit) @@ -367,7 +363,7 @@ public IPartition Take(int count) // _minIndexInclusive (which is count - 1) must fit in an int. // Example: e.Skip(50).Take(int.MaxValue). - return new EnumerablePartition(this, 0, count - 1); + return new IEnumerableSkipTakeIterator(this, 0, count - 1); } } else if ((uint)maxIndex >= (uint)_maxIndexInclusive) @@ -379,18 +375,23 @@ public IPartition Take(int count) } Debug.Assert(maxIndex >= 0, $"We should have taken care of all cases when {nameof(maxIndex)} overflows."); - return new EnumerablePartition(_source, _minIndexInclusive, maxIndex); + return new IEnumerableSkipTakeIterator(_source, _minIndexInclusive, maxIndex); } - public TSource? TryGetElementAt(int index, out bool found) + public override TSource? TryGetElementAt(int index, out bool found) { // If the index is negative or >= our max count, return early. if (index >= 0 && (!HasLimit || index < Limit)) { - using (IEnumerator en = _source.GetEnumerator()) + Debug.Assert(_minIndexInclusive + index >= 0, $"Adding {nameof(index)} caused {nameof(_minIndexInclusive)} to overflow."); + + if (_source is Iterator iterator) { - Debug.Assert(_minIndexInclusive + index >= 0, $"Adding {nameof(index)} caused {nameof(_minIndexInclusive)} to overflow."); + return iterator.TryGetElementAt(_minIndexInclusive + index, out found); + } + using (IEnumerator en = _source.GetEnumerator()) + { if (SkipBefore(_minIndexInclusive + index, en) && en.MoveNext()) { found = true; @@ -403,8 +404,15 @@ public IPartition Take(int count) return default; } - public TSource? TryGetFirst(out bool found) + public override TSource? TryGetFirst(out bool found) { + Debug.Assert(!HasLimit || Limit > 0); + + if (_source is Iterator iterator) + { + return iterator.TryGetElementAt(_minIndexInclusive, out found); + } + using (IEnumerator en = _source.GetEnumerator()) { if (SkipBeforeFirst(en) && en.MoveNext()) @@ -418,8 +426,17 @@ public IPartition Take(int count) return default; } - public TSource? TryGetLast(out bool found) + public override TSource? TryGetLast(out bool found) { + if (_source is Iterator iterator && + iterator.GetCount(onlyIfCheap: true) is int count && + count >= _minIndexInclusive) + { + return !HasLimit ? + iterator.TryGetLast(out found) : + iterator.TryGetElementAt(_maxIndexInclusive, out found); + } + using (IEnumerator en = _source.GetEnumerator()) { if (SkipBeforeFirst(en) && en.MoveNext()) @@ -444,7 +461,7 @@ public IPartition Take(int count) return default; } - public TSource[] ToArray() + public override TSource[] ToArray() { using (IEnumerator en = _source.GetEnumerator()) { @@ -472,7 +489,7 @@ public TSource[] ToArray() return []; } - public List ToList() + public override List ToList() { var list = new List(); diff --git a/src/libraries/System.Linq/src/System/Linq/Range.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Range.SpeedOpt.cs index c125673e16d36e..bde2cede4b1b9a 100644 --- a/src/libraries/System.Linq/src/System/Linq/Range.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Range.SpeedOpt.cs @@ -10,14 +10,14 @@ namespace System.Linq { public static partial class Enumerable { - private sealed partial class RangeIterator : IPartition, IList, IReadOnlyList + private sealed partial class RangeIterator : IList, IReadOnlyList { public override IEnumerable Select(Func selector) { - return new SelectRangeIterator(_start, _end, selector); + return new RangeSelectIterator(_start, _end, selector); } - public int[] ToArray() + public override int[] ToArray() { int start = _start; int[] array = new int[_end - start]; @@ -25,7 +25,7 @@ public int[] ToArray() return array; } - public List ToList() + public override List ToList() { (int start, int end) = (_start, _end); List list = new List(end - start); @@ -67,11 +67,11 @@ private static void Fill(Span destination, int value) } } - public int GetCount(bool onlyIfCheap) => _end - _start; + public override int GetCount(bool onlyIfCheap) => _end - _start; public int Count => _end - _start; - public IPartition? Skip(int count) + public override Iterator? Skip(int count) { if (count >= _end - _start) { @@ -81,7 +81,7 @@ private static void Fill(Span destination, int value) return new RangeIterator(_start + count, _end - _start - count); } - public IPartition Take(int count) + public override Iterator Take(int count) { int curCount = _end - _start; if (count >= curCount) @@ -92,7 +92,7 @@ public IPartition Take(int count) return new RangeIterator(_start, count); } - public int TryGetElementAt(int index, out bool found) + public override int TryGetElementAt(int index, out bool found) { if ((uint)index < (uint)(_end - _start)) { @@ -104,13 +104,13 @@ public int TryGetElementAt(int index, out bool found) return 0; } - public int TryGetFirst(out bool found) + public override int TryGetFirst(out bool found) { found = true; return _start; } - public int TryGetLast(out bool found) + public override int TryGetLast(out bool found) { found = true; return _end - 1; diff --git a/src/libraries/System.Linq/src/System/Linq/Repeat.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Repeat.SpeedOpt.cs index 3c25ee20ba5f2c..5ca5b2625b1078 100644 --- a/src/libraries/System.Linq/src/System/Linq/Repeat.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Repeat.SpeedOpt.cs @@ -8,12 +8,9 @@ namespace System.Linq { public static partial class Enumerable { - private sealed partial class RepeatIterator : IPartition, IList, IReadOnlyList + private sealed partial class RepeatIterator : IList, IReadOnlyList { - public override IEnumerable Select(Func selector) => - new SelectIPartitionIterator(this, selector); - - public TResult[] ToArray() + public override TResult[] ToArray() { TResult[] array = new TResult[_count]; if (_current != null) @@ -24,7 +21,7 @@ public TResult[] ToArray() return array; } - public List ToList() + public override List ToList() { List list = new List(_count); SetCountAndGetSpan(list, _count).Fill(_current); @@ -32,11 +29,11 @@ public List ToList() return list; } - public int GetCount(bool onlyIfCheap) => _count; + public override int GetCount(bool onlyIfCheap) => _count; public int Count => _count; - public IPartition? Skip(int count) + public override Iterator? Skip(int count) { Debug.Assert(count > 0); @@ -48,7 +45,7 @@ public List ToList() return new RepeatIterator(_current, _count - count); } - public IPartition Take(int count) + public override Iterator Take(int count) { Debug.Assert(count > 0); @@ -60,7 +57,7 @@ public IPartition Take(int count) return new RepeatIterator(_current, count); } - public TResult? TryGetElementAt(int index, out bool found) + public override TResult? TryGetElementAt(int index, out bool found) { if ((uint)index < (uint)_count) { @@ -72,13 +69,13 @@ public IPartition Take(int count) return default; } - public TResult TryGetFirst(out bool found) + public override TResult TryGetFirst(out bool found) { found = true; return _current; } - public TResult TryGetLast(out bool found) + public override TResult TryGetLast(out bool found) { found = true; return _current; diff --git a/src/libraries/System.Linq/src/System/Linq/Reverse.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Reverse.SpeedOpt.cs index bb301cc3084899..d1ec26de879afd 100644 --- a/src/libraries/System.Linq/src/System/Linq/Reverse.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Reverse.SpeedOpt.cs @@ -7,28 +7,28 @@ namespace System.Linq { public static partial class Enumerable { - private sealed partial class ReverseIterator : IPartition + private sealed partial class ReverseIterator { - public TSource[] ToArray() + public override TSource[] ToArray() { TSource[] array = _source.ToArray(); Array.Reverse(array); return array; } - public List ToList() + public override List ToList() { List list = _source.ToList(); list.Reverse(); return list; } - public int GetCount(bool onlyIfCheap) => + public override int GetCount(bool onlyIfCheap) => !onlyIfCheap ? _source.Count() : TryGetNonEnumeratedCount(_source, out int count) ? count : -1; - public TSource? TryGetElementAt(int index, out bool found) + public override TSource? TryGetElementAt(int index, out bool found) { if (_source is IList list) { @@ -53,11 +53,11 @@ public int GetCount(bool onlyIfCheap) => return default; } - public TSource? TryGetFirst(out bool found) + public override TSource? TryGetFirst(out bool found) { - if (_source is IPartition partition) + if (_source is Iterator iterator) { - return partition.TryGetLast(out found); + return iterator.TryGetLast(out found); } else if (_source is IList list) { @@ -89,11 +89,11 @@ public int GetCount(bool onlyIfCheap) => return default; } - public TSource? TryGetLast(out bool found) + public override TSource? TryGetLast(out bool found) { - if (_source is IPartition partition) + if (_source is Iterator iterator) { - return partition.TryGetFirst(out found); + return iterator.TryGetFirst(out found); } else if (_source is IList list) { @@ -116,10 +116,6 @@ public int GetCount(bool onlyIfCheap) => found = false; return default; } - - public IPartition? Skip(int count) => new EnumerablePartition(this, count, -1); - - public IPartition? Take(int count) => new EnumerablePartition(this, 0, count - 1); } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Select.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Select.SpeedOpt.cs index 06f87db9c0f786..f55b656033e3ab 100644 --- a/src/libraries/System.Linq/src/System/Linq/Select.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Select.SpeedOpt.cs @@ -10,15 +10,9 @@ namespace System.Linq { public static partial class Enumerable { - static partial void CreateSelectIPartitionIterator( - Func selector, IPartition partition, ref IEnumerable? result) + private sealed partial class IEnumerableSelectIterator { - result = new SelectIPartitionIterator(partition, selector); - } - - private sealed partial class SelectEnumerableIterator : IIListProvider - { - public TResult[] ToArray() + public override TResult[] ToArray() { SegmentedArrayBuilder.ScratchBuffer scratch = default; SegmentedArrayBuilder builder = new(scratch); @@ -35,7 +29,7 @@ public TResult[] ToArray() return result; } - public List ToList() + public override List ToList() { var list = new List(); @@ -48,7 +42,7 @@ public List ToList() return list; } - public int GetCount(bool onlyIfCheap) + public override int GetCount(bool onlyIfCheap) { // In case someone uses Count() to force evaluation of // the selector, run it provided `onlyIfCheap` is false. @@ -71,11 +65,73 @@ public int GetCount(bool onlyIfCheap) return count; } + + public override TResult? TryGetElementAt(int index, out bool found) + { + if (index >= 0) + { + IEnumerator e = _source.GetEnumerator(); + try + { + while (e.MoveNext()) + { + if (index == 0) + { + found = true; + return _selector(e.Current); + } + + index--; + } + } + finally + { + (e as IDisposable)?.Dispose(); + } + } + + found = false; + return default; + } + + public override TResult? TryGetFirst(out bool found) + { + using IEnumerator e = _source.GetEnumerator(); + if (e.MoveNext()) + { + found = true; + return _selector(e.Current); + } + + found = false; + return default; + } + + public override TResult? TryGetLast(out bool found) + { + using IEnumerator e = _source.GetEnumerator(); + + if (e.MoveNext()) + { + found = true; + TSource last = e.Current; + + while (e.MoveNext()) + { + last = e.Current; + } + + return _selector(last); + } + + found = false; + return default; + } } - private sealed partial class SelectArrayIterator : IPartition + private sealed partial class ArraySelectIterator { - public TResult[] ToArray() + public override TResult[] ToArray() { // See assert in constructor. // Since _source should never be empty, we don't check for 0/return Array.Empty. @@ -88,7 +144,7 @@ public TResult[] ToArray() return results; } - public List ToList() + public override List ToList() { TSource[] source = _source; Debug.Assert(source.Length > 0); @@ -107,7 +163,7 @@ private static void Fill(ReadOnlySpan source, Span destination } } - public int GetCount(bool onlyIfCheap) + public override int GetCount(bool onlyIfCheap) { // In case someone uses Count() to force evaluation of // the selector, run it provided `onlyIfCheap` is false. @@ -123,7 +179,7 @@ public int GetCount(bool onlyIfCheap) return _source.Length; } - public IPartition? Skip(int count) + public override Iterator? Skip(int count) { Debug.Assert(count > 0); if (count >= _source.Length) @@ -131,30 +187,31 @@ public int GetCount(bool onlyIfCheap) return null; } - return new SelectListPartitionIterator(_source, _selector, count, int.MaxValue); + return new IListSkipTakeSelectIterator(_source, _selector, count, int.MaxValue); } - public IPartition Take(int count) + public override Iterator Take(int count) { Debug.Assert(count > 0); return count >= _source.Length ? this : - new SelectListPartitionIterator(_source, _selector, 0, count - 1); + new IListSkipTakeSelectIterator(_source, _selector, 0, count - 1); } - public TResult? TryGetElementAt(int index, out bool found) + public override TResult? TryGetElementAt(int index, out bool found) { - if ((uint)index < (uint)_source.Length) + TSource[] source = _source; + if ((uint)index < (uint)source.Length) { found = true; - return _selector(_source[index]); + return _selector(source[index]); } found = false; return default; } - public TResult TryGetFirst(out bool found) + public override TResult TryGetFirst(out bool found) { Debug.Assert(_source.Length > 0); // See assert in constructor @@ -162,22 +219,22 @@ public TResult TryGetFirst(out bool found) return _selector(_source[0]); } - public TResult TryGetLast(out bool found) + public override TResult TryGetLast(out bool found) { Debug.Assert(_source.Length > 0); // See assert in constructor found = true; - return _selector(_source[_source.Length - 1]); + return _selector(_source[^1]); } } - private sealed partial class SelectRangeIterator : Iterator, IPartition + private sealed partial class RangeSelectIterator : Iterator { private readonly int _start; private readonly int _end; private readonly Func _selector; - public SelectRangeIterator(int start, int end, Func selector) + public RangeSelectIterator(int start, int end, Func selector) { Debug.Assert(start < end); Debug.Assert((uint)(end - start) <= (uint)int.MaxValue); @@ -189,7 +246,7 @@ public SelectRangeIterator(int start, int end, Func selector) } public override Iterator Clone() => - new SelectRangeIterator(_start, _end, _selector); + new RangeSelectIterator(_start, _end, _selector); public override bool MoveNext() { @@ -206,9 +263,9 @@ public override bool MoveNext() } public override IEnumerable Select(Func selector) => - new SelectRangeIterator(_start, _end, CombineSelectors(_selector, selector)); + new RangeSelectIterator(_start, _end, CombineSelectors(_selector, selector)); - public TResult[] ToArray() + public override TResult[] ToArray() { var results = new TResult[_end - _start]; Fill(results, _start, _selector); @@ -216,7 +273,7 @@ public TResult[] ToArray() return results; } - public List ToList() + public override List ToList() { var results = new List(_end - _start); Fill(SetCountAndGetSpan(results, _end - _start), _start, _selector); @@ -232,7 +289,7 @@ private static void Fill(Span results, int start, Func fu } } - public int GetCount(bool onlyIfCheap) + public override int GetCount(bool onlyIfCheap) { // In case someone uses Count() to force evaluation of the selector, // run it provided `onlyIfCheap` is false. @@ -247,7 +304,7 @@ public int GetCount(bool onlyIfCheap) return _end - _start; } - public IPartition? Skip(int count) + public override Iterator? Skip(int count) { Debug.Assert(count > 0); @@ -256,10 +313,10 @@ public int GetCount(bool onlyIfCheap) return null; } - return new SelectRangeIterator(_start + count, _end, _selector); + return new RangeSelectIterator(_start + count, _end, _selector); } - public IPartition Take(int count) + public override Iterator Take(int count) { Debug.Assert(count > 0); @@ -268,10 +325,10 @@ public IPartition Take(int count) return this; } - return new SelectRangeIterator(_start, _start + count, _selector); + return new RangeSelectIterator(_start, _start + count, _selector); } - public TResult? TryGetElementAt(int index, out bool found) + public override TResult? TryGetElementAt(int index, out bool found) { if ((uint)index < (uint)(_end - _start)) { @@ -283,14 +340,14 @@ public IPartition Take(int count) return default; } - public TResult TryGetFirst(out bool found) + public override TResult TryGetFirst(out bool found) { Debug.Assert(_end > _start); found = true; return _selector(_start); } - public TResult TryGetLast(out bool found) + public override TResult TryGetLast(out bool found) { Debug.Assert(_end > _start); found = true; @@ -298,9 +355,9 @@ public TResult TryGetLast(out bool found) } } - private sealed partial class SelectListIterator : IPartition + private sealed partial class ListSelectIterator { - public TResult[] ToArray() + public override TResult[] ToArray() { ReadOnlySpan source = CollectionsMarshal.AsSpan(_source); if (source.Length == 0) @@ -314,7 +371,7 @@ public TResult[] ToArray() return results; } - public List ToList() + public override List ToList() { ReadOnlySpan source = CollectionsMarshal.AsSpan(_source); @@ -332,7 +389,7 @@ private static void Fill(ReadOnlySpan source, Span destination } } - public int GetCount(bool onlyIfCheap) + public override int GetCount(bool onlyIfCheap) { // In case someone uses Count() to force evaluation of // the selector, run it provided `onlyIfCheap` is false. @@ -350,19 +407,19 @@ public int GetCount(bool onlyIfCheap) return count; } - public IPartition Skip(int count) + public override Iterator Skip(int count) { Debug.Assert(count > 0); - return new SelectListPartitionIterator(_source, _selector, count, int.MaxValue); + return new IListSkipTakeSelectIterator(_source, _selector, count, int.MaxValue); } - public IPartition Take(int count) + public override Iterator Take(int count) { Debug.Assert(count > 0); - return new SelectListPartitionIterator(_source, _selector, 0, count - 1); + return new IListSkipTakeSelectIterator(_source, _selector, 0, count - 1); } - public TResult? TryGetElementAt(int index, out bool found) + public override TResult? TryGetElementAt(int index, out bool found) { if ((uint)index < (uint)_source.Count) { @@ -374,7 +431,7 @@ public IPartition Take(int count) return default; } - public TResult? TryGetFirst(out bool found) + public override TResult? TryGetFirst(out bool found) { if (_source.Count != 0) { @@ -386,7 +443,7 @@ public IPartition Take(int count) return default; } - public TResult? TryGetLast(out bool found) + public override TResult? TryGetLast(out bool found) { int len = _source.Count; if (len != 0) @@ -400,9 +457,9 @@ public IPartition Take(int count) } } - private sealed partial class SelectIListIterator : IPartition + private sealed partial class IListSelectIterator { - public TResult[] ToArray() + public override TResult[] ToArray() { int count = _source.Count; if (count == 0) @@ -416,7 +473,7 @@ public TResult[] ToArray() return results; } - public List ToList() + public override List ToList() { IList source = _source; int count = _source.Count; @@ -435,7 +492,7 @@ private static void Fill(IList source, Span results, Func Skip(int count) + public override Iterator Skip(int count) { Debug.Assert(count > 0); - return new SelectListPartitionIterator(_source, _selector, count, int.MaxValue); + return new IListSkipTakeSelectIterator(_source, _selector, count, int.MaxValue); } - public IPartition Take(int count) + public override Iterator Take(int count) { Debug.Assert(count > 0); - return new SelectListPartitionIterator(_source, _selector, 0, count - 1); + return new IListSkipTakeSelectIterator(_source, _selector, 0, count - 1); } - public TResult? TryGetElementAt(int index, out bool found) + public override TResult? TryGetElementAt(int index, out bool found) { if ((uint)index < (uint)_source.Count) { @@ -477,7 +534,7 @@ public IPartition Take(int count) return default; } - public TResult? TryGetFirst(out bool found) + public override TResult? TryGetFirst(out bool found) { if (_source.Count != 0) { @@ -489,7 +546,7 @@ public IPartition Take(int count) return default; } - public TResult? TryGetLast(out bool found) + public override TResult? TryGetLast(out bool found) { int len = _source.Count; if (len != 0) @@ -504,17 +561,17 @@ public IPartition Take(int count) } /// - /// An iterator that maps each item of an . + /// An iterator that maps each item of an . /// - /// The type of the source partition. + /// The type of the source elements. /// The type of the mapped items. - private sealed class SelectIPartitionIterator : Iterator, IPartition + private sealed class IteratorSelectIterator : Iterator { - private readonly IPartition _source; + private readonly Iterator _source; private readonly Func _selector; private IEnumerator? _enumerator; - public SelectIPartitionIterator(IPartition source, Func selector) + public IteratorSelectIterator(Iterator source, Func selector) { Debug.Assert(source != null); Debug.Assert(selector != null); @@ -523,7 +580,7 @@ public SelectIPartitionIterator(IPartition source, Func Clone() => - new SelectIPartitionIterator(_source, _selector); + new IteratorSelectIterator(_source, _selector); public override bool MoveNext() { @@ -560,23 +617,23 @@ public override void Dispose() } public override IEnumerable Select(Func selector) => - new SelectIPartitionIterator(_source, CombineSelectors(_selector, selector)); + new IteratorSelectIterator(_source, CombineSelectors(_selector, selector)); - public IPartition? Skip(int count) + public override Iterator? Skip(int count) { Debug.Assert(count > 0); - IPartition? source = _source.Skip(count); - return source is null ? null : new SelectIPartitionIterator(source, _selector); + Iterator? source = _source.Skip(count); + return source is null ? null : new IteratorSelectIterator(source, _selector); } - public IPartition? Take(int count) + public override Iterator? Take(int count) { Debug.Assert(count > 0); - IPartition? source = _source.Take(count); - return source is null ? null : new SelectIPartitionIterator(source, _selector); + Iterator? source = _source.Take(count); + return source is null ? null : new IteratorSelectIterator(source, _selector); } - public TResult? TryGetElementAt(int index, out bool found) + public override TResult? TryGetElementAt(int index, out bool found) { bool sourceFound; TSource? input = _source.TryGetElementAt(index, out sourceFound); @@ -584,7 +641,7 @@ public override IEnumerable Select(Func s return sourceFound ? _selector(input!) : default!; } - public TResult? TryGetFirst(out bool found) + public override TResult? TryGetFirst(out bool found) { bool sourceFound; TSource? input = _source.TryGetFirst(out sourceFound); @@ -592,7 +649,7 @@ public override IEnumerable Select(Func s return sourceFound ? _selector(input!) : default!; } - public TResult? TryGetLast(out bool found) + public override TResult? TryGetLast(out bool found) { bool sourceFound; TSource? input = _source.TryGetLast(out sourceFound); @@ -629,7 +686,7 @@ private TResult[] PreallocatingToArray(int count) return array; } - public TResult[] ToArray() + public override TResult[] ToArray() { int count = _source.GetCount(onlyIfCheap: true); return count switch @@ -640,7 +697,7 @@ public TResult[] ToArray() }; } - public List ToList() + public override List ToList() { int count = _source.GetCount(onlyIfCheap: true); List list; @@ -665,7 +722,7 @@ public List ToList() return list; } - private static void Fill(IPartition source, Span results, Func func) + private static void Fill(Iterator source, Span results, Func func) { int index = 0; foreach (TSource item in source) @@ -677,7 +734,7 @@ private static void Fill(IPartition source, Span results, Func Debug.Assert(index == results.Length, "All list elements were not initialized."); } - public int GetCount(bool onlyIfCheap) + public override int GetCount(bool onlyIfCheap) { if (!onlyIfCheap) { @@ -705,14 +762,14 @@ public int GetCount(bool onlyIfCheap) /// The type of the source list. /// The type of the mapped items. [DebuggerDisplay("Count = {Count}")] - private sealed class SelectListPartitionIterator : Iterator, IPartition + private sealed class IListSkipTakeSelectIterator : Iterator { private readonly IList _source; private readonly Func _selector; private readonly int _minIndexInclusive; private readonly int _maxIndexInclusive; - public SelectListPartitionIterator(IList source, Func selector, int minIndexInclusive, int maxIndexInclusive) + public IListSkipTakeSelectIterator(IList source, Func selector, int minIndexInclusive, int maxIndexInclusive) { Debug.Assert(source != null); Debug.Assert(selector != null); @@ -725,7 +782,7 @@ public SelectListPartitionIterator(IList source, Func } public override Iterator Clone() => - new SelectListPartitionIterator(_source, _selector, _minIndexInclusive, _maxIndexInclusive); + new IListSkipTakeSelectIterator(_source, _selector, _minIndexInclusive, _maxIndexInclusive); public override bool MoveNext() { @@ -745,23 +802,23 @@ public override bool MoveNext() } public override IEnumerable Select(Func selector) => - new SelectListPartitionIterator(_source, CombineSelectors(_selector, selector), _minIndexInclusive, _maxIndexInclusive); + new IListSkipTakeSelectIterator(_source, CombineSelectors(_selector, selector), _minIndexInclusive, _maxIndexInclusive); - public IPartition? Skip(int count) + public override Iterator? Skip(int count) { Debug.Assert(count > 0); int minIndex = _minIndexInclusive + count; - return (uint)minIndex > (uint)_maxIndexInclusive ? null : new SelectListPartitionIterator(_source, _selector, minIndex, _maxIndexInclusive); + return (uint)minIndex > (uint)_maxIndexInclusive ? null : new IListSkipTakeSelectIterator(_source, _selector, minIndex, _maxIndexInclusive); } - public IPartition Take(int count) + public override Iterator Take(int count) { Debug.Assert(count > 0); int maxIndex = _minIndexInclusive + count - 1; - return (uint)maxIndex >= (uint)_maxIndexInclusive ? this : new SelectListPartitionIterator(_source, _selector, _minIndexInclusive, maxIndex); + return (uint)maxIndex >= (uint)_maxIndexInclusive ? this : new IListSkipTakeSelectIterator(_source, _selector, _minIndexInclusive, maxIndex); } - public TResult? TryGetElementAt(int index, out bool found) + public override TResult? TryGetElementAt(int index, out bool found) { if ((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive) && index < _source.Count - _minIndexInclusive) { @@ -773,7 +830,7 @@ public IPartition Take(int count) return default; } - public TResult? TryGetFirst(out bool found) + public override TResult? TryGetFirst(out bool found) { if (_source.Count > _minIndexInclusive) { @@ -785,7 +842,7 @@ public IPartition Take(int count) return default; } - public TResult? TryGetLast(out bool found) + public override TResult? TryGetLast(out bool found) { int lastIndex = _source.Count - 1; if (lastIndex >= _minIndexInclusive) @@ -812,7 +869,7 @@ private int Count } } - public TResult[] ToArray() + public override TResult[] ToArray() { int count = Count; if (count == 0) @@ -826,7 +883,7 @@ public TResult[] ToArray() return array; } - public List ToList() + public override List ToList() { int count = Count; if (count == 0) @@ -848,7 +905,7 @@ private static void Fill(IList source, Span destination, Func< } } - public int GetCount(bool onlyIfCheap) + public override int GetCount(bool onlyIfCheap) { // In case someone uses Count() to force evaluation of // the selector, run it provided `onlyIfCheap` is false. diff --git a/src/libraries/System.Linq/src/System/Linq/Select.cs b/src/libraries/System.Linq/src/System/Linq/Select.cs index 3059fa7f0a8e02..916267c1ecaf05 100644 --- a/src/libraries/System.Linq/src/System/Linq/Select.cs +++ b/src/libraries/System.Linq/src/System/Linq/Select.cs @@ -37,35 +37,20 @@ public static IEnumerable Select( return []; } - return new SelectArrayIterator(array, selector); + return new ArraySelectIterator(array, selector); } if (source is List list) { - return new SelectListIterator(list, selector); + return new ListSelectIterator(list, selector); } - return new SelectIListIterator(ilist, selector); + return new IListSelectIterator(ilist, selector); } - if (source is IPartition partition) - { - IEnumerable? result = null; - CreateSelectIPartitionIterator(selector, partition, ref result); - if (result != null) - { - return result; - } - } - - return new SelectEnumerableIterator(source, selector); + return new IEnumerableSelectIterator(source, selector); } -#pragma warning disable IDE0060 // https://github.com/dotnet/roslyn-analyzers/issues/6177 - static partial void CreateSelectIPartitionIterator( - Func selector, IPartition partition, [NotNull] ref IEnumerable? result); -#pragma warning restore IDE0060 - public static IEnumerable Select(this IEnumerable source, Func selector) { if (source == null) @@ -105,13 +90,13 @@ private static IEnumerable SelectIterator(IEnumerable /// /// The type of the source enumerable. /// The type of the mapped items. - private sealed partial class SelectEnumerableIterator : Iterator + private sealed partial class IEnumerableSelectIterator : Iterator { private readonly IEnumerable _source; private readonly Func _selector; private IEnumerator? _enumerator; - public SelectEnumerableIterator(IEnumerable source, Func selector) + public IEnumerableSelectIterator(IEnumerable source, Func selector) { Debug.Assert(source != null); Debug.Assert(selector != null); @@ -120,7 +105,7 @@ public SelectEnumerableIterator(IEnumerable source, Func Clone() => - new SelectEnumerableIterator(_source, _selector); + new IEnumerableSelectIterator(_source, _selector); public override void Dispose() { @@ -157,7 +142,7 @@ public override bool MoveNext() } public override IEnumerable Select(Func selector) => - new SelectEnumerableIterator(_source, CombineSelectors(_selector, selector)); + new IEnumerableSelectIterator(_source, CombineSelectors(_selector, selector)); } /// @@ -166,12 +151,12 @@ public override IEnumerable Select(Func s /// The type of the source array. /// The type of the mapped items. [DebuggerDisplay("Count = {CountForDebugger}")] - private sealed partial class SelectArrayIterator : Iterator + private sealed partial class ArraySelectIterator : Iterator { private readonly TSource[] _source; private readonly Func _selector; - public SelectArrayIterator(TSource[] source, Func selector) + public ArraySelectIterator(TSource[] source, Func selector) { Debug.Assert(source != null); Debug.Assert(selector != null); @@ -182,7 +167,7 @@ public SelectArrayIterator(TSource[] source, Func selector) private int CountForDebugger => _source.Length; - public override Iterator Clone() => new SelectArrayIterator(_source, _selector); + public override Iterator Clone() => new ArraySelectIterator(_source, _selector); public override bool MoveNext() { @@ -200,7 +185,7 @@ public override bool MoveNext() } public override IEnumerable Select(Func selector) => - new SelectArrayIterator(_source, CombineSelectors(_selector, selector)); + new ArraySelectIterator(_source, CombineSelectors(_selector, selector)); } /// @@ -209,13 +194,13 @@ public override IEnumerable Select(Func s /// The type of the source list. /// The type of the mapped items. [DebuggerDisplay("Count = {CountForDebugger}")] - private sealed partial class SelectListIterator : Iterator + private sealed partial class ListSelectIterator : Iterator { private readonly List _source; private readonly Func _selector; private List.Enumerator _enumerator; - public SelectListIterator(List source, Func selector) + public ListSelectIterator(List source, Func selector) { Debug.Assert(source != null); Debug.Assert(selector != null); @@ -225,7 +210,7 @@ public SelectListIterator(List source, Func selector) private int CountForDebugger => _source.Count; - public override Iterator Clone() => new SelectListIterator(_source, _selector); + public override Iterator Clone() => new ListSelectIterator(_source, _selector); public override bool MoveNext() { @@ -250,7 +235,7 @@ public override bool MoveNext() } public override IEnumerable Select(Func selector) => - new SelectListIterator(_source, CombineSelectors(_selector, selector)); + new ListSelectIterator(_source, CombineSelectors(_selector, selector)); } /// @@ -259,13 +244,13 @@ public override IEnumerable Select(Func s /// The type of the source list. /// The type of the mapped items. [DebuggerDisplay("Count = {CountForDebugger}")] - private sealed partial class SelectIListIterator : Iterator + private sealed partial class IListSelectIterator : Iterator { private readonly IList _source; private readonly Func _selector; private IEnumerator? _enumerator; - public SelectIListIterator(IList source, Func selector) + public IListSelectIterator(IList source, Func selector) { Debug.Assert(source != null); Debug.Assert(selector != null); @@ -275,7 +260,7 @@ public SelectIListIterator(IList source, Func selecto private int CountForDebugger => _source.Count; - public override Iterator Clone() => new SelectIListIterator(_source, _selector); + public override Iterator Clone() => new IListSelectIterator(_source, _selector); public override bool MoveNext() { @@ -312,7 +297,7 @@ public override void Dispose() } public override IEnumerable Select(Func selector) => - new SelectIListIterator(_source, CombineSelectors(_selector, selector)); + new IListSelectIterator(_source, CombineSelectors(_selector, selector)); } } } diff --git a/src/libraries/System.Linq/src/System/Linq/SelectMany.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/SelectMany.SpeedOpt.cs index 050ae6a4e06b61..ae0bf35ef8f1af 100644 --- a/src/libraries/System.Linq/src/System/Linq/SelectMany.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/SelectMany.SpeedOpt.cs @@ -7,9 +7,9 @@ namespace System.Linq { public static partial class Enumerable { - private sealed partial class SelectManySingleSelectorIterator : IIListProvider + private sealed partial class SelectManySingleSelectorIterator { - public int GetCount(bool onlyIfCheap) + public override int GetCount(bool onlyIfCheap) { if (onlyIfCheap) { @@ -29,7 +29,7 @@ public int GetCount(bool onlyIfCheap) return count; } - public TResult[] ToArray() + public override TResult[] ToArray() { SegmentedArrayBuilder.ScratchBuffer scratch = default; SegmentedArrayBuilder builder = new(scratch); @@ -46,7 +46,7 @@ public TResult[] ToArray() return result; } - public List ToList() + public override List ToList() { var list = new List(); diff --git a/src/libraries/System.Linq/src/System/Linq/Skip.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Skip.SpeedOpt.cs index 1596dc0cc7cfe0..74ff73a068242c 100644 --- a/src/libraries/System.Linq/src/System/Linq/Skip.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Skip.SpeedOpt.cs @@ -9,7 +9,7 @@ public static partial class Enumerable { private static IEnumerable SkipIterator(IEnumerable source, int count) => source is IList sourceList ? - (IEnumerable)new ListPartition(sourceList, count, int.MaxValue) : - new EnumerablePartition(source, count, -1); + (IEnumerable)new IListSkipTakeIterator(sourceList, count, int.MaxValue) : + new IEnumerableSkipTakeIterator(source, count, -1); } } diff --git a/src/libraries/System.Linq/src/System/Linq/Skip.cs b/src/libraries/System.Linq/src/System/Linq/Skip.cs index 3652d1da3e79e8..582fb14f12a47f 100644 --- a/src/libraries/System.Linq/src/System/Linq/Skip.cs +++ b/src/libraries/System.Linq/src/System/Linq/Skip.cs @@ -23,17 +23,19 @@ public static IEnumerable Skip(this IEnumerable sourc { // Return source if not actually skipping, but only if it's a type from here, to avoid // issues if collections are used as keys or otherwise must not be aliased. - if (source is Iterator || source is IPartition) + if (source is Iterator) { return source; } count = 0; } - else if (source is IPartition partition) +#if !OPTIMIZE_FOR_SIZE + else if (source is Iterator iterator) { - return partition.Skip(count) ?? Empty(); + return iterator.Skip(count) ?? Empty(); } +#endif return SkipIterator(source, count); } diff --git a/src/libraries/System.Linq/src/System/Linq/Take.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Take.SpeedOpt.cs index f97c5295f75a78..f761d4a8ab07be 100644 --- a/src/libraries/System.Linq/src/System/Linq/Take.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Take.SpeedOpt.cs @@ -14,9 +14,9 @@ private static IEnumerable TakeIterator(IEnumerable s Debug.Assert(count > 0); return - source is IPartition partition ? (partition.Take(count) ?? Empty()) : - source is IList sourceList ? new ListPartition(sourceList, 0, count - 1) : - new EnumerablePartition(source, 0, count - 1); + source is Iterator iterator ? (iterator.Take(count) ?? Empty()) : + source is IList sourceList ? new IListSkipTakeIterator(sourceList, 0, count - 1) : + new IEnumerableSkipTakeIterator(source, 0, count - 1); } private static IEnumerable TakeRangeIterator(IEnumerable source, int startIndex, int endIndex) @@ -25,15 +25,15 @@ private static IEnumerable TakeRangeIterator(IEnumerable= 0 && startIndex < endIndex); return - source is IPartition partition ? TakePartitionRange(partition, startIndex, endIndex) : - source is IList sourceList ? new ListPartition(sourceList, startIndex, endIndex - 1) : - new EnumerablePartition(source, startIndex, endIndex - 1); + source is Iterator iterator ? TakeIteratorRange(iterator, startIndex, endIndex) : + source is IList sourceList ? new IListSkipTakeIterator(sourceList, startIndex, endIndex - 1) : + new IEnumerableSkipTakeIterator(source, startIndex, endIndex - 1); - static IEnumerable TakePartitionRange(IPartition partition, int startIndex, int endIndex) + static IEnumerable TakeIteratorRange(Iterator iterator, int startIndex, int endIndex) { - IPartition? source; + Iterator? source; if (endIndex != 0 && - (source = partition.Take(endIndex)) is not null && + (source = iterator.Take(endIndex)) is not null && (startIndex == 0 || (source = source!.Skip(startIndex)) is not null)) { return source; diff --git a/src/libraries/System.Linq/src/System/Linq/ToCollection.cs b/src/libraries/System.Linq/src/System/Linq/ToCollection.cs index 043cac8f0038b0..afbf1d4e158a9b 100644 --- a/src/libraries/System.Linq/src/System/Linq/ToCollection.cs +++ b/src/libraries/System.Linq/src/System/Linq/ToCollection.cs @@ -16,10 +16,12 @@ public static TSource[] ToArray(this IEnumerable source) ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); } - if (source is IIListProvider arrayProvider) +#if !OPTIMIZE_FOR_SIZE + if (source is Iterator iterator) { - return arrayProvider.ToArray(); + return iterator.ToArray(); } +#endif if (source is ICollection collection) { @@ -57,10 +59,12 @@ public static List ToList(this IEnumerable source) ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); } - if (source is IIListProvider listProvider) +#if !OPTIMIZE_FOR_SIZE + if (source is Iterator iterator) { - return listProvider.ToList(); + return iterator.ToList(); } +#endif return new List(source); } diff --git a/src/libraries/System.Linq/src/System/Linq/Union.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Union.SpeedOpt.cs index 1868e57e817390..0efa8248405b32 100644 --- a/src/libraries/System.Linq/src/System/Linq/Union.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Union.SpeedOpt.cs @@ -7,7 +7,7 @@ namespace System.Linq { public static partial class Enumerable { - private abstract partial class UnionIterator : IIListProvider + private abstract partial class UnionIterator { private HashSet FillSet() { @@ -24,11 +24,27 @@ private HashSet FillSet() } } - public TSource[] ToArray() => HashSetToArray(FillSet()); + public override TSource[] ToArray() => HashSetToArray(FillSet()); - public List ToList() => new List(FillSet()); + public override List ToList() => new List(FillSet()); - public int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : FillSet().Count; + public override int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : FillSet().Count; + + public override TSource? TryGetFirst(out bool found) + { + IEnumerable? source; + for (int i = 0; (source = GetEnumerable(i)) is not null; i++) + { + TSource? result = source.TryGetFirst(out found); + if (found) + { + return result; + } + } + + found = false; + return default; + } } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Where.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Where.SpeedOpt.cs index 5bf0a718088386..d7adb1bd99170e 100644 --- a/src/libraries/System.Linq/src/System/Linq/Where.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Where.SpeedOpt.cs @@ -8,9 +8,9 @@ namespace System.Linq { public static partial class Enumerable { - private sealed partial class WhereEnumerableIterator : IPartition + private sealed partial class IEnumerableWhereIterator { - public int GetCount(bool onlyIfCheap) + public override int GetCount(bool onlyIfCheap) { if (onlyIfCheap) { @@ -33,7 +33,7 @@ public int GetCount(bool onlyIfCheap) return count; } - public TSource[] ToArray() + public override TSource[] ToArray() { SegmentedArrayBuilder.ScratchBuffer scratch = default; SegmentedArrayBuilder builder = new(scratch); @@ -53,7 +53,7 @@ public TSource[] ToArray() return result; } - public List ToList() + public override List ToList() { var list = new List(); @@ -69,7 +69,7 @@ public List ToList() return list; } - public TSource? TryGetFirst(out bool found) + public override TSource? TryGetFirst(out bool found) { Func predicate = _predicate; @@ -86,7 +86,7 @@ public List ToList() return default; } - public TSource? TryGetLast(out bool found) + public override TSource? TryGetLast(out bool found) { using IEnumerator e = _source.GetEnumerator(); @@ -121,7 +121,7 @@ public List ToList() return default; } - public TSource? TryGetElementAt(int index, out bool found) + public override TSource? TryGetElementAt(int index, out bool found) { if (index >= 0) { @@ -145,15 +145,11 @@ public List ToList() found = false; return default; } - - public IPartition? Skip(int count) => new EnumerablePartition(this, count, -1); - - public IPartition? Take(int count) => new EnumerablePartition(this, 0, count - 1); } - internal sealed partial class WhereArrayIterator : IPartition + internal sealed partial class ArrayWhereIterator { - public int GetCount(bool onlyIfCheap) => GetCount(onlyIfCheap, _source, _predicate); + public override int GetCount(bool onlyIfCheap) => GetCount(onlyIfCheap, _source, _predicate); public static int GetCount(bool onlyIfCheap, ReadOnlySpan source, Func predicate) { @@ -178,7 +174,7 @@ public static int GetCount(bool onlyIfCheap, ReadOnlySpan source, Func< return count; } - public TSource[] ToArray() => ToArray(_source, _predicate); + public override TSource[] ToArray() => ToArray(_source, _predicate); public static TSource[] ToArray(ReadOnlySpan source, Func predicate) { @@ -199,7 +195,7 @@ public static TSource[] ToArray(ReadOnlySpan source, Func ToList() => ToList(_source, _predicate); + public override List ToList() => ToList(_source, _predicate); public static List ToList(ReadOnlySpan source, Func predicate) { @@ -216,7 +212,7 @@ public static List ToList(ReadOnlySpan source, Func predicate = _predicate; @@ -233,7 +229,7 @@ public static List ToList(ReadOnlySpan source, Func predicate = _predicate; @@ -251,7 +247,7 @@ public static List ToList(ReadOnlySpan source, Func= 0) { @@ -275,21 +271,17 @@ public static List ToList(ReadOnlySpan source, Func? Skip(int count) => new EnumerablePartition(this, count, -1); - - public IPartition? Take(int count) => new EnumerablePartition(this, 0, count - 1); } - private sealed partial class WhereListIterator : Iterator, IPartition + private sealed partial class ListWhereIterator : Iterator { - public int GetCount(bool onlyIfCheap) => WhereArrayIterator.GetCount(onlyIfCheap, CollectionsMarshal.AsSpan(_source), _predicate); + public override int GetCount(bool onlyIfCheap) => ArrayWhereIterator.GetCount(onlyIfCheap, CollectionsMarshal.AsSpan(_source), _predicate); - public TSource[] ToArray() => WhereArrayIterator.ToArray(CollectionsMarshal.AsSpan(_source), _predicate); + public override TSource[] ToArray() => ArrayWhereIterator.ToArray(CollectionsMarshal.AsSpan(_source), _predicate); - public List ToList() => WhereArrayIterator.ToList(CollectionsMarshal.AsSpan(_source), _predicate); + public override List ToList() => ArrayWhereIterator.ToList(CollectionsMarshal.AsSpan(_source), _predicate); - public TSource? TryGetFirst(out bool found) + public override TSource? TryGetFirst(out bool found) { Func predicate = _predicate; @@ -306,7 +298,7 @@ private sealed partial class WhereListIterator : Iterator, IPa return default; } - public TSource? TryGetLast(out bool found) + public override TSource? TryGetLast(out bool found) { ReadOnlySpan source = CollectionsMarshal.AsSpan(_source); Func predicate = _predicate; @@ -324,7 +316,7 @@ private sealed partial class WhereListIterator : Iterator, IPa return default; } - public TSource? TryGetElementAt(int index, out bool found) + public override TSource? TryGetElementAt(int index, out bool found) { if (index >= 0) { @@ -348,15 +340,11 @@ private sealed partial class WhereListIterator : Iterator, IPa found = false; return default; } - - public IPartition? Skip(int count) => new EnumerablePartition(this, count, -1); - - public IPartition? Take(int count) => new EnumerablePartition(this, 0, count - 1); } - private sealed partial class WhereSelectArrayIterator : IPartition + private sealed partial class ArrayWhereSelectIterator { - public int GetCount(bool onlyIfCheap) => GetCount(onlyIfCheap, _source, _predicate, _selector); + public override int GetCount(bool onlyIfCheap) => GetCount(onlyIfCheap, _source, _predicate, _selector); public static int GetCount(bool onlyIfCheap, ReadOnlySpan source, Func predicate, Func selector) { @@ -385,7 +373,7 @@ public static int GetCount(bool onlyIfCheap, ReadOnlySpan source, Func< return count; } - public TResult[] ToArray() => ToArray(_source, _predicate, _selector); + public override TResult[] ToArray() => ToArray(_source, _predicate, _selector); public static TResult[] ToArray(ReadOnlySpan source, Func predicate, Func selector) { @@ -406,7 +394,7 @@ public static TResult[] ToArray(ReadOnlySpan source, Func ToList() => ToList(_source, _predicate, _selector); + public override List ToList() => ToList(_source, _predicate, _selector); public static List ToList(ReadOnlySpan source, Func predicate, Func selector) { @@ -423,7 +411,7 @@ public static List ToList(ReadOnlySpan source, Func TryGetFirst(_source, _predicate, _selector, out found); + public override TResult? TryGetFirst(out bool found) => TryGetFirst(_source, _predicate, _selector, out found); public static TResult? TryGetFirst(ReadOnlySpan source, Func predicate, Func selector, out bool found) { @@ -440,7 +428,7 @@ public static List ToList(ReadOnlySpan source, Func TryGetLast(_source, _predicate, _selector, out found); + public override TResult? TryGetLast(out bool found) => TryGetLast(_source, _predicate, _selector, out found); public static TResult? TryGetLast(ReadOnlySpan source, Func predicate, Func selector, out bool found) { @@ -457,7 +445,7 @@ public static List ToList(ReadOnlySpan source, Func TryGetElementAt(_source, _predicate, _selector, index, out found); + public override TResult? TryGetElementAt(int index, out bool found) => TryGetElementAt(_source, _predicate, _selector, index, out found); public static TResult? TryGetElementAt(ReadOnlySpan source, Func predicate, Func selector, int index, out bool found) { @@ -481,34 +469,26 @@ public static List ToList(ReadOnlySpan source, Func? Skip(int count) => new EnumerablePartition(this, count, -1); - - public IPartition? Take(int count) => new EnumerablePartition(this, 0, count - 1); } - private sealed partial class WhereSelectListIterator : IPartition + private sealed partial class ListWhereSelectIterator { - public int GetCount(bool onlyIfCheap) => WhereSelectArrayIterator.GetCount(onlyIfCheap, CollectionsMarshal.AsSpan(_source), _predicate, _selector); - - public TResult[] ToArray() => WhereSelectArrayIterator.ToArray(CollectionsMarshal.AsSpan(_source), _predicate, _selector); + public override int GetCount(bool onlyIfCheap) => ArrayWhereSelectIterator.GetCount(onlyIfCheap, CollectionsMarshal.AsSpan(_source), _predicate, _selector); - public List ToList() => WhereSelectArrayIterator.ToList(CollectionsMarshal.AsSpan(_source), _predicate, _selector); + public override TResult[] ToArray() => ArrayWhereSelectIterator.ToArray(CollectionsMarshal.AsSpan(_source), _predicate, _selector); - public TResult? TryGetElementAt(int index, out bool found) => WhereSelectArrayIterator.TryGetElementAt(CollectionsMarshal.AsSpan(_source), _predicate, _selector, index, out found); + public override List ToList() => ArrayWhereSelectIterator.ToList(CollectionsMarshal.AsSpan(_source), _predicate, _selector); - public TResult? TryGetFirst(out bool found) => WhereSelectArrayIterator.TryGetFirst(CollectionsMarshal.AsSpan(_source), _predicate, _selector, out found); + public override TResult? TryGetElementAt(int index, out bool found) => ArrayWhereSelectIterator.TryGetElementAt(CollectionsMarshal.AsSpan(_source), _predicate, _selector, index, out found); - public TResult? TryGetLast(out bool found) => WhereSelectArrayIterator.TryGetLast(CollectionsMarshal.AsSpan(_source), _predicate, _selector, out found); + public override TResult? TryGetFirst(out bool found) => ArrayWhereSelectIterator.TryGetFirst(CollectionsMarshal.AsSpan(_source), _predicate, _selector, out found); - public IPartition? Skip(int count) => new EnumerablePartition(this, count, -1); - - public IPartition? Take(int count) => new EnumerablePartition(this, 0, count - 1); + public override TResult? TryGetLast(out bool found) => ArrayWhereSelectIterator.TryGetLast(CollectionsMarshal.AsSpan(_source), _predicate, _selector, out found); } - private sealed partial class WhereSelectEnumerableIterator : IPartition + private sealed partial class IEnumerableWhereSelectIterator { - public int GetCount(bool onlyIfCheap) + public override int GetCount(bool onlyIfCheap) { // In case someone uses Count() to force evaluation of // the selector, run it provided `onlyIfCheap` is false. @@ -535,7 +515,7 @@ public int GetCount(bool onlyIfCheap) return count; } - public TResult[] ToArray() + public override TResult[] ToArray() { SegmentedArrayBuilder.ScratchBuffer scratch = default; SegmentedArrayBuilder builder = new(scratch); @@ -556,7 +536,7 @@ public TResult[] ToArray() return result; } - public List ToList() + public override List ToList() { var list = new List(); @@ -573,7 +553,7 @@ public List ToList() return list; } - public TResult? TryGetFirst(out bool found) + public override TResult? TryGetFirst(out bool found) { Func predicate = _predicate; @@ -590,7 +570,7 @@ public List ToList() return default; } - public TResult? TryGetLast(out bool found) + public override TResult? TryGetLast(out bool found) { using IEnumerator e = _source.GetEnumerator(); @@ -625,7 +605,7 @@ public List ToList() return default; } - public TResult? TryGetElementAt(int index, out bool found) + public override TResult? TryGetElementAt(int index, out bool found) { if (index >= 0) { @@ -649,10 +629,6 @@ public List ToList() found = false; return default; } - - public IPartition? Skip(int count) => new EnumerablePartition(this, count, -1); - - public IPartition? Take(int count) => new EnumerablePartition(this, 0, count - 1); } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Where.cs b/src/libraries/System.Linq/src/System/Linq/Where.cs index aec6370a330f8e..a3a6656c43bafc 100644 --- a/src/libraries/System.Linq/src/System/Linq/Where.cs +++ b/src/libraries/System.Linq/src/System/Linq/Where.cs @@ -33,15 +33,15 @@ public static IEnumerable Where(this IEnumerable sour return []; } - return new WhereArrayIterator(array, predicate); + return new ArrayWhereIterator(array, predicate); } if (source is List list) { - return new WhereListIterator(list, predicate); + return new ListWhereIterator(list, predicate); } - return new WhereEnumerableIterator(source, predicate); + return new IEnumerableWhereIterator(source, predicate); } public static IEnumerable Where(this IEnumerable source, Func predicate) @@ -85,13 +85,13 @@ private static IEnumerable WhereIterator(IEnumerable /// An iterator that filters each item of an . /// /// The type of the source enumerable. - private sealed partial class WhereEnumerableIterator : Iterator + private sealed partial class IEnumerableWhereIterator : Iterator { private readonly IEnumerable _source; private readonly Func _predicate; private IEnumerator? _enumerator; - public WhereEnumerableIterator(IEnumerable source, Func predicate) + public IEnumerableWhereIterator(IEnumerable source, Func predicate) { Debug.Assert(source != null); Debug.Assert(predicate != null); @@ -99,7 +99,7 @@ public WhereEnumerableIterator(IEnumerable source, Func _predicate = predicate; } - public override Iterator Clone() => new WhereEnumerableIterator(_source, _predicate); + public override Iterator Clone() => new IEnumerableWhereIterator(_source, _predicate); public override void Dispose() { @@ -140,22 +140,22 @@ public override bool MoveNext() } public override IEnumerable Select(Func selector) => - new WhereSelectEnumerableIterator(_source, _predicate, selector); + new IEnumerableWhereSelectIterator(_source, _predicate, selector); public override IEnumerable Where(Func predicate) => - new WhereEnumerableIterator(_source, CombinePredicates(_predicate, predicate)); + new IEnumerableWhereIterator(_source, CombinePredicates(_predicate, predicate)); } /// /// An iterator that filters each item of an array. /// /// The type of the source array. - internal sealed partial class WhereArrayIterator : Iterator + internal sealed partial class ArrayWhereIterator : Iterator { private readonly TSource[] _source; private readonly Func _predicate; - public WhereArrayIterator(TSource[] source, Func predicate) + public ArrayWhereIterator(TSource[] source, Func predicate) { Debug.Assert(source != null && source.Length > 0); Debug.Assert(predicate != null); @@ -164,7 +164,7 @@ public WhereArrayIterator(TSource[] source, Func predicate) } public override Iterator Clone() => - new WhereArrayIterator(_source, _predicate); + new ArrayWhereIterator(_source, _predicate); public override bool MoveNext() { @@ -187,23 +187,23 @@ public override bool MoveNext() } public override IEnumerable Select(Func selector) => - new WhereSelectArrayIterator(_source, _predicate, selector); + new ArrayWhereSelectIterator(_source, _predicate, selector); public override IEnumerable Where(Func predicate) => - new WhereArrayIterator(_source, CombinePredicates(_predicate, predicate)); + new ArrayWhereIterator(_source, CombinePredicates(_predicate, predicate)); } /// /// An iterator that filters each item of a . /// /// The type of the source list. - private sealed partial class WhereListIterator : Iterator + private sealed partial class ListWhereIterator : Iterator { private readonly List _source; private readonly Func _predicate; private List.Enumerator _enumerator; - public WhereListIterator(List source, Func predicate) + public ListWhereIterator(List source, Func predicate) { Debug.Assert(source != null); Debug.Assert(predicate != null); @@ -212,7 +212,7 @@ public WhereListIterator(List source, Func predicate) } public override Iterator Clone() => - new WhereListIterator(_source, _predicate); + new ListWhereIterator(_source, _predicate); public override bool MoveNext() { @@ -241,10 +241,10 @@ public override bool MoveNext() } public override IEnumerable Select(Func selector) => - new WhereSelectListIterator(_source, _predicate, selector); + new ListWhereSelectIterator(_source, _predicate, selector); public override IEnumerable Where(Func predicate) => - new WhereListIterator(_source, CombinePredicates(_predicate, predicate)); + new ListWhereIterator(_source, CombinePredicates(_predicate, predicate)); } /// @@ -252,13 +252,13 @@ public override IEnumerable Where(Func predicate) => /// /// The type of the source array. /// The type of the mapped items. - private sealed partial class WhereSelectArrayIterator : Iterator + private sealed partial class ArrayWhereSelectIterator : Iterator { private readonly TSource[] _source; private readonly Func _predicate; private readonly Func _selector; - public WhereSelectArrayIterator(TSource[] source, Func predicate, Func selector) + public ArrayWhereSelectIterator(TSource[] source, Func predicate, Func selector) { Debug.Assert(source != null && source.Length > 0); Debug.Assert(predicate != null); @@ -269,7 +269,7 @@ public WhereSelectArrayIterator(TSource[] source, Func predicate, } public override Iterator Clone() => - new WhereSelectArrayIterator(_source, _predicate, _selector); + new ArrayWhereSelectIterator(_source, _predicate, _selector); public override bool MoveNext() { @@ -292,7 +292,7 @@ public override bool MoveNext() } public override IEnumerable Select(Func selector) => - new WhereSelectArrayIterator(_source, _predicate, CombineSelectors(_selector, selector)); + new ArrayWhereSelectIterator(_source, _predicate, CombineSelectors(_selector, selector)); } /// @@ -300,14 +300,14 @@ public override IEnumerable Select(Func s /// /// The type of the source list. /// The type of the mapped items. - private sealed partial class WhereSelectListIterator : Iterator + private sealed partial class ListWhereSelectIterator : Iterator { private readonly List _source; private readonly Func _predicate; private readonly Func _selector; private List.Enumerator _enumerator; - public WhereSelectListIterator(List source, Func predicate, Func selector) + public ListWhereSelectIterator(List source, Func predicate, Func selector) { Debug.Assert(source != null); Debug.Assert(predicate != null); @@ -318,7 +318,7 @@ public WhereSelectListIterator(List source, Func predica } public override Iterator Clone() => - new WhereSelectListIterator(_source, _predicate, _selector); + new ListWhereSelectIterator(_source, _predicate, _selector); public override bool MoveNext() { @@ -347,7 +347,7 @@ public override bool MoveNext() } public override IEnumerable Select(Func selector) => - new WhereSelectListIterator(_source, _predicate, CombineSelectors(_selector, selector)); + new ListWhereSelectIterator(_source, _predicate, CombineSelectors(_selector, selector)); } /// @@ -355,14 +355,14 @@ public override IEnumerable Select(Func s /// /// The type of the source enumerable. /// The type of the mapped items. - private sealed partial class WhereSelectEnumerableIterator : Iterator + private sealed partial class IEnumerableWhereSelectIterator : Iterator { private readonly IEnumerable _source; private readonly Func _predicate; private readonly Func _selector; private IEnumerator? _enumerator; - public WhereSelectEnumerableIterator(IEnumerable source, Func predicate, Func selector) + public IEnumerableWhereSelectIterator(IEnumerable source, Func predicate, Func selector) { Debug.Assert(source != null); Debug.Assert(predicate != null); @@ -373,7 +373,7 @@ public WhereSelectEnumerableIterator(IEnumerable source, Func Clone() => - new WhereSelectEnumerableIterator(_source, _predicate, _selector); + new IEnumerableWhereSelectIterator(_source, _predicate, _selector); public override void Dispose() { @@ -414,7 +414,7 @@ public override bool MoveNext() } public override IEnumerable Select(Func selector) => - new WhereSelectEnumerableIterator(_source, _predicate, CombineSelectors(_selector, selector)); + new IEnumerableWhereSelectIterator(_source, _predicate, CombineSelectors(_selector, selector)); } } } diff --git a/src/libraries/System.Linq/tests/AppendPrependTests.cs b/src/libraries/System.Linq/tests/AppendPrependTests.cs index 9df154f46ff47a..a42f1297bf0930 100644 --- a/src/libraries/System.Linq/tests/AppendPrependTests.cs +++ b/src/libraries/System.Linq/tests/AppendPrependTests.cs @@ -263,5 +263,27 @@ public void AppendPrependRunOnce() source = NumberRangeGuaranteedNotCollectionType(2, 2).Prepend(1).Prepend(0).Append(4).Append(5).RunOnce(); Assert.Equal(Enumerable.Range(0, 6), source.ToList()); } + + [Fact] + public void AppendPrepend_First_Last_ElementAt() + { + Assert.Equal(42, new int[] { 42 }.Append(84).First()); + Assert.Equal(42, new int[] { 84 }.Prepend(42).First()); + Assert.Equal(84, new int[] { 42 }.Append(84).Last()); + Assert.Equal(84, new int[] { 84 }.Prepend(42).Last()); + Assert.Equal(42, new int[] { 42 }.Append(84).ElementAt(0)); + Assert.Equal(42, new int[] { 84 }.Prepend(42).ElementAt(0)); + Assert.Equal(84, new int[] { 42 }.Append(84).ElementAt(1)); + Assert.Equal(84, new int[] { 84 }.Prepend(42).ElementAt(1)); + + Assert.Equal(42, NumberRangeGuaranteedNotCollectionType(42, 1).Append(84).First()); + Assert.Equal(42, NumberRangeGuaranteedNotCollectionType(84, 1).Prepend(42).First()); + Assert.Equal(84, NumberRangeGuaranteedNotCollectionType(42, 1).Append(84).Last()); + Assert.Equal(84, NumberRangeGuaranteedNotCollectionType(84, 1).Prepend(42).Last()); + Assert.Equal(42, NumberRangeGuaranteedNotCollectionType(42, 1).Append(84).ElementAt(0)); + Assert.Equal(42, NumberRangeGuaranteedNotCollectionType(84, 1).Prepend(42).ElementAt(0)); + Assert.Equal(84, NumberRangeGuaranteedNotCollectionType(42, 1).Append(84).ElementAt(1)); + Assert.Equal(84, NumberRangeGuaranteedNotCollectionType(84, 1).Prepend(42).ElementAt(1)); + } } } diff --git a/src/libraries/System.Linq/tests/ElementAtTests.cs b/src/libraries/System.Linq/tests/ElementAtTests.cs index 8644a7df0d1e9b..303e44d1411d5a 100644 --- a/src/libraries/System.Linq/tests/ElementAtTests.cs +++ b/src/libraries/System.Linq/tests/ElementAtTests.cs @@ -229,14 +229,14 @@ public void NonEmptySource_Consistency_ThrowsIListIndexerException() Assert.Throws("index", () => source.ElementAt(^11)); // ImmutableArray implements IList. ElementAt calls ImmutableArray's indexer, which throws IndexOutOfRangeException instead of ArgumentOutOfRangeException. Assert.Throws(() => ImmutableArray.Create(source).ElementAt(-1)); - Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^11)); + Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^11)); Assert.Throws("index", () => source.ElementAt(10)); Assert.Throws("index", () => source.ElementAt(new Index(10))); Assert.Throws("index", () => source.ElementAt(^0)); Assert.Throws(() => ImmutableArray.Create(source).ElementAt(10)); - Assert.Throws(() => ImmutableArray.Create(source).ElementAt(new Index(10))); - Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^0)); + Assert.Throws(() => ImmutableArray.Create(source).ElementAt(new Index(10))); + Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^0)); } [Fact] @@ -425,17 +425,17 @@ public void EmptySource_Consistency_ThrowsIListIndexerException() Assert.Throws("index", () => source.ElementAt(^1)); // ImmutableArray implements IList. ElementAt calls ImmutableArray's indexer, which throws IndexOutOfRangeException instead of ArgumentOutOfRangeException. Assert.Throws(() => ImmutableArray.Create(source).ElementAt(-1)); - Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^1)); + Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^1)); Assert.Throws("index", () => source.ElementAt(0)); Assert.Throws("index", () => source.ElementAt(^0)); Assert.Throws(() => ImmutableArray.Create(source).ElementAt(0)); - Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^0)); + Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^0)); Assert.Throws("index", () => source.ElementAt(1)); Assert.Throws("index", () => source.ElementAt(new Index(1))); Assert.Throws(() => ImmutableArray.Create(source).ElementAt(1)); - Assert.Throws(() => ImmutableArray.Create(source).ElementAt(new Index(1))); + Assert.Throws(() => ImmutableArray.Create(source).ElementAt(new Index(1))); } [Fact] From 2fc74f1a5d7d5dde8ba17361e8ae852a145a8075 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 26 Feb 2024 14:23:42 -0500 Subject: [PATCH 4/7] Rename Partition.SpeedOpt.cs to SkipTake.SpeedOpt.cs --- src/libraries/System.Linq/src/System.Linq.csproj | 2 +- .../System/Linq/{Partition.SpeedOpt.cs => SkipTake.SpeedOpt.cs} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/libraries/System.Linq/src/System/Linq/{Partition.SpeedOpt.cs => SkipTake.SpeedOpt.cs} (100%) diff --git a/src/libraries/System.Linq/src/System.Linq.csproj b/src/libraries/System.Linq/src/System.Linq.csproj index c9bcf4ac0470a8..7c656e5b50d097 100644 --- a/src/libraries/System.Linq/src/System.Linq.csproj +++ b/src/libraries/System.Linq/src/System.Linq.csproj @@ -27,13 +27,13 @@ - + diff --git a/src/libraries/System.Linq/src/System/Linq/Partition.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/SkipTake.SpeedOpt.cs similarity index 100% rename from src/libraries/System.Linq/src/System/Linq/Partition.SpeedOpt.cs rename to src/libraries/System.Linq/src/System/Linq/SkipTake.SpeedOpt.cs From 7a2baa6050876e44688762932f3317945989be76 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 27 Feb 2024 09:36:09 -0500 Subject: [PATCH 5/7] Address PR feedback, and revert ElementAt bounds check on IList --- .../System/Linq/DefaultIfEmpty.SpeedOpt.cs | 6 ++-- .../System.Linq/src/System/Linq/ElementAt.cs | 35 ++++++++++--------- .../src/System/Linq/Grouping.SpeedOpt.cs | 8 ++--- .../System.Linq/src/System/Linq/Grouping.cs | 8 ++--- .../System.Linq/src/System/Linq/Iterator.cs | 7 ++-- .../System.Linq/tests/DefaultIfEmptyTests.cs | 19 ++++++++++ .../System.Linq/tests/ElementAtTests.cs | 12 +++---- 7 files changed, 58 insertions(+), 37 deletions(-) diff --git a/src/libraries/System.Linq/src/System/Linq/DefaultIfEmpty.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/DefaultIfEmpty.SpeedOpt.cs index e147870eb0d9f7..c89d6797581e9b 100644 --- a/src/libraries/System.Linq/src/System/Linq/DefaultIfEmpty.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/DefaultIfEmpty.SpeedOpt.cs @@ -51,7 +51,7 @@ public override int GetCount(bool onlyIfCheap) } found = true; - return default; + return _default; } public override TSource? TryGetLast(out bool found) @@ -63,7 +63,7 @@ public override int GetCount(bool onlyIfCheap) } found = true; - return default; + return _default; } public override TSource? TryGetElementAt(int index, out bool found) @@ -79,7 +79,7 @@ public override int GetCount(bool onlyIfCheap) found = true; } - return default; + return _default; } } } diff --git a/src/libraries/System.Linq/src/System/Linq/ElementAt.cs b/src/libraries/System.Linq/src/System/Linq/ElementAt.cs index f2bec12be649aa..0466f7d9694bc5 100644 --- a/src/libraries/System.Linq/src/System/Linq/ElementAt.cs +++ b/src/libraries/System.Linq/src/System/Linq/ElementAt.cs @@ -16,7 +16,7 @@ public static TSource ElementAt(this IEnumerable source, int i ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); } - TSource? element = TryGetElementAt(source, index, out bool found); + TSource? element = TryGetElementAt(source, index, out bool found, guardIListLength: false); if (!found) { ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.index); @@ -102,39 +102,42 @@ public static TSource ElementAt(this IEnumerable source, Index return element; } - private static TSource? TryGetElementAt(this IEnumerable source, int index, out bool found) => + private static TSource? TryGetElementAt(this IEnumerable source, int index, out bool found, bool guardIListLength = true) => #if !OPTIMIZE_FOR_SIZE source is Iterator iterator ? iterator.TryGetElementAt(index, out found) : #endif - TryGetElementAtNonIterator(source, index, out found); + TryGetElementAtNonIterator(source, index, out found, guardIListLength); - private static TSource? TryGetElementAtNonIterator(IEnumerable source, int index, out bool found) + private static TSource? TryGetElementAtNonIterator(IEnumerable source, int index, out bool found, bool guardIListLength = true) { Debug.Assert(source != null); if (source is IList list) { - if ((uint)index < (uint)list.Count) + // Historically, ElementAt would simply delegate to IList[int] without first checking the bounds. + // That in turn meant that whatever exception the IList[int] throws for out-of-bounds access would + // propagate, e.g. ImmutableArray throws IndexOutOfRangeException whereas List throws ArgumentOutOfRangeException. + // Other uses of this, though, do need to guard, such as ElementAtOrDefault and all the various + // internal TryGetElementAt helpers. So, we have a guardIListLength parameter to allow the caller + // to specify whether to guard or not. + if (!guardIListLength || (uint)index < (uint)list.Count) { found = true; return list[index]; } } - else + else if (index >= 0) { - if (index >= 0) + using IEnumerator e = source.GetEnumerator(); + while (e.MoveNext()) { - using IEnumerator e = source.GetEnumerator(); - while (e.MoveNext()) + if (index == 0) { - if (index == 0) - { - found = true; - return e.Current; - } - - index--; + found = true; + return e.Current; } + + index--; } } diff --git a/src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs index 1153dc7ff4977a..d081a09380f956 100644 --- a/src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs @@ -7,7 +7,7 @@ namespace System.Linq { public static partial class Enumerable { - internal sealed partial class GroupByResultIterator + private sealed partial class GroupByResultIterator { public override TResult[] ToArray() => Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToArray(_resultSelector); @@ -19,7 +19,7 @@ public override int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _elementSelector, _comparer).Count; } - internal sealed partial class GroupByResultIterator + private sealed partial class GroupByResultIterator { public override TResult[] ToArray() => Lookup.Create(_source, _keySelector, _comparer).ToArray(_resultSelector); @@ -31,7 +31,7 @@ public override int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _comparer).Count; } - internal sealed partial class GroupByIterator + private sealed partial class GroupByIterator { public override IGrouping[] ToArray() => Lookup.Create(_source, _keySelector, _elementSelector, _comparer).ToArray(); @@ -43,7 +43,7 @@ public override int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : Lookup.Create(_source, _keySelector, _elementSelector, _comparer).Count; } - internal sealed partial class GroupByIterator + private sealed partial class GroupByIterator { public override IGrouping[] ToArray() => Lookup.Create(_source, _keySelector, _comparer).ToArray(); diff --git a/src/libraries/System.Linq/src/System/Linq/Grouping.cs b/src/libraries/System.Linq/src/System/Linq/Grouping.cs index c9b55b85cfad89..b81b7909d4b844 100644 --- a/src/libraries/System.Linq/src/System/Linq/Grouping.cs +++ b/src/libraries/System.Linq/src/System/Linq/Grouping.cs @@ -121,7 +121,7 @@ public static IEnumerable GroupBy(thi return new GroupByResultIterator(source, keySelector, elementSelector, resultSelector, comparer); } - internal sealed partial class GroupByResultIterator : Iterator + private sealed partial class GroupByResultIterator : Iterator { private readonly IEnumerable _source; private readonly Func _keySelector; @@ -179,7 +179,7 @@ public override bool MoveNext() } } - internal sealed partial class GroupByResultIterator : Iterator + private sealed partial class GroupByResultIterator : Iterator { private readonly IEnumerable _source; private readonly Func _keySelector; @@ -235,7 +235,7 @@ public override bool MoveNext() } } - internal sealed partial class GroupByIterator : Iterator> + private sealed partial class GroupByIterator : Iterator> { private readonly IEnumerable _source; private readonly Func _keySelector; @@ -290,7 +290,7 @@ public override bool MoveNext() } } - internal sealed partial class GroupByIterator : Iterator> + private sealed partial class GroupByIterator : Iterator> { private readonly IEnumerable _source; private readonly Func _keySelector; diff --git a/src/libraries/System.Linq/src/System/Linq/Iterator.cs b/src/libraries/System.Linq/src/System/Linq/Iterator.cs index 933868212428bf..4738d2343069eb 100644 --- a/src/libraries/System.Linq/src/System/Linq/Iterator.cs +++ b/src/libraries/System.Linq/src/System/Linq/Iterator.cs @@ -88,13 +88,12 @@ public IEnumerator GetEnumerator() /// The type of the mapped items. /// The selector used to map each item. public virtual IEnumerable Select(Func selector) => - new #if OPTIMIZE_FOR_SIZE - IEnumerableSelectIterator + new IEnumerableSelectIterator(this, selector); #else - IteratorSelectIterator + new IteratorSelectIterator(this, selector); #endif - (this, selector); + /// /// Returns an enumerable that filters each item in this iterator based on a predicate. diff --git a/src/libraries/System.Linq/tests/DefaultIfEmptyTests.cs b/src/libraries/System.Linq/tests/DefaultIfEmptyTests.cs index b4e9e13f804240..843d2a7d9e425f 100644 --- a/src/libraries/System.Linq/tests/DefaultIfEmptyTests.cs +++ b/src/libraries/System.Linq/tests/DefaultIfEmptyTests.cs @@ -105,5 +105,24 @@ public void ForcedToEnumeratorDoesntEnumerate() var en = iterator as IEnumerator; Assert.False(en != null && en.MoveNext()); } + + [Fact] + public void First_Last_ElementAt() + { + IEnumerable nonEmpty = Enumerable.Range(1, 3); + Assert.Equal(1, nonEmpty.First()); + Assert.Equal(3, nonEmpty.Last()); + Assert.Equal(1, nonEmpty.ElementAt(0)); + Assert.Equal(2, nonEmpty.ElementAt(1)); + Assert.Equal(3, nonEmpty.ElementAt(2)); + Assert.Throws(() => nonEmpty.ElementAt(-1)); + Assert.Throws(() => nonEmpty.ElementAt(4)); + + IEnumerable empty = Enumerable.Empty(); + Assert.Equal(42, empty.DefaultIfEmpty(42).First()); + Assert.Equal(42, empty.DefaultIfEmpty(42).Last()); + Assert.Equal(42, empty.DefaultIfEmpty(42).ElementAt(0)); + Assert.Throws(() => empty.DefaultIfEmpty(42).ElementAt(1)); + } } } diff --git a/src/libraries/System.Linq/tests/ElementAtTests.cs b/src/libraries/System.Linq/tests/ElementAtTests.cs index 303e44d1411d5a..8644a7df0d1e9b 100644 --- a/src/libraries/System.Linq/tests/ElementAtTests.cs +++ b/src/libraries/System.Linq/tests/ElementAtTests.cs @@ -229,14 +229,14 @@ public void NonEmptySource_Consistency_ThrowsIListIndexerException() Assert.Throws("index", () => source.ElementAt(^11)); // ImmutableArray implements IList. ElementAt calls ImmutableArray's indexer, which throws IndexOutOfRangeException instead of ArgumentOutOfRangeException. Assert.Throws(() => ImmutableArray.Create(source).ElementAt(-1)); - Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^11)); + Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^11)); Assert.Throws("index", () => source.ElementAt(10)); Assert.Throws("index", () => source.ElementAt(new Index(10))); Assert.Throws("index", () => source.ElementAt(^0)); Assert.Throws(() => ImmutableArray.Create(source).ElementAt(10)); - Assert.Throws(() => ImmutableArray.Create(source).ElementAt(new Index(10))); - Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^0)); + Assert.Throws(() => ImmutableArray.Create(source).ElementAt(new Index(10))); + Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^0)); } [Fact] @@ -425,17 +425,17 @@ public void EmptySource_Consistency_ThrowsIListIndexerException() Assert.Throws("index", () => source.ElementAt(^1)); // ImmutableArray implements IList. ElementAt calls ImmutableArray's indexer, which throws IndexOutOfRangeException instead of ArgumentOutOfRangeException. Assert.Throws(() => ImmutableArray.Create(source).ElementAt(-1)); - Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^1)); + Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^1)); Assert.Throws("index", () => source.ElementAt(0)); Assert.Throws("index", () => source.ElementAt(^0)); Assert.Throws(() => ImmutableArray.Create(source).ElementAt(0)); - Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^0)); + Assert.Throws(() => ImmutableArray.Create(source).ElementAt(^0)); Assert.Throws("index", () => source.ElementAt(1)); Assert.Throws("index", () => source.ElementAt(new Index(1))); Assert.Throws(() => ImmutableArray.Create(source).ElementAt(1)); - Assert.Throws(() => ImmutableArray.Create(source).ElementAt(new Index(1))); + Assert.Throws(() => ImmutableArray.Create(source).ElementAt(new Index(1))); } [Fact] From 1b25bee5f6e688e2b61fc8060c2ee2cd058a6ade Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 27 Feb 2024 14:02:42 -0500 Subject: [PATCH 6/7] Fill some code coverage test holes --- .../System.Linq/tests/AggregateByTests.cs | 22 +++++++-- src/libraries/System.Linq/tests/ChunkTests.cs | 6 +++ .../System.Linq/tests/ConcatTests.cs | 49 +++++++++++++++++++ .../System.Linq/tests/DistinctTests.cs | 6 +++ .../System.Linq/tests/GroupByTests.cs | 25 +++++++++- src/libraries/System.Linq/tests/IndexTests.cs | 6 +++ src/libraries/System.Linq/tests/MaxTests.cs | 9 ++++ src/libraries/System.Linq/tests/MinTests.cs | 13 +++++ src/libraries/System.Linq/tests/OrderTests.cs | 16 ++++++ src/libraries/System.Linq/tests/RangeTests.cs | 3 ++ .../System.Linq/tests/RepeatTests.cs | 3 ++ .../System.Linq/tests/SequenceEqualTests.cs | 29 ++++++++++- .../System.Linq/tests/SkipLastTests.cs | 6 +++ .../System.Linq/tests/SkipWhileTests.cs | 9 +++- .../System.Linq/tests/TakeLastTests.cs | 6 +++ src/libraries/System.Linq/tests/TakeTests.cs | 33 +++++++++++++ .../System.Linq/tests/TakeWhileTests.cs | 8 ++- .../System.Linq/tests/ToLookupTests.cs | 33 +++++++++++-- src/libraries/System.Linq/tests/WhereTests.cs | 44 +++++++++++++++++ 19 files changed, 314 insertions(+), 12 deletions(-) diff --git a/src/libraries/System.Linq/tests/AggregateByTests.cs b/src/libraries/System.Linq/tests/AggregateByTests.cs index daae145f775508..6232ce24a6df49 100644 --- a/src/libraries/System.Linq/tests/AggregateByTests.cs +++ b/src/libraries/System.Linq/tests/AggregateByTests.cs @@ -8,6 +8,16 @@ namespace System.Linq.Tests { public class AggregateByTests : EnumerableTests { + [Fact] + public void Empty() + { + Assert.All(IdentityTransforms(), transform => + { + Assert.Equal(Enumerable.Empty>(), transform(Enumerable.Empty()).AggregateBy(i => i, i => i, (a, i) => a + i)); + Assert.Equal(Enumerable.Empty>(), transform(Enumerable.Empty()).AggregateBy(i => i, 0, (a, i) => a + i)); + }); + } + [Fact] public void AggregateBy_SourceNull_ThrowsArgumentNullException() { @@ -15,22 +25,26 @@ public void AggregateBy_SourceNull_ThrowsArgumentNullException() AssertExtensions.Throws("source", () => first.AggregateBy(x => x, string.Empty, (x, y) => x + y)); AssertExtensions.Throws("source", () => first.AggregateBy(x => x, string.Empty, (x, y) => x + y, new AnagramEqualityComparer())); + AssertExtensions.Throws("source", () => first.AggregateBy(x => x, x => x, (x, y) => x + y)); + AssertExtensions.Throws("source", () => first.AggregateBy(x => x, x => x, (x, y) => x + y, new AnagramEqualityComparer())); } [Fact] public void AggregateBy_KeySelectorNull_ThrowsArgumentNullException() { - string[] source = { }; + string[] source = ["test"]; Func keySelector = null; AssertExtensions.Throws("keySelector", () => source.AggregateBy(keySelector, string.Empty, (x, y) => x + y)); AssertExtensions.Throws("keySelector", () => source.AggregateBy(keySelector, string.Empty, (x, y) => x + y, new AnagramEqualityComparer())); + AssertExtensions.Throws("keySelector", () => source.AggregateBy(keySelector, x => x, (x, y) => x + y)); + AssertExtensions.Throws("keySelector", () => source.AggregateBy(keySelector, x => x, (x, y) => x + y, new AnagramEqualityComparer())); } [Fact] public void AggregateBy_SeedSelectorNull_ThrowsArgumentNullException() { - string[] source = { }; + string[] source = ["test"]; Func seedSelector = null; AssertExtensions.Throws("seedSelector", () => source.AggregateBy(x => x, seedSelector, (x, y) => x + y)); @@ -40,11 +54,13 @@ public void AggregateBy_SeedSelectorNull_ThrowsArgumentNullException() [Fact] public void AggregateBy_FuncNull_ThrowsArgumentNullException() { - string[] source = { }; + string[] source = ["test"]; Func func = null; AssertExtensions.Throws("func", () => source.AggregateBy(x => x, string.Empty, func)); AssertExtensions.Throws("func", () => source.AggregateBy(x => x, string.Empty, func, new AnagramEqualityComparer())); + AssertExtensions.Throws("func", () => source.AggregateBy(x => x, x => x, func)); + AssertExtensions.Throws("func", () => source.AggregateBy(x => x, x => x, func, new AnagramEqualityComparer())); } [Fact] diff --git a/src/libraries/System.Linq/tests/ChunkTests.cs b/src/libraries/System.Linq/tests/ChunkTests.cs index ee348604192792..f8cfc4de6b854e 100644 --- a/src/libraries/System.Linq/tests/ChunkTests.cs +++ b/src/libraries/System.Linq/tests/ChunkTests.cs @@ -7,6 +7,12 @@ namespace System.Linq.Tests { public class ChunkTests : EnumerableTests { + [Fact] + public void Empty() + { + Assert.Equal(Enumerable.Empty(), Enumerable.Empty().Chunk(4)); + } + [Fact] public void ThrowsOnNullSource() { diff --git a/src/libraries/System.Linq/tests/ConcatTests.cs b/src/libraries/System.Linq/tests/ConcatTests.cs index 6209d846ff2438..0435bba5d7f05f 100644 --- a/src/libraries/System.Linq/tests/ConcatTests.cs +++ b/src/libraries/System.Linq/tests/ConcatTests.cs @@ -83,6 +83,55 @@ public void VerifyEquals(IEnumerable expected, IEnumerable actual) VerifyEqualsWorker(expected, actual); } + [Theory] + [MemberData(nameof(ArraySourcesData))] + [MemberData(nameof(SelectArraySourcesData))] + [MemberData(nameof(EnumerableSourcesData))] + [MemberData(nameof(NonCollectionSourcesData))] + [MemberData(nameof(ListSourcesData))] + [MemberData(nameof(ConcatOfConcatsData))] + [MemberData(nameof(ConcatWithSelfData))] + [MemberData(nameof(ChainedCollectionConcatData))] + [MemberData(nameof(AppendedPrependedConcatAlternationsData))] + public void First_Last_ElementAt(IEnumerable _, IEnumerable actual) + { + int count = actual.Count(); + if (count == 0) + { + Assert.Throws(() => actual.First()); + Assert.Throws(() => actual.Last()); + Assert.Throws(() => actual.ElementAt(0)); + } + else + { + int first = actual.First(); + int last = actual.Last(); + int elementAt = actual.ElementAt(count / 2); + + int enumeratedFirst = 0, enumeratedLast = 0, enumeratedElementAt = 0; + int i = 0; + foreach (int item in actual) + { + if (i == 0) + { + enumeratedFirst = item; + } + + if (i == count / 2) + { + enumeratedElementAt = item; + } + + enumeratedLast = item; + i++; + } + + Assert.Equal(enumeratedFirst, first); + Assert.Equal(enumeratedLast, last); + Assert.Equal(enumeratedElementAt, elementAt); + } + } + private static void VerifyEqualsWorker(IEnumerable expected, IEnumerable actual) { // Returns a list of functions that, when applied to enumerable, should return diff --git a/src/libraries/System.Linq/tests/DistinctTests.cs b/src/libraries/System.Linq/tests/DistinctTests.cs index 7408e96ddb38ce..e9de987ee0481d 100644 --- a/src/libraries/System.Linq/tests/DistinctTests.cs +++ b/src/libraries/System.Linq/tests/DistinctTests.cs @@ -303,6 +303,12 @@ public static void DistinctBy_RunOnce_HasExpectedOutput(IEnumerab public static IEnumerable DistinctBy_TestData() { + yield return WrapArgs( + source: Array.Empty(), + keySelector: x => x, + comparer: null, + expected: Enumerable.Empty()); + yield return WrapArgs( source: Enumerable.Range(0, 10), keySelector: x => x, diff --git a/src/libraries/System.Linq/tests/GroupByTests.cs b/src/libraries/System.Linq/tests/GroupByTests.cs index 4b8967a28a82e7..5036ebe9b02698 100644 --- a/src/libraries/System.Linq/tests/GroupByTests.cs +++ b/src/libraries/System.Linq/tests/GroupByTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; -using System.Diagnostics; using System.Reflection; using Xunit; @@ -864,5 +863,29 @@ public static void GroupingKeyIsPublic() PropertyInfo key = grouptype.GetProperty("Key", BindingFlags.Instance | BindingFlags.Public); Assert.NotNull(key); } + + [Fact] + public void MultipleIterationsOfSameEnumerable() + { + foreach (IEnumerable> e1 in new[] { Enumerable.Range(0, 10).GroupBy(i => i), Enumerable.Range(0, 10).GroupBy(i => i, i => i) }) + { + for (int trial = 0; trial < 3; trial++) + { + int count = 0; + foreach (IGrouping g in e1) count++; + Assert.Equal(10, count); + } + } + + foreach (IEnumerable e2 in new[] { Enumerable.Range(0, 10).GroupBy(i => i, (i, e) => i), Enumerable.Range(0, 10).GroupBy(i => i, i => i, (i, e) => i) }) + { + for (int trial = 0; trial < 3; trial++) + { + int count = 0; + foreach (int i in e2) count++; + Assert.Equal(10, count); + } + } + } } } diff --git a/src/libraries/System.Linq/tests/IndexTests.cs b/src/libraries/System.Linq/tests/IndexTests.cs index 4b08820fe0e30c..0742569f787d40 100644 --- a/src/libraries/System.Linq/tests/IndexTests.cs +++ b/src/libraries/System.Linq/tests/IndexTests.cs @@ -8,6 +8,12 @@ namespace System.Linq.Tests { public class IndexTests : EnumerableTests { + [Fact] + public void Empty() + { + Assert.Empty(Enumerable.Empty().Index()); + } + [Fact] public void Index_SourceIsNull_ArgumentNullExceptionThrown() { diff --git a/src/libraries/System.Linq/tests/MaxTests.cs b/src/libraries/System.Linq/tests/MaxTests.cs index a1509855091d11..bb70a14d684c72 100644 --- a/src/libraries/System.Linq/tests/MaxTests.cs +++ b/src/libraries/System.Linq/tests/MaxTests.cs @@ -251,6 +251,8 @@ public void Max_Float_EmptySource_ThrowsInvalidOperationException() { Assert.Throws(() => Enumerable.Empty().Max()); Assert.Throws(() => Enumerable.Empty().Max(x => x)); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Max()); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Max(x => x)); Assert.Throws(() => Array.Empty().Max()); Assert.Throws(() => new List().Max()); } @@ -331,6 +333,8 @@ public void Max_Double_EmptySource_ThrowsInvalidOperationException() { Assert.Throws(() => Enumerable.Empty().Max()); Assert.Throws(() => Enumerable.Empty().Max(x => x)); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Max()); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Max(x => x)); Assert.Throws(() => Array.Empty().Max()); Assert.Throws(() => new List().Max()); } @@ -397,6 +401,8 @@ public void Max_Decimal_EmptySource_ThrowsInvalidOperationException() { Assert.Throws(() => Enumerable.Empty().Max()); Assert.Throws(() => Enumerable.Empty().Max(x => x)); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Max()); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Max(x => x)); Assert.Throws(() => Array.Empty().Max()); Assert.Throws(() => new List().Max(x => x)); } @@ -622,6 +628,8 @@ public void Max_DateTime_EmptySource_ThrowsInvalidOperationException() { Assert.Throws(() => Enumerable.Empty().Max()); Assert.Throws(() => Enumerable.Empty().Max(i => i)); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Max()); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Max(i => i)); } public static IEnumerable Max_String_TestData() @@ -888,6 +896,7 @@ public void Max_String_WithSelectorAccessingProperty() public void Max_Boolean_EmptySource_ThrowsInvalidOperationException() { Assert.Throws(() => Enumerable.Empty().Max()); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Max()); } [Fact] diff --git a/src/libraries/System.Linq/tests/MinTests.cs b/src/libraries/System.Linq/tests/MinTests.cs index feca6994d066d6..0cc72fa43a1001 100644 --- a/src/libraries/System.Linq/tests/MinTests.cs +++ b/src/libraries/System.Linq/tests/MinTests.cs @@ -136,6 +136,8 @@ public void Min_Int_EmptySource_ThrowsInvalidOperationException() { Assert.Throws(() => Enumerable.Empty().Min()); Assert.Throws(() => Enumerable.Empty().Min(x => x)); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min()); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min(x => x)); Assert.Throws(() => Array.Empty().Min()); Assert.Throws(() => new List().Min()); } @@ -182,6 +184,8 @@ public void Min_Long_EmptySource_ThrowsInvalidOperationException() { Assert.Throws(() => Enumerable.Empty().Min()); Assert.Throws(() => Enumerable.Empty().Min(x => x)); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min()); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min(x => x)); Assert.Throws(() => Array.Empty().Min()); Assert.Throws(() => new List().Min()); } @@ -250,6 +254,8 @@ public void Min_Float_EmptySource_ThrowsInvalidOperationException() { Assert.Throws(() => Enumerable.Empty().Min()); Assert.Throws(() => Enumerable.Empty().Min(x => x)); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min()); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min(x => x)); Assert.Throws(() => Array.Empty().Min()); Assert.Throws(() => new List().Min()); } @@ -316,6 +322,8 @@ public void Min_Double_EmptySource_ThrowsInvalidOperationException() { Assert.Throws(() => Enumerable.Empty().Min()); Assert.Throws(() => Enumerable.Empty().Min(x => x)); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min()); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min(x => x)); Assert.Throws(() => Array.Empty().Min()); Assert.Throws(() => new List().Min()); } @@ -355,6 +363,8 @@ public void Min_Decimal_EmptySource_ThrowsInvalidOperationException() { Assert.Throws(() => Enumerable.Empty().Min()); Assert.Throws(() => Enumerable.Empty().Min(x => x)); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min()); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min(x => x)); Assert.Throws(() => Array.Empty().Min()); Assert.Throws(() => new List().Min()); } @@ -595,6 +605,8 @@ public void Min_DateTime_EmptySource_ThrowsInvalidOperationException() { Assert.Throws(() => Enumerable.Empty().Min()); Assert.Throws(() => Enumerable.Empty().Min(x => x)); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min()); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min(x => x)); Assert.Throws(() => Array.Empty().Min()); Assert.Throws(() => new List().Min()); } @@ -858,6 +870,7 @@ public void Min_String_NullSelector_ThrowsArgumentNullException() public void Min_Bool_EmptySource_ThrowsInvalodOperationException() { Assert.Throws(() => Enumerable.Empty().Min()); + Assert.Throws(() => ForceNotCollection(Enumerable.Empty()).Min()); } [Fact] diff --git a/src/libraries/System.Linq/tests/OrderTests.cs b/src/libraries/System.Linq/tests/OrderTests.cs index ed2dd9bfc8765e..dee76efe73823d 100644 --- a/src/libraries/System.Linq/tests/OrderTests.cs +++ b/src/libraries/System.Linq/tests/OrderTests.cs @@ -196,6 +196,9 @@ public void FirstOnOrdered() { Assert.Equal(0, Enumerable.Range(0, 10).Shuffle().Order().First()); Assert.Equal(9, Enumerable.Range(0, 10).Shuffle().OrderDescending().First()); + + Assert.Equal(0, ForceNotCollection(Enumerable.Range(0, 10).Shuffle()).Order().First()); + Assert.Equal(9, ForceNotCollection(Enumerable.Range(0, 10).Shuffle()).OrderDescending().First()); } [Fact] @@ -281,6 +284,9 @@ public void LastOnOrdered() { Assert.Equal(9, Enumerable.Range(0, 10).Shuffle().Order().Last()); Assert.Equal(0, Enumerable.Range(0, 10).Shuffle().OrderDescending().Last()); + + Assert.Equal(9, ForceNotCollection(Enumerable.Range(0, 10).Shuffle()).Order().Last()); + Assert.Equal(0, ForceNotCollection(Enumerable.Range(0, 10).Shuffle()).OrderDescending().Last()); } [Fact] @@ -307,6 +313,16 @@ public void LastOrDefaultOnOrdered() Assert.Equal(0, Enumerable.Empty().Order().LastOrDefault()); } + [Fact] + public void ElementAtOnOrdered() + { + Assert.Equal(4, Enumerable.Range(0, 10).Shuffle().Order().ElementAt(4)); + Assert.Equal(5, Enumerable.Range(0, 10).Shuffle().OrderDescending().ElementAt(4)); + + Assert.Equal(4, ForceNotCollection(Enumerable.Range(0, 10).Shuffle()).Order().ElementAt(4)); + Assert.Equal(5, ForceNotCollection(Enumerable.Range(0, 10).Shuffle()).OrderDescending().ElementAt(4)); + } + [Fact] public void EnumeratorDoesntContinue() { diff --git a/src/libraries/System.Linq/tests/RangeTests.cs b/src/libraries/System.Linq/tests/RangeTests.cs index 8421a66ba890c9..2e331cfee7ecd7 100644 --- a/src/libraries/System.Linq/tests/RangeTests.cs +++ b/src/libraries/System.Linq/tests/RangeTests.cs @@ -243,6 +243,7 @@ static void Validate(IEnumerable e, int[] expected) Assert.Throws(() => list.Insert(0, 42)); Assert.Throws(() => list.Clear()); Assert.Throws(() => list.Remove(42)); + Assert.Throws(() => list.RemoveAt(0)); Assert.Throws(() => list[0] = 42); AssertExtensions.Throws("index", () => list[-1]); AssertExtensions.Throws("index", () => list[expected.Length]); @@ -255,6 +256,8 @@ static void Validate(IEnumerable e, int[] expected) Assert.False(list.Contains(expected[0] - 1)); Assert.False(list.Contains(expected[^1] + 1)); + Assert.Equal(-1, list.IndexOf(expected[0] - 1)); + Assert.Equal(-1, list.IndexOf(expected[^1] + 1)); Assert.All(expected, i => Assert.True(list.Contains(i))); Assert.All(expected, i => Assert.Equal(Array.IndexOf(expected, i), list.IndexOf(i))); for (int i = 0; i < expected.Length; i++) diff --git a/src/libraries/System.Linq/tests/RepeatTests.cs b/src/libraries/System.Linq/tests/RepeatTests.cs index 625dff376de311..df8eebda35691e 100644 --- a/src/libraries/System.Linq/tests/RepeatTests.cs +++ b/src/libraries/System.Linq/tests/RepeatTests.cs @@ -255,6 +255,7 @@ static void Validate(IEnumerable e, int[] expected) Assert.Throws(() => list.Insert(0, 42)); Assert.Throws(() => list.Clear()); Assert.Throws(() => list.Remove(42)); + Assert.Throws(() => list.RemoveAt(0)); Assert.Throws(() => list[0] = 42); AssertExtensions.Throws("index", () => list[-1]); AssertExtensions.Throws("index", () => list[expected.Length]); @@ -267,6 +268,8 @@ static void Validate(IEnumerable e, int[] expected) Assert.False(list.Contains(expected[0] - 1)); Assert.False(list.Contains(expected[^1] + 1)); + Assert.Equal(-1, list.IndexOf(expected[0] - 1)); + Assert.Equal(-1, list.IndexOf(expected[^1] + 1)); Assert.All(expected, i => Assert.True(list.Contains(i))); Assert.All(expected, i => Assert.Equal(Array.IndexOf(expected, i), list.IndexOf(i))); for (int i = 0; i < expected.Length; i++) diff --git a/src/libraries/System.Linq/tests/SequenceEqualTests.cs b/src/libraries/System.Linq/tests/SequenceEqualTests.cs index 380916550efd60..7393d18947aa8f 100644 --- a/src/libraries/System.Linq/tests/SequenceEqualTests.cs +++ b/src/libraries/System.Linq/tests/SequenceEqualTests.cs @@ -1,8 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -246,5 +245,31 @@ public void ByteArrays_SpecialCasedButExpectedBehavior() } } } + + [Fact] + public void ICollectionsCompareCorrectly() + { + Assert.True(new TestCollection([]).SequenceEqual(new TestCollection([]))); + Assert.True(new TestCollection([1]).SequenceEqual(new TestCollection([1]))); + Assert.True(new TestCollection([1, 2, 3]).SequenceEqual(new TestCollection([1, 2, 3]))); + + Assert.False(new TestCollection([1, 2, 3, 4]).SequenceEqual(new TestCollection([1, 2, 3]))); + Assert.False(new TestCollection([1, 2, 3]).SequenceEqual(new TestCollection([1, 2, 3, 4]))); + Assert.False(new TestCollection([1, 2, 3]).SequenceEqual(new TestCollection([1, 2, 4]))); + Assert.False(new TestCollection([-1, 2, 3]).SequenceEqual(new TestCollection([-2, 2, 3]))); + } + + [Fact] + public void IListsCompareCorrectly() + { + Assert.True(new ReadOnlyCollection([]).SequenceEqual(new ReadOnlyCollection([]))); + Assert.True(new ReadOnlyCollection([1]).SequenceEqual(new ReadOnlyCollection([1]))); + Assert.True(new ReadOnlyCollection([1, 2, 3]).SequenceEqual(new ReadOnlyCollection([1, 2, 3]))); + + Assert.False(new ReadOnlyCollection([1, 2, 3, 4]).SequenceEqual(new ReadOnlyCollection([1, 2, 3]))); + Assert.False(new ReadOnlyCollection([1, 2, 3]).SequenceEqual(new ReadOnlyCollection([1, 2, 3, 4]))); + Assert.False(new ReadOnlyCollection([1, 2, 3]).SequenceEqual(new ReadOnlyCollection([1, 2, 4]))); + Assert.False(new ReadOnlyCollection([-1, 2, 3]).SequenceEqual(new ReadOnlyCollection([-2, 2, 3]))); + } } } diff --git a/src/libraries/System.Linq/tests/SkipLastTests.cs b/src/libraries/System.Linq/tests/SkipLastTests.cs index fe9652a875e9d9..c4770410870d47 100644 --- a/src/libraries/System.Linq/tests/SkipLastTests.cs +++ b/src/libraries/System.Linq/tests/SkipLastTests.cs @@ -9,6 +9,12 @@ namespace System.Linq.Tests { public class SkipLastTests : EnumerableTests { + [Fact] + public void SkipLastThrowsOnNull() + { + AssertExtensions.Throws("source", () => ((IEnumerable)null).SkipLast(10)); + } + [Theory] [MemberData(nameof(EnumerableData), MemberType = typeof(SkipTakeData))] public void SkipLast(IEnumerable source, int count) diff --git a/src/libraries/System.Linq/tests/SkipWhileTests.cs b/src/libraries/System.Linq/tests/SkipWhileTests.cs index 26281efc5f862a..75a90967042887 100644 --- a/src/libraries/System.Linq/tests/SkipWhileTests.cs +++ b/src/libraries/System.Linq/tests/SkipWhileTests.cs @@ -1,15 +1,20 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Collections; using System.Collections.Generic; using Xunit; -using Xunit.Abstractions; namespace System.Linq.Tests { public class SkipWhileTests : EnumerableTests { + [Fact] + public void Empty() + { + Assert.Equal(Enumerable.Empty(), Enumerable.Empty().SkipWhile(i => i < 40)); + Assert.Equal(Enumerable.Empty(), Enumerable.Empty().SkipWhile((i, index) => i < 40)); + } + [Fact] public void SkipWhileAllTrue() { diff --git a/src/libraries/System.Linq/tests/TakeLastTests.cs b/src/libraries/System.Linq/tests/TakeLastTests.cs index 31b58d5bf017ae..b39d59e94263a6 100644 --- a/src/libraries/System.Linq/tests/TakeLastTests.cs +++ b/src/libraries/System.Linq/tests/TakeLastTests.cs @@ -9,6 +9,12 @@ namespace System.Linq.Tests { public class TakeLastTests : EnumerableTests { + [Fact] + public void SkipLastThrowsOnNull() + { + AssertExtensions.Throws("source", () => ((IEnumerable)null).TakeLast(10)); + } + [Theory] [MemberData(nameof(EnumerableData), MemberType = typeof(SkipTakeData))] public void TakeLast(IEnumerable source, int count) diff --git a/src/libraries/System.Linq/tests/TakeTests.cs b/src/libraries/System.Linq/tests/TakeTests.cs index 93a0405bfaf312..9dfd829e5d9efa 100644 --- a/src/libraries/System.Linq/tests/TakeTests.cs +++ b/src/libraries/System.Linq/tests/TakeTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -2031,5 +2032,37 @@ public void EmptySource_DoNotThrowException_EnumerablePartition() Assert.Empty(EnumerablePartitionOrEmpty(source).Take(3..^8)); Assert.Empty(EnumerablePartitionOrEmpty(source).Take(^6..^7)); } + + [Fact] + public void SkipTakeOnIListIsIList() + { + IList list = new ReadOnlyCollection(Enumerable.Range(0, 100).ToList()); + IList skipTake = Assert.IsAssignableFrom>(list.Skip(10).Take(20)); + + Assert.True(skipTake.IsReadOnly); + Assert.Equal(20, skipTake.Count); + int[] results = new int[20]; + skipTake.CopyTo(results, 0); + for (int i = 0; i < 20; i++) + { + Assert.Equal(i + 10, skipTake[i]); + Assert.Equal(i + 10, results[i]); + Assert.True(skipTake.Contains(i + 10)); + Assert.True(skipTake.IndexOf(i + 10) == i); + } + + Assert.False(skipTake.Contains(9)); + Assert.False(skipTake.Contains(30)); + + Assert.Throws(() => skipTake[-1]); + Assert.Throws(() => skipTake[20]); + + Assert.Throws(() => skipTake.Add(42)); + Assert.Throws(() => skipTake.Clear()); + Assert.Throws(() => skipTake.Insert(0, 42)); + Assert.Throws(() => skipTake.Remove(42)); + Assert.Throws(() => skipTake.RemoveAt(0)); + Assert.Throws(() => skipTake[0] = 42); + } } } diff --git a/src/libraries/System.Linq/tests/TakeWhileTests.cs b/src/libraries/System.Linq/tests/TakeWhileTests.cs index 55f02459978a00..18589d97192b0f 100644 --- a/src/libraries/System.Linq/tests/TakeWhileTests.cs +++ b/src/libraries/System.Linq/tests/TakeWhileTests.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Collections; using System.Collections.Generic; using Xunit; @@ -9,6 +8,13 @@ namespace System.Linq.Tests { public class TakeWhileTests : EnumerableTests { + [Fact] + public void Empty() + { + Assert.Equal(Enumerable.Empty(), Enumerable.Empty().TakeWhile(i => i < 40)); + Assert.Equal(Enumerable.Empty(), Enumerable.Empty().TakeWhile((i, index) => i < 40)); + } + [Fact] public void SameResultsRepeatCallsIntQuery() { diff --git a/src/libraries/System.Linq/tests/ToLookupTests.cs b/src/libraries/System.Linq/tests/ToLookupTests.cs index 8d1f1f5fd00822..d2f81e9a662d5c 100644 --- a/src/libraries/System.Linq/tests/ToLookupTests.cs +++ b/src/libraries/System.Linq/tests/ToLookupTests.cs @@ -1,10 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; +using System.Collections; using System.Collections.Generic; -using System.Diagnostics; -using System.Reflection; using Xunit; namespace System.Linq.Tests @@ -53,6 +51,14 @@ from x4 in q2 Assert.Equal(q.ToLookup(e => e.a1), q.ToLookup(e => e.a1)); } + [Fact] + public void Empty() + { + AssertMatches(Enumerable.Empty(), Enumerable.Empty(), Enumerable.Empty().ToLookup(i => i)); + Assert.False(Enumerable.Empty().ToLookup(i => i).Contains(0)); + Assert.Empty(Enumerable.Empty().ToLookup(i => i)[0]); + } + [Fact] public void NullKeyIncluded() { @@ -289,6 +295,18 @@ public void ApplyResultSelectorForGroup(int enumType) Assert.Equal(expected, result); } + [Fact] + public void ApplyResultSelector() + { + Lookup lookup = (Lookup)new int[] { 1, 2, 2, 3, 3, 3 }.ToLookup(i => i); + IEnumerable sums = lookup.ApplyResultSelector((key, elements) => + { + Assert.Equal(key, elements.Count()); + return elements.Sum(); + }); + Assert.Equal([1, 4, 9], sums); + } + [Theory] [InlineData(0)] [InlineData(1)] @@ -302,6 +320,7 @@ public void LookupImplementsICollection(int count) var collection = (ICollection>)Enumerable.Range(0, count).ToLookup(i => i.ToString()); Assert.Equal(count, collection.Count); + Assert.True(collection.IsReadOnly); Assert.Throws(() => collection.Add(null)); Assert.Throws(() => collection.Remove(null)); Assert.Throws(() => collection.Clear()); @@ -313,6 +332,7 @@ public void LookupImplementsICollection(int count) Assert.True(collection.Contains(first)); Assert.True(collection.Contains(last)); } + Assert.False(collection.Contains(new NopGrouping())); IGrouping[] items = new IGrouping[count]; collection.CopyTo(items, 0); @@ -321,6 +341,13 @@ public void LookupImplementsICollection(int count) Assert.Equal(items, Enumerable.Range(0, count).ToLookup(i => i.ToString()).ToList()); } + private sealed class NopGrouping : IGrouping + { + public string Key => ""; + public IEnumerator GetEnumerator() => ((IList)Array.Empty()).GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + public class Membership { public int Id { get; set; } diff --git a/src/libraries/System.Linq/tests/WhereTests.cs b/src/libraries/System.Linq/tests/WhereTests.cs index a6bc625e9e4936..a25b3f1abf9c31 100644 --- a/src/libraries/System.Linq/tests/WhereTests.cs +++ b/src/libraries/System.Linq/tests/WhereTests.cs @@ -1094,6 +1094,50 @@ public void ToCollection(IEnumerable source) } } + [Fact] + public void WhereFirstLast() + { + Assert.All(IdentityTransforms(), transform => + { + IEnumerable data = transform(Enumerable.Range(0, 10)); + + Assert.Equal(3, data.Where(i => i == 3).First()); + Assert.Equal(0, data.Where(i => i % 2 == 0).First()); + + Assert.Equal(3, data.Where(i => i == 3).Last()); + Assert.Equal(8, data.Where(i => i % 2 == 0).Last()); + + Assert.Equal(3, data.Where(i => i == 3).ElementAt(0)); + Assert.Equal(8, data.Where(i => i % 2 == 0).ElementAt(4)); + + Assert.Throws(() => data.Where(i => i == 10).First()); + Assert.Throws(() => data.Where(i => i == 10).Last()); + Assert.Throws(() => data.Where(i => i == 10).ElementAt(0)); + }); + } + + [Fact] + public void WhereSelectFirstLast() + { + Assert.All(IdentityTransforms(), transform => + { + IEnumerable data = transform(Enumerable.Range(0, 10)); + + Assert.Equal(6, data.Where(i => i == 3).Select(i => i * 2).First()); + Assert.Equal(0, data.Where(i => i % 2 == 0).Select(i => i * 2).First()); + + Assert.Equal(6, data.Where(i => i == 3).Select(i => i * 2).Last()); + Assert.Equal(16, data.Where(i => i % 2 == 0).Select(i => i * 2).Last()); + + Assert.Equal(6, data.Where(i => i == 3).Select(i => i * 2).ElementAt(0)); + Assert.Equal(16, data.Where(i => i % 2 == 0).Select(i => i * 2).ElementAt(4)); + + Assert.Throws(() => data.Where(i => i == 10).Select(i => i * 2).First()); + Assert.Throws(() => data.Where(i => i == 10).Select(i => i * 2).Last()); + Assert.Throws(() => data.Where(i => i == 10).Select(i => i * 2).ElementAt(0)); + }); + } + public static IEnumerable ToCollectionData() { IEnumerable seq = GenerateRandomSequnce(seed: 0xdeadbeef, count: 10); From 8a09e9f1532a208788a3aea3e0d8346482ef7ae8 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 27 Feb 2024 21:21:06 -0500 Subject: [PATCH 7/7] Disable speed-optimized test on wasm --- src/libraries/System.Linq/tests/TakeTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Linq/tests/TakeTests.cs b/src/libraries/System.Linq/tests/TakeTests.cs index 9dfd829e5d9efa..b19c69ab0fd9af 100644 --- a/src/libraries/System.Linq/tests/TakeTests.cs +++ b/src/libraries/System.Linq/tests/TakeTests.cs @@ -2033,7 +2033,7 @@ public void EmptySource_DoNotThrowException_EnumerablePartition() Assert.Empty(EnumerablePartitionOrEmpty(source).Take(^6..^7)); } - [Fact] + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsSpeedOptimized))] public void SkipTakeOnIListIsIList() { IList list = new ReadOnlyCollection(Enumerable.Range(0, 100).ToList());