casualjim
10/22/2011 - 9:20 PM

CORS support for Scalatra

A specs2 acceptance spec for the `CORSSupport` trait, followed by the trait and its companion object.

package backchat
package web
package tests

import org.scalatra.test.specs2.ScalatraSpec
import org.scalatra.ScalatraServlet

/**
 * Acceptance spec for [[CORSSupport]].
 *
 * Mounts a servlet, for every path, whose CORS configuration allows only the origin
 * http://www.example.com, then checks which requests get an
 * Access-Control-Allow-Origin response header.
 */
class CORSSupportSpec extends ScalatraSpec {

  addServlet(new ScalatraServlet with Logging with CORSSupport {
    override protected lazy val corsConfig = {
      val origins = List("http://www.example.com")
      val methods = List("GET", "HEAD", "POST")
      val headerNames = "X-Requested-With,Authorization,Content-Type,Accept,Origin".split(",")
      CORSConfig(origins, methods, headerNames, true)
    }

    get("/") { "OK" }
  }, "/*")

  def is =
    "The CORS support should" ^
      "augment a valid simple request" ! context.validSimpleRequest ^
      "not touch a regular request" ! context.dontTouchRegularRequest ^
      "respond to a valid preflight request" ! context.validPreflightRequest ^
      "respond to a valid preflight request with headers" ! context.validPreflightRequestWithHeaders ^ end

  /** The examples referenced from `is`. */
  object context {
    // The single origin whitelisted by the servlet's configuration above.
    private val allowedOrigin = "http://www.example.com"

    // Reads the CORS allow-origin header from the response of the current request.
    private def allowOriginHeader: String =
      response.getHeader(CORSSupport.ACCESS_CONTROL_ALLOW_ORIGIN_HEADER)

    def validSimpleRequest =
      get("/", headers = Map(CORSSupport.ORIGIN_HEADER -> allowedOrigin)) {
        allowOriginHeader must_== allowedOrigin
      }

    def dontTouchRegularRequest =
      get("/") {
        allowOriginHeader must beNull
      }

    def validPreflightRequest = {
      val preflight = Map(
        CORSSupport.ORIGIN_HEADER -> allowedOrigin,
        CORSSupport.ACCESS_CONTROL_REQUEST_METHOD_HEADER -> "GET",
        "Content-Type" -> "application/json")
      options("/", headers = preflight) {
        allowOriginHeader must_== allowedOrigin
      }
    }

    def validPreflightRequestWithHeaders = {
      val preflight = Map(
        CORSSupport.ORIGIN_HEADER -> allowedOrigin,
        CORSSupport.ACCESS_CONTROL_REQUEST_METHOD_HEADER -> "GET",
        CORSSupport.ACCESS_CONTROL_REQUEST_HEADERS_HEADER -> "Origin, Authorization, Accept",
        "Content-Type" -> "application/json")
      options("/", headers = preflight) {
        allowOriginHeader must_== allowedOrigin
      }
    }
  }
}
package backchat
package web

import javax.servlet.http.{ HttpServletResponse, HttpServletRequest }
import org.scalatra._
import collection.JavaConversions._

/**
 * Header names and lookup tables shared with the [[CORSSupport]] trait.
 *
 * The public constants are the HTTP header names used during CORS negotiation.
 * The private tables hold upper-cased values, because the trait upper-cases the
 * values it compares against them using the ENGLISH locale.
 */
object CORSSupport {
  // Headers sent by the client.
  val ORIGIN_HEADER: String = "Origin"
  val ACCESS_CONTROL_REQUEST_METHOD_HEADER: String = "Access-Control-Request-Method"
  val ACCESS_CONTROL_REQUEST_HEADERS_HEADER: String = "Access-Control-Request-Headers"

