AndersonFirmino
1/30/2011 - 5:03 AM

Tail-Recursion helper in Python

#!/usr/bin/env python

"""
Tail-Recursion helper in Python.

Inspired by the trampoline function at 
http://jasonmbaker.com/tail-recursion-in-python-using-pysistence

A tail-recursive function returns the result of a call to a tail-recursive
function (itself, most of the time) directly, with no work left to do after
that call returns. For example, this is tail-recursive:

sum [] acc = acc
sum (x:xs) acc = sum xs (acc+x)

And this is not:
fib n | n == 0 || n == 1 = 1
      | otherwise = (fib (n-1)) + (fib (n-2))

because fib n returns an application of (+), not directly of fib. 
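The same distinction in Python, as a sketch. The first fib mirrors the Haskell definition and is not tail-recursive, because the addition still has to run after both recursive calls return. The second, fib_tr (a hypothetical name), threads two accumulators through the recursion. This makes the recursive call the whole return expression, so the function is tail-recursive while computing the same values (fib 0 = fib 1 = 1, as above).

```python
def fib(n):
    # Not tail-recursive: (+) is applied after the recursive calls return.
    if n == 0 or n == 1:
        return 1
    return fib(n - 1) + fib(n - 2)


def fib_tr(n, a=1, b=1):
    # Tail-recursive: a and b hold two consecutive Fibonacci numbers,
    # and the recursive call is the entire return value.
    if n == 0:
        return a
    return fib_tr(n - 1, b, a + b)


print([fib_tr(n) for n in range(8)])  # [1, 1, 2, 3, 5, 8, 13, 21]
```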

Suppose we wanted to write sum in Python like we could in Haskell:
"""

# `iterator` must provide has_next() and next() methods (see LookAheadIterator below)
def nontrampsum(iterator, accumulator):
    if not iterator.has_next():
        return accumulator
    else:
        head = iterator.next()
        accumulator += head
        return nontrampsum(iterator, accumulator)

"""
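It looks elegant, but it would blow the stack quickly. CPython does not
eliminate tail calls, so every recursive call pushes another frame until the
recursion limit (sys.getrecursionlimit(), 1000 by default) is exceeded. GHC,
by contrast, compiles Haskell tail calls into jumps.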
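Here is a quick way to see the problem, sketched in Python 3. A list and an index stand in for the has_next()/next() iterator, and rec_sum and overflows are hypothetical helpers. The recursion depth grows with the input, so summing any list longer than the interpreter's recursion limit raises RecursionError, a RuntimeError subclass since Python 3.5.

```python
import sys


def rec_sum(xs, i=0, acc=0):
    # Same shape as nontrampsum, but indexing a list instead of using has_next().
    if i == len(xs):
        return acc
    return rec_sum(xs, i + 1, acc + xs[i])


def overflows(n):
    # True if summing range(n) recursively exceeds the recursion limit.
    try:
        rec_sum(list(range(n)))
    except RecursionError:
        return True
    return False


print(sys.getrecursionlimit())  # 1000 by default in CPython
print(overflows(100), overflows(10 * sys.getrecursionlimit()))  # False True
```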

We'll need some help:
"""

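# Wraps a tail-recursive function f that, instead of calling itself, returns
# zero-argument suspensions (see partial below). The wrapper keeps calling
# them until it gets a non-callable result, so the stack never grows.
# Caveat: a final result that is itself callable would be called too.
def trampoline(f):
    def trampolined_f(*args, **kwargs):
        result = f(*args, **kwargs)
        while callable(result):
            result = result()
        return result
    return trampolined_f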

# Creates a 'suspension' of f: a zero-argument function that, when called,
# applies f to the captured arguments.
# functools.partial does more, though (e.g. it accepts extra arguments at call time).
def partial(f, *args, **kwargs):
    def partial_f():
        return f(*args, **kwargs)
    return partial_f

"""
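First, we make our tail-recursive function return a closure in which it is
applied, instead of calling itself directly. Then we wrap it with trampoline,
which calls the suspensions it returns until the base case of the recursion is
reached. (We wrap with an explicit call rather than the @trampoline decorator
syntax: the suspensions must point at the unwrapped trampsum_inner, otherwise
each one would start a nested trampoline and the stack would grow again.)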
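To make the mechanism concrete, here is a sketch that runs the trampoline's loop by hand on a plain list. The names sum_inner and run_by_hand are hypothetical, and an index replaces the has_next()/next() iterator. Each call of sum_inner does one step of work and returns a suspension instead of recursing, so the call stack stays flat however long the list is.

```python
import sys


def partial(f, *args, **kwargs):
    # Same zero-argument suspension as in the script.
    def partial_f():
        return f(*args, **kwargs)
    return partial_f


def sum_inner(xs, i, acc):
    # One step of the sum: either the final result or a suspension of the next step.
    if i == len(xs):
        return acc
    return partial(sum_inner, xs, i + 1, acc + xs[i])


def run_by_hand(xs):
    # The body of trampolined_f, unrolled, plus a count of suspensions called.
    result = sum_inner(xs, 0, 0)
    bounces = 0
    while callable(result):
        result = result()
        bounces += 1
    return result, bounces


print(run_by_hand([1, 2, 3]))  # (6, 3)
big = list(range(10 * sys.getrecursionlimit()))
print(run_by_hand(big)[0] == sum(big))  # True, far deeper than the recursion limit
```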
"""

def trampsum_inner(iterator, acc):
    if not iterator.has_next():
        return acc
    else:
        head = iterator.next()
        acc += head
        return partial(trampsum_inner, iterator, acc)
trampsum = trampoline(trampsum_inner)

"""    
And a digression: we'll need to define an iterator with a has_next method.

I'd like to be able to pattern-match on iterators the way Haskell does on lists:
sum [] acc = acc
sum (x:xs) acc = sum xs (acc+x)

We can just wrap an iterator and look ahead lazily.
"""

try:
    from collections.abc import Iterator  # Python 3.3+
except ImportError:
    from collections import Iterator  # Python 2

class LookAheadIterator(Iterator):
    def __init__(self, wrapped):
        self._wrapped = iter(wrapped)
        self._need_to_advance = True
        self._has_next = False
        self._cache = None

    def has_next(self):
        if self._need_to_advance:
            self._advance()
        return self._has_next

    def _advance(self):
        try:
            self._cache = next(self._wrapped)
            self._has_next = True
        except StopIteration:
            self._has_next = False
        self._need_to_advance = False

    def next(self):
        if self._need_to_advance:
            self._advance()
        if self._has_next:
            self._need_to_advance = True
            return self._cache
        else:
            raise StopIteration()

    # Python 3 iterator protocol; must return the value, not just advance.
    def __next__(self):
        return self.next()

"""
Let's prove (sadly) that it's not the speediest:
"""

import cProfile

def test(f):
    iterator = LookAheadIterator(range(1000000))
    accumulator = 0
    print(f(iterator, accumulator))

print("Summing with built-in sum")
cProfile.run('test(sum)')

print("Summing with trampolined sum")
cProfile.run('test(trampsum)')

"""
499999500000
         2000009 function calls in 2.254 CPU seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    2.254    2.254 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 _abcoll.py:66(__iter__)
  1000001    0.909    0.000    1.747    0.000 cool.py:107(next)
        1    0.000    0.000    2.254    2.254 cool.py:125(test)
        1    0.000    0.000    0.000    0.000 cool.py:88(__init__)
  1000001    0.838    0.000    0.838    0.000 cool.py:99(_advance)
        1    0.000    0.000    0.000    0.000 {iter}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        1    0.506    0.506    2.254    2.254 {sum}


499999500000
         7000010 function calls in 6.139 CPU seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    6.139    6.139 <string>:1(<module>)
  1000000    0.588    0.000    0.588    0.000 cool.py:107(next)
        1    0.000    0.000    6.139    6.139 cool.py:125(test)
        1    0.761    0.761    6.139    6.139 cool.py:44(trampolined_f)
  1000000    0.358    0.000    0.358    0.000 cool.py:54(partial)
  1000000    0.784    0.000    5.242    0.000 cool.py:55(partial_f)
  1000001    1.763    0.000    4.457    0.000 cool.py:66(trampsum_inner)
        1    0.000    0.000    0.000    0.000 cool.py:88(__init__)
  1000001    0.736    0.000    1.748    0.000 cool.py:94(has_next)
  1000001    1.012    0.000    1.012    0.000 cool.py:99(_advance)
  1000001    0.136    0.000    0.136    0.000 {callable}
        1    0.000    0.000    0.000    0.000 {iter}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}


But it won't topple the stack:
"""

# cProfile.run('test(nontrampsum)')
# RuntimeError: maximum recursion depth exceeded while calling a Python object
# (on Python 3.5+ this is RecursionError, a subclass of RuntimeError)

"""
We can build a left fold with it, too:
"""

def foldl_inner(f, accumulator, iterator):
    if not iterator.has_next():
        return accumulator
    else:
        head = iterator.next()
        accumulator = f(accumulator, head)
        return partial(foldl_inner, f, accumulator, iterator)
foldl = trampoline(foldl_inner)

def add(a, b): return a + b

def foldlsum(iterator, accumulator):
    return foldl(add, accumulator, iterator)

print("Summing with trampolined foldl")
cProfile.run('test(foldlsum)')

"""
499999500000
         8000011 function calls in 6.874 CPU seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    6.874    6.874 <string>:1(<module>)
  1000000    0.601    0.000    0.601    0.000 cool.py:107(next)
        1    0.000    0.000    6.874    6.874 cool.py:125(test)
  1000001    2.159    0.000    5.109    0.000 cool.py:187(foldl_inner)
  1000000    0.206    0.000    0.206    0.000 cool.py:196(add)
        1    0.000    0.000    6.874    6.874 cool.py:198(foldlsum)
        1    0.824    0.824    6.874    6.874 cool.py:44(trampolined_f)
  1000000    0.389    0.000    0.389    0.000 cool.py:54(partial)
  1000000    0.819    0.000    5.927    0.000 cool.py:55(partial_f)
        1    0.000    0.000    0.000    0.000 cool.py:88(__init__)
  1000001    0.750    0.000    1.754    0.000 cool.py:94(has_next)
  1000001    1.004    0.000    1.004    0.000 cool.py:99(_advance)
  1000001    0.122    0.000    0.122    0.000 {callable}
        1    0.000    0.000    0.000    0.000 {iter}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
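The trampolined foldl computes the same left fold as the standard library's functools.reduce(function, iterable, initializer). The sketch below restates foldl over an ordinary Python 3 iterator and checks it against reduce. It uses a private sentinel object with next(iterator, default) in place of has_next(). That is an assumption of this sketch, not how the script does it.

```python
import functools
import operator

_END = object()  # sentinel returned by next() once the iterator is exhausted


def partial(f, *args, **kwargs):
    def partial_f():
        return f(*args, **kwargs)
    return partial_f


def trampoline(f):
    def trampolined_f(*args, **kwargs):
        result = f(*args, **kwargs)
        while callable(result):
            result = result()
        return result
    return trampolined_f


def foldl_inner(f, accumulator, iterator):
    head = next(iterator, _END)
    if head is _END:
        return accumulator
    return partial(foldl_inner, f, f(accumulator, head), iterator)


foldl = trampoline(foldl_inner)

xs = range(1, 11)
print(foldl(operator.add, 0, iter(xs)), functools.reduce(operator.add, xs, 0))  # 55 55
print(foldl(operator.mul, 1, iter(xs)) == functools.reduce(operator.mul, xs, 1))  # True
```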


Use it wisely, I guess.  (Or not at all.)  Look into functools.partial too.
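One reason for caution: trampolined_f treats any callable result as a suspension. A function whose final answer is itself callable therefore gets called by mistake. With the plain trampoline, the final lambda below would be called with no arguments and raise TypeError. The sketch shows one common workaround, assuming you can wrap suspensions in a dedicated marker class; Bounce and tagged_trampoline are hypothetical names. Only Bounce instances are re-entered, and each one holds a functools.partial object, the standard-library counterpart of the partial helper.

```python
import functools


class Bounce:
    # Hypothetical marker for "call me again": wraps a functools.partial so the
    # trampoline can tell suspensions apart from callable final results.
    __slots__ = ("call",)

    def __init__(self, f, *args, **kwargs):
        self.call = functools.partial(f, *args, **kwargs)


def tagged_trampoline(f):
    @functools.wraps(f)
    def run(*args, **kwargs):
        result = f(*args, **kwargs)
        while isinstance(result, Bounce):  # anything else, callable or not, is final
            result = result.call()
        return result
    return run


def _adder_after(n, k):
    # Counts n down to zero, then returns a function as its final result.
    if n == 0:
        return lambda x: x + k
    return Bounce(_adder_after, n - 1, k)


adder_after = tagged_trampoline(_adder_after)
add3 = adder_after(100000, 3)
print(add3(10))  # 13
```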
"""