import sys
import os
import logging
from pony.py23compat import PY2
from ponytest import with_cli_args, pony_fixtures, ValidationError, Fixture, provider_validators, provider
from functools import wraps, partial
import click
from contextlib import contextmanager, closing
from pony.orm.dbproviders.mysql import mysql_module
from pony.utils import cached_property, class_property
if not PY2:
    from contextlib import contextmanager, ContextDecorator, ExitStack
else:
    from contextlib2 import contextmanager, ContextDecorator, ExitStack
import unittest
from pony.orm import db_session, Database, rollback, delete
if not PY2:
    from io import StringIO
else:
    from StringIO import StringIO
from multiprocessing import Process
import threading
class DBContext(ContextDecorator):
    """Base class for the database fixtures defined in this module.

    An instance wraps a test class or a test instance, exposes the fixture's
    database on it as ``db`` and stores the provider name as ``db_provider``.
    Entering the context calls ``init_db()`` and, if the test has a callable
    ``make_entities()``, generates the mapping; leaving it disconnects the
    database provider.

    Subclasses are expected to define ``PROVIDER`` and to override
    ``init_db()`` and ``db``.
    """

    __fixture__ = 'db'
    enabled = False
    # Database name shared by the concrete fixtures.
    db_name = 'testdb'

    def __init__(self, Test):
        """Attach ``db`` and ``db_provider`` to ``Test`` (a class or an instance)."""
        def _fixture_db(_owner):
            # Resolved lazily so the database is only built on first access.
            return self.db

        if isinstance(Test, type):
            Test.db = class_property(_fixture_db)
        else:
            # A property cannot be set on an instance, so the instance is
            # re-classed to a throwaway subclass that carries the property.
            original_cls = type(Test)
            patched_cls = type(original_cls.__name__, (original_cls,), {})
            patched_cls.__module__ = original_cls.__module__
            patched_cls.db = property(_fixture_db)
            Test.__class__ = patched_cls
        Test.db_provider = self.provider
        self.Test = Test

    @class_property
    def fixture_name(cls):
        """Name of the fixture; the same value as ``provider``."""
        return cls.provider

    @class_property
    def provider(cls):
        """Provider name taken from ``PROVIDER``; is used in tests."""
        return cls.PROVIDER

    def init_db(self):
        """Prepare the underlying database; must be overridden."""
        raise NotImplementedError

    @cached_property
    def db(self):
        """The database object; must be overridden."""
        raise NotImplementedError

    def __enter__(self):
        """Initialise the database, build entities if possible, return ``db``."""
        self.init_db()
        try:
            self.Test.make_entities()
        except (AttributeError, TypeError):
            # The test has no make_entities() callable with a matching
            # signature: skip mapping generation.
            pass
        else:
            self.db.generate_mapping(check_tables=True, create_tables=True)
        return self.db

    def __exit__(self, *exc_info):
        """Disconnect the database provider; exceptions are not suppressed."""
        self.db.provider.disconnect()

    # Disabled CLI-based provider selection, kept for reference:
    # @classmethod
    # @with_cli_args
    # @click.option('--db', '-d', 'database', multiple=True)
    # @click.option('--exclude-db', '-e', multiple=True)
    # def cli(cls, database, exclude_db):
    #     fixture = [
    #         MySqlContext, OracleContext, SqliteContext, PostgresContext,
    #         SqlServerContext,
    #     ]
    #     all_db = [ctx.provider for ctx in fixture]
    #     for db in database:
    #         if db == 'all':
    #             continue
    #         assert db in all_db, (
    #             "Unknown provider: %s. Use one of %s." % (db, ', '.join(all_db))
    #         )
    #     if 'all' in database:
    #         database = all_db
    #     elif exclude_db and not database:
    #         database = set(all_db) - set(exclude_db)
    #     elif not database:
    #         database = ['sqlite']
    #     for Ctx in fixture:
    #         if Ctx.provider in database:
    #             yield Ctx
