6qat
6/7/2018 - 11:11 AM

Quick implementation of the Rx CombineLatest operator for Akka Streams

Quick implementation of the Rx CombineLatest operator for Akka Streams

import akka.actor.ActorSystem
import akka.stream._
import akka.stream.scaladsl.{Flow, GraphDSL, Keep, RunnableGraph, Sink, Source}
import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
import akka.testkit.TestKit
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike}

import scala.concurrent.Await

/**
  * Fan-in stage modelled on the Rx `combineLatest` operator.
  *
  * The stage remembers the most recent element received on each input. Once both
  * inputs have produced at least one element, every newly received element causes
  * the pair of latest values to be emitted, provided downstream has signalled demand.
  *
  * Back-pressure: while downstream has no outstanding demand an input is not pulled
  * again after delivering an element, so at most one element per input is held.
  *
  * Completion: no custom `onUpstreamFinish` is installed, so the default handler
  * applies and the stage completes as soon as either input finishes. A pair that
  * has been received but not yet pushed at that moment is not emitted.
  *
  * @tparam A element type of the first input (`in0`)
  * @tparam B element type of the second input (`in1`)
  */
class CombineLatest[A, B]
  extends GraphStage[FanInShape2[A, B, (A, B)]] {
  // Port names mirror the val names so that Akka's port-related error messages
  // point at the right inlet.
  val in0: Inlet[A] = Inlet[A]("CombineLatest.in0")
  val in1: Inlet[B] = Inlet[B]("CombineLatest.in1")
  val out: Outlet[(A, B)] = Outlet[(A, B)]("CombineLatest.out")

  override val shape: FanInShape2[A, B, (A, B)] =
    new FanInShape2[A, B, (A, B)](in0, in1, out)

  override def createLogic(attr: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) {
      // Most recent element seen on each input; None until the first one arrives.
      private var latestA = Option.empty[A]
      private var latestB = Option.empty[B]
      // True when an element has arrived since the last pair was pushed. Tracking
      // arrival (rather than comparing pairs for equality) means a repeated value
      // is treated like any other new element.
      private var unsent = false

      setHandler(out, new OutHandler {
        override def onPull(): Unit = {
          // Serve the demand right away if something new is pending; otherwise the
          // demand stays open and is served from onElement().
          emitIfPossible()
          pullBoth()
        }
      })

      setHandler(in0, new InHandler {
        override def onPush(): Unit = {
          latestA = Some(grab(in0))
          onElement()
        }
      })

      setHandler(in1, new InHandler {
        override def onPush(): Unit = {
          latestB = Some(grab(in1))
          onElement()
        }
      })

      /** Bookkeeping shared by both inputs after a new element has been stored. */
      private def onElement(): Unit = {
        unsent = true
        if (isAvailable(out)) {
          emitIfPossible()
          // Keep consuming even when the other side has not produced anything yet,
          // so the value paired with its first element really is the latest one.
          pullBoth()
        }
        // Without downstream demand the input is left un-pulled (back-pressure);
        // the next onPull re-pulls it.
      }

      /** Pushes the latest pair if there is demand, both sides are known and something is new. */
      private def emitIfPossible(): Unit =
        if (unsent && isAvailable(out)) {
          for (a <- latestA; b <- latestB) {
            push(out, (a, b))
            unsent = false
          }
        }

      /** Requests the next element from every input that has no pull outstanding. */
      private def pullBoth(): Unit = {
        if (!hasBeenPulled(in0)) pull(in0)
        if (!hasBeenPulled(in1)) pull(in1)
      }
    }
}

/**
  * Timing-based spec for [[CombineLatest]].
  *
  * Two throttled sources are wired into the stage and the emitted pairs are
  * collected through a throttled sink. The expectation relies on wall-clock
  * ordering of the throttled elements.
  */
class CombineLatestSpec extends TestKit(ActorSystem("CombineLatestSpec"))
  with WordSpecLike
  with Matchers
  with BeforeAndAfterAll {

  import scala.concurrent.duration._

  implicit val materializer: ActorMaterializer = ActorMaterializer()

  // Shut the test actor system down once all tests in this spec have run.
  override def afterAll(): Unit = TestKit.shutdownActorSystem(system)

  "CombineLatest" must {
    "work with happy case" in {
      // One Boolean every 200 ms.
      val dataSource1 = Source(List(true, false, true, false))
        .throttle(1, 200.millisecond, 1, ThrottleMode.Shaping)
      // One Int every 100 ms, of which only the odd ones get through.
      val dataSource2 = Source(0 to 7)
        .throttle(1, 100.millisecond, 1, ThrottleMode.Shaping)
        .filter(_ % 2 == 1)
      // T:    0    100    200    300    400    500    600    700
      // S1:  true         false         true          false
      // S2:        1               3            5              7
      //
      // S1 finishes after its element at ~600 ms; the expected output below has no
      // (false, 7), i.e. the stage is expected to be done before 7 arrives at ~700 ms.
      val sink = Flow[(Boolean, Int)]
        .throttle(1, 50.milliseconds, 1, ThrottleMode.Shaping)
        .toMat(Sink.seq)(Keep.right)

      val g = RunnableGraph.fromGraph(GraphDSL.create(sink) { implicit b =>
        (s1) =>
          import akka.stream.scaladsl.GraphDSL.Implicits._
          val clatest = b.add(new CombineLatest[Boolean, Int]())
          dataSource1 ~> clatest.in0
          dataSource2 ~> clatest.in1
          clatest.out ~> s1

          ClosedShape
      })

      // Blocking here is acceptable: this is the edge of the test.
      val result = Await.result(g.run(), 10.seconds)
      result shouldBe Seq((true, 1), (false, 1), (false, 3), (true, 3), (true, 5), (false, 5))
    }
  }
}