WebSocket support for Scalatra
/**
 * Example servlet: a WebSocket echo endpoint plus an HTML page explaining how to
 * try it from a browser's JavaScript console.
 */
class WebSocketApp extends ScalatraServlet with WebSocketSupport {

  /**
   * Route path for the current request: the trimmed path-info, or "/" when there is
   * none. Overridden so the root URL works both with and without a trailing slash.
   */
  override def requestPath: String = {
    val info = Option(request.getPathInfo).map(_.trim).getOrElse("")
    if (info.nonEmpty) info else "/"
  }

  // Echo endpoint: each incoming message is logged and sent back prefixed with "ECHO: ".
  ws("/?") { socket =>
    socket.onMessage { (client, text) =>
      log.debug("RECV: %s".format(text))
      client.sendMessage("ECHO: %s" format text)
    }
  }

  // Landing page with instructions for connecting to the echo socket.
  get("/?") {
    log.debug("executing get index")
    indexPage
  }

  /** HTML landing page served at the root path. */
  private def indexPage = {
    // Lines inside the literal are deliberately left-aligned: whitespace in a Scala
    // XML literal becomes text in the rendered page (and is significant inside <pre>).
    <html>
<head><title>WebSocket connection</title></head>
<body>
<h1>Hello</h1>
<p>In a javascript console</p>
<pre>
var ws = new WebSocket("ws://localhost:8888/suske")
ws.onmessage = { "function(m) { console.log(m.data); };" }
ws.send("hello scalatra");
// Some time passes
ws.close()
</pre>
</body>
</html>
  }
}
package com.mojolly.websocket
import org.scalatra._
import javax.servlet.http.{HttpServletResponse, HttpServletRequest}
import org.eclipse.jetty.websocket.{WebSocket => ServletWebSocket, WebSocketFactory}
import akka.util.Logging
import java.io.UnsupportedEncodingException
import org.eclipse.jetty.websocket.WebSocket.Outbound
import collection.mutable.{ HashSet, SynchronizedSet }
/**
 * Process-wide settings shared by all Scalatra WebSockets.
 */
object WebSocket {
  /**
   * Charset name that ScalatraWebSocket uses to turn byte-array messages into Strings.
   * Mutable so applications can reconfigure it; starts as "UTF-8".
   */
  var encoding: String = "UTF-8"
}
/**
 * Builder handed to a `ws(...)` route action for registering socket callbacks.
 *
 * Each handler kind is kept in an insertion-ordered, synchronized set. Registering the
 * same function instance twice therefore has no extra effect, and handlers are invoked
 * in the order they were registered. Call `result` to obtain the Jetty socket.
 */
trait WebSocket {
  private type MessageHandler = (ScalatraWebSocket, String) => Unit
  private type ConnectingHandler = (ScalatraWebSocket) => Boolean
  private type DisconnectedHandler = (ScalatraWebSocket) => Unit

  // LinkedHashSet rather than HashSet: a HashSet of function objects iterates in
  // identity-hash order, so handler invocation order would vary between runs.
  private def orderedSet[A] =
    new scala.collection.mutable.LinkedHashSet[A] with scala.collection.mutable.SynchronizedSet[A]

  private val _messageHandlers = orderedSet[MessageHandler]
  private val _connectingHandlers = orderedSet[ConnectingHandler]
  private val _disconnectedHandlers = orderedSet[DisconnectedHandler]

  /** Registers a handler called with the socket and the text of every incoming message. */
  def onMessage(handler: MessageHandler) {
    _messageHandlers += handler
  }

  /**
   * Registers a handler called after the socket connects. If it returns `false`,
   * onConnect throws a RuntimeException and later connecting handlers are not run.
   */
  def connecting(handler: ConnectingHandler) {
    _connectingHandlers += handler
  }

  /** Registers a handler called when the socket disconnects. */
  def disconnect(handler: DisconnectedHandler) {
    _disconnectedHandlers += handler
  }

  /**
   * Builds a Jetty socket that dispatches to the handlers registered so far.
   *
   * @return a new socket instance on every call
   * @throws RuntimeException if no message handler has been registered
   */
  def result: ServletWebSocket = {
    if (_messageHandlers.isEmpty) throw new RuntimeException("You need to define at least 1 message handler")
    // ScalatraWebSocket already extends ServletWebSocket, so the declared return type
    // replaces the original unchecked asInstanceOf cast.
    new ScalatraWebSocket {
      def onDisconnect() {
        _disconnectedHandlers foreach { _(this) }
      }

      override def onConnect(outbound: Outbound) {
        super.onConnect(outbound)
        _connectingHandlers foreach { handler =>
          if (!handler(this)) throw new RuntimeException("There was a problem connecting the websocket")
        }
      }

      def onMessage(opcode: Byte, data: String) {
        _messageHandlers foreach { _(this, data) }
      }
    }
  }
}
/**
 * Base adapter over Jetty's WebSocket interface.
 *
 * Stores the Outbound channel handed over in onConnect. Decodes byte-array messages
 * into Strings (using WebSocket.encoding) before delegating to the abstract
 * `onMessage(opcode, String)`.
 */