# class DbFixture(Fixture):
#     __key__ = 'db'
# class GenerateMapping(Fixture):
#     __key__ = 'generate_mapping'
@provider()
class GenerateMapping(ContextDecorator):
    """Class-scoped fixture that generates the mapping of the test's ``db``.

    Nothing happens when the test has no ``db``, when the database has no
    entities, or when the schema has already been generated.
    """

    __fixture__ = 'generate_mapping'
    scope = 'class'
    weight = 200

    def __init__(self, Test):
        self.Test = Test

    def __enter__(self):
        """Generate the mapping if the first entity's database has no schema."""
        database = getattr(self.Test, 'db', None)
        if database and database.entities:
            # Only the first entity is inspected to decide whether the
            # mapping still has to be generated.
            first_entity = next(iter(database.entities.values()), None)
            if first_entity is not None and first_entity._database_.schema is None:
                database.generate_mapping(check_tables=True, create_tables=True)

    def __exit__(self, *exc_info):
        """Nothing to clean up; exceptions are not suppressed."""
        return None
@provider()
class MySqlContext(DBContext):
    """Database fixture backed by a MySQL server.

    ``init_db()`` drops and recreates the ``db_name`` database using the
    credentials in ``CONN``; ``db`` is a ``Database`` bound to it.
    """

    PROVIDER = 'mysql'

    # Connection parameters without a database selected.
    CONN = {
        'host': "localhost",
        'user': "ponytest",
        'passwd': "ponytest",
    }

    def drop_db(self, cursor):
        """Drop the test database through ``cursor``."""
        # Switch to another database first so the one being dropped is not
        # the currently selected one.
        cursor.execute('use sys')
        cursor.execute('drop database %s' % self.db_name)

    def init_db(self):
        """Recreate the test database from scratch.

        A failure to drop the old database (e.g. on the first run) is printed
        and otherwise ignored. Both the cursor and the connection are closed
        when done; previously the connection was never closed explicitly.
        """
        conn = mysql_module.connect(**self.CONN)
        with closing(conn), closing(conn.cursor()) as c:
            try:
                self.drop_db(c)
            except mysql_module.DatabaseError as exc:
                print('Failed to drop db: %s' % exc)
            c.execute('create database %s' % self.db_name)
            c.execute('use %s' % self.db_name)

    @cached_property
    def db(self):
        """``Database`` connected to the test database (created once)."""
        CONN = dict(self.CONN, db=self.db_name)
        return Database('mysql', **CONN)
@provider()
class SqlServerContext(DBContext):
    """Database fixture backed by Microsoft SQL Server, accessed via pyodbc.

    ``init_db()`` drops and recreates the ``db_name`` database; ``db`` is a
    ``Database`` bound to it.
    """

    PROVIDER = 'sqlserver'

    def get_conn_string(self, db=None):
        """Return the ODBC connection string, optionally selecting ``db``."""
        s = (
            'DSN=MSSQLdb;'
            'SERVER=mssql;'
            'UID=sa;'
            'PWD=pass;'
        )
        if db:
            s += 'DATABASE=%s' % db
        return s

    @cached_property
    def db(self):
        """``Database`` connected to the test database (created once)."""
        CONN = self.get_conn_string(self.db_name)
        return Database('mssqlserver', CONN)

    def init_db(self):
        """Recreate the test database from scratch.

        A failure to drop the old database is printed and otherwise ignored.
        Both the cursor and the connection are closed when done; previously
        the connection was never closed explicitly.
        """
        import pyodbc
        conn = pyodbc.connect(self.get_conn_string(), autocommit=True)
        with closing(conn), closing(conn.cursor()) as c:
            try:
                self.drop_db(c)
            except pyodbc.DatabaseError as exc:
                print('Failed to drop db: %s' % exc)
            c.execute('create database %s' % self.db_name)
            c.execute('use %s' % self.db_name)

    def drop_db(self, cursor):
        """Drop the test database through ``cursor``."""
        # Switch to another database first so the one being dropped is not
        # the currently selected one.
        cursor.execute('use master')
        cursor.execute('drop database %s' % self.db_name)
