abetkin
8/18/2016 - 9:33 AM

fixtures.py

import sys
import os
import logging

from pony.py23compat import PY2
from ponytest import with_cli_args, pony_fixtures, ValidationError, Fixture, provider_validators, provider

from functools import wraps, partial
import click
from contextlib import contextmanager, closing


from pony.orm.dbproviders.mysql import mysql_module
from pony.utils import cached_property, class_property

if not PY2:
    from contextlib import contextmanager, ContextDecorator, ExitStack
else:
    from contextlib2 import contextmanager, ContextDecorator, ExitStack

import unittest

from pony.orm import db_session, Database, rollback, delete

if not PY2:
    from io import StringIO
else:
    from StringIO import StringIO

from multiprocessing import Process

import threading


class DBContext(ContextDecorator):
    """Base fixture that provides a fresh database to a test (or test class).

    Subclasses set ``PROVIDER`` and implement ``init_db`` (create an empty
    database) and ``db`` (a ``Database`` bound to it).  On construction the
    fixture exposes its ``db`` on the test as ``Test.db`` and stores the
    provider name in ``Test.db_provider``.  Entering the context initialises
    the database, calls ``Test.make_entities()`` when it exists and, if that
    call succeeds, generates the mapping; exiting disconnects the provider.
    """

    __fixture__ = 'db'

    enabled = False

    def __init__(self, Test):
        if not isinstance(Test, type):
            # A test *instance*: give it a private subclass (same name and
            # module) so the ``db`` property is attached to this instance's
            # class only, not to the shared test class.
            TestCls = type(Test)
            NewClass = type(TestCls.__name__, (TestCls,), {})
            NewClass.__module__ = TestCls.__module__
            NewClass.db = property(lambda t: self.db)
            Test.__class__ = NewClass
        else:
            Test.db = class_property(lambda cls: self.db)
        Test.db_provider = self.provider
        self.Test = Test

    @class_property
    def fixture_name(cls):
        return cls.provider

    @class_property
    def provider(cls):
        # is used in tests
        return cls.PROVIDER

    def init_db(self):
        """Create an empty test database; implemented by subclasses."""
        raise NotImplementedError

    @cached_property
    def db(self):
        """``Database`` connected to the test database; implemented by subclasses."""
        raise NotImplementedError

    def __enter__(self):
        self.init_db()
        make_entities = getattr(self.Test, 'make_entities', None)
        if make_entities is not None:
            # Only the *lookup* may be missing; errors raised inside a real
            # make_entities (e.g. an AttributeError from a typo) now propagate
            # instead of silently skipping mapping generation.
            try:
                make_entities()
            except TypeError:
                # No method make_entities with due signature (e.g. an instance
                # method looked up on the class): treat it as absent.
                # NOTE(review): a TypeError raised *inside* make_entities is
                # still swallowed by this clause.
                pass
            else:
                self.db.generate_mapping(check_tables=True, create_tables=True)
        return self.db

    def __exit__(self, *exc_info):
        self.db.provider.disconnect()

    # @classmethod
    # @with_cli_args
    # @click.option('--db', '-d', 'database', multiple=True)
    # @click.option('--exclude-db', '-e', multiple=True)
    # def cli(cls, database, exclude_db):
    #     fixture = [
    #         MySqlContext, OracleContext, SqliteContext, PostgresContext,
    #         SqlServerContext,
    #     ]
    #     all_db = [ctx.provider for ctx in fixture]
    #     for db in database:
    #         if db == 'all':
    #             continue
    #         assert db in all_db, (
    #             "Unknown provider: %s. Use one of %s." % (db, ', '.join(all_db))
    #         )
    #     if 'all' in database:
    #         database = all_db
    #     elif exclude_db and not database:
    #         database = set(all_db) - set(exclude_db)
    #     elif not database:
    #         database = ['sqlite']
    #     for Ctx in fixture:
    #         if Ctx.provider in database:
    #             yield Ctx

    # Name of the database/file each provider creates for the test run.
    db_name = 'testdb'

