casualjim
12/18/2010 - 9:15 PM

WebSocket support for Scalatra

/**
 * Example servlet: echoes every WebSocket text message back to its sender and
 * serves an HTML page at the root explaining how to try it from a browser console.
 */
class WebSocketApp extends ScalatraServlet with WebSocketSupport {

  /**
   * Overridden because otherwise the root URL would need a trailing slash.
   * A missing or blank path info is treated as "/", so the root matches with or
   * without the slash.
   */
  override def requestPath = {
    val trimmed = Option(request.getPathInfo) map { _.trim }
    trimmed filter { path => !path.isEmpty } getOrElse "/"
  }

  // Echo route: every text message received is logged and sent straight back.
  ws("/?") { socket =>
    socket onMessage { (conn, text) =>
      log.debug("RECV: %s" format text)
      conn.sendMessage("ECHO: %s".format(text))
    }
  }

  // Index page with a copy-pasteable JavaScript snippet for manual testing.
  get("/?") {
    log.debug("executing get index")
    <html>
      <head><title>WebSocket connection</title></head>
      <body>
        <h1>Hello</h1>
        <p>In a javascript console</p>
        <pre>
          var ws = new WebSocket("ws://localhost:8888/suske")
          ws.onmessage = { "function(m) { console.log(m.data); };" }
          ws.send("hello scalatra");
          // Some time passes
          ws.close()
        </pre>
      </body>
    </html>
  }

}
package com.mojolly.websocket

import org.scalatra._
import javax.servlet.http.{HttpServletResponse, HttpServletRequest}
import org.eclipse.jetty.websocket.{WebSocket => ServletWebSocket, WebSocketFactory}
import akka.util.Logging
import java.io.UnsupportedEncodingException
import org.eclipse.jetty.websocket.WebSocket.Outbound
import collection.mutable.{ HashSet, SynchronizedSet }

/** Process-wide settings shared by all Scalatra WebSockets. */
object WebSocket {
  /**
   * Charset name used to decode incoming binary frames into text.
   * Defaults to "UTF-8"; it is a `var` so applications can reassign it.
   */
  var encoding: String = "UTF-8"
}

/**
 * Per-connection builder used inside a `ws(...)` route.
 *
 * Callers register callbacks with `onMessage`, `connecting` and `disconnect`.
 * `result` then wraps them in a Jetty WebSocket whose callbacks dispatch to the
 * registered handlers.
 */
trait WebSocket {
  import scala.collection.mutable.{ ArrayBuffer, SynchronizedBuffer }

  private type MessageHandler = (ScalatraWebSocket, String) => Unit
  private type ConnectingHandler = (ScalatraWebSocket) => Boolean
  private type DisconnectedHandler = (ScalatraWebSocket) => Unit

  // Buffers instead of hash sets, so handlers run in the order they were
  // registered; iteration order over a HashSet of functions is unspecified.
  private val _messageHandlers = new ArrayBuffer[MessageHandler] with SynchronizedBuffer[MessageHandler]
  private val _connectingHandlers = new ArrayBuffer[ConnectingHandler] with SynchronizedBuffer[ConnectingHandler]
  private val _disconnectedHandlers = new ArrayBuffer[DisconnectedHandler] with SynchronizedBuffer[DisconnectedHandler]

  /** Registers a handler invoked with (socket, text) for every text message received. */
  def onMessage(handler: MessageHandler) {
    _messageHandlers += handler
  }

  /**
   * Registers a check run from the socket's `onConnect`.
   * If any check returns false, `onConnect` throws a RuntimeException.
   */
  def connecting(handler: ConnectingHandler) {
    _connectingHandlers += handler
  }

  /** Registers a handler run from the socket's `onDisconnect` callback. */
  def disconnect(handler: DisconnectedHandler) {
    _disconnectedHandlers += handler
  }

  /**
   * Builds the Jetty WebSocket backed by the handlers registered so far.
   *
   * @return a [[ScalatraWebSocket]] typed as the Jetty WebSocket interface
   * @throws RuntimeException if no message handler has been registered
   */
  def result: ServletWebSocket = {
    if (_messageHandlers.isEmpty)
      throw new RuntimeException("You need to define at least 1 message handler")

    new ScalatraWebSocket {

      def onDisconnect() {
        _disconnectedHandlers foreach { _(this) }
      }

      override def onConnect(outbound: Outbound) {
        super.onConnect(outbound)
        // forall stops at the first handler that refuses, as the original loop did.
        // NOTE(review): the outbound stays recorded even when a check refuses.
        if (!(_connectingHandlers forall { _(this) }))
          throw new RuntimeException("There was a problem connecting the websocket")
      }

      def onMessage(opcode: Byte, data: String) {
        _messageHandlers foreach { _(this, data) }
      }
    }
  }
}
/**
 * Base for the Jetty WebSocket produced by `WebSocket.result`.
 *
 * It remembers the Outbound handed over in `onConnect` and decodes binary frames
 * into Strings, which are forwarded to the abstract text `onMessage`.
 */