@provider()
class SqliteContext(DBContext):
    """Database fixture backed by an SQLite file next to this module.

    This is the only database fixture in the module with ``enabled = True``.
    """

    PROVIDER = 'sqlite'
    enabled = True

    @cached_property
    def db_path(self):
        """Absolute path of ``<db_name>.sqlite`` in this module's directory."""
        module_dir = os.path.dirname(__file__)
        file_name = '%s.sqlite' % self.db_name
        return os.path.abspath(os.path.join(module_dir, file_name))

    def init_db(self):
        """Delete the database file; a failure is printed and ignored."""
        try:
            os.remove(self.db_path)
        except OSError as exc:
            print('Failed to drop db: %s' % exc)

    @cached_property
    def db(self):
        """``Database`` bound to the SQLite file, creating it if missing."""
        return Database('sqlite', self.db_path, create_db=True)
@provider()
class PostgresContext(DBContext):
    """Database fixture backed by a PostgreSQL server, accessed via psycopg2.

    ``init_db()`` drops and recreates the ``db_name`` database; ``db`` is a
    ``Database`` bound to it.
    """

    PROVIDER = 'postgresql'

    def get_conn_dict(self, no_db=False):
        """Return connection keyword arguments.

        With ``no_db=True`` the connection targets the ``postgres`` database
        instead of the test database.
        """
        d = dict(
            user='ponytest', password='ponytest',
            host='localhost', database='postgres',
        )
        if not no_db:
            d.update(database=self.db_name)
        return d

    def init_db(self):
        """Recreate the test database from scratch.

        A failure to drop the old database is printed and otherwise ignored.
        The administrative connection is closed when done; previously it was
        never closed.
        """
        import psycopg2
        conn = psycopg2.connect(
            **self.get_conn_dict(no_db=True)
        )
        # closing() is used rather than `with conn:` because psycopg2's own
        # context manager ends the transaction but leaves the connection open.
        with closing(conn):
            # Isolation level 0 is autocommit.
            conn.set_isolation_level(0)
            with closing(conn.cursor()) as cursor:
                try:
                    self.drop_db(cursor)
                except psycopg2.DatabaseError as exc:
                    print('Failed to drop db: %s' % exc)
                cursor.execute('create database %s' % self.db_name)

    def drop_db(self, cursor):
        """Drop the test database through ``cursor``."""
        cursor.execute('drop database %s' % self.db_name)

    @cached_property
    def db(self):
        """``Database`` connected to the test database (created once)."""
        return Database('postgres', **self.get_conn_dict())
