pmunin
2/28/2017 - 2:08 AM

CartesianProduct extensions

CartesianProduct extensions

//latest version is here: https://gist.github.com/7d99ff883fdf2ddfe968f6f9c57b1d4a.git
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using PriorityQueueUtils; //from https://gist.github.com/affc23dd89950e67ece9ca3b19b9508a.git
using CachingEnumeratorUtils; //from https://gist.github.com/b1059488d17c52da5a732a100ba6e09f.git

namespace CartesianProductUtils
{
    public class CartesianProduct
    {
        /// <summary>
        /// Generates the cartesian product of two sorted countable sequences in sorted order, using a
        /// best-first walk over the (i, j) index grid (an approach similar to gradient descent).
        /// </summary>
        /// <typeparam name="T1">Item type of the first sequence.</typeparam>
        /// <typeparam name="T2">Item type of the second sequence.</typeparam>
        /// <typeparam name="TResult">Type produced by <paramref name="combine"/>.</typeparam>
        /// <param name="sortedItems1">First sequence, already sorted.</param>
        /// <param name="sortedItems2">Second sequence, already sorted.</param>
        /// <param name="combine">Builds one result from an item of each sequence.</param>
        /// <param name="resultComparer">Comparer that defines the output order; forwarded to the priority queue.</param>
        /// <param name="startPoint">
        /// Index pair (into <paramref name="sortedItems1"/>, <paramref name="sortedItems2"/>) whose result is the
        /// MINIMUM, i.e. the first value to yield. Defaults to (0, 0).
        /// </param>
        /// <returns>
        /// A lazily evaluated countable sequence. Its Count is the product of the input counts, computed on
        /// demand and checked for overflow. The sequence is empty when either input is empty.
        /// </returns>
        public static IEnumerable<TResult> CartesianProductSorted<T1, T2, TResult>(
            IEnumerableCountable<T1> sortedItems1,
            IEnumerableCountable<T2> sortedItems2,
            Func<T1, T2, TResult> combine,
            IComparer<TResult> resultComparer = null,
            Tuple<int, int> startPoint = null // default: (0, 0)
            )
        {
            // The count is computed only when asked for. Taking just the first few results of a huge
            // product therefore never trips the overflow check.
            return IterateCartesianProduct<T1, T2, TResult>(sortedItems1, sortedItems2, combine, resultComparer, startPoint)
                .AsCountable(() => checked(sortedItems1.Count * sortedItems2.Count));
        }

        /// <summary>
        /// Iterator behind the two-sequence <c>CartesianProductSorted</c>. It repeatedly dequeues the smallest
        /// queued grid cell, yields its result, then queues the cell's in-range neighbours that were never queued.
        /// </summary>
        /// <remarks>
        /// Algorithm: http://stackoverflow.com/questions/4279524/how-to-sort-a-m-x-n-matrix-which-has-all-its-m-rows-sorted-and-n-columns-sorted
        /// </remarks>
        static IEnumerable<TResult> IterateCartesianProduct<T1, T2, TResult>(
            IEnumerableCountable<T1> sortedItems1,
            IEnumerableCountable<T2> sortedItems2,
            Func<T1, T2, TResult> combine,
            IComparer<TResult> resultComparer = null,
            Tuple<int, int> startPoint = null // default: (0, 0)
            )
        {
            var count1 = sortedItems1.Count;
            var count2 = sortedItems2.Count;

            // An empty input makes the product empty. Without this guard, ElementAt below would throw.
            if (count1 == 0 || count2 == 0)
                yield break;

            // Enumerate each source at most once, as the multi-dimensional overload does. Positional
            // lookups are then served from the cache instead of re-running the source every time.
            var cachedItems1 = sortedItems1.AsCaching();
            var cachedItems2 = sortedItems2.AsCaching();
            Func<Tuple<int, int>, TResult> combineAt =
                position => combine(cachedItems1.ElementAt(position.Item1), cachedItems2.ElementAt(position.Item2));

            startPoint = startPoint ?? Tuple.Create(0, 0);
            var startNode = new { Position = startPoint, Result = combineAt(startPoint) };
            // Every cell that has ever been queued, so no cell is queued twice.
            var visitedNodes = new HashSet<Tuple<int, int>> { startPoint };
            var queue = PriorityQueue.Create(new[] { startNode }, n => n.Result, resultComparer);

            while (!queue.IsEmpty)
            {
                var current = queue.Dequeue();
                yield return current.Result;

                var coords = current.Position;
                var newNeighborCoords = new[]
                    {
                        Tuple.Create(coords.Item1 - 1, coords.Item2), // previous item of the first sequence
                        Tuple.Create(coords.Item1 + 1, coords.Item2), // next item of the first sequence
                        Tuple.Create(coords.Item1, coords.Item2 - 1), // previous item of the second sequence
                        Tuple.Create(coords.Item1, coords.Item2 + 1), // next item of the second sequence
                    }
                    .Where(n => n.Item1 >= 0 && n.Item1 < count1
                             && n.Item2 >= 0 && n.Item2 < count2
                             && !visitedNodes.Contains(n))
                    .ToArray();

                queue.AddRange(newNeighborCoords.Select(c => new { Position = c, Result = combineAt(c) }));
                foreach (var neighbor in newNeighborCoords)
                {
                    visitedNodes.Add(neighbor);
                }
            }
        }

