xfan001
8/25/2016 - 4:06 PM

IO多路复用接口封装,事件循环

IO多路复用接口封装,事件循环

# -*- coding: utf-8 -*-

# Unified wrapper over the select and epoll interfaces (对select和epoll接口的统一封装)
import time
import functools
import heapq
import select
import traceback
import errno


# Event bit flags.  The numeric values line up with Linux's EPOLLIN (0x001),
# EPOLLOUT (0x004), EPOLLERR (0x008) and EPOLLHUP (0x010), which is what lets
# EventLoop hand these masks straight to select.epoll without translation.
POLL_NULL = 0
POLL_IN = 1 << 0
POLL_OUT = 1 << 2
POLL_ERR = 1 << 3
POLL_HUP = 1 << 4
POLL_NVAL = 1 << 5  # same value as poll(2)'s POLLNVAL; unused in the code shown here

# Short aliases used when building event masks for add_handler/update_handler.
NONE = POLL_NULL
READ = POLL_IN
WRITE = POLL_OUT
ERROR = POLL_HUP | POLL_ERR

# Upper bound, in seconds, on a single poll() call so that timers and stop()
# are noticed at least this often.
MAX_POLLTIME = 1e-2


class _SelectLoop(object):
    """Portable ``select.select`` back-end exposing an epoll-like interface.

    Provides the ``register`` / ``unregister`` / ``modify`` / ``poll`` calls
    that EventLoop makes, so it can stand in for ``select.epoll`` on platforms
    that lack it.  Event masks use this module's READ / WRITE / ERROR bits.
    """

    def __init__(self):
        self.read_fds = set()
        self.write_fds = set()
        self.error_fds = set()

    def register(self, fd, events):
        """Start watching ``fd`` for the READ / WRITE / ERROR bits in ``events``.

        :raises IOError: if ``fd`` is already being watched (as epoll does).
        """
        if fd in self.read_fds or fd in self.write_fds or fd in self.error_fds:
            raise IOError("fd %s already registered" % fd)
        if events & READ:
            self.read_fds.add(fd)
        if events & WRITE:
            self.write_fds.add(fd)
        if events & ERROR:
            self.error_fds.add(fd)

    def unregister(self, fd):
        """Stop watching ``fd``; fds that are not registered are ignored."""
        self.read_fds.discard(fd)
        self.write_fds.discard(fd)
        self.error_fds.discard(fd)

    def modify(self, fd, events):
        """Replace the event mask watched for ``fd``."""
        self.unregister(fd)
        self.register(fd, events)

    def poll(self, timeout):
        """Wait up to ``timeout`` seconds and return a list of ``(fd, events)``.

        A negative or ``None`` timeout means "block until something is ready",
        matching the ``-1`` convention of ``select.epoll.poll``.  When nothing
        is registered the call just waits out a positive timeout (or returns
        immediately when blocking) and returns an empty list.
        """
        if timeout is not None and timeout < 0:
            timeout = None
        if not (self.read_fds or self.write_fds or self.error_fds):
            # select() with all three sets empty is an error on Windows
            # (WSAEINVAL), so emulate the wait instead of calling it.
            if timeout:
                time.sleep(timeout)
            return []
        readable, writeable, errors = select.select(
            self.read_fds, self.write_fds, self.error_fds, timeout)
        events = {}
        for fd in readable:
            events[fd] = events.get(fd, 0) | READ
        for fd in writeable:
            events[fd] = events.get(fd, 0) | WRITE
        for fd in errors:
            events[fd] = events.get(fd, 0) | ERROR
        # A real list, like epoll.poll(), on both Python 2 and 3.
        return list(events.items())


