Ramandeep Singh Nanda
Published

Thu 12 April 2018


Removing Projection Column Ambiguity in Spark

Column ambiguity is quite common when you join two tables. It becomes an unnecessary hassle when you want to select all the columns from both tables while discarding the duplicates. The problem is especially painful with wide tables, where you would rather not type out every column name.
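To see the problem concretely, here is a minimal sketch; the `emp` and `dept` views and their columns are made up for illustration:

```scala
import org.apache.spark.sql.SparkSession

val ss = SparkSession.builder().master("local[*]").appName("ambiguity").getOrCreate()
import ss.implicits._

Seq((1, "alice"), (2, "bob")).toDF("id", "name").createOrReplaceTempView("emp")
Seq((1, "hr"), (2, "eng")).toDF("id", "dept").createOrReplaceTempView("dept")

val joined = ss.sql("SELECT * FROM emp JOIN dept ON emp.id = dept.id")
joined.printSchema()      // schema contains two columns named `id`
// joined.select("id")    // throws AnalysisException: Reference 'id' is ambiguous
```

Selecting all columns keeps both copies of `id`, and referring to `id` by name fails, which is exactly what the helpers below work around.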

There are a couple of programmatic solutions to the problem. Both do essentially the same thing, but achieve the result differently.

  • Either execute the SQL traditionally using ss.sql(query), then convert the resulting Dataset to an RDD, drop the duplicate column names, and rebuild the Dataset, or
  • Execute the query the way Spark does internally, drop the duplicate column names, and only then create the Dataset. This avoids the unnecessary conversion and creation of a Dataset before the duplicates are dropped.

Note: There are significant trade-offs between these approaches. Solution 1 works for all scenarios; Solution 2 uses executeCollectPublic, which collects all rows to the driver and can easily cause heap (OutOfMemoryError) errors on large results.
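With the helpers defined below in scope, usage looks like this (a sketch; `emp`, `dept`, and the query are hypothetical, and `ss` is an existing SparkSession):

```scala
// Both helpers take the SQL text and the SparkSession, and return a
// Dataset[Row] whose schema keeps only the first occurrence of each
// (case-insensitive) column name.
val query = "SELECT * FROM emp JOIN dept ON emp.id = dept.id"

// Solution 1: safe for large results -- the data stays distributed as an RDD.
val deduped = sqlDropDuplicateColumnsRDD(query, ss)

// Solution 2: collects to the driver first -- suitable for small results only.
val dedupedCollected = sqlDropDuplicateColumns(query, ss)

deduped.select("id") // no longer ambiguous
```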

Solution 2:

import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.types.{StructField, StructType}

def sqlDropDuplicateColumns(query: String, ss: SparkSession): Dataset[Row] = {
  val logicalPlan = ss.sessionState.sqlParser.parsePlan(sqlText = query)
  val qe = ss.sessionState.executePlan(logicalPlan)
  // Assert the plan is valid
  qe.assertAnalyzed()
  val ep = qe.executedPlan
  // Drop duplicate column names from the schema
  val schema = schemaToSet(ep.schema)
  // Collects all rows to the driver -- may exhaust driver memory on large results
  val rows = ep.executeCollectPublic().map { r =>
    val lb = new ListBuffer[Any]
    schema.foreach(sf => lb += r.getAs(sf.name))
    Row(lb: _*)
  }
  ss.createDataFrame(ss.sparkContext.parallelize(rows), schema)
}

Solution 1:

def sqlDropDuplicateColumnsRDD(query: String, ss: SparkSession): Dataset[Row] = {
  val df = ss.sql(query)
  // Drop duplicate column names from the schema
  val schema = schemaToSet(df.schema)
  // Re-project each row onto the deduplicated schema, looking fields up by name
  val rdd = df.rdd.map { r =>
    val lb = new ListBuffer[Any]
    schema.foreach(sf => lb += r.getAs(sf.name))
    Row(lb: _*)
  }
  ss.createDataFrame(rdd, schema)
}

Helper method:

def schemaToSet(schema: StructType): StructType = {
  // LinkedHashMap preserves the original column order;
  // a plain HashMap would scramble it.
  val schemaMap = new mutable.LinkedHashMap[String, StructField]()
  for (sf <- schema) {
    // First occurrence wins; names are compared case-insensitively
    if (!schemaMap.contains(sf.name.toLowerCase)) {
      schemaMap.put(sf.name.toLowerCase, sf)
    }
  }
  StructType(schemaMap.values.toArray)
}
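The core of the helper is a first-occurrence-wins, case-insensitive dedup. Stripped of the Spark types, the same idea looks like this, with hypothetical (name, type) pairs standing in for `StructField`s:

```scala
import scala.collection.mutable

// Keep the first field for each case-insensitive name, preserving order.
def dedupFields(fields: Seq[(String, String)]): Seq[(String, String)] = {
  val seen = new mutable.LinkedHashMap[String, (String, String)]()
  for (f <- fields) {
    if (!seen.contains(f._1.toLowerCase)) {
      seen.put(f._1.toLowerCase, f)
    }
  }
  seen.values.toSeq
}

// A joined schema with a duplicate `id` column (differing only in case):
val joined = Seq(("id", "int"), ("name", "string"), ("ID", "int"), ("dept", "string"))
dedupFields(joined) // Seq(("id","int"), ("name","string"), ("dept","string"))
```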