        /// <summary>
        /// Wraps an <c>int[]</c> grid position with element-wise value equality, so positions can be
        /// tracked in a <see cref="HashSet{T}"/>.
        /// </summary>
        class Coords : IEquatable<Coords>
        {
            class CoordsComparer : IEqualityComparer<Coords>
            {
                public bool Equals(Coords x, Coords y)
                {
                    if (ReferenceEquals(x, y))
                        return true;
                    if (ReferenceEquals(x, null) || ReferenceEquals(y, null))
                        return false;
                    return x.Equals(y);
                }

                public int GetHashCode(Coords obj)
                {
                    return ReferenceEquals(obj, null) ? 0 : obj.GetHashCode();
                }
            }

            public static readonly IEqualityComparer<Coords> Comparer = new CoordsComparer();

            public Coords(int[] position)
            {
                this.Position = position;
            }

            // Get-only: a Coords serves as a HashSet key, so the array it hashes must not be swapped out.
            // The array contents must not be mutated either.
            public int[] Position { get; }

            public bool Equals(Coords other)
            {
                if (ReferenceEquals(other, null))
                    return false;
                if (ReferenceEquals(this, other))
                    return true;
                if (Position == null || other.Position == null)
                    return Position == other.Position;
                return Position.SequenceEqual(other.Position);
            }

            public override bool Equals(object obj)
            {
                return obj is Coords other && Equals(other);
            }

            public override int GetHashCode()
            {
                // Polynomial hash over the coordinates, from http://stackoverflow.com/a/3404820/508797
                var array = Position;
                if (array == null)
                    return 0;
                int hc = array.Length;
                for (int i = 0; i < array.Length; ++i)
                {
                    hc = unchecked(hc * 314159 + array[i]);
                }
                return hc;
            }

            /// <summary>
            /// Yields every position that differs from <paramref name="position"/> by exactly one step (-1 or +1)
            /// in a single dimension i and stays within [0, maxPosition[i]] in that dimension.
            /// Each yielded array is a fresh copy; <paramref name="position"/> is not modified.
            /// </summary>
            public static IEnumerable<int[]> GenerateNeighbors(int[] position, int[] maxPosition)
            {
                for (int i = 0; i < (position?.Length ?? 0); i++)
                {
                    var beforePosition = position.ToArray();
                    beforePosition[i]--;
                    if (beforePosition[i] >= 0 && beforePosition[i] <= maxPosition[i])
                        yield return beforePosition;
                    var afterPosition = position.ToArray();
                    afterPosition[i]++;
                    if (afterPosition[i] >= 0 && afterPosition[i] <= maxPosition[i])
                        yield return afterPosition;
                }
            }
        }

