loicdescotte
8/21/2013 - 12:55 PM

Scala object to transform an HTTP GET response stream into an Enumerator[Whatever]

package play.api.libs.ws

import play.api.libs.iteratee.{Enumeratee, Concurrent, Enumerator}

import play.api.libs.concurrent.Execution.Implicits._
import com.ning.http.client._
import com.ning.http.client.AsyncHandler.STATE
import play.api.Logger
import scala.concurrent.{Future, Promise}

/**
 * Exposes the body of an HTTP GET as a Play `Enumerator`, so a response can be
 * consumed chunk by chunk as it arrives.
 *
 * When the consuming iteratee finishes while the HTTP call is still running, the
 * call is aborted so the rest of the body is not downloaded.
 */
object WSEnumerator {

  /** Passed to `abort` when the call is cancelled on purpose; recognised in `onThrowable`. */
  private class AbortOnIterateeDone() extends RuntimeException
  private val logger = Logger("WSEnumerator")

  /**
   * Streams the body of a GET request, transforming every received chunk with `f`.
   *
   * @param url     URL to fetch
   * @param timeout request timeout in milliseconds, passed to
   *                `PerRequestConfig.setRequestTimeoutInMs`
   *                (NOTE(review): meaning of the default `-1` is defined by the HTTP client — confirm)
   * @param f       applied to each body chunk exactly as delivered by the HTTP client;
   *                chunk boundaries depend on the transport, not on lines or records
   * @return same completion/failure semantics as [[getRawStream]], with chunks mapped through `f`
   */
  def getStream[A](url: String, timeout: Int = -1)(f: Array[Byte] => A): Future[Enumerator[A]] =
    getRawStream(url, timeout).map(_.through(Enumeratee.map[Array[Byte]](f)))

  /**
   * Issues a GET to `url` and exposes the response body as an `Enumerator` of raw byte chunks.
   *
   * The returned future:
   *  - completes with the enumerator once the status line and the headers have been
   *    received and the status code is below 300;
   *  - fails with `IllegalStateException` when the status code is 300 or above
   *    (the HTTP call is then aborted instead of downloading the error body);
   *  - fails with the client's exception when the request fails before the status
   *    and headers have arrived.
   *
   * If the transfer fails after the enumerator has been handed out, the enumerator is
   * ended with that exception rather than with a plain EOF, so a truncated body is not
   * mistaken for a complete one.
   *
   * NOTE(review): the body is fed through `Concurrent.broadcast`. As far as I know, that
   * channel does not buffer for iteratees that attach later. Chunks arriving between the
   * headers and the moment the caller applies an iteratee may therefore be lost (for a
   * small body, possibly all of it) — confirm and consider buffering until the first
   * iteratee attaches.
   *
   * @param url     URL to fetch
   * @param timeout request timeout in milliseconds, passed to `PerRequestConfig.setRequestTimeoutInMs`
   */
  def getRawStream(url: String, timeout: Int = -1): Future[Enumerator[Array[Byte]]] = {
    val promiseStatus = Promise[Int]()
    val promiseHeader = Promise[HttpResponseHeaders]()
    val config = new PerRequestConfig()
    config.setRequestTimeoutInMs(timeout)
    val client = WS.client.prepareGet(url).setPerRequestConfig(config)
    val (enumerator, channel) = Concurrent.broadcast[Array[Byte]]

    val listenableFuture = client.execute(new AsyncHandler[Unit]() {
      override def onThrowable(p1: Throwable): Unit = {
        p1 match {
          case _: AbortOnIterateeDone =>
            logger.debug(s"WS call aborted on purpose : $p1")
          case _ =>
            logger.debug("Actual exception, closing enumerator channel and leaking exception")
            // Fail whichever status/header promise is still pending. Without this, the future
            // returned to the caller never completes when the request fails early
            // (e.g. connection refused). Both calls are no-ops once a promise is completed.
            promiseStatus.tryFailure(p1)
            promiseHeader.tryFailure(p1)
            // End with the error instead of a bare EOF, so a consumer already reading
            // the body sees a failure rather than a silently truncated stream.
            // Rethrowing here would only reach the HTTP client's callback thread.
            channel.end(p1)
        }
      }

      override def onBodyPartReceived(p1: HttpResponseBodyPart): STATE = {
        channel.push(p1.getBodyPartBytes)
        STATE.CONTINUE
      }

      override def onStatusReceived(p1: HttpResponseStatus): STATE = {
        val code = p1.getStatusCode
        if (code >= 300) {
          promiseStatus.tryFailure(new IllegalStateException(s"HTTP status is ${code} for URL ${url}"))
          // The caller's future fails on this status, so nobody can ever read the body:
          // stop the transfer instead of downloading it into an unreachable channel.
          STATE.ABORT
        } else {
          promiseStatus.trySuccess(code)
          STATE.CONTINUE
        }
      }

      override def onHeadersReceived(p1: HttpResponseHeaders): STATE = {
        // trySuccess: this callback may fire more than once (presumably for trailing
        // headers of a chunked response — confirm), and a second `success` would throw.
        promiseHeader.trySuccess(p1)
        STATE.CONTINUE
      }

      override def onCompleted(): Unit = {
        logger.debug("Closing channel as WS call is completed")
        channel.eofAndEnd()
      }
    })

    // Once the consumer stops reading, cancel the HTTP call if it is still running.
    val stream = enumerator.through(Enumeratee.onIterateeDone[Array[Byte]] { () =>
      logger.debug("Iteratee is done ...")
      if (!listenableFuture.isDone) {
        listenableFuture.abort(new AbortOnIterateeDone())
        channel.eofAndEnd()
        logger.debug("Aborting WS call")
      } else {
        logger.debug("WS Call already finished")
      }
    })

    // Hand out the stream only after a successful status line and the headers.
    for {
      _ <- promiseStatus.future
      _ <- promiseHeader.future
    } yield stream
  }
}