# class DbFixture(Fixture):
#     __key__ = 'db'


# class GenerateMapping(Fixture):
#     __key__ = 'generate_mapping'


@provider()
class GenerateMapping(ContextDecorator):
    """Class-scoped fixture that generates the ORM mapping for ``Test.db``.

    Nothing happens when the test has no (truthy) ``db`` or the database has
    no entities.  Otherwise only the first entity is inspected: if its
    database has no schema yet, the mapping is generated (checking and
    creating tables).
    """

    weight = 200
    scope = 'class'
    __fixture__ = 'generate_mapping'

    def __init__(self, Test):
        self.Test = Test

    def __enter__(self):
        database = getattr(self.Test, 'db', None)
        if database and database.entities:
            first_entity = next(iter(database.entities.values()))
            if first_entity._database_.schema is None:
                database.generate_mapping(check_tables=True, create_tables=True)

    def __exit__(self, *exc_info):
        return None

@provider()
class MySqlContext(DBContext):
    """MySQL database fixture that (re)creates ``db_name`` on a local server."""
    PROVIDER = 'mysql'

    def drop_db(self, cursor):
        """Drop the test database, switching away from it first."""
        # NOTE(review): 'sys' is used only as a schema to switch to; it
        # presumably exists on the target server (MySQL 5.7+) — confirm.
        cursor.execute('use sys')
        cursor.execute('drop database %s' % self.db_name)


    def init_db(self):
        """Drop the test database if it exists, then create it again.

        A failed drop (e.g. the database does not exist) is reported and
        ignored.  Both the cursor and the server connection are closed.
        """
        # Keep a strong reference to the connection: some DB-API drivers'
        # cursors reference their connection only weakly, and the connection
        # must be closed explicitly rather than left to the garbage collector.
        conn = mysql_module.connect(**self.CONN)
        with closing(conn), closing(conn.cursor()) as c:
            try:
                self.drop_db(c)
            except mysql_module.DatabaseError as exc:
                print('Failed to drop db: %s' % exc)
            c.execute('create database %s' % self.db_name)
            c.execute('use %s' % self.db_name)

    # Server credentials; ``db`` adds the database name to these.
    CONN = {
        'host': "localhost",
        'user': "ponytest",
        'passwd': "ponytest",
    }

    @cached_property
    def db(self):
        CONN = dict(self.CONN, db=self.db_name)
        return Database('mysql', **CONN)

@provider()
class SqlServerContext(DBContext):
    """Microsoft SQL Server fixture using pyodbc and the ``MSSQLdb`` DSN."""

    PROVIDER = 'sqlserver'

    def get_conn_string(self, db=None):
        """Return the ODBC connection string, selecting database ``db`` if given."""
        s = (
            'DSN=MSSQLdb;'
            'SERVER=mssql;'
            'UID=sa;'
            'PWD=pass;'
        )
        if db:
            s += 'DATABASE=%s' % db
        return s

    @cached_property
    def db(self):
        CONN = self.get_conn_string(self.db_name)
        return Database('mssqlserver', CONN)

    def init_db(self):
        """Drop the test database if it exists, then create it again.

        A failed drop is reported and ignored.  Both the cursor and the
        connection are closed afterwards.
        """
        import pyodbc
        # autocommit so the CREATE/DROP DATABASE statements are not wrapped in
        # a transaction; the connection itself is now closed, not leaked.
        conn = pyodbc.connect(self.get_conn_string(), autocommit=True)
        with closing(conn), closing(conn.cursor()) as c:
            try:
                self.drop_db(c)
            except pyodbc.DatabaseError as exc:
                print('Failed to drop db: %s' % exc)
            c.execute('create database %s' % self.db_name)
            c.execute('use %s' % self.db_name)

    def drop_db(self, cursor):
        """Switch to ``master`` and drop the test database."""
        cursor.execute('use master')
        cursor.execute('drop database %s' % self.db_name)