class EventLoop(object):
    """Single-threaded reactor that dispatches fd readiness events and timers.

    Uses ``select.epoll`` when the platform provides it and otherwise falls
    back to :class:`_SelectLoop`.  Handlers and timer callbacks are invoked
    directly from :meth:`start`, i.e. on the thread that runs the loop.
    """

    def __init__(self):
        if hasattr(select, 'epoll'):
            self._impl = select.epoll()
        elif hasattr(select, 'select'):
            self._impl = _SelectLoop()
        else:
            raise TypeError("cannot find functions in select")
        # fd -> (object originally passed to add_handler, handler callable)
        self._handlers = {}
        # heap of _Timer objects, earliest deadline first
        self._timers = []
        self._stop = False

    def add_handler(self, fd, handler, events):
        """Call ``handler(obj, events)`` whenever ``fd`` becomes ready.

        :param fd: an integer fd or any object with ``fileno()``; the ``obj``
            passed back to the handler is exactly this argument.
        :param handler: callable taking ``(obj, events)``.
        :param events: bitmask of READ / WRITE / ERROR.
        :raises IOError/OSError: from the back-end if ``fd`` is already registered.
        """
        fd, obj = self._split_fd(fd)
        # Register first so a rejected registration leaves _handlers untouched.
        self._impl.register(fd, events)
        self._handlers[fd] = (obj, handler)

    def remove_handler(self, fd):
        """Stop watching ``fd`` and forget its handler.

        :raises KeyError: if no handler is registered for ``fd``.
        """
        fd, obj = self._split_fd(fd)
        del self._handlers[fd]
        self._impl.unregister(fd)

    def update_handler(self, fd, events):
        """Change the event mask watched for an already-registered ``fd``."""
        fd, obj = self._split_fd(fd)
        self._impl.modify(fd, events)

    def add_timeout(self, timeval, is_period, callback, *args, **kwargs):
        """Schedule ``callback(*args, **kwargs)`` and return a cancellable handle.

        For a one-shot timer (``is_period`` false) ``timeval`` is an absolute
        ``time.time()`` deadline.  For a periodic timer it is the interval in
        seconds, measured again from the end of each run.
        """
        timer = _Timer(functools.partial(callback, *args, **kwargs), timeval, is_period)
        heapq.heappush(self._timers, timer)
        return timer

    def remove_timeout(self, timeout):
        """Cancel a handle returned by add_timeout; it is dropped from the heap lazily."""
        timeout.callback = None

    def start(self):
        """Run the loop until :meth:`stop` is called.

        Each iteration fires due timers, polls for at most MAX_POLLTIME seconds
        (less if a timer is due sooner), then dispatches the ready fds.
        Exceptions raised by timer callbacks propagate out of start().
        A stopped loop cannot be started again (asserted below).
        """
        assert self._stop == False
        while not self._stop:
            self._run_due_timers()
            events = self._poll_events(self._next_poll_timeout())
            self._dispatch_events(events)

    def _run_due_timers(self):
        """Pop every expired timer, run it, and re-arm periodic ones."""
        due_timers = []
        now = time.time()
        while self._timers:
            head = self._timers[0]
            if head.callback is None:
                # Cancelled through remove_timeout(): discard it.
                heapq.heappop(self._timers)
            elif head.deadline < now:
                due_timers.append(heapq.heappop(self._timers))
            else:
                break
        for timer in due_timers:
            # An earlier callback in this batch may have cancelled this timer.
            if timer.callback is not None:
                timer.callback()
                timer.update()
                # update() clears callback for one-shot timers; a periodic
                # timer keeps it unless it was cancelled inside its callback.
                if timer.callback:
                    heapq.heappush(self._timers, timer)

    def _next_poll_timeout(self):
        """Seconds to block in poll: until the next timer, capped at MAX_POLLTIME."""
        if self._timers:
            poll_timeout = self._timers[0].deadline - time.time()
            return max(0, min(poll_timeout, MAX_POLLTIME))
        return MAX_POLLTIME

    def _poll_events(self, timeout):
        """Poll the back-end; on failure report the error and return no events."""
        try:
            return self._impl.poll(timeout)
        except (select.error, IOError, OSError) as e:
            # Python 2's epoll.poll raises IOError rather than select.error,
            # so both must be caught or a signal (EINTR) would kill the loop.
            code = errno_from_exception(e)
            if code in (errno.EPIPE, errno.EINTR):
                # EPIPE: client closed the socket; EINTR: interrupted by a signal
                print(e)
            else:
                traceback.print_exc()
            return []

    def _dispatch_events(self, events):
        """Invoke the handler registered for each ready fd.

        IOError/OSError from a handler is printed and the remaining fds are
        still served; any other exception propagates out of the loop.
        """
        for fd, event in events:
            # The handler may have been removed by an earlier handler this round.
            sock, handler_func = self._handlers.get(fd, (None, None))
            if handler_func:
                try:
                    handler_func(sock, event)
                except (OSError, IOError):
                    traceback.print_exc()
                    continue

    def stop(self):
        """Ask the loop to exit after the current iteration."""
        self._stop = True

    def time(self):
        """Return the clock used for timer deadlines (``time.time()``)."""
        return time.time()

    def _split_fd(self, fd):
        """Return ``(fileno, original_object)`` for an int fd or file-like object."""
        try:
            return fd.fileno(), fd
        except AttributeError:
            return fd, fd


class _Timer(object):
    """A scheduled callback stored in EventLoop's timer heap.

    For a one-shot timer ``timeval`` is used directly as the absolute
    deadline; for a periodic timer it is the interval in seconds and the first
    deadline is that far from now.
    """

    def __init__(self, callback, timeval, is_period=False):
        self.callback = callback
        self.is_period = is_period
        if not is_period:
            self.deadline = timeval
            return
        self.period = timeval
        self.deadline = time.time() + self.period

    def update(self):
        """After a run: push a periodic timer's deadline forward, or disarm a one-shot."""
        if not self.is_period:
            self.deadline = 0
            self.callback = None
            return
        self.deadline = time.time() + self.period

    def __lt__(self, other):
        # heapq orders timers by deadline through this comparison.
        return other.deadline > self.deadline

    def __gt__(self, other):
        return other.deadline < self.deadline


# from tornado
def errno_from_exception(e):
    """Return the errno carried by exception ``e``, or None if there is none.

    Prefers an ``errno`` attribute when the exception has one (even if it is
    None).  Otherwise it falls back to the first positional argument, which is
    where some exceptions keep the code.  An exception with no args yields
    None instead of an IndexError.
    """
    if not hasattr(e, 'errno'):
        return e.args[0] if e.args else None
    return e.errno


if __name__ == '__main__':
    # Demo server: accept TCP connections on PORT and print the first chunk
    # each client sends.  Usage: python <this file> PORT
    import socket
    import sys

    if len(sys.argv) != 2:
        sys.exit("usage: %s PORT" % sys.argv[0])
    port = int(sys.argv[1])
    listen_fd = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
    listen_fd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    listen_fd.bind(('0.0.0.0', port))
    listen_fd.listen(10240)
    listen_fd.setblocking(0)

    loop = EventLoop()

    def handler(sock, event):
        conn, address = sock.accept()
        try:
            # NOTE(review): recv() runs synchronously; if conn is blocking this
            # stalls the whole loop until the client sends -- demo only.
            print(conn.recv(10240))
        finally:
            conn.close()

    # A listening socket only signals readability (a pending accept), so
    # write readiness is not requested.
    loop.add_handler(listen_fd, handler, POLL_IN | POLL_ERR)
    try:
        loop.start()
    finally:
        listen_fd.close()