@provider()
class OracleContext(DBContext):
    """Database fixture backed by an Oracle instance, accessed via cx_Oracle.

    ``init_db()`` recreates the test user and its two tablespaces using a
    SYSDBA connection; ``db`` connects as the test user.
    """

    PROVIDER = 'oracle'

    # Values substituted into the SQL statements issued by this fixture.
    parameters = {
        'tblspace': 'test_tblspace',
        'tblspace_temp': 'test_tblspace_temp',
        'datafile': 'test_datafile.dbf',
        'datafile_tmp': 'test_datafile_tmp.dbf',
        'user': 'ponytest',
        'password': 'ponytest',
        'maxsize': '100M',
        'maxsize_tmp': '100M',
    }

    def __enter__(self):
        """Export the Oracle environment variables, then enter as usual."""
        oracle_env = {
            'ORACLE_BASE': '/u01/app/oracle',
            'ORACLE_HOME': '/u01/app/oracle/product/12.1.0/dbhome_1',
            'ORACLE_OWNR': 'oracle',
            'ORACLE_SID': 'orcl',
        }
        os.environ.update(oracle_env)
        return super(OracleContext, self).__enter__()

    def connect_sys(self):
        """Open a SYSDBA connection."""
        import cx_Oracle
        return cx_Oracle.connect('sys/the@localhost/ORCL', mode=cx_Oracle.SYSDBA)

    def connect_test(self):
        """Open a connection as the test user."""
        import cx_Oracle
        return cx_Oracle.connect('ponytest/ponytest@localhost/ORCL')

    @cached_property
    def db(self):
        """``Database`` connected as the test user (created once)."""
        return Database('oracle', 'ponytest/ponytest@localhost/ORCL')

    def init_db(self):
        """Drop and recreate the test user and its tablespaces.

        Failures while dropping the old user or tablespaces are printed and
        otherwise ignored; failures while creating are propagated.
        """
        import cx_Oracle
        # Clean-up steps paired with the message printed when one fails.
        cleanup_steps = (
            (self._destroy_test_user, 'Failed to drop user: %s'),
            (self._drop_tablespace, 'Failed to drop db: %s'),
        )
        with closing(self.connect_sys()) as sys_conn:
            with closing(sys_conn.cursor()) as cur:
                for step, failure_message in cleanup_steps:
                    try:
                        step(cur)
                    except cx_Oracle.DatabaseError as exc:
                        print(failure_message % exc)
                create_tablespace = """CREATE TABLESPACE %(tblspace)s
                DATAFILE '%(datafile)s' SIZE 20M
                REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize)s
                """ % self.parameters
                cur.execute(create_tablespace)
                create_temp_tablespace = """CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
                TEMPFILE '%(datafile_tmp)s' SIZE 20M
                REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize_tmp)s
                """ % self.parameters
                cur.execute(create_temp_tablespace)
                self._create_test_user(cur)

    def _drop_tablespace(self, cursor):
        """Drop the permanent tablespace, then the temporary one."""
        drop_permanent = (
            'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS'
        % self.parameters)
        cursor.execute(drop_permanent)
        drop_temporary = (
            'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS'
        % self.parameters)
        cursor.execute(drop_temporary)

    def _create_test_user(self, cursor):
        """Create the test user and grant it the privileges listed below."""
        create_user = (
        """CREATE USER %(user)s
            IDENTIFIED BY %(password)s
            DEFAULT TABLESPACE %(tblspace)s
            TEMPORARY TABLESPACE %(tblspace_temp)s
            QUOTA UNLIMITED ON %(tblspace)s
        """ % self.parameters
        )
        cursor.execute(create_user)
        grant_privileges = (
        """GRANT CREATE SESSION,
                    CREATE TABLE,
                    CREATE SEQUENCE,
                    CREATE PROCEDURE,
                    CREATE TRIGGER
            TO %(user)s
        """ % self.parameters
        )
        cursor.execute(grant_privileges)

    def _destroy_test_user(self, cursor):
        """Drop the test user together with everything it owns."""
        drop_user = '''
            DROP USER %(user)s CASCADE
        ''' % self.parameters
        cursor.execute(drop_user)
@provider(__fixture__='log', weight=100, enabled=False)
@contextmanager
def logging_context(test):
    """Run the wrapped test with INFO logging and Pony SQL debugging on.

    On exit the root logger's previous level is restored and ``sql_debug`` is
    reset to the ``debug`` value imported from ``pony.orm.core`` at entry.
    The restore runs in a ``finally`` block so that a failing test does not
    leave verbose logging switched on for the tests that follow.

    :param test: the test being wrapped (unused here).
    """
    from pony.orm.core import debug, sql_debug
    root_logger = logging.getLogger()
    level = root_logger.level
    root_logger.setLevel(logging.INFO)
    sql_debug(True)
    try:
        yield
    finally:
        root_logger.setLevel(level)
        sql_debug(debug)
# @provider('log_all', scope='class', weight=-100, enabled=False)
# def log_all(Test):
#     return logging_context(Test)
# @with_cli_args
# @click.option('--log', 'scope', flag_value='test')
# @click.option('--log-all', 'scope', flag_value='all')
# def use_logging(scope):
#     if scope == 'test':
#         yield logging_context
#     elif scope =='all':
#         yield log_all
@provider()
class DBSessionProvider(object):
    """Provider for the 'db_session' fixture.

    "Instantiating" this class does not create an instance of it:
    ``__new__`` hands back Pony's ``db_session`` object instead.
    """

    weight = 30
    __fixture__ = 'db_session'

    def __new__(cls, test):
        # The test argument is ignored; the same db_session is always returned.
        return db_session