@provider()
class SqliteContext(DBContext):
    """SQLite fixture backed by ``<db_name>.sqlite`` next to this module.

    Enabled by default.  ``init_db`` deletes any existing database file; the
    ``Database`` is created with ``create_db=True`` so the file is recreated
    on first use.
    """
    PROVIDER = 'sqlite'
    enabled = True

    def init_db(self):
        # Removing a file that is not there only prints a message.
        try:
            os.remove(self.db_path)
        except OSError as err:
            print('Failed to drop db: %s' % err)

    @cached_property
    def db_path(self):
        module_dir = os.path.dirname(__file__)
        file_name = '%s.sqlite' % self.db_name
        return os.path.abspath(os.path.join(module_dir, file_name))

    @cached_property
    def db(self):
        return Database('sqlite', self.db_path, create_db=True)


@provider()
class PostgresContext(DBContext):
    """PostgreSQL fixture that (re)creates ``db_name`` on a local server."""
    PROVIDER = 'postgresql'

    def get_conn_dict(self, no_db=False):
        """Return connection keyword arguments.

        With ``no_db=True`` the ``postgres`` database is selected (used while
        dropping/creating the test database); otherwise ``db_name``.
        """
        d = dict(
            user='ponytest', password='ponytest',
            host='localhost', database='postgres',
        )
        if not no_db:
            d.update(database=self.db_name)
        return d

    def init_db(self):
        """Drop the test database if it exists, then create it again.

        A failed drop is reported and ignored.  The cursor and the
        connection are both closed afterwards.
        """
        import psycopg2
        conn = psycopg2.connect(
            **self.get_conn_dict(no_db=True)
        )
        # psycopg2's own ``with conn`` only ends a transaction; closing()
        # really closes the connection, which was previously leaked.
        with closing(conn):
            # 0 == ISOLATION_LEVEL_AUTOCOMMIT: PostgreSQL refuses to run
            # CREATE/DROP DATABASE inside a transaction block.
            conn.set_isolation_level(0)
            with closing(conn.cursor()) as cursor:
                try:
                    self.drop_db(cursor)
                except psycopg2.DatabaseError as exc:
                    print('Failed to drop db: %s' % exc)
                cursor.execute('create database %s' % self.db_name)

    def drop_db(self, cursor):
        cursor.execute('drop database %s' % self.db_name)


    @cached_property
    def db(self):
        return Database('postgres', **self.get_conn_dict())


