CodyKochmann
8/20/2017 - 5:33 PM

restrict the calling of a function to a predetermined whitelist of functions in python

restrict the calling of a function to a predetermined whitelist of functions in python

# restrict the calling of a function to a predetermined whitelist of functions
# by: Cody Kochmann

In [52]: import traceback

In [53]: class authorized_callers:
    ...:     def __init__(self,*allowed_functions):
    ...:         assert all(callable(i) for i in allowed_functions)
    ...:         self.allowed_functions = allowed_functions
    ...:     def __call__(self,fn):
    ...:         def wrapper(*a,**k):
    ...:             caller = globals()[[ i.split('", line ')[1].split('\n')[0].split(' ')[-1] for i in traceback.format_stack() if (', in {}\n'.format(fn.__name__) not in i
    ...:  and ', in wrapper\n' not in i) ][-1]]
    ...:             assert caller in fn.authorized_callers, '{} was not in the authorized callers list, only these are allowed to call this function: {}'.format(caller, fn.
    ...: authorized_callers)
    ...:             return fn(*a,**k)
    ...:         fn.authorized_callers=self.allowed_functions
    ...:         wrapper.__name__ = fn.__name__
    ...:         return wrapper
    ...:

In [54]: def c():
    ...:     ''' calls a '''
    ...:     return a(6)
    ...:

In [55]: def b():
    ...:     ''' calls a '''
    ...:     return a(5)
    ...:

In [56]: @authorized_callers(b)
    ...: def a(i):
    ...:     return i*i
    ...:

In [57]: # at this point only b is authorized to call a

In [58]: b()
Out[58]: 25

In [59]: c()
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-59-1f2bdb17cf98> in <module>()
----> 1 c()

<ipython-input-54-8c3d65edf0d8> in c()
      1 def c():
      2     ''' calls a '''
----> 3     return a(6)

<ipython-input-53-a6e5bea086da> in wrapper(*a, **k)
      6         def wrapper(*a,**k):
      7             caller = globals()[[ i.split('", line ')[1].split('\n')[0].split(' ')[-1] for i in traceback.format_stack() if (', in {}\n'.format(fn.__name__) not in i and ', in wrapper\n' not in i) ][-1]]
----> 8             assert caller in fn.authorized_callers, '{} was not in the authorized callers list, only these are allowed to call this function: {}'.format(caller, fn.authorized_callers)
      9             return fn(*a,**k)
     10         fn.authorized_callers=self.allowed_functions

AssertionError: <function c at 0x10ba8a1e0> was not in the authorized callers list, only these are allowed to call this function: (<function b at 0x10b3adbf8>,)