// LSTANCZYK
// 10/17/2014 - 2:53 PM
//
// TestEnvironment.cs

using System;
using System.Configuration;
using System.Data.Common;
using System.Data.SqlClient;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
using FluentNHibernate.Cfg;
using FluentNHibernate.Cfg.Db;
using Inflector;
using Ionic.Zip;
using NHibernate;
using NHibernate.Connection;
using NSubstitute;
using TechTalk.SpecFlow;
using roundhouse;

/// <summary>
/// SpecFlow hooks that give every scenario its own throw-away SQL Server database.
/// </summary>
/// <remarks>
/// Once per test run, a per-process database ("P{pid}_{catalog}") is set up via
/// <c>CreateTemplateDatabase</c> (defined elsewhere; presumably it runs the extracted
/// migrations — TODO confirm). That database is then detached, and its data file is copied
/// into a temp directory as a template before the database is re-attached. Before each
/// scenario, the template is copied and attached under a scenario-specific name. After the
/// scenario, that database is dropped. NHibernate connections are routed to the right
/// database by <see cref="DynamicConnectionProvider"/>.
/// NOTE(review): copying the .mdf with File.Copy assumes SQL Server runs on the same
/// machine as the tests and can read/write the temp directory.
/// </remarks>
[Binding]
public class TestEnvironment
{
  /// <summary>Per-run scratch directory holding the extracted migrations and the template data file.</summary>
  static string TempDir;

  /// <summary>Full path of the template data file every scenario database is cloned from.</summary>
  static string TemplateMdfPath;

  static ISessionFactory SessionFactory;

  /// <summary>Id of the test process, captured once; used to keep database names unique per process.</summary>
  static readonly int ProcessId = GetProcessId();

  /// <summary>Anything that is not a word or whitespace character is stripped from scenario titles.</summary>
  static readonly Regex NonNameCharacters = new Regex(@"[^\w\s]");

  /// <summary>
  /// Creates the per-process database, builds the NHibernate session factory and
  /// snapshots the database's data file as the scenario template.
  /// </summary>
  [BeforeTestRun]
  public static void BeforeTestRun()
  {
    TempDir = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName());
    Directory.CreateDirectory(TempDir);

    ExtractMigrations();

    var csb = GetConnectionStringBuilder();
    csb.InitialCatalog = InitialCatalogForCurrentProcess(csb);

    CreateTemplateDatabase(csb);
    CreateSessionFactory(csb.ConnectionString);

    var databaseName = csb.InitialCatalog;
    var mdfPath = GetMdfPath(csb);