@provider()
class OracleContext(DBContext):
    """Oracle database fixture.

    Instead of a separate database, the test schema is a dedicated user
    (``parameters['user']``) with its own permanent and temporary
    tablespaces; ``init_db`` drops and recreates the user and both
    tablespaces.
    """
    PROVIDER = 'oracle'

    def __enter__(self):
        # Point the Oracle client at the local installation before connecting.
        # NOTE(review): paths and SID are hard-coded for one machine, and
        # os.environ is not restored on exit.
        os.environ.update(dict(
            ORACLE_BASE='/u01/app/oracle',
            ORACLE_HOME='/u01/app/oracle/product/12.1.0/dbhome_1',
            ORACLE_OWNR='oracle',
            ORACLE_SID='orcl',
        ))
        return super(OracleContext, self).__enter__()

    def init_db(self):
        """Drop the test user and tablespaces, then recreate them.

        Failures while dropping are printed and ignored.  The tablespaces
        are created first because the new user references them.
        """

        import cx_Oracle
        with closing(self.connect_sys()) as conn:
            with closing(conn.cursor()) as cursor:
                try:
                    self._destroy_test_user(cursor)
                except cx_Oracle.DatabaseError as exc:
                    print('Failed to drop user: %s' % exc)
                try:
                    self._drop_tablespace(cursor)
                except cx_Oracle.DatabaseError as exc:
                    print('Failed to drop db: %s' % exc)
                cursor.execute(
                """CREATE TABLESPACE %(tblspace)s
                DATAFILE '%(datafile)s' SIZE 20M
                REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize)s
                """ % self.parameters)
                cursor.execute(
                """CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
                TEMPFILE '%(datafile_tmp)s' SIZE 20M
                REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize_tmp)s
                """ % self.parameters)
                self._create_test_user(cursor)


    def _drop_tablespace(self, cursor):
        """Drop the permanent and the temporary test tablespace with their datafiles.

        If the first DROP fails, the second is not attempted.
        """
        cursor.execute(
            'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS'
        % self.parameters)
        cursor.execute(
            'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS'
        % self.parameters)


    # Values interpolated into the DDL statements of this class.
    parameters = {
        'tblspace': 'test_tblspace',
        'tblspace_temp': 'test_tblspace_temp',
        'datafile': 'test_datafile.dbf',
        'datafile_tmp': 'test_datafile_tmp.dbf',
        'user': 'ponytest',
        'password': 'ponytest',
        'maxsize': '100M',
        'maxsize_tmp': '100M',
    }

    def connect_sys(self):
        """Open a SYSDBA connection as ``sys`` (credentials are hard-coded)."""
        import cx_Oracle
        return cx_Oracle.connect('sys/the@localhost/ORCL', mode=cx_Oracle.SYSDBA)

    def connect_test(self):
        """Open a connection as the test user.

        NOTE(review): not referenced by the code visible in this file.
        """
        import cx_Oracle
        return cx_Oracle.connect('ponytest/ponytest@localhost/ORCL')


    @cached_property
    def db(self):
        # Database bound to the test user's schema.
        return Database('oracle', 'ponytest/ponytest@localhost/ORCL')

    def _create_test_user(self, cursor):
        """Create the test user on the test tablespaces and grant it the
        privileges to log in and create tables, sequences, procedures and
        triggers."""
        cursor.execute(
        """CREATE USER %(user)s
            IDENTIFIED BY %(password)s
            DEFAULT TABLESPACE %(tblspace)s
            TEMPORARY TABLESPACE %(tblspace_temp)s
            QUOTA UNLIMITED ON %(tblspace)s
        """ % self.parameters
        )
        cursor.execute(
        """GRANT CREATE SESSION,
                    CREATE TABLE,
                    CREATE SEQUENCE,
                    CREATE PROCEDURE,
                    CREATE TRIGGER
            TO %(user)s
        """ % self.parameters
        )

    def _destroy_test_user(self, cursor):
        """Drop the test user together with the objects it owns (CASCADE)."""
        cursor.execute('''
            DROP USER %(user)s CASCADE
        ''' % self.parameters)


@provider(__fixture__='log', weight=100, enabled=False)
@contextmanager
def logging_context(test):
    """Enable Pony SQL debug output at INFO level for the duration of a test.

    The root logger level and the ``pony.orm.core.debug`` value captured on
    entry are restored on exit, even when the test raises.  Previously an
    exception from the test skipped the restore and left logging enabled
    for every later test.

    :param test: the wrapped test; unused, accepted to match the other
        fixture factories' signature.
    """
    from pony.orm.core import debug, sql_debug
    root_logger = logging.getLogger()
    previous_level = root_logger.level
    root_logger.setLevel(logging.INFO)
    sql_debug(True)
    try:
        yield
    finally:
        root_logger.setLevel(previous_level)
        sql_debug(debug)

# @provider('log_all', scope='class', weight=-100, enabled=False)
# def log_all(Test):
#     return logging_context(Test)



# @with_cli_args
# @click.option('--log', 'scope', flag_value='test')
# @click.option('--log-all', 'scope', flag_value='all')
# def use_logging(scope):
#     if scope == 'test':
#         yield logging_context
#     elif scope =='all':
#         yield log_all




@provider()
class DBSessionProvider(object):
    """Fixture named ``db_session`` that supplies Pony's ``db_session``.

    Instantiating the class never yields an instance of it: ``__new__``
    returns the ``db_session`` object itself, so ``__init__`` is not run.
    """

    weight = 30

    __fixture__= 'db_session'

    def __new__(cls, test):
        # ``test`` is accepted to mirror the other providers' constructors.
        session = db_session
        return session