  // Headers written by the server.
  val ACCESS_CONTROL_ALLOW_ORIGIN_HEADER: String = "Access-Control-Allow-Origin"
  val ACCESS_CONTROL_ALLOW_METHODS_HEADER: String = "Access-Control-Allow-Methods"
  val ACCESS_CONTROL_ALLOW_HEADERS_HEADER: String = "Access-Control-Allow-Headers"
  val ACCESS_CONTROL_MAX_AGE_HEADER: String = "Access-Control-Max-Age"
  val ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER: String = "Access-Control-Allow-Credentials"

  // Wildcard entry in the allowed-origins configuration.
  private val ANY_ORIGIN: String = "*"

  // Upper-cased request header names treated as "simple" (always acceptable).
  private val SIMPLE_HEADERS: List[String] =
    ORIGIN_HEADER.toUpperCase(ENGLISH) :: "ACCEPT" :: "ACCEPT-LANGUAGE" :: "CONTENT-LANGUAGE" :: Nil

  // Upper-cased Content-Type prefixes that keep a Content-Type header "simple".
  private val SIMPLE_CONTENT_TYPES: List[String] =
    "APPLICATION/X-WWW-FORM-URLENCODED" :: "MULTIPART/FORM-DATA" :: "TEXT/PLAIN" :: Nil

  /** Every CORS-related header name; the trait always accepts these as requested headers. */
  val CORS_HEADERS: List[String] =
    ORIGIN_HEADER ::
      ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER ::
      ACCESS_CONTROL_ALLOW_HEADERS_HEADER ::
      ACCESS_CONTROL_ALLOW_METHODS_HEADER ::
      ACCESS_CONTROL_ALLOW_ORIGIN_HEADER ::
      ACCESS_CONTROL_MAX_AGE_HEADER ::
      ACCESS_CONTROL_REQUEST_HEADERS_HEADER ::
      ACCESS_CONTROL_REQUEST_METHOD_HEADER :: Nil
}
/**
 * Adds Cross-Origin Resource Sharing (CORS) handling to a Scalatra kernel.
 *
 * The settings are copied out of `corsConfig` once, when the trait is initialised.
 * `handle` then classifies every request:
 *  - A valid preflight `OPTIONS` request is answered by `handlePreflightRequest()` and never
 *    reaches the routes.
 *  - A simple `GET`/`POST`/`HEAD` request, or any other CORS request (one carrying an
 *    `Origin` header) whose origin is allowed, gets the CORS response headers from
 *    `augmentSimpleRequest()` and is then routed normally.
 *  - Everything else is passed through untouched. This includes CORS requests from origins
 *    that are not allowed, WebSocket upgrades and the "eb_ping" endpoint.
 *
 * The section numbers in the comments (5.2.x, 6.x.y) presumably refer to a draft of the
 * W3C CORS specification; confirm against the version in use.
 */
trait CORSSupport extends Handler { self: ScalatraKernel with Logging ⇒

  import CORSSupport._

  /** The CORS settings. Override to supply a different configuration (defaults to `Config.CORS`). */
  protected def corsConfig = Config.CORS

  private val allowedOrigins = corsConfig.allowedOrigins
  // A "*" entry in the configured origins allows requests from any origin.
  private val anyOriginAllowed: Boolean = allowedOrigins.contains(ANY_ORIGIN)
  private val allowedMethods = corsConfig.allowedMethods
  private val allowedHeaders = corsConfig.allowedHeaders
  private val preflightMaxAge: Int = corsConfig.preflightMaxAge
  private val allowCredentials: Boolean = corsConfig.allowCredentials

  logger debug "Enabled CORS Support with:\nallowedOrigins: %s\nallowedMethods: %s\nallowedHeaders: %s".format(
    allowedOrigins mkString ", ",
    allowedMethods mkString ", ",
    allowedHeaders mkString ", ")