        /// <summary>
        /// Generates the sorted multi-dimensional cartesian product using a best-first walk over the index grid
        /// (an approach similar to gradient descent).
        /// </summary>
        /// <typeparam name="T">Item type of every dimension.</typeparam>
        /// <typeparam name="TResult">Type produced by <paramref name="combine"/>.</typeparam>
        /// <param name="sortedItems">
        /// One sorted countable sequence per dimension. Each is wrapped with a caching enumerable, so its
        /// enumerator runs only once and later lookups read the cached items.
        /// </param>
        /// <param name="combine">Builds one result from one item per dimension, given in dimension order.</param>
        /// <param name="resultComparer">Comparer that defines the output order; forwarded to the priority queue.</param>
        /// <param name="startPoint">
        /// Position (one index per dimension) whose result is the MINIMUM, i.e. the first value to yield.
        /// Defaults to all zeros.
        /// </param>
        /// <returns>
        /// A lazily evaluated countable sequence. Its Count is the product of the dimension counts, computed
        /// on demand and checked for overflow. The sequence is empty when any dimension is empty.
        /// </returns>
        /// <exception cref="ArgumentException">
        /// Raised when enumeration starts, if <paramref name="startPoint"/> does not have exactly one
        /// coordinate per dimension.
        /// </exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// Raised when enumeration starts, if a coordinate of <paramref name="startPoint"/> lies outside its dimension.
        /// </exception>
        public static IEnumerable<TResult> CartesianProductSorted<T, TResult>(
            IEnumerableCountable<IEnumerableCountable<T>> sortedItems,
            Func<IEnumerable<T>, TResult> combine,
            IComparer<TResult> resultComparer = null,
            int[] startPoint = null // default: [0, 0, 0, ...]
            )
        {
            return IterateCartesianProduct<T, TResult>(sortedItems, combine, resultComparer, startPoint)
                .AsCountable(() => sortedItems.Aggregate(1, (total, dimension) => checked(total * dimension.Count)));
        }

        /// <summary>
        /// Picks, for every dimension i, the item at index <c>position[i]</c>.
        /// The result is materialized so that <c>combine</c> sees a fixed snapshot and can enumerate it repeatedly.
        /// </summary>
        static IEnumerable<T> GetItemsByPosition<T>(IEnumerable<IEnumerable<T>> sortedItems, int[] position)
        {
            return sortedItems
                .Select((items, dimension) => items.ElementAt(position[dimension]))
                .ToArray();
        }

        /// <summary>
        /// Iterator behind the multi-dimensional <c>CartesianProductSorted</c>. It repeatedly dequeues the
        /// smallest queued grid position, yields its result, then queues neighbouring positions that were never queued.
        /// </summary>
        /// <remarks>
        /// Algorithm: http://stackoverflow.com/questions/4279524/how-to-sort-a-m-x-n-matrix-which-has-all-its-m-rows-sorted-and-n-columns-sorted
        /// </remarks>
        private static IEnumerable<TResult> IterateCartesianProduct<T, TResult>(
            IEnumerable<IEnumerableCountable<T>> sortedItems
            , Func<IEnumerable<T>, TResult> combine
            , IComparer<TResult> resultComparer
            , int[] startPoint = null
            )
        {
            // Materialize the dimensions once; otherwise the source sequence is enumerated several times.
            var dimensions = sortedItems.ToArray();
            var maxPosition = dimensions.Select(dimension => dimension.Count - 1).ToArray();

            // Any empty dimension makes the whole product empty. The lookup must not be allowed to fall
            // back to default(T) and produce a bogus result.
            if (maxPosition.Any(max => max < 0))
                yield break;

            // Copy, so later changes to the caller's array cannot alter a position stored in the queue or set.
            startPoint = startPoint == null ? new int[dimensions.Length] : startPoint.ToArray();
            if (startPoint.Length != dimensions.Length)
                throw new ArgumentException("startPoint must have exactly one coordinate per dimension.", nameof(startPoint));
            for (int i = 0; i < startPoint.Length; i++)
            {
                if (startPoint[i] < 0 || startPoint[i] > maxPosition[i])
                    throw new ArgumentOutOfRangeException(nameof(startPoint), "startPoint coordinate lies outside its dimension.");
            }

            var cachedItems = dimensions.Select(dimension => dimension.AsCaching()).ToArray();
            Func<int[], TResult> combineAt = position => combine(GetItemsByPosition(cachedItems, position));

            var startNode = new { Position = startPoint, Result = combineAt(startPoint) };
            // Every position that has ever been queued, so no position is queued twice.
            var visitedNodes = new HashSet<Coords> { new Coords(startPoint) };
            var queue = PriorityQueue.Create(new[] { startNode }, n => n.Result, resultComparer);

            while (!queue.IsEmpty)
            {
                var current = queue.Dequeue();
                yield return current.Result;

                var newNeighbors = Coords.GenerateNeighbors(current.Position, maxPosition)
                    .Select(position => new Coords(position))
                    .Where(c => !visitedNodes.Contains(c))
                    .ToArray();

                queue.AddRange(newNeighbors.Select(c => new { Position = c.Position, Result = combineAt(c.Position) }));
                foreach (var neighbor in newNeighbors)
                {
                    visitedNodes.Add(neighbor);
                }
            }
        }
    }

