Skip to content

Fixes for EF issues #185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ public static Task<Dictionary<TKey, TSource>> ToDictionary<TSource, TKey>(this I
return source.ToDictionary(keySelector, x => x, EqualityComparer<TKey>.Default, cancellationToken);
}

public static Task<ILookup<TKey, TElement>> ToLookup<TSource, TKey, TElement>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
public static async Task<ILookup<TKey, TElement>> ToLookup<TSource, TKey, TElement>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
{
if (source == null)
throw new ArgumentNullException("source");
Expand All @@ -679,7 +679,9 @@ public static Task<ILookup<TKey, TElement>> ToLookup<TSource, TKey, TElement>(th
if (comparer == null)
throw new ArgumentNullException("comparer");

return source.Aggregate(new Lookup<TKey, TElement>(comparer), (lookup, x) => { lookup.Add(keySelector(x), elementSelector(x)); return lookup; }, lookup => (ILookup<TKey, TElement>)lookup, cancellationToken);
var lookup = await Internal.Lookup<TKey, TElement>.CreateAsync(source, keySelector, elementSelector, comparer, cancellationToken).ConfigureAwait(false);

return lookup;
}

public static Task<ILookup<TKey, TElement>> ToLookup<TSource, TKey, TElement>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, CancellationToken cancellationToken)
Expand Down Expand Up @@ -716,53 +718,6 @@ public static Task<ILookup<TKey, TSource>> ToLookup<TSource, TKey>(this IAsyncEn
return source.ToLookup(keySelector, x => x, EqualityComparer<TKey>.Default, cancellationToken);
}

class Lookup<TKey, TElement> : ILookup<TKey, TElement>
{
private readonly Dictionary<TKey, EnumerableGrouping<TKey, TElement>> map;

public Lookup(IEqualityComparer<TKey> comparer)
{
map = new Dictionary<TKey, EnumerableGrouping<TKey, TElement>>(comparer);
}

public void Add(TKey key, TElement element)
{
var g = default(EnumerableGrouping<TKey, TElement>);
if (!map.TryGetValue(key, out g))
{
g = new EnumerableGrouping<TKey, TElement>(key);
map.Add(key, g);
}

g.Add(element);
}

public bool Contains(TKey key)
{
return map.ContainsKey(key);
}

public int Count
{
get { return map.Keys.Count; }
}

public IEnumerable<TElement> this[TKey key]
{
get { return map[key]; }
}

public IEnumerator<IGrouping<TKey, TElement>> GetEnumerator()
{
return map.Values.Cast<IGrouping<TKey, TElement>>().GetEnumerator();
}

System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}

