casualjim
2/13/2012 - 7:42 PM

A Netty-based WebSocket client and server in Scala

A Netty-based WebSocket client and server in Scala

package io.backchat.minutes.river

import org.jboss.netty.bootstrap.ServerBootstrap
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory
import java.util.concurrent.{ TimeUnit, Executors }
import java.net.{ InetSocketAddress }
import org.jboss.netty.channel._
import group.{ChannelGroup, DefaultChannelGroup}
import org.elasticsearch.common.logging.{ESLogger, ESLoggerFactory}
import org.jboss.netty.handler.codec.http.{HttpRequest, HttpChunkAggregator, HttpRequestDecoder, HttpResponseEncoder}
import org.jboss.netty.handler.codec.http.websocketx._
import org.jboss.netty.handler.codec.http.HttpHeaders.Values
import org.jboss.netty.handler.codec.http.HttpHeaders.Names
import java.util.Locale.ENGLISH

/**
 * Settings consumed by the WebSocket server when it binds its listening socket.
 */
trait WebSocketServerConfig {
  /** Host name or IP address to bind to; when null or blank the server binds to the wildcard address. */
  def listenOn: String
  /** TCP port to bind to. */
  def port: Int
}


/**
 * Netty based WebSocketServer
 * requires netty 3.3.x or later
 * 
 * Usage:
 * <pre>
 *   val conf = new WebSocketServerConfig {
 *     val port = 14567
 *     val listenOn = "0.0.0.0"
 *   }
 *   
 *   val server = WebSocketServer(conf) {
 *     case Connect(_) => println("got a client connection")
 *     case TextMessage(cl, text) => cl.write(new TextWebSocketFrame("ECHO: " + text))
 *     case Disconnected(_) => println("client disconnected")
 *   }
 *   server.start
 *   // time passes......
 *   server.stop
 * </pre>
 */
object WebSocketServer {

  /** Callback type: a partial function over the events emitted by the server. */
  type WebSocketHandler = PartialFunction[WebSocketMessage, Unit]

  /** Events delivered to a [[WebSocketHandler]]. */
  sealed trait WebSocketMessage
  /** Emitted once the server has answered the WebSocket upgrade request of `client`. */
  case class Connect(client: Channel) extends WebSocketMessage
  /** A complete text message (fragments already joined) received from `client`. */
  case class TextMessage(client: Channel, content: String) extends WebSocketMessage
  /** A binary frame received from `client`; `content` is a private copy of the payload. */
  case class BinaryMessage(client: Channel, content: Array[Byte]) extends WebSocketMessage
  /** An exception was caught in the pipeline of `client`. */
  case class Error(client: Channel, cause: Option[Throwable]) extends WebSocketMessage
  /** `client` sent a close frame. */
  case class Disconnected(client: Channel) extends WebSocketMessage


  /**
   * Creates a (not yet started) server.
   *
   * @param config  the address and port to listen on
   * @param handler the callback receiving the server events
   */
  def apply(config: WebSocketServerConfig)(handler: WebSocketServer.WebSocketHandler): WebSocketServer =
    new WebSocketServer(config, handler)

  /** Keeps `channels` in sync with the set of currently open connections. */
  private class ConnectionTracker(channels: ChannelGroup) extends SimpleChannelUpstreamHandler {
    override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      channels remove e.getChannel
      ctx.sendUpstream(e)
    }

