// casualjim
// 3/27/2011 - 9:31 PM
//
// variance.scala

package experiments

/**
 * Per-type-argument variance policy consumed by [[Def]].
 *
 * `Co` makes [[Def]] accept an actual type argument that is a subtype of the
 * desired one. `Contra` accepts a supertype. `No` requires equality.
 */
object Variance extends Enumeration {
  val Co     = Value("Co")
  val Contra = Value("Contra")
  val No     = Value("No")
}

/**
 * Extractor that matches a value whose static type (described by its `Manifest`)
 * conforms to `ToMatch`, using a variance policy per type argument.
 *
 * @tparam ToMatch the type a candidate must conform to. Its `Manifest` is captured
 *                 from the context bound when the extractor is constructed.
 * @param variance one variance per type argument of `ToMatch`, by position. Any
 *                 position beyond the supplied list is treated as `Variance.No`.
 */
class Def[ToMatch : Manifest](variance: Variance.Value*) {
  import Variance._

  /**
   * Returns `Some(candidate)` when the candidate's type matches `ToMatch`, else `None`.
   *
   * The outer type matches when either condition holds:
   *  - `ToMatch`'s erasure is assignable from the candidate's erasure, or
   *  - `Manifest`'s own `>:>` reports conformance.
   *
   * Type arguments are then compared pairwise by position:
   *  - `Co`: the desired argument must be a supertype of the actual one.
   *  - `Contra`: the desired argument must be a subtype of the actual one.
   *  - `No`: the two arguments must be equal.
   *
   * Arguments are paired with `zip`, so surplus arguments on either side are ignored.
   *
   * @param candidate the value being matched; it is returned unchanged on success
   * @param mf        manifest of the candidate's static type, resolved at the match site
   */
  def unapply[Candidate](candidate: Candidate)(implicit mf: Manifest[Candidate]): Option[Candidate] = {
    val toMatch = implicitly[Manifest[ToMatch]]

    // A `def`, so the argument walk only happens if the outer type already matched.
    def argsConform: Boolean =
      toMatch.typeArguments.zip(mf.typeArguments).zipWithIndex forall {
        case ((desired, actual), index) =>
          // Look the variance up once per position instead of once per guard.
          getVariance(index) match {
            case Contra => desired <:< actual
            case No     => desired == actual
            case _      => desired >:> actual // Co
          }
      }

    // The `>:>` fallback covers cases where erasures alone disagree (e.g. `Any` vs a
    // primitive manifest, whose erasure is a primitive class).
    val isAssignable = toMatch.erasure.isAssignableFrom(mf.erasure) || (toMatch >:> mf)
    if (isAssignable && argsConform) Some(candidate) else None
  }

  /**
   * Returns the variance configured for the type argument at `index`.
   *
   * Returns `No` (invariant) when fewer than `index + 1` variances were supplied.
   */
  def getVariance(index: Int): Variance.Value =
    if (variance.length > index) variance(index) else No
}

/** Root of the sample class hierarchy; exposes its constructor argument as `a`. */
class A(val a: Int)

/** Marker traits mixed into the sample subclasses below. */
sealed trait AA
trait BB extends AA
trait CC extends AA

// `a` is forwarded to the superclass constructor instead of being re-declared as
// `override val a`. Re-declaring it would add a redundant field at every level of the
// hierarchy. With forwarding, `x.a` still resolves to A's public val.
class B(a: Int) extends A(a) with BB
class C(a: Int) extends B(a) with CC

/** Console demo exercising the [[Def]] extractors against lists of different element types. */
object TestIt {

  /** `List[...]` whose type argument equals `Int`. */
  val IntList: Def[List[Int]] = new Def[List[Int]](Variance.No)
  /** `List[X]` where `B >:> X` (the element type is `B` or a subtype). */
  val Covariant: Def[List[B]] = new Def[List[B]](Variance.Co)
  /** `List[X]` where `B <:< X` (the element type is `B` or a supertype). */
  val Contravariant: Def[List[B]] = new Def[List[B]](Variance.Contra)
  /** `List[...]` whose type argument equals `B`. */
  val Invariant: Def[List[B]] = new Def[List[B]](Variance.No)

  /**
   * Prints which extractor, if any, matches `a` for its static type `T`.
   *
   * Cases are tried in order, so a `List[B]` is reported by `Invariant` before the
   * covariant and contravariant cases get a chance. `mf` is never referenced directly;
   * it is the implicit `Manifest` that each extractor's `unapply` resolves.
   */
  def matchIt[T](a: T)(implicit mf: Manifest[T]): Unit = a match {
    // The extractor hands back the candidate unchanged (typed as `T`), hence the cast.
    case IntList(il)      => println("This was an int list: " + il.asInstanceOf[List[Int]].mkString(", "))
    case Invariant(_)     => println("this was an invariant match")
    case Covariant(_)     => println("This was a covariant match")
    case Contravariant(_) => println("This was a contravariant match")
    case _                => println("NO int list: " + a)
  }

  /** Runs [[matchIt]] over lists of `Int`, `A`, `B` and `C`, printing a label before each. */
  def now: Unit = {
    println("matching int list")
    matchIt(List(1, 2, 3, 4, 5, 6, 7))
    println("matching list[A]")
    matchIt(List(new A(1), new A(2), new A(3)))
    println("matching list[B]")
    matchIt(List(new B(1), new B(2), new B(3)))
    println("matching list[C]")
    matchIt(List(new C(1), new C(2), new C(3)))
  }
}