
Given n enumerables of the same type that return distinct elements in ascending order, for example:

IEnumerable<char> s1 = "adhjlstxyz";
IEnumerable<char> s2 = "bdeijmnpsz";
IEnumerable<char> s3 = "dejlnopsvw";

I want to efficiently find all values that are elements of all enumerables:

IEnumerable<char> sx = Intersect(new[] { s1, s2, s3 });

Debug.Assert(sx.SequenceEqual("djs"));

"Efficiently" here means that

  1. the input enumerables should each be enumerated only once,
  2. the elements of the input enumerables should be retrieved only when needed, and
  3. the algorithm should not recursively enumerate its own output.
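One way to check constraints 1 and 2 in practice is to wrap each input in an iterator that counts how many enumerations were started and how many elements were pulled. This is a minimal sketch written as a C# 9 top-level program. `Track` and the two counter dictionaries are hypothetical names, not part of any library.

```csharp
using System;
using System.Collections.Generic;

// Per-label counters: enumerations started and elements pulled per wrapped source.
var started = new Dictionary<string, int>();
var pulled = new Dictionary<string, int>();

// Hypothetical helper (not a library API). Because this is an iterator,
// its body only runs once an enumerator's MoveNext is first called,
// so 'started' counts enumerations that actually began.
IEnumerable<T> Track<T>(string label, IEnumerable<T> source)
{
    started.TryGetValue(label, out var runs);
    started[label] = runs + 1;
    pulled.TryAdd(label, 0);
    foreach (var item in source)
    {
        pulled[label]++;
        yield return item;
    }
}

var s1 = Track("s1", "adhjlstxyz");

// A consumer that stops early should pull only the elements it inspects.
foreach (var c in s1)
{
    if (c >= 'h') break;
}

// prints "s1: 1 enumeration(s), 3 element(s) pulled"
Console.WriteLine($"s1: {started["s1"]} enumeration(s), {pulled["s1"]} element(s) pulled");
```

Wrapping s1, s2 and s3 this way before passing them to an Intersect implementation shows directly whether it enumerates an input more than once or pulls elements it never needed.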

I need some hints on how to approach a solution.


Here is my (naive) attempt so far:

static IEnumerable<T> Intersect<T>(IEnumerable<T>[] enums)
{
    return enums[0].Intersect(
        enums.Length == 2 ? enums[1] : Intersect(enums.Skip(1).ToArray()));
}

Enumerable.Intersect collects the second enumerable into a HashSet, then enumerates the first enumerable and yields each element that is also in the set. Intersect then recursively intersects the result with the next enumerable. This obviously isn't very efficient: whole inputs are buffered in hash sets (violating constraint 2), and each level buffers the output of the recursive call (violating constraint 3). Nor does it exploit the fact that the elements are sorted.


Here is my attempt to intersect two enumerables. Maybe it can be generalized to n enumerables?

static IEnumerable<T> Intersect<T>(IEnumerable<T> first, IEnumerable<T> second)
{
    using (var left = first.GetEnumerator())
    using (var right = second.GetEnumerator())
    {
        var leftHasNext = left.MoveNext();
        var rightHasNext = right.MoveNext();

        var comparer = Comparer<T>.Default;

        while (leftHasNext && rightHasNext)
        {
            // Advance whichever side holds the smaller value; emit on a tie.
            switch (Math.Sign(comparer.Compare(left.Current, right.Current)))
            {
            case -1: // left is behind
                leftHasNext = left.MoveNext();
                break;
            case 0: // both sides hold the same value
                yield return left.Current;
                leftHasNext = left.MoveNext();
                rightHasNext = right.MoveNext();
                break;
            case 1: // right is behind
                rightHasNext = right.MoveNext();
                break;
            }
        }
    }
}
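One straightforward way to generalize the two-way merge to n inputs is to chain it, folding each further input into the running result. This is a sketch assuming every input is sorted ascending without duplicates; `IntersectSorted` and `IntersectAll` are hypothetical names. Each input is still enumerated once and lazily, but every stage pulls from the stage before it, so this still does not satisfy constraint 3.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Two-way merge intersection (same logic as the version above):
// advance whichever side holds the smaller value, and emit on a tie.
static IEnumerable<T> IntersectSorted<T>(IEnumerable<T> first, IEnumerable<T> second)
{
    var comparer = Comparer<T>.Default;
    using var left = first.GetEnumerator();
    using var right = second.GetEnumerator();
    bool leftHasNext = left.MoveNext();
    bool rightHasNext = right.MoveNext();
    while (leftHasNext && rightHasNext)
    {
        int cmp = comparer.Compare(left.Current, right.Current);
        if (cmp < 0)
        {
            leftHasNext = left.MoveNext();
        }
        else if (cmp > 0)
        {
            rightHasNext = right.MoveNext();
        }
        else
        {
            yield return left.Current;
            leftHasNext = left.MoveNext();
            rightHasNext = right.MoveNext();
        }
    }
}

// Chain the pairwise merge over all inputs; an empty list of inputs yields nothing.
static IEnumerable<T> IntersectAll<T>(IEnumerable<IEnumerable<T>> enums) =>
    enums.DefaultIfEmpty(Enumerable.Empty<T>())
         .Aggregate((acc, next) => IntersectSorted(acc, next));

var sx = IntersectAll(new IEnumerable<char>[] { "adhjlstxyz", "bdeijmnpsz", "dejlnopsvw" });
Console.WriteLine(new string(sx.ToArray())); // prints "djs"
```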
+2  A: 

You can use LINQ:

    public static IEnumerable<T> Intersect<T>(IEnumerable<IEnumerable<T>> enums) {
        using (var iter = enums.GetEnumerator()) {
            IEnumerable<T> result;
            if (iter.MoveNext()) {
                result = iter.Current;
                while (iter.MoveNext()) {
                    result = result.Intersect(iter.Current);
                }
            } else {
                result = Enumerable.Empty<T>();
            }
            return result;
        }
    }

This is simple, although it builds a hash set for every Intersect call. Advancing all n at once (to take advantage of the sort order) would be harder, but perhaps you could also build a single hash set and remove the items missing from each subsequent sequence?

Marc Gravell
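The single-hash-set alternative suggested above could look roughly like this: build one set from the first sequence, then let `HashSet<T>.IntersectWith` drop every item that is missing from each later sequence. This is a sketch, not code from the thread; `IntersectViaOneSet` is a hypothetical name. It ignores the sort order and buffers the first sequence, so it still violates constraint 2, but only one set is ever built.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

static IEnumerable<T> IntersectViaOneSet<T>(IEnumerable<IEnumerable<T>> enums)
{
    using var iter = enums.GetEnumerator();
    if (!iter.MoveNext())
        yield break; // no inputs: empty result

    // Buffer the first sequence once; its order becomes the output order.
    var firstItems = iter.Current.ToList();
    var set = new HashSet<T>(firstItems);

    // Shrink the single set with each further sequence.
    while (iter.MoveNext())
        set.IntersectWith(iter.Current);

    // Replay the first sequence's order, keeping only the surviving items.
    foreach (var item in firstItems)
        if (set.Contains(item))
            yield return item;
}

var sx = IntersectViaOneSet(new IEnumerable<char>[] { "adhjlstxyz", "bdeijmnpsz", "dejlnopsvw" });
Console.WriteLine(new string(sx.ToArray())); // prints "djs"
```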
I'm looking for a less simple solution :-) My question is basically: how do I approach a solution that advances all n at once (to take advantage of the sort order)?
dtb
Your second version, combined with my second attempt, looks pretty good. I'll grab some coffee and try to understand why it works.
dtb
D'oh. This is basically `enums.Aggregate(Enumerable.Empty<T>(), Enumerable.Intersect)` (modulo the little optimization if enums is non-empty).
dtb
You're right. It's `enums.DefaultIfEmpty(Enumerable.Empty<T>()).Aggregate(Enumerable.Intersect);` then :-)
dtb
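A quick way to see why the seed mattered in this exchange: seeding `Aggregate` with an empty sequence makes every result empty, because the first Intersect already yields nothing, while `DefaultIfEmpty` substitutes the empty sequence only when there are no inputs at all. This sketch uses only standard LINQ operators on the question's sample data.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

var inputs = new IEnumerable<char>[] { "adhjlstxyz", "bdeijmnpsz", "dejlnopsvw" };

// Seeded with an empty sequence: the first Intersect already yields nothing.
var seeded = inputs.Aggregate(Enumerable.Empty<char>(), (acc, next) => acc.Intersect(next));

// DefaultIfEmpty only kicks in when the list of inputs itself is empty.
var unseeded = inputs
    .DefaultIfEmpty(Enumerable.Empty<char>())
    .Aggregate((acc, next) => acc.Intersect(next));

var none = Array.Empty<IEnumerable<char>>()
    .DefaultIfEmpty(Enumerable.Empty<char>())
    .Aggregate((acc, next) => acc.Intersect(next));

Console.WriteLine($"seeded: \"{new string(seeded.ToArray())}\"");     // prints seeded: ""
Console.WriteLine($"unseeded: \"{new string(unseeded.ToArray())}\""); // prints unseeded: "djs"
Console.WriteLine($"no inputs: \"{new string(none.ToArray())}\"");    // prints no inputs: ""
```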
+3  A: 

OK; more complex answer:

public static IEnumerable<T> Intersect<T>(params IEnumerable<T>[] enums) {
    return Intersect<T>(null, enums);
}
public static IEnumerable<T> Intersect<T>(IComparer<T> comparer, params IEnumerable<T>[] enums) {
    if(enums == null) throw new ArgumentNullException("enums");
    if(enums.Length == 0) return Enumerable.Empty<T>();
    if(enums.Length == 1) return enums[0];
    if(comparer == null) comparer = Comparer<T>.Default;
    return IntersectImpl(comparer, enums);
}
public static IEnumerable<T> IntersectImpl<T>(IComparer<T> comparer, IEnumerable<T>[] enums) {
    IEnumerator<T>[] iters = new IEnumerator<T>[enums.Length];
    try {
        // create iterators and move as far as the first item
        for (int i = 0; i < enums.Length; i++) {
            if(!(iters[i] = enums[i].GetEnumerator()).MoveNext()) {
                yield break; // no data for one of the iterators
            }
        }
        bool first = true;
        T lastValue = default(T);
        do { // get the next item from the first sequence
            T value = iters[0].Current;
            if (!first && comparer.Compare(value, lastValue) == 0) continue; // dup in first source
            bool allTrue = true;
            for (int i = 1; i < iters.Length; i++) {
                var iter = iters[i];
                // if any sequence isn't there yet, progress it; if any sequence
                // ends, we're all done
                while (comparer.Compare(iter.Current, value) < 0) {
                    if (!iter.MoveNext()) goto alldone; // nasty, but
                }
                // if any sequence is now **past** value, then short-circuit
                if (comparer.Compare(iter.Current, value) > 0) {
                    allTrue = false;
                    break;
                }
            }
            // so all sequences have this value
            if (allTrue) yield return value;
            first = false;
            lastValue = value;
        } while (iters[0].MoveNext());
    alldone:
        ;
    } finally { // clean up all iterators
        for (int i = 0; i < iters.Length; i++) {
            if (iters[i] != null) {
                try { iters[i].Dispose(); }
                catch { }
            }
        }
    }
}
Marc Gravell
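For comparison, the same idea of advancing all n enumerators together can also be written without `goto` by always taking the largest current value as the next candidate, in the style of what is sometimes called a leapfrog join. This is an alternative sketch, not the code above; `LeapfrogIntersect` is a hypothetical name, and, like the question, it assumes every input is sorted ascending under the given comparer with no duplicates.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical alternative: every enumerator only moves forward, and
// whichever enumerator overshoots the candidate supplies the next candidate.
static IEnumerable<T> LeapfrogIntersect<T>(IComparer<T> comparer, IEnumerable<T>[] enums)
{
    comparer ??= Comparer<T>.Default;
    if (enums.Length == 0)
        yield break;

    var iters = new IEnumerator<T>[enums.Length];
    try
    {
        // Position every enumerator on its first element; an empty input
        // means the intersection is empty.
        for (int i = 0; i < enums.Length; i++)
        {
            iters[i] = enums[i].GetEnumerator();
            if (!iters[i].MoveNext())
                yield break;
        }

        T candidate = iters[0].Current;
        while (true)
        {
            bool allEqual = true;
            foreach (var iter in iters)
            {
                // Catch this enumerator up to the candidate; stop when any input ends.
                while (comparer.Compare(iter.Current, candidate) < 0)
                {
                    if (!iter.MoveNext())
                        yield break;
                }
                // Overshooting makes this element the new candidate.
                if (comparer.Compare(iter.Current, candidate) > 0)
                {
                    candidate = iter.Current;
                    allEqual = false;
                }
            }

            if (allEqual)
            {
                // Every enumerator sits on the candidate: emit it and move on.
                yield return candidate;
                if (!iters[0].MoveNext())
                    yield break;
                candidate = iters[0].Current;
            }
        }
    }
    finally
    {
        // Dispose whatever enumerators were created.
        foreach (var iter in iters)
            iter?.Dispose();
    }
}

var sx = LeapfrogIntersect(Comparer<char>.Default,
    new IEnumerable<char>[] { "adhjlstxyz", "bdeijmnpsz", "dejlnopsvw" });
Console.WriteLine(new string(sx.ToArray())); // prints "djs"
```

Like the answer's version, it never rewinds an enumerator and pulls elements only on demand; the difference is only in which value drives the search.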
Amazing. Thanks! Interestingly, my second attempt is faster than this solution for n=2, but this solution is faster than my second attempt chained for any n!=0. Any solution involving Enumerable.Intersect is much slower than both.
dtb
Any rough estimate on the complexity of this algorithm? I'm tempted to say it's `O(n0+n1+..nn)` but I've got a feeling that's wrong...
dtb
It never rewinds anything, so each sequence is enumerated at most once, and it stops as soon as **any** sequence is exhausted. With m sequences of lengths n[1], ..., n[m], that means at most n[1] + ... + n[m] MoveNext calls, plus O(m) comparisons for each element taken from the first sequence, i.e. O(m * n[1] + n[2] + ... + n[m]) overall.
Marc Gravell
Right, thanks. Looks like the problem can't be solved more efficiently than this, complexity-wise :-)
dtb