trait ScalatraWebSocket extends ServletWebSocket {
  // Set by onConnect; null until then and only exposed through outOption/out.
  // NOTE(review): never cleared here, so it is still set after a disconnect.
  private var _outbound: Outbound = null

  /** The outbound side of the connection, or None before `onConnect` was called. */
  def outOption: Option[Outbound] = {
    Option(_outbound)
  }

  /**
   * The outbound side of the connection.
   *
   * @throws RuntimeException ("Not connected") before `onConnect` was called
   */
  def out: Outbound = outOption getOrElse (throw new RuntimeException("Not connected"))

  /** Invoked through the Jetty WebSocket interface when the connection closes. */
  def onDisconnect(): Unit

  /**
   * Decodes `length` bytes of `data` starting at `offset` using `WebSocket.encoding`,
   * then forwards the text to `onMessage(opcode, String)`.
   *
   * If the configured encoding is unsupported, the bytes are decoded as UTF-8.
   * Previously such messages were silently dropped.
   */
  def onMessage(opcode: Byte, data: Array[Byte], offset: Int, length: Int): Unit = {
    val text =
      try {
        new String(data, offset, length, WebSocket.encoding)
      } catch {
        // UTF-8 is required on every JVM, so this fallback cannot fail the same way.
        case e: UnsupportedEncodingException => new String(data, offset, length, "UTF-8")
      }
    onMessage(opcode, text)
  }

  /** Fragmented frames are currently ignored. */
  def onFragment(more: Boolean, opcode: Byte, data: Array[Byte], offset: Int, length: Int) = {

  }

  /** Handles one decoded text message. */
  def onMessage(opcode: Byte, data: String)

  /** Records the outbound side so `out`, `outOption` and `sendMessage` can use it. */
  def onConnect(p1: Outbound) = {
    _outbound = p1
  }

  /** Sends `data` if connected; does nothing when no outbound has been recorded yet. */
  def sendMessage(data: String) {
    outOption foreach { _.sendMessage(data) }
  }
}

/**
 * Adds WebSocket routes to a Scalatra kernel.
 *
 * `ws(...)` registers a route whose action creates a [[WebSocket]] builder and
 * lets the caller attach handlers. It then upgrades the current request through
 * Jetty's WebSocketFactory.
 */
trait WebSocketSupport extends Logging { self: ScalatraKernel =>

  private val wsFactory = new WebSocketFactory

  /** Callback that registers handlers on a freshly created [[WebSocket]]. */
  type WebSocketAction = WebSocket => Unit

  /**
   * Registers a WebSocket route under the "WS" verb.
   *
   * If building the socket or upgrading fails, the error is logged. If the response
   * is not yet committed, a 503 is sent so the client does not see a silent success.
   *
   * @param routeMatchers matchers selecting the request path
   * @param action callback that registers the socket's handlers
   */
  def ws(routeMatchers: RouteMatcher*)(action: WebSocketAction) = {
    // NOTE(review): upgrade requests arrive as HTTP GET; confirm the kernel
    // dispatches them to routes registered under "WS".
    addRoute("WS", routeMatchers, {
      try {
        doUpgrade { () =>
          val websocket = new WebSocket { }
          action(websocket)
          websocket.result
        }
      } catch {
        // Exception, not Throwable: JVM errors must keep propagating.
        case e: Exception =>
          log error (e, "There was an error creating the websocket")
          if (!response.isCommitted) response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE)
      }
      // The action handles the response itself, so yield the unit value `()`.
      // Previously this was the `Unit` companion object, which is not a Unit value.
      ()
    })
  }

  /**
   * Performs the handshake for the current request with the socket built by `matcher`.
   *
   * A request carrying `Sec-WebSocket-Key1` is treated as a "hixie" handshake; that
   * flag selects which protocol header is read first. A missing Origin falls back
   * to Host.
   *
   * @return the socket from `matcher`, possibly null. When null, a 503 is sent
   *         instead of upgrading.
   */
  private def doUpgrade(matcher: () => ServletWebSocket) = {
    val hixie = Option(request.getHeader("Sec-WebSocket-Key1")).isDefined
    val ph = request.getHeader(if(hixie) "Sec-WebSocket-Protocol" else "WebSocket-Protocol")
    val protocol = Option(ph) getOrElse request.getHeader("Sec-WebSocket-Protocol")
    val host = request.getHeader("Host")
    val origin = Option(request.getHeader("Origin")) getOrElse host

    log debug ("is hixie: %s\nprotocol: %s\nhost: %s\norigin: %s", hixie, protocol, host, origin)
    val websocket = matcher()
    if(websocket == null) {
      if(hixie) response.setHeader("Connection", "close")
      response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE)
    } else {
      wsFactory.upgrade(request, response, websocket, origin, protocol)
    }
    websocket
  }
}