    override def channelConnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      channels add e.getChannel
      ctx.sendUpstream(e)
    }

    override def channelDisconnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      channels remove e.getChannel
      ctx.sendUpstream(e)
    }

  }

  /**
   * Translates HTTP upgrade requests and WebSocket frames into [[WebSocketMessage]]s
   * and feeds them to `handler`.
   *
   * The instance keeps per-connection state (handshaker, pending fragments), so it
   * must not be shared between channels.
   */
  private class WebSocketPartialFunctionHandler(handler: WebSocketHandler, logger: ESLogger) extends SimpleChannelUpstreamHandler {

    // Text of the fragments received so far for a fragmented message that is still in flight.
    private[this] var pendingFragments: Vector[String] = Vector.empty

    // Stays null until an upgrade request has been seen on this connection.
    private[this] var handshaker: WebSocketServerHandshaker = _

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
      e.getMessage match {
        case httpRequest: HttpRequest if isWebSocketUpgrade(httpRequest) =>
          handleUpgrade(ctx, httpRequest)

        // First fragment of a fragmented text message: start collecting instead of
        // delivering a partial message.
        case m: TextWebSocketFrame if !m.isFinalFragment =>
          pendingFragments = Vector(m.getText)

        case m: TextWebSocketFrame =>
          handler.lift(TextMessage(e.getChannel, m.getText))

        case m: BinaryWebSocketFrame =>
          // Copy only the readable bytes: array() fails for buffers without a backing
          // array and may expose bytes outside the readable region.
          val payload = m.getBinaryData
          val bytes = new Array[Byte](payload.readableBytes)
          payload.getBytes(payload.readerIndex, bytes)
          handler.lift(BinaryMessage(e.getChannel, bytes))

        // NOTE(review): continuation frames are always treated as text, also when they
        // follow a non-final binary frame (same as before this change).
        case m: ContinuationWebSocketFrame =>
          pendingFragments :+= m.getText
          if (m.isFinalFragment) {
            // The final fragment carries data too, so it is appended before joining.
            val text = pendingFragments.mkString
            pendingFragments = Vector.empty
            handler.lift(TextMessage(e.getChannel, text))
          }

        case f: CloseWebSocketFrame =>
          Option(handshaker) foreach (_.close(ctx.getChannel, f))
          handler.lift(Disconnected(e.getChannel))

        // A pong has to carry the application data of the ping it answers.
        case p: PingWebSocketFrame =>
          e.getChannel.write(new PongWebSocketFrame(p.getBinaryData))

        case _ => ctx.sendUpstream(e)
      }
    }

    override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
      handler.lift(Error(e.getChannel, Option(e.getCause)))
    }

    /**
     * True when the request asks for a WebSocket upgrade. The Connection header is
     * treated as a comma separated token list, so values such as
     * "keep-alive, Upgrade" are accepted as well.
     */
    private def isWebSocketUpgrade(httpRequest: HttpRequest): Boolean = {
      val connectionTokens =
        Option(httpRequest.getHeader(Names.CONNECTION)).toList.flatMap(_.split(",").toList).map(_.trim)
      val upgradeHeader = Option(httpRequest.getHeader(Names.UPGRADE))
      connectionTokens.exists(_.equalsIgnoreCase(Values.UPGRADE)) &&
        upgradeHeader.exists(_.trim.equalsIgnoreCase(Values.WEBSOCKET))
    }

    /** Performs the handshake, or rejects the request when its WebSocket version is unsupported. */
    private def handleUpgrade(ctx: ChannelHandlerContext, httpRequest: HttpRequest): Unit = {
      val handshakerFactory = new WebSocketServerHandshakerFactory(websocketLocation(httpRequest), null, false)
      handshaker = handshakerFactory.newHandshaker(httpRequest)
      if (handshaker == null) handshakerFactory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel)
      else {
        handshaker.handshake(ctx.getChannel, httpRequest)
        handler.lift(Connect(ctx.getChannel))
      }
    }

    /**
     * Guesses whether the client reached us through HTTPS by looking at headers
     * whose value starts with "https".
     */
    private def isHttps(req: HttpRequest): Boolean = {
      def startsWithHttps(headerName: String): Boolean =
        Option(req.getHeader(headerName)).exists(_.toUpperCase(ENGLISH).startsWith("HTTPS"))

      // The original code read REQUEST_URI twice; the second lookup is assumed to have
      // been meant for the proxy supplied protocol header -- TODO confirm.
      startsWithHttps("REQUEST_URI") || startsWithHttps("X-Forwarded-Proto")
    }

    /** Builds the ws:// or wss:// location handed to the handshaker factory from the Host header. */
    private def websocketLocation(req: HttpRequest): String = {
      val scheme = if (isHttps(req)) "wss://" else "ws://"
      scheme + req.getHeader(Names.HOST) + "/"
    }
  }

}

/**
 * A WebSocket server bound to the address described by `config`; every event is
 * passed to `handler`, events it does not handle fall through to a default that
 * only prints errors.
 *
 * @param config  the address and port to listen on
 * @param handler the user supplied event callback
 */
class WebSocketServer(val config: WebSocketServerConfig, val handler: WebSocketServer.WebSocketHandler) {

  import WebSocketServer._

  // Has to be initialised BEFORE realHandler: vals are initialised in declaration
  // order, and `handler orElse devNull` would otherwise capture a null fallback.
  private[this] val devNull: WebSocketHandler = {
    case WebSocketServer.Error(_, Some(ex)) =>
      System.err.println(ex.getMessage)
      ex.printStackTrace()
    case _ =>
  }
  private[this] val realHandler = handler orElse devNull
  protected val logger = ESLoggerFactory.getLogger(getClass.getName)
  private[this] val boss = Executors.newCachedThreadPool()
  private[this] val worker = Executors.newCachedThreadPool()
  private[this] val server = {
    val bs = new ServerBootstrap(new NioServerSocketChannelFactory(boss, worker))
    bs.setOption("soLinger", 0)
    bs.setOption("reuseAddress", true)
    bs.setOption("child.tcpNoDelay", true)
    bs
  }