@provider(__fixture__='rollback', weight=40)
@contextmanager
def do_rollback(test):
    """Fixture named ``rollback``: call Pony's ``rollback()`` after the test.

    The rollback runs whether the test passes or raises; exceptions from the
    test still propagate.
    """
    with ExitStack() as cleanup:
        cleanup.callback(rollback)
        yield


@provider()
class SeparateProcess(object):
    """Class-scoped fixture that runs each wrapped test in a child process.

    Disabled by default; ``validate_chain`` decides when it applies.
    """

    # TODO read failures from sep process better

    __fixture__ = 'separate_process'

    enabled = False

    scope = 'class'

    def __init__(self, Test):
        self.Test = Test

    def __call__(self, func):
        """Wrap ``func`` so it runs as a one-test unittest suite in a child process.

        The parent only learns pass/fail: it asserts the child's exit code is 0.
        """
        def wrapper(Test):
            rnr = unittest.runner.TextTestRunner()
            TestCls = Test if isinstance(Test, type) else type(Test)
            def runTest(self):
                try:
                    func(Test)
                finally:
                    # Swap the runner's stream for an in-memory buffer once
                    # the test body has run; presumably to keep the runner's
                    # summary off the console — confirm.
                    rnr.stream = unittest.runner._WritelnDecorator(StringIO())
            name = getattr(func, '__name__', 'runTest')
            # Throwaway subclass of the test class carrying the single test
            # method under ``name``, keeping the original class name/module.
            Case = type(TestCls.__name__, (TestCls,), {name: runTest})
            Case.__module__ = TestCls.__module__
            case = Case(name)
            suite = unittest.suite.TestSuite([case])
            def run():
                # Child entry point: an unsuccessful suite becomes exit code 1.
                result = rnr.run(suite)
                if not result.wasSuccessful():
                    sys.exit(1)
            # NOTE(review): ``run`` is a nested function and cannot be
            # pickled, so this relies on the 'fork' start method.
            p = Process(target=run, args=())
            p.start()
            p.join()
            case.assertEqual(p.exitcode, 0)
        return wrapper

    @classmethod
    def validate_chain(cls, fixtures, klass):
        """Decide whether this fixture applies to a fixture chain.

        Returns False when an ``ipdb``/``ipdb_all`` fixture is present, True
        when the ``db`` fixture's provider is SQL Server or Oracle, and None
        (implicitly) otherwise.
        """
        for f in fixtures:
            if f.KEY in ('ipdb', 'ipdb_all'):
                return False
        for f in fixtures:
            if f.KEY == 'db' and f.PROVIDER in ('sqlserver', 'oracle'):
                return True

@provider()
class ClearTables(ContextDecorator):
    """Fixture that empties the test database's tables after a test.

    On exit, inside a ``db_session``, every entity of ``test.db`` has all its
    rows deleted, in ``db.entities`` iteration order.  Iteration stops at
    the first entity whose database has no schema (mapping not generated).
    NOTE(review): with foreign keys, this deletion order may matter — confirm.
    """

    __fixture__ = 'clear_tables'

    def __init__(self, test):
        self.test = test

    def __enter__(self):
        return None

    @db_session
    def __exit__(self, *exc_info):
        database = self.test.db
        for Entity in database.entities.values():
            has_schema = Entity._database_.schema is not None
            if not has_schema:
                break
            delete(obj for obj in Entity)


@provider()
class NoJson1(ContextDecorator):
    """Class-scoped fixture that hides the provider's JSON1 support.

    Constructing it sets ``no_json1 = True`` on the test class.  While the
    context is active, ``Test.db.provider.json1_available`` is forced to
    False; the previous value is restored on exit.
    """

    __fixture__ = 'no_json1'
    fixture_name = 'no_json1'
    scope = 'class'

    def __init__(self, cls):
        self.Test = cls
        # Flag stored on the test class; presumably read by the tests.
        cls.no_json1 = True

    def __enter__(self):
        db_provider = self.Test.db.provider
        self.json1_available = db_provider.json1_available
        db_provider.json1_available = False

    def __exit__(self, *exc_info):
        self.Test.db.provider.json1_available = self.json1_available

    @classmethod
    def validate_chain(cls, fixtures, klass):
        """Decide whether this fixture applies to a fixture chain.

        Returns False when an ``ipdb``/``ipdb_all`` fixture is present, True
        when the ``db`` fixture's provider is SQL Server or Oracle, and None
        otherwise.

        NOTE(review): JSON1 is an SQLite extension, yet this only enables the
        fixture for SQL Server/Oracle — confirm the provider list is intended.
        """
        debugger_keys = ('ipdb', 'ipdb_all')
        if any(f.KEY in debugger_keys for f in fixtures):
            return False
        if any(f.KEY == 'db' and f.PROVIDER in ('sqlserver', 'oracle')
               for f in fixtures):
            return True
        return None


