Pitfalls of Reading Elasticsearch from Spark

This post covers some pitfalls of reading Elasticsearch from Spark, along with a working example of the correct approach.

I recently needed to read Elasticsearch data via Spark SQL in a Spark job and stumbled over a few issues along the way; this post summarizes them.

Environment

The Spark job is written in Scala, with scala-library version 2.11.8.

The Spark dependencies (spark-core, spark-sql, etc.) are version 2.3.2.

Elasticsearch data

schema

{
  "settings": {
    "number_of_replicas": 1
  },
  "mappings": {
    "label": {
      "properties": {
        "docId": {
          "type": "keyword"
        },
        "labels": {
          "type": "nested",
          "properties": {
            "id": {
              "type": "long"
            },
            "label": {
              "type": "keyword"
            }
          }
        },
        "itemId": {
          "type": "long"
        }
      }
    }
  }
}

sample data

{
  "took" : 141,
  "timed_out" : false,
  "_shards" : {
    "total" : 5,
    "successful" : 5,
    "skipped" : 0,
    "failed" : 0
  },
  "hits" : {
    "total" : 17370929,
    "max_score" : 1.0,
    "hits" : [
      {
        "_index" : "aen-label-v1",
        "_type" : "label",
        "_id" : "123_ITEM",
        "_score" : 1.0,
        "_source" : {
          "docId" : "123_ITEM",
          "labels" : [
            {
              "id" : 7378,
              "label" : "1kg"
            }
          ],
          "itemId" : 123
        }
      },
      {
        "_index" : "aen-label-v1",
        "_type" : "label",
        "_id" : "456_ITEM",
        "_score" : 1.0,
        "_source" : {
          "docId" : "456_ITEM",
          "labels" : [
            {
              "id" : 7378,
              "label" : "2kg"
            }
          ],
          "itemId" : 456
        }
      }
    ]
  }
}

Preparation

Since we are going to use Spark SQL, we of course need its dependencies:

dependencies {
  implementation 'org.apache.spark:spark-core_2.11:2.3.2'
  implementation 'org.apache.spark:spark-sql_2.11:2.3.2'
}

As for the ES libraries, the official documentation says that to access ES from Spark you need to put the elasticsearch-hadoop package on the classpath of the Spark job, which in practice means adding it to the job project's dependencies. The latest version in our company's Nexus is currently 7.15.0, and we manage dependencies with Gradle, so the dependency declaration looks like this:

dependencies {
  implementation 'org.elasticsearch:elasticsearch-hadoop:7.15.0'
}

Local testing

Depending on the resource manager, Spark can run in one of two modes: local mode or cluster mode, selected via the --master parameter. In local mode no external Spark cluster is needed; Spark runs everything on a single machine, which is very convenient for local testing. For Spark SQL, all you need is to use the local master when creating the SparkSession:

import org.apache.spark.sql.SparkSession

class MyUtils extends Serializable {
  def esHost() = "es.sherlockyb.club"
  
  // local mode
  def getLocalSparkSession: SparkSession = SparkSession.builder()
    .master("local")
    .getOrCreate()
  
  // cluster mode
  def getSparkSession: SparkSession = SparkSession.builder()
    .enableHiveSupport()
    .config("spark.sql.broadcastTimeout", "3600")
    .getOrCreate()
}

Test code

import com.google.common.base.Joiner
import com.typesafe.scalalogging.LazyLogging
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, lit}

object LocalTest extends LazyLogging {
  def main(args: Array[String]): Unit = {
    new LocalTest().run()
  }
}

class LocalTest {
  def run(): Unit = {
    val myUtils = new MyUtils
    val spark = myUtils.getLocalSparkSession
    import spark.implicits._

    var start = System.currentTimeMillis()
    val attributeId = 7378L
    val labelNames = Array("aen-label-retail", "aen-label-seller")
    spark.read
      .format("es")
      .option("es.nodes", myUtils.esHost())
      .option("es.port", "9200")
      .option("es.nodes.wan.only", value = true)
      .option("es.resource", Joiner.on(",").join(java.util.Arrays.asList(labelNames:_*)) + "/label")
      .option("es.scroll.size", 2000)
      .load()
      .createOrReplaceTempView("temp_labels")
    
    val sqlDf = spark.sql("select itemId, labels from temp_labels where itemId in (123, 456)")
    val newDf = sqlDf
      .map(row => {
        val labels = row.getAs[Seq[Row]]("labels")
        val labelValue = labels.find(p => p.getAs[Long]("id") == attributeId).map(p => p.getAs[String]("label"))

        (row.getAs[Long]("itemId"), attributeId, labelValue.orNull)
      })
      .withColumn("final_result", lit("PASS"))
      .toDF("itemId", "attributeId", "label", "final_result")

    val finalDf = newDf.toDF("itemId", "attributeId", "label", "result")
    finalDf.printSchema()
    finalDf.show()
    
    var emptyDf = newDf
      .filter(col("label").isNotNull)
      .toDF("itemId", "attributeId", "label", "result")
    emptyDf = emptyDf.union(finalDf)
    emptyDf.printSchema()
    emptyDf.show()

    emptyDf.filter(col("itemId") === 6238081929L and col("label").notEqual(col("result")))
      .show()

    val attributeTypeIds = Array.fill(3)(100)
    val attributeTypeIdsStr = Joiner.on(",").join(java.util.Arrays.asList(attributeTypeIds:_*))
    println(attributeTypeIdsStr)


    // trainItemIds is assumed to be a java.util.List[java.lang.Long] of item ids to
    // exclude, defined elsewhere in the real job
    import scala.collection.JavaConverters._
    emptyDf = emptyDf.filter(!col("itemId").isin(trainItemIds.asScala.map(Long2long).toList:_*))
    emptyDf.show(false)
  }
}

Pitfalls

es.nodes.wan.only

This setting tells the connector whether it is talking to an Elasticsearch instance in a cloud or otherwise restricted environment over the WAN, such as AWS, and it defaults to false. Our company's Elasticsearch cluster runs on AWS and its endpoint is only reachable from the internal network, so my first test runs failed with the following error:

Exception in thread "main" org.elasticsearch.hadoop.EsHadoopIllegalArgumentException: No data nodes with HTTP-enabled available
	at org.elasticsearch.hadoop.rest.InitializationUtils.filterNonDataNodesIfNeeded(InitializationUtils.java:159)
	at org.elasticsearch.hadoop.rest.RestService.findPartitions(RestService.java:223)
	at org.elasticsearch.spark.rdd.AbstractEsRDD.esPartitions$lzycompute(AbstractEsRDD.scala:73)
	at org.elasticsearch.spark.rdd.AbstractEsRDD.esPartitions(AbstractEsRDD.scala:72)
	at org.elasticsearch.spark.rdd.AbstractEsRDD.getPartitions(AbstractEsRDD.scala:44)
	at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:253)
	at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:251)
	at scala.Option.getOrElse(Option.scala:121)
	at org.apache.spark.rdd.RDD.partitions(RDD.scala:251)
	at org.apache.spark.rdd.MapPartitionsRDD.getPartitions(MapPartitionsRDD.scala:46)
	at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:253)
	at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:251)
	at scala.Option.getOrElse(Option.scala:121)
	at org.apache.spark.rdd.RDD.partitions(RDD.scala:251)
	at org.apache.spark.rdd.MapPartitionsRDD.getPartitions(MapPartitionsRDD.scala:46)
	at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:253)
	at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:251)
	at scala.Option.getOrElse(Option.scala:121)

Setting the option to true via option("es.nodes.wan.only", value = true) fixed the problem.
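For reference, a minimal sketch of the reader with the option applied (host and index names taken from the test code above); with wan.only enabled, the connector talks only to the declared es.nodes and skips node discovery:

val df = spark.read
  .format("es")
  .option("es.nodes", "es.sherlockyb.club")
  .option("es.port", "9200")
  .option("es.nodes.wan.only", "true") // no node discovery; only the declared nodes are used
  .option("es.resource", "aen-label-retail/label")
  .load()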

import spark.implicits._

While mapping over the DataFrame I hit the following compile error:

Unable to find encoder for type stored in a Dataset.  Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._

You need to add import spark.implicits._ before working with the DataFrame; it brings into scope the encoders that convert common Scala objects into Datasets/DataFrames, and is usually done right after obtaining the SparkSession.
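A minimal sketch of where the import goes; since it is a member import on the concrete spark value, it can only appear after the session has been created:

val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._ // encoders for primitives, case classes and tuples

// with the encoders in scope, map/toDS on Datasets compiles
val ds = Seq((123L, "1kg"), (456L, "2kg")).toDS()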

When Spark SQL reads an array-typed column from a Hive table, the type you get in Scala is WrappedArray, not Array

Once we build a temporary view with createOrReplaceTempView("temp_labels"), we can query the data with SQL just as we would a Hive table, for example selecting specific columns:

val sqlDf = spark.sql("select itemId, labels from temp_labels where itemId in (123, 456)")

sqlDf.printSchema() shows that the schema of sqlDf looks like this:

root
 |-- itemId: long (nullable = true)
 |-- labels: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- id: long (nullable = true)
 |    |    |-- label: string (nullable = true)

labels is an array of structs, so when pulling the labels column out of each row I first tried to read it as an Array:

val newDf = sqlDf.map(
  row => {
    val labels = row.getAs[Array[Row]]("labels")
    val labelValue = labels.find(p => p.getAs[Long]("id") == attributeId).map(p => p.getAs[String]("label"))

    (row.getAs[Long]("itemId"), attributeId, labelValue.orNull)
  }
)

This failed with:

java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to [Lorg.apache.spark.sql.Row;

So Spark SQL stores array columns as scala.collection.mutable.WrappedArray. Looking at its class definition, WrappedArray indirectly implements the Seq trait, so the column can be read with row.getAs[Seq[Row]]. Note that although Array[T] is declared as a class in the Scala sources, it maps to the native Java array type T[], which is why the cast from WrappedArray fails.
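The corrected read, matching what the test code above does:

val labels = row.getAs[Seq[Row]]("labels") // WrappedArray implements Seq, so this works
val labelValue = labels
  .find(_.getAs[Long]("id") == attributeId)
  .map(_.getAs[String]("label"))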

To check whether a Column is null, use is null / is not null (isNull / isNotNull), not === or !==

With the wrong form the filter does not do what you expect, as in:

newDf.filter(col("label") !== null)

This matches how null checks behave for Hive and MySQL tables; the correct form is:

newDf.filter(col("label").isNotNull)
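The underlying reason is SQL three-valued logic: comparing a column to null with === or !== evaluates to SQL NULL rather than true/false, so the predicate never matches and the filter drops every row. A small sketch of the null-aware alternatives:

// === null yields NULL for every row, so this filter returns nothing
newDf.filter(col("label") === lit(null)).show()

// null-aware operators behave as expected
newDf.filter(col("label").isNull).show()
newDf.filter(col("label") <=> lit(null)).show() // <=> is null-safe equality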

Final code

import com.google.common.base.Joiner
import com.typesafe.scalalogging.LazyLogging
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}

object TestMain extends LazyLogging {
  def main(args: Array[String]): Unit = {
    val myUtils = new MyUtils
    new TestApp(myUtils).run()
  }
}

class TestApp(myUtils: MyUtils) extends Serializable with LazyLogging {  
  def esDf(spark: SparkSession, indices: Array[String]): DataFrame = {
    spark.read
      .format("es")
      .option("es.nodes", myUtils.esHost())
      .option("es.port", "9200")
      .option("es.nodes.wan.only", value = true)
      .option("es.resource", Joiner.on(",").join(java.util.Arrays.asList(indices:_*)) + "/label")
      .option("es.scroll.size", 2000)
      .load()
  }
  
  def run(): Unit = {
    val spark = myUtils.getSparkSession
    import spark.implicits._
    
    val esTempView = "es_label"
    val labelNames = Array("aen-label-retail", "aen-label-seller")
    esDf(spark, labelNames).createOrReplaceTempView(esTempView)
    
    // itemIdsStr, attributeTypeIds, jobId and the temp_request view come from the job's
    // input parameters and upstream steps; their setup is omitted here
    val labelDf = getLabelDf(spark, itemIdsStr, attributeTypeIds, esTempView)
    println("debug log")
    labelDf.printSchema()
    labelDf.show()
    labelDf.createOrReplaceTempView("final_labels")
    
    val data = spark.sql(
      s"""
      |select cc.*, pp.final_result, pp.label, null as remark
      |from temp_request cc
      |left join final_labels pp
      |on cc.itemid = pp.itemId
      |and cc.attributetypeid = pp.attributeId
      |where cc.profile = '$jobId'
      |""".stripMargin)

    data.distinct().write
      .mode(SaveMode.Overwrite)
      .option("compression", "gzip")
      .json(s"s3://sherlockyb-test/check-precision/job_id=$jobId")
  }
  
  def getLabelDf(spark: SparkSession, itemIdsStr: String, attributeTypeIds: Array[String], esTempView: String): DataFrame = {
    import spark.implicits._

    val sqlDf = spark.sql(s"select itemId, labels from $esTempView where itemId in ($itemIdsStr)")
    val emptyDf = spark.emptyDataFrame
    var labelDf = emptyDf
    attributeTypeIds.foreach(attributeTypeId => {
      val attributeDf = sqlDf
        .map(row => {
          val labels = row.getAs[Seq[Row]]("labels")
          val labelValue = labels.find(p => p.getAs[Long]("id") == attributeTypeId.toLong).map(p => p.getAs[String]("label"))

          (row.getAs[Long]("itemId"), attributeTypeId.toLong, labelValue.orNull)
        })
        .withColumn("final_result", lit("PASS"))
        .toDF("itemId", "attributeId", "label", "final_result")
        .filter(col("label").isNotNull)
      if (labelDf == emptyDf) {
        labelDf = attributeDf
      } else {
        labelDf = labelDf.union(attributeDf)
      }
    })

    labelDf.dropDuplicates("itemId","attributeId")
  }
}