  // Holds the server channel and every accepted channel so stop can close them all.
  private[this] val allChannels = new DefaultChannelGroup

  /** Builds the handler chain for ONE connection; called once per accepted channel. */
  protected def getPipeline: ChannelPipeline = {
    val pipe = Channels.pipeline()
    pipe.addLast("connection-tracker", new ConnectionTracker(allChannels))
    pipe.addLast("decoder", new HttpRequestDecoder(4096, 8192, 8192))
    pipe.addLast("aggregator", new HttpChunkAggregator(64 * 1024))
    pipe.addLast("encoder", new HttpResponseEncoder)
    pipe.addLast("websocketmessages", new WebSocketPartialFunctionHandler(realHandler, logger))
    pipe
  }

  private[this] val servName = getClass.getSimpleName

  /** Binds the server socket and starts accepting connections. */
  def start: Unit = synchronized {
    // A factory (instead of setPipeline) gives every connection its own handler
    // instances; the HTTP decoder, the aggregator and the WebSocket handler are stateful.
    server.setPipelineFactory(new ChannelPipelineFactory {
      def getPipeline: ChannelPipeline = WebSocketServer.this.getPipeline
    })
    val addr = if (config.listenOn == null || config.listenOn.trim.isEmpty) new InetSocketAddress(config.port)
    else new InetSocketAddress(config.listenOn, config.port)
    val sc = server.bind(addr)
    allChannels add sc
    logger info "Started %s on [%s:%d]".format(servName, config.listenOn, config.port)
  }

  /** Closes every tracked channel and releases the resources held by the bootstrap. */
  def stop: Unit = synchronized {
    allChannels.close().awaitUninterruptibly()
    // The release runs on a dedicated non-daemon thread which is joined below.
    val thread = new Thread {
      override def run(): Unit = {
        server.releaseExternalResources()
        boss.awaitTermination(5, TimeUnit.SECONDS)
        worker.awaitTermination(5, TimeUnit.SECONDS)
      }
    }
    thread.setDaemon(false)
    thread.start()
    thread.join()
    logger info "Stopped %s".format(servName)
  }
}
package mojolly.io

import org.jboss.netty.bootstrap.ClientBootstrap
import org.jboss.netty.channel._
import socket.nio.NioClientSocketChannelFactory
import java.util.concurrent.Executors
import org.jboss.netty.handler.codec.http._
import collection.JavaConversions._
import websocketx._
import java.net.{InetSocketAddress, URI}
import java.nio.charset.Charset
import org.jboss.netty.buffer.ChannelBuffers
import org.jboss.netty.util.CharsetUtil
import akka.actor.ActorRef
import mojolly.LibraryConstants


/**
 * Usage of the simple websocket client:
 * <pre>
 *   WebSocketClient(new URI("ws://localhost:8080/thesocket")) {
 *     case Connected(client) => println("Connection has been established to: " + client.url.toASCIIString)
 *     case Disconnected(client, _) => println("The websocket to " + client.url.toASCIIString + " disconnected.")
 *     case TextMessage(client, message) => {
 *       println("RECV: " + message)
 *       client send ("ECHO: " + message)
 *     }
 *   }
 * </pre>
 */
object WebSocketClient {

  /** Events delivered to a client [[Handler]]. */
  object Messages {
    sealed trait WebSocketClientMessage
    /** Emitted right before the TCP connect is attempted. */
    case object Connecting extends WebSocketClientMessage
    /** The TCP connect did not succeed. */
    case class ConnectionFailed(client: WebSocketClient, reason: Option[Throwable] = None) extends WebSocketClientMessage
    /** The WebSocket handshake completed. */
    case class Connected(client: WebSocketClient) extends WebSocketClientMessage
    /** A text frame was received. */
    case class TextMessage(client: WebSocketClient, text: String) extends WebSocketClientMessage
    /** Writing `message` to the channel failed. */
    case class WriteFailed(client: WebSocketClient, message: String, reason: Option[Throwable]) extends WebSocketClientMessage
    /** Emitted right before the close frame is written. */
    case object Disconnecting extends WebSocketClientMessage
    /** The channel was closed. */
    case class Disconnected(client: WebSocketClient, reason: Option[Throwable] = None) extends WebSocketClientMessage
    /** An exception was caught in the channel pipeline. */
    case class Error(client: WebSocketClient, th: Throwable) extends WebSocketClientMessage
  }