  /**
   * Writes the preflight response headers and flushes the response.
   *
   * It adds the same headers as `augmentSimpleRequest()`, plus Access-Control-Max-Age
   * (only when the configured max age is positive), Access-Control-Allow-Methods and
   * Access-Control-Allow-Headers. The request is not routed afterwards.
   */
  protected def handlePreflightRequest() {
    logger trace "handling preflight request"
    // 5.2.7
    augmentSimpleRequest()
    // 5.2.8
    if (preflightMaxAge > 0) response.setHeader(ACCESS_CONTROL_MAX_AGE_HEADER, preflightMaxAge.toString)
    // 5.2.9
    response.setHeader(ACCESS_CONTROL_ALLOW_METHODS_HEADER, allowedMethods mkString ",")
    // 5.2.10
    response.setHeader(ACCESS_CONTROL_ALLOW_HEADERS_HEADER, allowedHeaders mkString ",")
    response.flushBuffer()
    response.getOutputStream.flush()
  }

  /**
   * Adds the Access-Control-Allow-Origin header and, when credentials are allowed,
   * `Access-Control-Allow-Credentials: true`.
   *
   * The allow-origin value is the literal "*" only when any origin is allowed and
   * credentials are not. Otherwise the request's own Origin value is echoed, because CORS
   * does not permit the "*" wildcard together with credentials.
   */
  protected def augmentSimpleRequest() {
    val hdr = if (anyOriginAllowed && !allowCredentials) ANY_ORIGIN else request.getHeader(ORIGIN_HEADER)
    response.setHeader(ACCESS_CONTROL_ALLOW_ORIGIN_HEADER, hdr)
    if (allowCredentials) response.setHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER, "true")
  }

  // True when any origin is allowed or the request's Origin value is one of the configured origins.
  private def originMatches = // 6.2.2
    anyOriginAllowed || (allowedOrigins contains request.getHeader(ORIGIN_HEADER))

  // CORS handling is skipped for WebSocket upgrade requests and for the ping endpoint.
  private def isEnabled =
    !("Upgrade".equalsIgnoreCase(request.getHeader("Connection")) &&
      "WebSocket".equalsIgnoreCase(request.getHeader("Upgrade"))) &&
      !requestPath.contains("eb_ping") // don't do anything for the ping endpoint

  private def isValidRoute: Boolean = routes.matchingMethods.nonEmpty

  // A preflight must be a CORS request for a routed path, name an allowed method in
  // Access-Control-Request-Method, and request only acceptable headers.
  private def isPreflightRequest = {
    val isCors = isCORSRequest
    val validRoute = isValidRoute
    val isPreflight = request.getHeader(ACCESS_CONTROL_REQUEST_METHOD_HEADER).isNotBlank
    val enabled = isEnabled
    val matchesOrigin = originMatches
    val methodAllowed = allowsMethod
    val allowsHeaders = headersAreAllowed
    val result = isCors && validRoute && isPreflight && enabled && matchesOrigin && methodAllowed && allowsHeaders
    logger trace "This is a preflight validation check. valid? %s".format(result)
    logger trace "cors? %s, route? %s, preflight? %s, enabled? %s, origin? %s, method? %s, header? %s".format(
      isCors, validRoute, isPreflight, enabled, matchesOrigin, methodAllowed, allowsHeaders)
    result
  }

  // A request is a CORS request when it carries a non-blank Origin header.
  private def isCORSRequest = { // 6.x.1
    val h = request.getHeader(ORIGIN_HEADER)
    val result = h.isNotBlank
    if (!result) logger trace ("No origin found in the request")
    else logger trace ("We found the origin: %s".format(h))
    result
  }

  // True for Origin/Accept/Accept-Language/Content-Language, and for Content-Type when the
  // current request's own content type is one of the simple ones.
  private def isSimpleHeader(header: String): Boolean =
    header.toOption exists { h ⇒
      val hu = h.toUpperCase(ENGLISH)
      SIMPLE_HEADERS.contains(hu) || (hu == "CONTENT-TYPE" && hasSimpleContentType)
    }

  // True when the current request has a Content-Type starting with a simple content type.
  // getContentType returns null when the request has no Content-Type, so it is wrapped in an Option.
  private def hasSimpleContentType: Boolean =
    Option(request.getContentType) exists { ct ⇒
      val ctu = ct.toUpperCase(ENGLISH)
      SIMPLE_CONTENT_TYPES exists (prefix ⇒ ctu.startsWith(prefix))
    }

  // Every space-separated origin in the Origin header must be allowed; at least one must be present.
  private def allOriginsMatch: Boolean = { // 6.1.2
    val origins = request.getHeader(ORIGIN_HEADER).toOption.toList.flatMap(_.split(" ").toList).filter(_.nonEmpty)
    origins.nonEmpty && (anyOriginAllowed || origins.forall(o ⇒ allowedOrigins.contains(o)))
  }

  private def isSimpleRequest = {
    val isCors = isCORSRequest
    val enabled = isEnabled
    val allOrigins = allOriginsMatch
    val res = isCors && enabled && allOrigins && request.getHeaderNames.forall(isSimpleHeader)
    logger trace "This is a simple request: %s, because: %s, %s, %s".format(res, isCors, enabled, allOrigins)
    res
  }

  // Access-Control-Request-Method must be present and, once upper-cased, among the allowed methods.
  private def allowsMethod = { // 5.2.3 and 5.2.5
    val accessControlRequestMethod = request.getHeader(ACCESS_CONTROL_REQUEST_METHOD_HEADER)
    logger.trace("%s is %s" format (ACCESS_CONTROL_REQUEST_METHOD_HEADER, accessControlRequestMethod))
    val result = accessControlRequestMethod.isNotBlank && allowedMethods.contains(accessControlRequestMethod.toUpperCase(ENGLISH))
    logger.trace("Method %s is%s among allowed methods %s".format(accessControlRequestMethod, if (result) "" else " not", allowedMethods))
    result
  }

  // When Access-Control-Request-Headers is present, each listed header must be configured,
  // CORS-related, or simple. Each header is checked on its own, and empty list entries are ignored.
  private def headersAreAllowed = { // 5.2.4 and 5.2.6
    val accessControlRequestHeaders = request.getHeader(ACCESS_CONTROL_REQUEST_HEADERS_HEADER).toOption
    logger.trace("%s is %s".format(ACCESS_CONTROL_REQUEST_HEADERS_HEADER, accessControlRequestHeaders))
    val ah = (allowedHeaders ++ CORS_HEADERS).map(_.trim.toUpperCase(ENGLISH))
    val result = accessControlRequestHeaders forall { hdr ⇒
      val hdrs = hdr.split(",").map(_.trim).filter(_.nonEmpty)
      logger.debug("Headers [%s]".format(hdrs mkString ", "))
      hdrs.nonEmpty && hdrs.forall { h ⇒ ah.contains(h.toUpperCase(ENGLISH)) || isSimpleHeader(h) }
    }
    logger.trace("Headers [%s] are%s among allowed headers %s".format(
      accessControlRequestHeaders getOrElse "No headers", if (result) "" else " not", ah))
    result
  }

  abstract override def handle(req: HttpServletRequest, res: HttpServletResponse) {
    _request.withValue(req) {
      logger trace "the headers are: %s".format(req.getHeaderNames.mkString(", "))
      _response.withValue(res) {
        request.method match {
          case Options if isPreflightRequest ⇒
            handlePreflightRequest()
          case Get | Post | Head if isSimpleRequest ⇒
            augmentSimpleRequest()
            super.handle(req, res)
          // Actual (non-simple) CORS requests are only augmented for allowed origins; echoing
          // any Origin back would defeat the whitelist, especially with credentials allowed.
          case _ if isCORSRequest && isEnabled && originMatches ⇒
            augmentSimpleRequest()
            super.handle(req, res)
          case _ ⇒
            super.handle(req, res)
        }
      }
    }
  }

}