public static Task<double> Average(this IAsyncEnumerable<int> source, CancellationToken cancellationToken)
{
if (source == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public static partial class AsyncEnumerable
public static IAsyncEnumerable<TSource> ToAsyncEnumerable<TSource>(this IEnumerable<TSource> source)
{
if (source == null)
throw new ArgumentNullException("source");
throw new ArgumentNullException(nameof(source));

return Create(() =>
{
Expand Down Expand Up @@ -44,7 +44,7 @@ public static IAsyncEnumerable<TSource> ToAsyncEnumerable<TSource>(this IEnumera
public static IEnumerable<TSource> ToEnumerable<TSource>(this IAsyncEnumerable<TSource> source)
{
if (source == null)
throw new ArgumentNullException("source");
throw new ArgumentNullException(nameof(source));

return ToEnumerable_(source);
}
Expand All @@ -68,38 +68,28 @@ private static IEnumerable<TSource> ToEnumerable_<TSource>(IAsyncEnumerable<TSou
public static IAsyncEnumerable<TSource> ToAsyncEnumerable<TSource>(this Task<TSource> task)
{
if (task == null)
throw new ArgumentNullException("task");

throw new ArgumentNullException(nameof(task));
return Create(() =>
{
var called = 0;

var value = default(TSource);
return Create(
(ct, tcs) =>
async ct =>
{
if (Interlocked.CompareExchange(ref called, 1, 0) == 0)
{
task.Then(continuedTask =>
{
if (continuedTask.IsCanceled)
tcs.SetCanceled();
else if (continuedTask.IsFaulted)
tcs.SetException(continuedTask.Exception.InnerException);
else
tcs.SetResult(true);
});
value = await task.ConfigureAwait(false);
return true;
}
else
tcs.SetResult(false);

return tcs.Task;
return false;
},
() => task.Result,
() => value,
() => { });
});
}

#if !NO_RXINTERFACES
public static IAsyncEnumerable<TSource> ToAsyncEnumerable<TSource>(this IObservable<TSource> source)
{
if (source == null)
Expand Down Expand Up @@ -304,6 +294,5 @@ public IDisposable Subscribe(IObserver<T> observer)
return Disposable.Create(ctd, e);
}
}
#endif
}
}
210 changes: 130 additions & 80 deletions Ix.NET/Source/System.Interactive.Async/AsyncEnumerable.Multiple.cs
Original file line number Diff line number Diff line change
Expand Up @@ -356,87 +356,8 @@ public static IAsyncEnumerable<TResult> GroupJoin<TOuter, TInner, TKey, TResult>
if (comparer == null)
throw new ArgumentNullException("comparer");

return Create(() =>
{
var innerMap = default(Task<ILookup<TKey, TInner>>);
var getInnerMap = new Func<CancellationToken, Task<ILookup<TKey, TInner>>>(ct =>
{
if (innerMap == null)
innerMap = inner.ToLookup(innerKeySelector, comparer, ct);

return innerMap;
});

var outerE = outer.GetEnumerator();
var current = default(TResult);

var cts = new CancellationTokenDisposable();
var d = Disposable.Create(cts, outerE);

var f = default(Action<TaskCompletionSource<bool>, CancellationToken>);
f = (tcs, ct) =>
{
getInnerMap(ct).Then(ti =>
{
ti.Handle(tcs, map =>
{
outerE.MoveNext(ct).Then(to =>
{
to.Handle(tcs, res =>
{
if (res)
{
var element = outerE.Current;
var key = default(TKey);

try
{
key = outerKeySelector(element);
}
catch (Exception ex)
{
tcs.TrySetException(ex);
return;
}

var innerE = default(IAsyncEnumerable<TInner>);
if (!map.Contains(key))
innerE = AsyncEnumerable.Empty<TInner>();
else
innerE = map[key].ToAsyncEnumerable();

try
{
current = resultSelector(element, innerE);
}
catch (Exception ex)
{
tcs.TrySetException(ex);
return;
}

tcs.TrySetResult(true);
}
else
{
tcs.TrySetResult(false);
}
});
});
});
});
};

return Create(
(ct, tcs) =>
{
f(tcs, cts.Token);
return tcs.Task.UsingEnumerator(outerE);
},
() => current,
d.Dispose
);
});
return new GroupJoinAsyncEnumerable<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
}

public static IAsyncEnumerable<TResult> GroupJoin<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, IAsyncEnumerable<TInner>, TResult> resultSelector)
Expand All @@ -455,6 +376,135 @@ public static IAsyncEnumerable<TResult> GroupJoin<TOuter, TInner, TKey, TResult>
return outer.GroupJoin(inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer<TKey>.Default);
}


private sealed class GroupJoinAsyncEnumerable<TOuter, TInner, TKey, TResult> : IAsyncEnumerable<TResult>
{
private readonly IAsyncEnumerable<TOuter> _outer;
private readonly IAsyncEnumerable<TInner> _inner;
private readonly Func<TOuter, TKey> _outerKeySelector;
private readonly Func<TInner, TKey> _innerKeySelector;
private readonly Func<TOuter, IAsyncEnumerable<TInner>, TResult> _resultSelector;
private readonly IEqualityComparer<TKey> _comparer;

public GroupJoinAsyncEnumerable(
IAsyncEnumerable<TOuter> outer,
IAsyncEnumerable<TInner> inner,
Func<TOuter, TKey> outerKeySelector,
Func<TInner, TKey> innerKeySelector,
Func<TOuter, IAsyncEnumerable<TInner>, TResult> resultSelector,
IEqualityComparer<TKey> comparer)
{
_outer = outer;
_inner = inner;
_outerKeySelector = outerKeySelector;
_innerKeySelector = innerKeySelector;
_resultSelector = resultSelector;
_comparer = comparer;
}

public IAsyncEnumerator<TResult> GetEnumerator()
=> new GroupJoinAsyncEnumerator(
_outer.GetEnumerator(),
_inner,
_outerKeySelector,
_innerKeySelector,
_resultSelector,
_comparer);

private sealed class GroupJoinAsyncEnumerator : IAsyncEnumerator<TResult>
{
private readonly IAsyncEnumerator<TOuter> _outer;
private readonly IAsyncEnumerable<TInner> _inner;
private readonly Func<TOuter, TKey> _outerKeySelector;
private readonly Func<TInner, TKey> _innerKeySelector;
private readonly Func<TOuter, IAsyncEnumerable<TInner>, TResult> _resultSelector;
private readonly IEqualityComparer<TKey> _comparer;

private Internal.Lookup<TKey, TInner> _lookup;

public GroupJoinAsyncEnumerator(
IAsyncEnumerator<TOuter> outer,
IAsyncEnumerable<TInner> inner,
Func<TOuter, TKey> outerKeySelector,
Func<TInner, TKey> innerKeySelector,
Func<TOuter, IAsyncEnumerable<TInner>, TResult> resultSelector,
IEqualityComparer<TKey> comparer)
{
_outer = outer;
_inner = inner;
_outerKeySelector = outerKeySelector;
_innerKeySelector = innerKeySelector;
_resultSelector = resultSelector;
_comparer = comparer;
}

public async Task<bool> MoveNext(CancellationToken cancellationToken)
{
// nothing to do
if (!await _outer.MoveNext(cancellationToken).ConfigureAwait(false))
{
return false;
}

if (_lookup == null)
{
_lookup = await Internal.Lookup<TKey, TInner>.CreateForJoinAsync(_inner, _innerKeySelector, _comparer, cancellationToken).ConfigureAwait(false);
}

var item = _outer.Current;
Current = _resultSelector(item, new AsyncEnumerableAdapter<TInner>(_lookup[_outerKeySelector(item)]));
return true;
}

public TResult Current { get; private set; }

public void Dispose()
{
_outer.Dispose();
}


}
}

private sealed class AsyncEnumerableAdapter<T> : IAsyncEnumerable<T>
{
private readonly IEnumerable<T> _source;

public AsyncEnumerableAdapter(IEnumerable<T> source)
{
_source = source;
}

public IAsyncEnumerator<T> GetEnumerator()
=> new AsyncEnumeratorAdapter(_source.GetEnumerator());

private sealed class AsyncEnumeratorAdapter : IAsyncEnumerator<T>
{
private readonly IEnumerator<T> _enumerator;

public AsyncEnumeratorAdapter(IEnumerator<T> enumerator)
{
_enumerator = enumerator;
}

public Task<bool> MoveNext(CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

#if HAS_AWAIT
return Task.FromResult(_enumerator.MoveNext());
#else
return TaskEx.FromResult(_enumerator.MoveNext());
#endif
}

public T Current => _enumerator.Current;

public void Dispose() => _enumerator.Dispose();
}
}

public static IAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector, IEqualityComparer<TKey> comparer)
{
if (outer == null)
Expand Down
Loading