  type Handler = PartialFunction[Messages.WebSocketClientMessage, Unit]
  type FrameReader = WebSocketFrame => String

  /** Reads the text of a text frame and rejects every other frame type. */
  val defaultFrameReader: FrameReader = {
    case f: TextWebSocketFrame => f.getText
    case _ => throw new UnsupportedOperationException("Only single text frames are supported for now")
  }

  // Rejects URIs without a scheme (getScheme is null for those) as well as schemes
  // that merely start with "ws".
  private def requireWebSocketScheme(url: URI): Unit = {
    val scheme = Option(url.getScheme)
    require(
      scheme.exists(s => s.equalsIgnoreCase("ws") || s.equalsIgnoreCase("wss")),
      "The scheme of the url should be 'ws' or 'wss'")
  }

  /**
   * Creates a (not yet connected) client.
   *
   * @param url     the ws:// or wss:// address of the server
   * @param version the WebSocket protocol version used for the handshake
   * @param reader  converts a frame to its text
   * @param handle  the callback receiving the client events
   * @throws IllegalArgumentException when the scheme of `url` is neither ws nor wss
   */
  def apply(url: URI, version: WebSocketVersion = WebSocketVersion.V13, reader: FrameReader = defaultFrameReader)(handle: Handler): WebSocketClient = {
    requireWebSocketScheme(url)
    new DefaultWebSocketClient(url, version, handle, reader)
  }

  /** Creates a client that forwards every event to the actor `handle`. */
  def apply(url: URI, handle: ActorRef): WebSocketClient = {
    requireWebSocketScheme(url)
    WebSocketClient(url) { case x => handle ! x }
  }

  /** Completes the handshake and turns incoming frames into client events. */
  private class WebSocketClientHandler(handshaker: WebSocketClientHandshaker, client: WebSocketClient) extends SimpleChannelUpstreamHandler {

