Given a collection, is there a way to get the last N elements of that collection? If there isn't a method in the framework, what would be the best way to write an extension method to do this?
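coll.Reverse().Take(N).Reverse().ToList();
// If coll is a List<T>, write coll.AsEnumerable().Reverse() instead: List<T>.Reverse() is an in-place void method that takes precedence over the LINQ operator.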
// Declare inside a non-generic static class; requires using System.Linq;
public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> coll, int N)
{
    return coll.Reverse().Take(N).Reverse();
}
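The double Reverse buffers the entire sequence before taking anything. When the input is known to be an IList<T> with constant-time indexing, such as List<T> or an array (an assumption; the answer above accepts any IEnumerable<T>), the tail can be read by index instead. Below is a minimal sketch; TakeLastIndexed and its helper are hypothetical names. (On .NET Core 2.0 and later, System.Linq also ships Enumerable.TakeLast, so a custom method is only needed on older frameworks.)

```csharp
using System;
using System.Collections.Generic;

public static class IndexedTakeLastExtensions
{
    // Hypothetical helper: returns the last n items of an indexable list by
    // reading positions Count - n .. Count - 1, without reversing anything.
    public static IEnumerable<T> TakeLastIndexed<T>(this IList<T> list, int n)
    {
        if (list == null) throw new ArgumentNullException(nameof(list));
        if (n < 0) throw new ArgumentOutOfRangeException(nameof(n), "n must be 0 or greater");
        return Iterate(list, n);
    }

    // Kept separate so the argument checks above run at the call site,
    // not on the first MoveNext() of the returned sequence.
    private static IEnumerable<T> Iterate<T>(IList<T> list, int n)
    {
        // Clamp so asking for more items than exist returns the whole list.
        int start = Math.Max(0, list.Count - n);
        for (int i = start; i < list.Count; i++)
        {
            yield return list[i];
        }
    }
}
```

For example, new List<int> { 4, 6, 3, 6, 2, 5, 7 }.TakeLastIndexed(3) yields 2, 5, 7 while touching only those three slots.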
Note: I missed that your question title says "Using Linq", so this answer does not actually use LINQ.
If you want to avoid caching a non-lazy copy of the entire collection, you could write a simple method that does it using a linked list.
The following method adds each value from the original collection to a linked list and, whenever the list grows past the requested count, removes the oldest entry. Because the list is trimmed on every iteration, it only ever holds N items (N + 1 briefly, just before each trim), however large the original collection is.
It does not need to know the number of items in the original collection, and it iterates over the collection only once.
Usage:
IEnumerable<int> sequence = Enumerable.Range(1, 10000);
IEnumerable<int> last10 = sequence.Last(10);
...
Extension method:
public static class Extensions
{
    public static IEnumerable<T> Last<T>(this IEnumerable<T> collection, int n)
    {
        if (collection == null)
            throw new ArgumentNullException("collection");
        if (n < 0)
            throw new ArgumentOutOfRangeException("n", "n must be 0 or greater");

        // Keep only the most recent n values seen so far.
        LinkedList<T> temp = new LinkedList<T>();
        foreach (var value in collection)
        {
            temp.AddLast(value);
            if (temp.Count > n)
                temp.RemoveFirst(); // drop the oldest value
        }
        return temp;
    }
}
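The method above does all its work as soon as Last is called and returns the LinkedList<T> itself. If deferred execution is preferred, the same bounded-buffer idea can be written as an iterator. This sketch swaps the LinkedList<T> for a Queue<T> (a substitution, not part of the original answer) and uses the hypothetical name TakeLastDeferred; the argument checks stay in a non-iterator wrapper so they still fire immediately.

```csharp
using System;
using System.Collections.Generic;

public static class DeferredTakeLastExtensions
{
    // Hypothetical name: validates arguments now, reads the source later.
    public static IEnumerable<T> TakeLastDeferred<T>(this IEnumerable<T> source, int n)
    {
        if (source == null) throw new ArgumentNullException(nameof(source));
        if (n < 0) throw new ArgumentOutOfRangeException(nameof(n), "n must be 0 or greater");
        return TakeLastIterator(source, n);
    }

    // Runs only when the caller starts enumerating the result.
    private static IEnumerable<T> TakeLastIterator<T>(IEnumerable<T> source, int n)
    {
        if (n == 0) yield break;

        // Bounded FIFO buffer: never holds more than n items.
        var buffer = new Queue<T>();
        foreach (T item in source)
        {
            if (buffer.Count == n)
                buffer.Dequeue(); // drop the oldest item before adding the newest
            buffer.Enqueue(item);
        }

        // Queue<T> enumerates from oldest to newest, preserving source order.
        foreach (T item in buffer)
            yield return item;
    }
}
```

Nothing is read from the source until the result is enumerated (for example with ToList()), and the source is still walked only once per enumeration.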
Here's a method that works on any enumerable but uses only O(N) temporary storage:
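public static class TakeLastExtension
{
    public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> source, int n)
    {
        // Because this method uses yield, these checks run when the result
        // is first enumerated, not when TakeLast is called.
        if (source == null) { throw new ArgumentNullException("source"); }
        if (n < 0) { throw new ArgumentOutOfRangeException("n", "must not be negative"); }
        if (n == 0) { yield break; }

        // Ring buffer: slot i is where the next item goes and, once the
        // buffer has wrapped, it also holds the oldest surviving item.
        T[] result = new T[n];
        int i = 0;
        int count = 0;
        foreach (T t in source)
        {
            result[i] = t;
            i = (i + 1) % n;
            count++;
        }

        // Fewer than n items: they sit in slots 0..count-1, oldest first.
        if (count < n)
        {
            n = count;
            i = 0;
        }

        // Read n items starting from the oldest one.
        for (int j = 0; j < n; ++j)
        {
            yield return result[(i + j) % n];
        }
    }
}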
Usage:
List<int> l = new List<int> {4, 6, 3, 6, 2, 5, 7};
List<int> lastElements = l.TakeLast(3).ToList();
It works by using a ring buffer of size N to store the elements as it sees them, overwriting old elements with new ones. When the end of the enumerable is reached, the ring buffer contains the last N elements (or all of them, if the source had fewer than N).
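Tracing the usage example by hand (n = 3, input 4, 6, 3, 6, 2, 5, 7) shows why the final loop starts reading at slot i: once the buffer has wrapped, i always points at the oldest surviving item.

```text
item  slot written  buffer after write  i after write
4     0             [4, _, _]           1
6     1             [4, 6, _]           2
3     2             [4, 6, 3]           0
6     0             [6, 6, 3]           1
2     1             [6, 2, 3]           2
5     2             [6, 2, 5]           0
7     0             [7, 2, 5]           1
```

Since count = 7 is not less than n, i stays at 1 and the loop yields result[1], result[2], result[0], i.e. 2, 5, 7, the last three items in their original order.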
If you don't mind dipping into Rx as part of the monad, you can use TakeLast:
// Requires the System.Reactive package and using System.Reactive.Linq;
IEnumerable<int> source = Enumerable.Range(1, 10000);
IEnumerable<int> lastThree = source.ToObservable().TakeLast(3).ToEnumerable();