import signal

@provider()
class Timeout(object):
    """Class-scoped fixture (disabled by default) that time-limits tests.

    Each wrapped test runs in a subprocess.  If it is still running after
    ``timeout`` seconds (``--timeout`` on the CLI, else ``Test.TIMEOUT``),
    the subprocess is terminated and ``Timeout.Exception`` is raised.  A
    non-zero exit status raises ``Timeout.FailedInSubprocess``.
    """

    __fixture__ = 'timeout'

    @with_cli_args
    @click.option('--timeout', type=int)
    def __init__(self, Test, timeout):
        self.Test = Test
        # A truthy CLI value wins; otherwise fall back to the class TIMEOUT.
        self.timeout = timeout if timeout else Test.TIMEOUT

    scope = 'class'
    enabled = False

    class Exception(Exception):
        """Raised when the test subprocess was terminated by SIGTERM."""

    # NOTE: inside this class body ``Exception`` now names Timeout.Exception,
    # so FailedInSubprocess subclasses it (not the builtin).  Kept as-is so
    # existing ``except Timeout.Exception`` handlers behave the same.
    class FailedInSubprocess(Exception):
        """Raised when the test subprocess exited with a non-zero status."""

    def __call__(self, func):
        def wrapper(test):
            proc = Process(target=func, args=(test,))
            proc.start()
            try:
                # Wait at most ``timeout`` seconds; no helper thread needed.
                proc.join(self.timeout)
            finally:
                # Over time (or join interrupted): never leave the child running.
                if proc.is_alive():
                    proc.terminate()
                    proc.join()
            if proc.exitcode == -signal.SIGTERM:
                raise self.Exception(
                    'test subprocess terminated by SIGTERM (timeout: %s s)'
                    % self.timeout)
            elif proc.exitcode:
                raise self.FailedInSubprocess(
                    'test subprocess exited with code %s' % proc.exitcode)

        return wrapper

    @classmethod
    @with_cli_args
    @click.option('--timeout', type=int)
    def validate_chain(cls, fixtures, klass, timeout):
        """Decide whether this fixture applies to a fixture chain.

        Returns False when no timeout is configured or an ``ipdb``/``ipdb_all``
        fixture is present, True when the ``db`` fixture's provider is SQL
        Server or Oracle, and None otherwise.
        """
        if not getattr(klass, 'TIMEOUT', None) and not timeout:
            return False
        for f in fixtures:
            if f.KEY in ('ipdb', 'ipdb_all'):
                return False
        for f in fixtures:
            if f.KEY == 'db' and f.PROVIDER in ('sqlserver', 'oracle'):
                return True


# Register this module's fixtures by name in ponytest's per-scope fixture
# lists ('test' and 'class' scopes; presumably applied per test method and
# per test class respectively — confirm against ponytest).
pony_fixtures['test'].extend([
    'log',
    'clear_tables',
    'db_session',
])

pony_fixtures['class'].extend([
    'separate_process',
    'timeout',
    'db',
    'generate_mapping',
])

def db_is_present(providers, config):
    """Validator for the ``db`` fixture that accepts every candidate.

    ``config`` is ignored; ``providers`` is returned unchanged (the very
    same object).
    """
    accepted = providers  # no filtering is applied
    return accepted

# Use db_is_present as the validator for the 'db' fixture.
provider_validators.update({
    'db': db_is_present,
})