@provider(weight=40, __fixture__='rollback')
@contextmanager
def do_rollback(test):
    """Call ``rollback()`` after the wrapped test, whether it passed or raised.

    :param test: the test being wrapped (unused here).
    """
    try:
        yield
    finally:
        # Runs on success and on failure alike.
        rollback()
@provider()
class SeparateProcess(object):
    """Class-scoped fixture that runs the wrapped callable in a child process.

    The callable is wrapped into a one-test ``unittest`` suite which is run
    by a ``multiprocessing.Process``; the parent then asserts that the child
    exited with code 0.
    """
    # TODO read failures from sep process better

    __fixture__ = 'separate_process'
    scope = 'class'
    enabled = False

    def __init__(self, Test):
        self.Test = Test

    def __call__(self, func):
        """Return a wrapper that executes ``func`` in a separate process."""
        def wrapper(Test):
            runner = unittest.runner.TextTestRunner()
            if isinstance(Test, type):
                base_cls = Test
            else:
                base_cls = type(Test)

            def runTest(self):
                # `func` receives the outer Test object, not this case.
                try:
                    func(Test)
                finally:
                    # Swap the runner's stream for an in-memory buffer once
                    # the wrapped callable has finished.
                    runner.stream = unittest.runner._WritelnDecorator(StringIO())

            method_name = getattr(func, '__name__', 'runTest')
            case_cls = type(base_cls.__name__, (base_cls,), {method_name: runTest})
            case_cls.__module__ = base_cls.__module__
            case = case_cls(method_name)
            suite = unittest.suite.TestSuite([case])

            def run():
                # Executed in the child: a non-zero exit code signals failure.
                outcome = runner.run(suite)
                if not outcome.wasSuccessful():
                    sys.exit(1)

            child = Process(target=run, args=())
            child.start()
            child.join()
            case.assertEqual(child.exitcode, 0)
        return wrapper

    @classmethod
    def validate_chain(cls, fixtures, klass):
        """Decide whether this fixture fits into the chain ``fixtures``.

        Returns False if the chain contains an 'ipdb' or 'ipdb_all' fixture,
        True if it contains a 'db' fixture whose PROVIDER is 'sqlserver' or
        'oracle', and None otherwise.
        """
        if any(f.KEY in ('ipdb', 'ipdb_all') for f in fixtures):
            return False
        if any(f.KEY == 'db' and f.PROVIDER in ('sqlserver', 'oracle')
               for f in fixtures):
            return True
        return None
@provider()
class ClearTables(ContextDecorator):
    """Fixture that deletes all rows of every entity after the test."""

    __fixture__ = 'clear_tables'

    def __init__(self, test):
        self.test = test

    def __enter__(self):
        """Nothing to prepare."""
        return None

    @db_session
    def __exit__(self, *exc_info):
        """Delete every instance of every entity of the test's database.

        Stops at the first entity whose database has no generated schema.
        """
        for entity in self.test.db.entities.values():
            if entity._database_.schema is None:
                # No mapping was generated, so there are no tables to clear.
                break
            delete(i for i in entity)