    /// <summary>
    /// A sequence that also reports its item count. Consumers such as the cartesian product helpers
    /// can then learn the size up front instead of iterating over every item.
    /// </summary>
    /// <typeparam name="T">Item type; covariant, exactly like <see cref="IEnumerable{T}"/>.</typeparam>
    public interface IEnumerableCountable<out T>
        : IEnumerable<T>
    {
        /// <summary>Number of items in the sequence, as reported by the implementation.</summary>
        int Count { get; }
    }
    /// <summary>
    /// Helpers that expose an arbitrary sequence as an <see cref="IEnumerableCountable{T}"/>.
    /// </summary>
    public static class EnumerableCountable
    {
        /// <summary>
        /// Pass-through sequence whose <c>Count</c> comes from a delegate. Enumeration is forwarded
        /// unchanged to the wrapped sequence.
        /// </summary>
        class Wrapper<T> : IEnumerableCountable<T>
        {
            private readonly IEnumerable<T> _items;
            private readonly Func<int> _getCount;

            /// <summary>
            /// Returns the count when the runtime type of <paramref name="items"/> exposes it without
            /// enumeration; otherwise returns <c>null</c>.
            /// </summary>
            static int? TryGetKnownCount(IEnumerable<T> items)
            {
                switch (items)
                {
                    case IEnumerableCountable<T> countableT:
                        return countableT.Count;
                    case T[] array:
                        return array.Length;
                    case List<T> list:
                        return list.Count;
                    case LinkedList<T> linkedList:
                        return linkedList.Count;
                    case ICollection collection:
                        return collection.Count;
                    case ICollection<T> collectionT:
                        return collectionT.Count;
                    case IReadOnlyCollection<T> rocollectionT:
                        return rocollectionT.Count;
                    default:
                        return null;
                }
            }

            public Wrapper(IEnumerable<T> items, Func<int> getCount)
            {
                _items = items;
                _getCount = getCount;
            }

            public Wrapper(IEnumerable<T> items, int? knownCount)
            {
                _items = items;
                var count = knownCount ?? TryGetKnownCount(items);
                if (count.HasValue)
                {
                    var fixedCount = count.Value;
                    _getCount = () => fixedCount;
                }
                else
                {
                    // Counting this sequence requires a full enumeration. Defer it until Count is first
                    // read and do it only once, so wrapping a lazy or single-use sequence does not consume it.
                    var lazyCount = new Lazy<int>(() => items.Count());
                    _getCount = () => lazyCount.Value;
                }
            }

            public int Count { get { return _getCount(); } }

            public IEnumerator<T> GetEnumerator()
            {
                return _items.GetEnumerator();
            }

            IEnumerator IEnumerable.GetEnumerator()
            {
                return _items.GetEnumerator();
            }
        }

        /// <summary>
        /// Exposes <paramref name="items"/> as a countable sequence.
        /// </summary>
        /// <param name="items">Sequence to wrap.</param>
        /// <param name="knownCount">
        /// Count to report. If null, the count is read from the collection type when it exposes one.
        /// Otherwise it is computed by enumerating <paramref name="items"/> the first time <c>Count</c> is read.
        /// </param>
        /// <returns>
        /// <paramref name="items"/> itself when it already implements <see cref="IEnumerableCountable{T}"/>
        /// (in that case <paramref name="knownCount"/> is ignored); otherwise a pass-through wrapper.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
        public static IEnumerableCountable<T> AsCountable<T>(this IEnumerable<T> items, int? knownCount = null)
        {
            if (items == null) throw new ArgumentNullException(nameof(items));
            if (items is IEnumerableCountable<T> res) return res;
            return new Wrapper<T>(items, knownCount);
        }

        /// <summary>
        /// Exposes <paramref name="items"/> as a countable sequence whose <c>Count</c> invokes
        /// <paramref name="getCount"/> on every read.
        /// </summary>
        /// <param name="items">Sequence to wrap.</param>
        /// <param name="getCount">Delegate that supplies the count.</param>
        /// <returns>
        /// <paramref name="items"/> itself when it already implements <see cref="IEnumerableCountable{T}"/>
        /// (in that case <paramref name="getCount"/> is ignored); otherwise a pass-through wrapper.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="items"/> or <paramref name="getCount"/> is null.</exception>
        public static IEnumerableCountable<T> AsCountable<T>(this IEnumerable<T> items, Func<int> getCount)
        {
            if (items == null) throw new ArgumentNullException(nameof(items));
            if (getCount == null) throw new ArgumentNullException(nameof(getCount));
            if (items is IEnumerableCountable<T> res) return res;
            return new Wrapper<T>(items, getCount);
        }
    }


}