trait ScalatraWebSocket extends ServletWebSocket {
  // Assigned by onConnect; null until then.
  private var _outbound: Outbound = null

  /** The outbound channel, or None before onConnect has been called. */
  def outOption = Option(_outbound)

  /**
   * The outbound channel.
   *
   * @throws RuntimeException ("Not connected") before onConnect has been called
   */
  def out = outOption getOrElse (throw new RuntimeException("Not connected"))

  /** Called when the connection is closed. */
  def onDisconnect(): Unit

  /**
   * Decodes `length` bytes of `data`, starting at `offset`, and forwards the text to
   * `onMessage(opcode, String)`.
   *
   * The configured WebSocket.encoding is used. If the JVM does not support it, this
   * falls back to UTF-8, which every Java platform must support. Previously the
   * exception was swallowed and the message silently dropped.
   * Only decoding is guarded: exceptions thrown by the message handlers propagate.
   */
  def onMessage(opcode: Byte, data: Array[Byte], offset: Int, length: Int) {
    val text =
      try new String(data, offset, length, WebSocket.encoding)
      catch {
        case _: UnsupportedEncodingException => new String(data, offset, length, "UTF-8")
      }
    onMessage(opcode, text)
  }

  /**
   * Fragment callback: intentionally a no-op, so fragment data is neither buffered nor
   * forwarded to message handlers.
   * NOTE(review): confirm whether Jetty also delivers reassembled messages via onMessage.
   */
  def onFragment(more: Boolean, opcode: Byte, data: Array[Byte], offset: Int, length: Int) {
  }

  /** Handles one decoded text message. */
  def onMessage(opcode: Byte, data: String)

  /** Remembers the outbound channel so `out`, `outOption` and `sendMessage` can use it. */
  def onConnect(p1: Outbound) {
    _outbound = p1
  }

  /**
   * Sends `data` over the outbound channel. This is a silent no-op before onConnect.
   * Exceptions raised by Outbound.sendMessage are not caught here.
   */
  def sendMessage(data: String) {
    outOption foreach { _.sendMessage(data) }
  }
}
/**
 * Mixin for a Scalatra kernel that adds a `ws(...)` route DSL.
 *
 * Routes are registered under the pseudo-verb "WS", and the request is upgraded through
 * Jetty's WebSocketFactory.
 */
trait WebSocketSupport extends Logging { self: ScalatraKernel =>
  // A single Jetty factory per kernel instance performs every upgrade.
  private val wsFactory = new WebSocketFactory

  /** Route body: receives a fresh WebSocket builder on which to register handlers. */
  type WebSocketAction = WebSocket => Unit

  /**
   * Registers a WebSocket route.
   *
   * For each matching request, `action` is run against a new builder. The resulting
   * socket is then attached via the Jetty upgrade handshake.
   * Any exception is logged. If the response is not committed yet, the client receives
   * a 500 rather than an empty success response.
   *
   * @param routeMatchers matchers selecting the requests this route handles
   * @param action registers handlers; at least one `onMessage` handler is required
   * @return whatever `addRoute` returns
   */
  def ws(routeMatchers: RouteMatcher*)(action: WebSocketAction) =
    addRoute("WS", routeMatchers, {
      try {
        doUpgrade { () =>
          val websocket = new WebSocket {}
          action(websocket)
          websocket.result
        }
      } catch {
        // Exception (not Throwable) so fatal JVM errors are not swallowed.
        case e: Exception =>
          log error (e, "There was an error creating the websocket")
          if (!response.isCommitted) response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
      }
      // Yield the unit value `()`. The original returned the `Unit` companion object,
      // which is an ordinary object, not `()`.
      // NOTE(review): presumably the kernel treats a `()` result as "response already
      // handled"; confirm against ScalatraKernel's response rendering.
      ()
    })

  /**
   * Performs the upgrade handshake for the current request.
   *
   * When a Sec-WebSocket-Key1 header is present (the "hixie" handshake flavour), the
   * protocol is read from Sec-WebSocket-Protocol; otherwise from WebSocket-Protocol.
   * In either case it falls back to Sec-WebSocket-Protocol and may be null. Origin
   * falls back to the Host header.
   *
   * @param createSocket builds the socket to attach; a null result is answered with 503
   * @return the socket produced by `createSocket` (possibly null)
   */
  private def doUpgrade(createSocket: () => ServletWebSocket) = {
    val hixie = Option(request.getHeader("Sec-WebSocket-Key1")).isDefined
    val protocolHeader = if (hixie) "Sec-WebSocket-Protocol" else "WebSocket-Protocol"
    val protocol = Option(request.getHeader(protocolHeader)) getOrElse request.getHeader("Sec-WebSocket-Protocol")
    val host = request.getHeader("Host")
    val origin = Option(request.getHeader("Origin")) getOrElse host
    log debug ("is hixie: %s\nprotocol: %s\nhost: %s\norigin: %s", hixie, protocol, host, origin)
    val websocket = createSocket()
    if (websocket == null) {
      if (hixie) response.setHeader("Connection", "close")
      response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE)
    } else {
      wsFactory.upgrade(request, response, websocket, origin, protocol)
    }
    websocket
  }
}