    import Messages._
    override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      client.handler(Disconnected(client))
    }

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
      e.getMessage match {
        case resp: HttpResponse if handshaker.isHandshakeComplete =>
          throw new WebSocketException("Unexpected HttpResponse (status=" + resp.getStatus + ", content="
                              + resp.getContent.toString(CharsetUtil.UTF_8) + ")")
        case resp: HttpResponse =>
          handshaker.finishHandshake(ctx.getChannel, resp)
          client.handler(Connected(client))

        case f: TextWebSocketFrame => client.handler(TextMessage(client, f.getText))
        // Answer pings with a pong carrying the same payload.
        case p: PingWebSocketFrame => ctx.getChannel.write(new PongWebSocketFrame(p.getBinaryData))
        case _: PongWebSocketFrame =>
        case _: CloseWebSocketFrame => ctx.getChannel.close()
        // Anything else is passed on; previously it raised a MatchError, which
        // exceptionCaught answered by closing the connection.
        case _ => ctx.sendUpstream(e)
      }
    }


    override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
      client.handler(Error(client, e.getCause))
      e.getChannel.close()
    }

  }

  /**
   * Netty backed implementation of [[WebSocketClient]].
   *
   * NOTE(review): no SSL handler is added to the pipeline below, so a wss:// url is
   * connected to without TLS.
   */
  private class DefaultWebSocketClient(
      val url: URI,
      version: WebSocketVersion,
      private[this] val _handler: Handler,
      val reader: FrameReader = defaultFrameReader) extends WebSocketClient {
    val normalized = url.normalize()
    // The handshake needs a path, so an empty one is replaced with "/".
    val tgt = if (normalized.getPath == null || normalized.getPath.trim().isEmpty) {
      new URI(normalized.getScheme, normalized.getAuthority,"/", normalized.getQuery, normalized.getFragment)
    } else normalized

    val bootstrap = new ClientBootstrap(new NioClientSocketChannelFactory(Executors.newCachedThreadPool, Executors.newCachedThreadPool))
    val handshaker = new WebSocketClientHandshakerFactory().newHandshaker(tgt, version, null, false, Map.empty[String, String])
    val self = this
    // Assigned from the connect listener, read by the methods below.
    @volatile var channel: Channel = _

    import Messages._
    val handler = _handler orElse defaultHandler

    private def defaultHandler: Handler = {
      case Error(_, ex) => ex.printStackTrace()
      case _: WebSocketClientMessage =>
    }

    // URI.getPort is -1 when the url has no explicit port; fall back to the default
    // port of the scheme because InetSocketAddress rejects -1.
    private def targetPort: Int = {
      val explicit = tgt.getPort
      if (explicit != -1) explicit
      else if ("wss".equalsIgnoreCase(tgt.getScheme)) 443
      else 80
    }

    bootstrap.setPipelineFactory(new ChannelPipelineFactory {
      def getPipeline = {
        val pipeline = Channels.pipeline()
        if (version == WebSocketVersion.V00)
          pipeline.addLast("decoder", new WebSocketHttpResponseDecoder)
        else
          pipeline.addLast("decoder", new HttpResponseDecoder)

        pipeline.addLast("encoder", new HttpRequestEncoder)
        pipeline.addLast("ws-handler", new WebSocketClientHandler(handshaker, self))
        pipeline
      }
    })

    /**
     * Connects and starts the handshake unless a connected channel already exists.
     * Waits up to 5 seconds for the connect attempt; the outcome is reported to the handler.
     */
    def connect: Unit = {
      if (channel == null || !channel.isConnected) {
        val listener = futureListener { future =>
          if (future.isSuccess) {
            synchronized { channel = future.getChannel }
            handshaker.handshake(channel)
          } else {
            handler(ConnectionFailed(this, Option(future.getCause)))
          }
        }
        handler(Connecting)
        val fut = bootstrap.connect(new InetSocketAddress(tgt.getHost, targetPort))
        fut.addListener(listener)
        fut.await(5000L)
      }
    }

    /** Sends a close frame when connected. */
    def disconnect: Unit = {
      if (channel != null && channel.isConnected) {
        handler(Disconnecting)
        channel.write(new CloseWebSocketFrame())
      }
    }

    /**
     * Sends `message` as a text frame. Failures, including a send before any
     * connection was established, are reported to the handler as WriteFailed.
     */
    def send(message: String, charset: Charset = CharsetUtil.UTF_8): Unit = {
      Option(channel) match {
        case None =>
          // No channel yet: report the failure instead of throwing a NullPointerException.
          handler(WriteFailed(this, message, Some(new java.nio.channels.NotYetConnectedException)))
        case Some(ch) =>
          ch.write(new TextWebSocketFrame(ChannelBuffers.copiedBuffer(message, charset))).addListener(futureListener { fut =>
            if (!fut.isSuccess) {
              handler(WriteFailed(this, message, Option(fut.getCause)))
            }
          })
      }
    }

    /** Adapts a function to a Netty ChannelFutureListener. */
    def futureListener(handleWith: ChannelFuture => Unit) = new ChannelFutureListener {
      def operationComplete(future: ChannelFuture): Unit = { handleWith(future) }
    }
  }

  /**
   * Fix bug in standard HttpResponseDecoder for web socket clients. When status 101 is received for Hybi00, there are 16
   * bytes of contents expected, so a 101 response must NOT be treated as content-less.
   */
  class WebSocketHttpResponseDecoder extends HttpResponseDecoder {

    /** Final status codes whose responses never carry a body. */
    val codes = List(204, 205, 304)

    protected override def isContentAlwaysEmpty(msg: HttpMessage): Boolean = {
      msg match {
        case res: HttpResponse =>
          val code = res.getStatus.getCode
          // 101 keeps its content (the Hybi00 challenge response); every other
          // informational response and the codes listed above have none.
          if (code == 101) false
          else code < 200 || codes.contains(code)
        case _ => false
      }
    }
  }

  /**
   * A WebSocket related exception
   *
   * Copied from https://github.com/cgbystrom/netty-tools
   */
  class WebSocketException(s: String,  th: Throwable) extends java.io.IOException(s, th) {
    def this(s: String) = this(s, null)
  }

}
/**
 * A WebSocket client that reports what happens on the connection to its `handler`.
 */
trait WebSocketClient {

  /** The address of the server this client talks to. */
  def url: URI
  /** Converts a frame to its text. */
  def reader: WebSocketClient.FrameReader
  /** The callback receiving the events of this client. */
  def handler: WebSocketClient.Handler

  /** Opens the connection. */
  def connect: Unit

  /** Closes the connection. */
  def disconnect: Unit

  /**
   * Sends a text message.
   *
   * @param message the text to send
   * @param charset the charset used to encode `message`
   */
  def send(message: String, charset: Charset = CharsetUtil.UTF_8): Unit
}