// dotslash
// 4/4/2019 - 3:46 AM
//
// SnapshottableArray.scala

import scala.collection.mutable.ListBuffer

/**
 * An integer array whose earlier states can be read back via snapshot handles.
 */
trait SnapshottableArray {
    /** Stores `value` at position `index` of the live array. */
    def set(index: Int, value: Int): Unit

    /** Reads position `index` of the live array. */
    def get(index: Int): Int

    /** Records a snapshot of the current contents and returns the handle that names it. */
    def snapshot(): Int

    /**
     * Reads position `ind` as it was in the snapshot named by `snapHandle`.
     *
     * @return the value seen by that snapshot, or None when the implementation
     *         cannot answer for the given handle
     */
    def getSnapValue(ind: Int, snapHandle: Int): Option[Int]
}

/**
 * [[SnapshottableArray]] backed by one `java.util.TreeMap` per cell.
 *
 * Each cell maps a version (the value of `curHandle` at write time) to the value
 * written in that version; `put` overwrites, so only the last write per version
 * is kept. Reading a snapshot is a floor lookup on that map.
 *
 * @param size number of cells; every cell starts out holding 0
 */
class SnapshottableArrayTreeSet(size: Int) extends SnapshottableArray {
    // Seed every cell with value 0 at version 0 so a floor lookup for any
    // non-negative handle always finds an entry.
    private val data: Array[java.util.TreeMap[Int, Int]] = Array.fill(size) {
        val map = new java.util.TreeMap[Int, Int]()
        map.put(0, 0)
        map
    }

    // Version that writes currently go to; handles below it have been handed out by snapshot().
    private var curHandle = 0

    /** Live value of cell `ind` (the entry with the highest version). */
    def get(ind: Int): Int = data(ind).lastEntry.getValue

    /** Writes `v` to cell `ind` in the live version, replacing any earlier write in the same version. */
    def set(ind: Int, v: Int): Unit = {
        data(ind).put(curHandle, v)
    }

    /** Freezes the live version and returns its handle; later writes go to the next version. */
    def snapshot(): Int = {
        curHandle += 1
        curHandle - 1
    }

    /**
     * Value of cell `ind` as of version `snapHandle`.
     *
     * Passing the live handle (the number of snapshots taken so far) yields the current value.
     *
     * @return None when `snapHandle` is negative or beyond the live version
     */
    def getSnapValue(ind: Int, snapHandle: Int): Option[Int] = {
        if (snapHandle < 0 || snapHandle > curHandle) {
            None
        } else {
            // floorEntry = greatest version <= snapHandle, i.e. the last write visible to that snapshot.
            // Wrapping the *entry* in Option (not its value) keeps a missing entry from throwing.
            Option(data(ind).floorEntry(snapHandle)).map(_.getValue)
        }
    }
}


/**
 * One recorded write: `value` was stored while version `ind` was live.
 */
case class IndVal(
    ind: Int,
    value: Int
)
/**
 * [[SnapshottableArray]] that keeps, per cell, the list of writes ordered by version
 * and answers snapshot reads with a binary search over that list.
 *
 * Per-cell invariant maintained by `set`: entries have strictly increasing `ind`
 * (the version a write belongs to), and the first entry always has `ind == 0`.
 *
 * @param size number of cells; every cell starts out holding 0
 */
class SnapshottableArrayBS(size: Int) extends SnapshottableArray {
    // ArrayBuffer rather than ListBuffer: positional reads and the in-place update
    // in `set` must be O(1), otherwise the binary search degrades to linear scans.
    import scala.collection.mutable.ArrayBuffer

    private val data: Array[ArrayBuffer[IndVal]] = Array.fill(size) {
        ArrayBuffer(IndVal(0, 0))
    }

    // Version that writes currently go to; handles below it have been handed out by snapshot().
    private var curHandle = 0

    /** Live value of cell `ind` (the most recent write). */
    def get(ind: Int): Int = data(ind).last.value

    /** Writes `v` to cell `ind` in the live version. */
    def set(ind: Int, v: Int): Unit = {
        val versions = data(ind)
        val lastPos = versions.length - 1
        if (versions(lastPos).ind == curHandle) {
            // Already written in this version: overwrite so `ind` stays strictly increasing.
            versions(lastPos) = IndVal(curHandle, v)
        } else {
            versions.append(IndVal(curHandle, v))
        }
    }

    /** Freezes the live version and returns its handle; later writes go to the next version. */
    def snapshot(): Int = {
        curHandle += 1
        curHandle - 1
    }

    /**
     * Linear-scan reference version of [[getSnapValue]].
     *
     * @return the last write with version <= `snapHandle`, or None when
     *         `snapHandle` is negative or beyond the live version
     */
    def getSnapValueIneff(i: Int, snapHandle: Int): Option[Int] = {
        if (snapHandle < 0 || snapHandle > curHandle) None
        else data(i).takeWhile(_.ind <= snapHandle).lastOption.map(_.value)
    }

    /**
     * Value of cell `i` as of version `snapHandle`, found by binary search.
     *
     * Passing the live handle (the number of snapshots taken so far) yields the current value.
     *
     * @return None when `snapHandle` is negative or beyond the live version
     */
    def getSnapValue(i: Int, snapHandle: Int): Option[Int] = {
        if (snapHandle < 0 || snapHandle > curHandle) {
            None
        } else {
            val versions = data(i)
            // Loop invariant: versions(lo).ind <= snapHandle (true initially since
            // versions(0).ind == 0), and every position >= hi has ind > snapHandle.
            // Both bounds are *positions* in the buffer, never version numbers.
            var lo = 0
            var hi = versions.length
            while (hi - lo > 1) {
                val mid = lo + (hi - lo) / 2
                if (versions(mid).ind <= snapHandle) lo = mid else hi = mid
            }
            Some(versions(lo).value)
        }
    }
}

/** Small console demo of [[SnapshottableArrayBS]]: writes, snapshots and reads back. */
object HelloWorld {
    def main(args: Array[String]): Unit = {
        println("Hello, world!")

        val arr = new SnapshottableArrayBS(10)
        println(arr.get(0))

        arr.set(0, 10)
        println(arr.get(0))

        // Freeze the state holding 10, then overwrite the live value.
        val handle = arr.snapshot()
        arr.set(0, 50)
        println(arr.get(0))

        // The frozen snapshot, the live version, and one handle past the live version.
        Seq(handle, handle + 1, handle + 2).foreach(h => println(arr.getSnapValue(0, h)))
    }
}