@provider()
class NoJson1(ContextDecorator):
    """Class-scoped fixture that turns off ``json1_available`` on the provider.

    The previous value is remembered on entry and put back on exit. The
    wrapped test class is also flagged with ``no_json1 = True``.
    """

    __fixture__ = 'no_json1'
    fixture_name = 'no_json1'
    scope = 'class'

    def __init__(self, cls):
        cls.no_json1 = True
        self.Test = cls

    def __enter__(self):
        """Remember the current ``json1_available`` value and set it to False."""
        db_provider = self.Test.db.provider
        self.json1_available = db_provider.json1_available
        self.Test.db.provider.json1_available = False

    def __exit__(self, *exc_info):
        """Restore the remembered ``json1_available`` value."""
        self.Test.db.provider.json1_available = self.json1_available

    @classmethod
    def validate_chain(cls, fixtures, klass):
        """Decide whether this fixture fits into the chain ``fixtures``.

        Returns False if the chain contains an 'ipdb' or 'ipdb_all' fixture,
        True if it contains a 'db' fixture whose PROVIDER is 'sqlserver' or
        'oracle', and None otherwise.
        """
        # NOTE(review): same rule as the process-based fixtures in this
        # module; confirm it is intended for this fixture as well.
        if any(f.KEY in ('ipdb', 'ipdb_all') for f in fixtures):
            return False
        if any(f.KEY == 'db' and f.PROVIDER in ('sqlserver', 'oracle')
               for f in fixtures):
            return True
        return None
import signal
@provider()
class Timeout(object):
    """Class-scoped fixture that runs the wrapped callable with a time limit.

    The callable runs in a child process; a timer terminates the child once
    the limit (in seconds) has elapsed. The limit comes from the ``--timeout``
    CLI option or, when that is not given, from ``Test.TIMEOUT``.
    """

    __fixture__ = 'timeout'
    scope = 'class'
    enabled = False

    @with_cli_args
    @click.option('--timeout', type=int)
    def __init__(self, Test, timeout):
        self.Test = Test
        # Test.TIMEOUT is only consulted when no CLI value was supplied.
        self.timeout = timeout or Test.TIMEOUT

    class Exception(Exception):
        """Raised when the child process was terminated by SIGTERM."""
        pass

    # Inside this class body the name `Exception` now refers to the nested
    # class defined just above, so FailedInSubprocess derives from it.
    class FailedInSubprocess(Exception):
        """Raised when the child process exited with a non-zero code."""
        pass

    def __call__(self, func):
        """Return a wrapper that runs ``func`` in a time-limited child process."""
        def wrapper(test):
            child = Process(target=func, args=(test,))
            child.start()
            # The timer terminates the child if it outlives the limit.
            watchdog = threading.Timer(self.timeout, child.terminate)
            watchdog.start()
            child.join()
            watchdog.cancel()
            exit_code = child.exitcode
            if exit_code == -signal.SIGTERM:
                raise self.Exception
            if exit_code:
                raise self.FailedInSubprocess
        return wrapper

    @classmethod
    @with_cli_args
    @click.option('--timeout', type=int)
    def validate_chain(cls, fixtures, klass, timeout):
        """Decide whether this fixture fits into the chain ``fixtures``.

        Returns False if neither ``klass.TIMEOUT`` nor ``timeout`` is set or
        if the chain contains an 'ipdb' or 'ipdb_all' fixture, True if the
        chain contains a 'db' fixture whose PROVIDER is 'sqlserver' or
        'oracle', and None otherwise.
        """
        if not (getattr(klass, 'TIMEOUT', None) or timeout):
            return False
        if any(f.KEY in ('ipdb', 'ipdb_all') for f in fixtures):
            return False
        if any(f.KEY == 'db' and f.PROVIDER in ('sqlserver', 'oracle')
               for f in fixtures):
            return True
        return None
# Append this module's fixture keys to the 'test' entry of pony_fixtures
# (presumably the per-test scope -- confirm against ponytest).
pony_fixtures['test'].extend([
    'log',
    'clear_tables',
    'db_session',
])
# Append this module's fixture keys to the 'class' entry of pony_fixtures
# (presumably the per-class scope -- confirm against ponytest).
pony_fixtures['class'].extend([
    'separate_process',
    'timeout',
    'db',
    'generate_mapping',
])
def db_is_present(providers, config):
    """Validator registered for the 'db' fixture.

    Returns ``providers`` unchanged; ``config`` is ignored. Presumably an
    empty result means that no database provider is available -- confirm
    against how ponytest interprets validator results.
    """
    return providers
# Register db_is_present under the 'db' key of provider_validators.
provider_validators.update({
    'db': db_is_present,
})