kitaro-tn
1/10/2018 - 8:12 AM

[Spark Recommendation]行 変換 #Spark

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import spark.implicits._
import org.apache.spark.sql.functions.{udf, explode}



/** One MovieLens rating record: who rated which movie, the score, and when. */
final case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)

/**
 * Parses a single `::`-delimited MovieLens line (e.g. "0::2::3.0::1424380312")
 * into a [[Rating]].
 *
 * @param str raw input line in `userId::movieId::rating::timestamp` form
 * @return the parsed [[Rating]]
 * @throws IllegalArgumentException if the line does not have exactly 4 fields
 * @throws NumberFormatException if any field is not numeric
 */
def parseRating(str: String): Rating = {
  val fields = str.split("::")
  // `require` (not `assert`): this validates external input, and assertions
  // can be elided at compile time with -Xdisable-assertions.
  require(fields.size == 4, s"expected 4 '::'-separated fields, got ${fields.size} in: $str")
  Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
}

// Load the raw MovieLens sample as a Dataset[String], parse each line into a
// Rating, and convert to a DataFrame (column names come from the case class).
// NOTE(review): relies on an ambient `spark` session (spark-shell / notebook)
// and a relative data path — confirm both exist where this runs.
val ratings = spark.read.textFile("data/mllib/als/sample_movielens_ratings.txt")
  .map(parseRating)
  .toDF()
// 80/20 train/test split; randomSplit is nondeterministic without a seed,
// so each run produces a different partition.
val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))

// Build the recommendation model using ALS on the training data
// Configure collaborative filtering via ALS: first wire up which columns hold
// the user/item/rating triples, then set the training hyperparameters.
val als = new ALS()
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")
  .setMaxIter(5)      // few iterations — this is a demo-sized dataset
  .setRegParam(0.01)  // light regularization
// Fit the latent-factor model on the training split.
val model = als.fit(training)

// Top-10 recommendations per user; `recommendations` is an array of
// (movieId, rating) structs per userId row.
val userRecommend = model.recommendForAllUsers(10)

// Pairs the parallel movieId/rating arrays element-wise so explode() can turn
// each (movieId, rating) pair into its own row.
val zip = udf((xs: Seq[Int], ys: Seq[Float]) => xs.zip(ys))

userRecommend
  .withColumn("recommends", explode(zip($"recommendations.movieId", $"recommendations.rating")))
  .select($"userId", $"recommends._1".alias("movieId"), $"recommends._2".alias("rating"))
  .write // BUG FIX: DataFrame has no .format(); writing must go through DataFrameWriter
  .format("jdbc")
  .mode("overwrite")
  .option("driver", "com.mysql.jdbc.Driver")
  // SECURITY: credentials embedded in the URL in plaintext — move to config/secret storage.
  .option("url", "jdbc:mysql://127.0.0.1:3306/db?user=username&password=password&useSSL=false")
  .option("dbtable", "schema.table")
  // With mode("overwrite"), truncate=true makes Spark TRUNCATE the existing
  // table instead of dropping and recreating it (preserves schema/indexes).
  .option("truncate", "true")
  .save() // BUG FIX: without a terminal save(), the write is never executed