    // The data file can only be copied while the database is detached.
    DetachDatabase(csb, databaseName);
    try
    {
      // The template path is stored once so BeforeScenario does not have to re-derive
      // it (the original re-derivation assumed the .mdf was named after the database).
      TemplateMdfPath = Path.Combine(TempDir, "Template" + Path.GetFileName(mdfPath));
      File.Copy(mdfPath, TemplateMdfPath);
    }
    finally
    {
      // Re-attach even if the copy failed, so the database is not left orphaned.
      AttachDatabase(csb, databaseName, mdfPath);
    }
  }

  /// <summary>Attaches a fresh copy of the template data file under this scenario's database name.</summary>
  [BeforeScenario]
  public void BeforeScenario()
  {
    var destFileName = Path.Combine(TempDir, string.Format("{0}.mdf", DatabaseName));
    File.Copy(TemplateMdfPath, destFileName, true);

    var csb = GetConnectionStringBuilder();
    csb.InitialCatalog = "master";

    CreateDatabase(csb, destFileName);
  }

  /// <summary>Drops the scenario's database.</summary>
  [AfterScenario]
  public void AfterScenario()
  {
    DropDatabase(GetConnectionStringBuilder(), DatabaseName);
  }

  /// <summary>
  /// Disposes the session factory, drops the per-process database and deletes the
  /// temp directory. The directory is removed even if the drop fails.
  /// </summary>
  [AfterTestRun]
  public static void AfterTestRun()
  {
    try
    {
      // SessionFactory is null when BeforeTestRun failed before building it.
      if (SessionFactory != null)
      {
        SessionFactory.Dispose();
      }

      var csb = GetConnectionStringBuilder();
      DropDatabase(csb, InitialCatalogForCurrentProcess(csb));
    }
    finally
    {
      if (TempDir != null && Directory.Exists(TempDir))
      {
        Directory.Delete(TempDir, true);
      }
    }
  }

  /// <summary>
  /// Extracts the embedded "{AssemblyName}.Migrations.zip" resource into TempDir\Database.
  /// </summary>
  /// <exception cref="InvalidOperationException">The resource is not embedded in the assembly.</exception>
  static void ExtractMigrations()
  {
    var assembly = Assembly.GetExecutingAssembly();
    var directory = Path.Combine(TempDir, "Database");
    Directory.CreateDirectory(directory);

    var resourceName = string.Format("{0}.Migrations.zip", assembly.GetName().Name);
    using (var stream = assembly.GetManifestResourceStream(resourceName))
    {
      // GetManifestResourceStream returns null (not throws) for a missing resource.
      if (stream == null)
      {
        throw new InvalidOperationException(string.Format(
          "Embedded resource '{0}' was not found in assembly '{1}'.", resourceName, assembly.FullName));
      }

      using (var zipFile = ZipFile.Read(stream))
      {
        zipFile.ExtractAll(directory);
      }
    }
  }

  /// <summary>Returns a new, mutable builder for the "TestData" connection string from the config file.</summary>
  /// <exception cref="InvalidOperationException">No "TestData" connection string is configured.</exception>
  static SqlConnectionStringBuilder GetConnectionStringBuilder()
  {
    var settings = ConfigurationManager.ConnectionStrings["TestData"];
    if (settings == null)
    {
      throw new InvalidOperationException("Connection string 'TestData' is missing from the configuration file.");
    }
    return new SqlConnectionStringBuilder(settings.ConnectionString);
  }

  /// <summary>Returns "P{processId}_{csb.InitialCatalog}", the per-process database name.</summary>
  static string InitialCatalogForCurrentProcess(SqlConnectionStringBuilder csb)
  {
    return string.Format("P{0}_{1}", ProcessId, csb.InitialCatalog);
  }

  /// <summary>Returns the physical path of the first data file (type 0) of the database csb points at.</summary>
  static string GetMdfPath(SqlConnectionStringBuilder csb)
  {
    using (var cn = new SqlConnection(csb.ConnectionString))
    using (var cmd = cn.CreateCommand())
    {
      cmd.CommandText = @"select physical_name from sys.database_files where type = 0";
      cn.Open();
      var result = cmd.ExecuteScalar();
      if (result == null || result is DBNull)
      {
        throw new InvalidOperationException(string.Format(
          "No data file found for database '{0}'.", csb.InitialCatalog));
      }
      return result.ToString();
    }
  }

  /// <summary>
  /// Detaches <paramref name="databaseName"/> using a connection to master.
  /// The caller's builder is not modified.
  /// </summary>
  static void DetachDatabase(SqlConnectionStringBuilder csb, string databaseName)
  {
    // Idle pooled connections (e.g. the one GetMdfPath just returned to the pool) still
    // hold the database open and would make sp_detach_db fail with "currently in use".
    SqlConnection.ClearAllPools();
    ExecuteSql(
      MasterConnectionString(csb),
      "EXEC master.dbo.sp_detach_db @dbname = @dbname, @keepfulltextindexfile = N'false';",
      new SqlParameter("@dbname", databaseName));
  }

  /// <summary>Re-attaches <paramref name="databaseName"/> from <paramref name="mdfPath"/> using a connection to master.</summary>
  static void AttachDatabase(SqlConnectionStringBuilder csb, string databaseName, string mdfPath)
  {
    ExecuteSql(
      MasterConnectionString(csb),
      "EXEC master.dbo.sp_attach_db @dbname = @dbname, @filename1 = @filename1",
      new SqlParameter("@dbname", databaseName),
      new SqlParameter("@filename1", mdfPath));
  }

  /// <summary>
  /// Attaches <paramref name="destFileName"/> as the current scenario's database.
  /// <paramref name="csb"/> is expected to point at master.
  /// </summary>
  static void CreateDatabase(DbConnectionStringBuilder csb, string destFileName)
  {
    // CREATE DATABASE does not accept parameters, so the name and path are quoted/escaped instead.
    ExecuteSql(
      csb.ConnectionString,
      string.Format("CREATE DATABASE {0} ON (FILENAME = {1}) FOR ATTACH;",
        QuoteIdentifier(DatabaseName), QuoteLiteral(destFileName)));
  }

  /// <summary>
  /// Drops <paramref name="databaseName"/> if it exists, disconnecting any remaining sessions first.
  /// </summary>
  static void DropDatabase(SqlConnectionStringBuilder csb, string databaseName)
  {
    // Clear the pools first. Otherwise pooled connections keep the database in use, and
    // after the drop they would be handed out dead to the next scenario with the same
    // title (scenario outlines reuse titles).
    SqlConnection.ClearAllPools();

    // The DB_ID guard keeps a failed setup from being masked by a "database does not exist" error.
    var quotedName = QuoteIdentifier(databaseName);
    ExecuteSql(
      MasterConnectionString(csb),
      string.Format(
        "IF DB_ID(@dbname) IS NOT NULL BEGIN ALTER DATABASE {0} SET SINGLE_USER WITH ROLLBACK IMMEDIATE; DROP DATABASE {0}; END",
        quotedName),
      new SqlParameter("@dbname", databaseName));
  }

  /// <summary>Opens a connection, runs <paramref name="commandText"/> with the given parameters, and closes it.</summary>
  static void ExecuteSql(string connectionString, string commandText, params SqlParameter[] parameters)
  {
    using (var cn = new SqlConnection(connectionString))
    using (var cmd = cn.CreateCommand())
    {
      cmd.CommandText = commandText;
      cmd.Parameters.AddRange(parameters);
      cn.Open();
      cmd.ExecuteNonQuery();
    }
  }

  /// <summary>Returns csb's connection string retargeted at master, without modifying csb.</summary>
  static string MasterConnectionString(SqlConnectionStringBuilder csb)
  {
    var master = new SqlConnectionStringBuilder(csb.ConnectionString);
    master.InitialCatalog = "master";
    return master.ConnectionString;
  }

  /// <summary>Wraps a name in [ ] and doubles any closing brackets, for use as a T-SQL identifier.</summary>
  static string QuoteIdentifier(string name)
  {
    return "[" + name.Replace("]", "]]") + "]";
  }

  /// <summary>Wraps a value in N' ' and doubles any single quotes, for use as a T-SQL string literal.</summary>
  static string QuoteLiteral(string value)
  {
    return "N'" + value.Replace("'", "''") + "'";
  }

  /// <summary>Reads the current process id, disposing the Process handle afterwards.</summary>
  static int GetProcessId()
  {
    using (var process = Process.GetCurrentProcess())
    {
      return process.Id;
    }
  }

  /// <summary>
  /// Database name for the running scenario: "P{processId}_{title}". The title has
  /// non-word punctuation removed, invalid file-name characters replaced by '_', and is
  /// then passed through Inflector's Underscore().
  /// NOTE(review): very long titles may exceed SQL Server's 128-character identifier limit.
  /// </summary>
  static string DatabaseName
  {
    get
    {
      var title = ScenarioContext.Current.ScenarioInfo.Title;
      title = NonNameCharacters.Replace(title, string.Empty);
      title = Path.GetInvalidFileNameChars().Aggregate(title, (current, c) => current.Replace(c, '_'));
      return string.Format("P{0}_{1}", ProcessId, title.Underscore());
    }
  }

  /// <summary>Builds the shared NHibernate session factory for <paramref name="connectionString"/>.</summary>
  static void CreateSessionFactory(string connectionString)
  {
    var fluentConfiguration = GetFluentConfiguration(connectionString);
    var nhibernateConfiguration = fluentConfiguration.BuildConfiguration();
    SessionFactory = nhibernateConfiguration.BuildSessionFactory();
  }

  /// <summary>
  /// Fluent NHibernate configuration: SQL Server 2008 dialect, <see cref="DynamicConnectionProvider"/>,
  /// batch size 20, and hbm/fluent mappings plus conventions from the assembly containing Command.
  /// </summary>
  static FluentConfiguration GetFluentConfiguration(string connectionString)
  {
    var configuration = Fluently.Configure()
        .Database(
          MsSqlConfiguration.MsSql2008
          .Provider<DynamicConnectionProvider>()
          .ConnectionString(c => c.Is(connectionString))
          .AdoNetBatchSize(20)
        )
        .Mappings(m =>
        {
          m.HbmMappings.AddFromAssemblyOf<Command>();
          m.FluentMappings.AddFromAssemblyOf<Command>().Conventions.AddFromAssemblyOf<Command>();
        });
    return configuration;
  }

  /// <summary>
  /// NHibernate connection provider that recomputes the target database on every request.
  /// While a scenario is running (ScenarioContext.Current is not null), it targets that
  /// scenario's database; otherwise it targets the per-process database.
  /// </summary>
  public class DynamicConnectionProvider : DriverConnectionProvider
  {
    protected override string ConnectionString
    {
      get
      {
        var csb = GetConnectionStringBuilder();
        csb.InitialCatalog = ScenarioContext.Current != null
                            ? DatabaseName
                            : InitialCatalogForCurrentProcess(csb);
        return csb.ConnectionString;
      }
    }
  }
}