Ramandeep Singh Nanda

Wed 17 January 2018


Testing Spark Dataframes

Testing Spark Dataframe transforms is essential, and it can be done in a reusable manner. The way I generally accomplish that is to:

  • Read the test and expected Dataframes,
  • Invoke the desired transform on the test Dataframe, and
  • Calculate the difference between the result and the expected Dataframe. The only caveat is that the built-in except function is not sufficient for columns with decimal types, so those columns require a bit of extra work.
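
To see why exact equality is brittle for fractional values, here is a minimal sketch in plain Scala, without Spark. The object name ExactComparisonPitfalls is made up for illustration. It only demonstrates two well-known behaviours of Double and java.math.BigDecimal, and makes no claim about how except is implemented internally.

```scala
object ExactComparisonPitfalls {
  // Floating point arithmetic rarely reproduces a literal exactly.
  val computed: Double = 0.1 + 0.2
  val literal: Double = 0.3

  // java.math.BigDecimal.equals also compares the scale, so 1.10 and 1.1 differ.
  val a = new java.math.BigDecimal("1.10")
  val b = new java.math.BigDecimal("1.1")

  def main(args: Array[String]): Unit = {
    println(computed == literal)                 // false
    println(math.abs(computed - literal) < 0.01) // true
    println(a.equals(b))                         // false
    println(a.compareTo(b) == 0)                 // true
  }
}
```

A tolerance-based check accepts both pairs, which is the behaviour we want when comparing numeric columns.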

To accomplish generic dataframe comparison:

  • We need to look at the type of each column, and when it is numeric,
  • Convert it to the corresponding Java type and then do a numeric comparison, allowing for a configurable precision mismatch. Otherwise,
  • Just use the except function for the remaining columns.
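
The numeric branch of the steps above can be sketched without Spark. This is only an illustration on plain Scala collections: ToleranceSketch and withinTolerance are hypothetical names, and the default precision of 0.01 is an assumption for the example.

```scala
object ToleranceSketch {
  // Hypothetical helper: true when both columns have the same length and every
  // pair of sorted values differs by less than `precision`.
  def withinTolerance(result: Seq[Double], expected: Seq[Double], precision: Double = 0.01): Boolean =
    result.length == expected.length &&
      result.sorted.zip(expected.sorted).forall { case (r, e) => math.abs(r - e) < precision }

  def main(args: Array[String]): Unit = {
    // Row order does not matter because both sides are sorted first.
    println(withinTolerance(Seq(2.0, 0.1 + 0.2), Seq(0.3, 2.0))) // true
    // A difference of 0.02 exceeds the default precision.
    println(withinTolerance(Seq(1.0), Seq(1.02)))                // false
  }
}
```

Sorting both sides first makes the check independent of row order, at the cost of collecting the column to the driver, so it is best suited to small test fixtures.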

Comparison Code

import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.types._

def compareDF(result: Dataset[Row], expected: Dataset[Row]): Unit = {
  val expectedSchemaMap = expected.schema.map(sf => (sf.name, sf.dataType)).toMap[String, DataType]
  val resSchemaMap = result.schema.map(sf => (sf.name, sf.dataType)).toMap[String, DataType]
  expectedSchemaMap.foreach {
    // numeric columns: compare the values within a tolerance
    case (name: String, dType: NumericType) =>
      assert(compareNumericTypes(result, expected, resSchemaMap(name), dType, name), s"$name column was not equal")
    // all other columns: any row left over after except indicates a mismatch
    case (name: String, _) =>
      assert(result.select(name).except(expected.select(name)).count() == 0, s"$name column was not equal")
  }
}

def compareNumericTypes(result: Dataset[Row], expected: Dataset[Row], resType: DataType, expType: DataType, colName: String, precision: Double = 0.01): Boolean = {
  // collect both columns as sorted sequences
  val res = extractAndSortNumericRow(result, colName, resType)
  val exp = extractAndSortNumericRow(expected, colName, expType)
  // compare lengths first
  if (res.length != exp.length) return false
  res match {
    // integral columns must match exactly
    case Seq(_: java.lang.Integer, _*) | Seq(_: java.lang.Long, _*) =>
      !res.zip(exp).exists(zipped => safelyGet(zipped._1).longValue() != safelyGet(zipped._2).longValue())
    // fractional columns may differ by less than the given precision
    case Seq(_: java.lang.Float, _*) | Seq(_: java.lang.Double, _*) | Seq(_: java.math.BigDecimal, _*) =>
      !res.zip(exp).exists(zipped => (safelyGet(zipped._1).doubleValue() - safelyGet(zipped._2).doubleValue()).abs >= precision)
    // anything else, including two empty columns, falls back to exact equality
    case _ => res == exp
  }
}

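// upcast to the widest type of each family: integers to Long, floating point to Double
def safelyGet(v: Number): Number = v match {
  case _: java.lang.Long | _: java.lang.Integer => java.lang.Long.valueOf(v.longValue())
  // go through the string form so that a Float such as 0.1f becomes exactly 0.1
  case _: java.lang.Float | _: java.lang.Double => java.lang.Double.valueOf(v.toString)
  case _ => v
}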

// map internal Spark types to Java types
def extractAndSortNumericRow(df: Dataset[Row], colName: String, dt: DataType): Seq[Number] = {
  // sort on the column in Spark, then collect the single column to the driver
  val rows = df.select(colName).sort(colName).collect().toSeq
  dt match {
    case _: LongType => rows.map(_.getAs[java.lang.Long](0))
    case _: IntegerType => rows.map(_.getAs[java.lang.Integer](0))
    case _: DoubleType => rows.map(_.getAs[java.lang.Double](0))
    case _: FloatType => rows.map(_.getAs[java.lang.Float](0))
    case _: DecimalType => rows.map(_.getAs[java.math.BigDecimal](0))
  }
}

The code above does the heavy lifting of comparing dataframes. Now all we need is a simple function that invokes the transform, and some simple ScalaTest code showing all this in action.

Function that invokes the transform and does comparison:

def invokeAndCompare(testFileName: String, expectedFileName: String, func: Dataset[Row] => Dataset[Row]): Unit = {
  val df = readJsonDF(testFileName)
  val expected = readJsonDF(expectedFileName)
  val transformResult = func(df)
  compareDF(transformResult, expected)
}

def readJsonDF(fileName: String): Dataset[Row] = {
  // assumption: reuse the session created by the test setup to read the JSON file
  SparkSession.builder().getOrCreate().read.json(fileName)
}

Testing Code

Just use ScalaTest. Here is what a test for your transforms looks like.

import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class RandomTransformsTest extends FlatSpec with Matchers with BeforeAndAfter {
  // held in a field so that both before and after can reach the session
  var ss: SparkSession = _

  before {
    ss = SparkSession.builder().master("local[*]").getOrCreate()
  }

  after {
    // close spark session
    ss.stop()
  }

  "testRandomTransform" should "give correct output for input dataframe" in {
    val testFileLoc = ""
    val expectedFileLoc = ""
    // just get the function definition; invokeAndCompare will invoke it with the dataframe later on
    val func = RandomTransforms.someRandomFunc() _
    SomeObject.invokeAndCompare(testFileLoc, expectedFileLoc, func)
  }
}
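
The trailing underscore in the test relies on the transform having a second parameter list that takes the dataframe, which is an assumption about how someRandomFunc is declared. Here is a minimal sketch of that pattern on plain Scala collections. The names PartialApplicationSketch and scale are hypothetical, and Seq[Int] stands in for Dataset[Row].

```scala
object PartialApplicationSketch {
  // Hypothetical stand-in for a transform: the first parameter list holds the
  // configuration, the second receives the data (a Seq here, not a Dataset[Row]).
  def scale(factor: Int)(rows: Seq[Int]): Seq[Int] = rows.map(_ * factor)

  // Hypothetical stand-in for invokeAndCompare: applies the function and checks the output.
  def invokeAndCompare(input: Seq[Int], expected: Seq[Int], func: Seq[Int] => Seq[Int]): Boolean =
    func(input) == expected

  def main(args: Array[String]): Unit = {
    // Supplying only the first parameter list yields a function value; nothing runs yet.
    // In Scala 2, writing `scale(2) _` without the type annotation has the same effect.
    val func: Seq[Int] => Seq[Int] = scale(2)
    println(invokeAndCompare(Seq(1, 2, 3), Seq(2, 4, 6), func)) // true
  }
}
```

Keeping the data as the last parameter list means every transform can be handed to the same comparison helper, whatever configuration it takes.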


Wrap Up:

So, there we go, testing made easy for Spark dataframes. It requires some tedious mapping for decimal numbers, but once developed, tests are easy to write for all your dataframe transforms.

