How to Write Code Using The Spark Dataframe API: A Focus on Composability And Testing

I was recently thinking about how we should write Spark code using the Dataframe API. It turns out that there are many different choices you can make and, sometimes, innocuous-looking ones can bite you in the long run.

The Question Before The Question: DataFrame API Or Spark SQL?

(TL;DR: Use the DataFrame API!)

Before beginning: I assume you have a SparkSession available as spark and that you have a dataframe called df.

The first point to consider is that you can write code using the Dataframe API like this1:

df = ...  # IO here
df_my_column = df.select("my_column") 

Or using Spark SQL, like this

df_my_column = spark.sql("SELECT my_column FROM input")

There are advantages and disadvantages to both methods.

The Spark SQL method is very familiar to analysts and anyone fluent in SQL. As a drawback, even though it returns a dataframe, it requires every dataframe to be registered as a temporary view before it can be queried:

df = source_df.select('a_column')
try:
    spark.sql("select mean(a_column) from df")
except Exception:  # an AnalysisException is raised: Spark can't resolve "df"
    print("It can't find df")

df.createOrReplaceTempView("df")
spark.sql("select mean(a_column) from df")  # now it works

The Dataframe API is much more concise:

import pyspark.sql.functions as sf

df = source_df.select('a_column')
df.select(sf.mean('a_column'))

On the other hand, it can get quite involved and "scary":

from pyspark.sql import Window

d_types = ...
c_types = ...
df.withColumn('type',
              sf.when(sf.sum(sf.col('vehicle').isin(d_types).cast('int'))
                        .over(Window.partitionBy('id')) > 0, 'd_type')
                .when(sf.col('vehicle').isin(c_types), 'c_type')
                .otherwise('other_type'))

(In all fairness, writing the above bit in SQL would also be quite daunting.)

But for me the real advantage comes from composing and dealing with objects in a more abstract way. The above snippet of code should, ideally, be a function:

def my_function(df, d_types, c_types):
    return df.withColumn('type',
                         sf.when(sf.sum(sf.col('vehicle').isin(d_types).cast('int'))
                                   .over(Window.partitionBy('id')) > 0, 'd_type')
                           .when(sf.col('vehicle').isin(c_types), 'c_type')
                           .otherwise('other_type'))

If I were to rewrite that in Spark SQL, I'd have to do the following

def my_function(df, d_types, c_types):
    # do something with d_types and c_types to be able to pass them to SQL
    table_name = 'find_a_unique_table_name_not_to_clash_with_other'
    df.createOrReplaceTempView(table_name)
    return spark.sql("""
                        YOUR SQL HERE WITH %s AND MORE %s's TO INSERT c_types, d_types AND table_name
                     """ % (c_types, d_types, table_name)) 

The above function mixes IO (the createOrReplaceTempView call) with logic (the SQL execution). As a cherry on top, it does string interpolation, which is bad (like really, really bad!).

Disentangling the two would mean rewriting it like so:

def register_df_as_table(df):
    table_name = .... # generate some random unique name here
    df.createOrReplaceTempView(table_name)
    return table_name


def my_function(table_name, d_types, c_types): 
    # do something with d_types and c_types to be able to pass them to SQL 
    return spark.sql("""
                        YOUR SQL HERE WITH %s AND MORE %s's TO INSERT c_types, d_types AND table_name
                     """ % (c_types, d_types, table_name)) 

In principle you could create a decorator out of register_df_as_table and decorate my_function, but you can see that this is getting pretty involved. With the dataframe API you can compose functions much more easily.
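For completeness, here is a minimal sketch of what that decorator could look like. The names (with_registered_table, my_sql_query) and the uuid-based naming scheme are my inventions, not from the original post; the decorated body just builds the query string so the sketch has no Spark dependency.

```python
import uuid
from functools import wraps


def with_registered_table(func):
    """Register the incoming dataframe under a unique temp-view name,
    then call the wrapped function with that name instead of the dataframe."""
    @wraps(func)
    def wrapper(df, *args, **kwargs):
        table_name = "tbl_" + uuid.uuid4().hex  # unique, so it won't clash
        df.createOrReplaceTempView(table_name)
        return func(table_name, *args, **kwargs)
    return wrapper


@with_registered_table
def my_sql_query(table_name, column):
    # In real code this would be spark.sql(...); here we only build the string.
    return "SELECT mean({}) FROM {}".format(column, table_name)
```

The decorator keeps the IO (view registration) out of the function body, but notice that the caller still passes a dataframe while the function receives a string, which is exactly the kind of indirection the DataFrame API avoids.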

Further composing away

With that out of the way, let's see how you can further compose your functions and test them.

I won't write the code here, but let's say we have two extra functions, a_function and another_function, with a flow like this:

def load_data(..):
    pass


def my_function(df, other_args):
    pass


def a_function(df, other_args):
    pass


def another_function(df): 
    pass


def main():
    df_1 = load_data(..)
    df_2 = my_function(df_1, args_1)
    df_3 = a_function(df_2, args_2)
    df_4 = another_function(df_3)
    return df_4

The naming of those variables (df_{1..4}) is terrible, but, as you all know, there are only two hard problems in computer science: naming things, off by one errors, and overwriting variables (such as naming them all df).

A better alternative would involve piping the various functions:

def pipe(data, *funcs):
    for func in funcs:
        data = func(data)
    return data


def main():
    partial_my_function = lambda df: my_function(df, args_1)
    partial_a_function = lambda df: a_function(df, args_2)
    return pipe(load_data(),
                partial_my_function,
                partial_a_function,
                another_function)
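As a side note, the lambdas can be replaced with functools.partial, and pipe itself can be written as a reduce. This is a sketch: add and double are toy stand-ins for my_function and a_function (operating on a plain number instead of a dataframe) so that the snippet is self-contained and runnable.

```python
from functools import partial, reduce


def pipe(data, *funcs):
    # Same idea as the loop version above, expressed as a left fold.
    return reduce(lambda acc, func: func(acc), funcs, data)


def add(df, n):       # stand-in for my_function(df, other_args)
    return df + n


def double(df):       # stand-in for another_function(df)
    return df * 2


def main():
    # df is the *first* positional argument, so bind the extras by keyword.
    return pipe(1,
                partial(add, n=3),
                double)
```

partial has the small advantage over a lambda that the bound arguments are inspectable (partial(add, n=3).keywords), which can make debugging a long pipeline easier.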

This makes it, to my eyes, much better. Testing such a flow would then look like

def get_test_data():
    # do something
    return data


def test_my_function():
    data = get_test_data()
    assert my_function(data, args_1) == something  # ideally this is a bit more involved


def test_a_function():
    partial_my_function = lambda df: my_function(df, args_1)
    data = pipe(get_test_data(), partial_my_function)
    assert a_function(data, args_2) == something


def test_another_function():
    partial_my_function = lambda df: my_function(df, args_1)
    partial_a_function = lambda df: a_function(df, args_2)
    data = pipe(get_test_data(), partial_my_function, partial_a_function)
    assert another_function(data) == something

# other tests here

This way, when one of the functions breaks, all successive tests will fail2.
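On that "ideally this is a bit more involved" comment: == on two Spark dataframes compares object identity, not contents, so the assertions above need a helper. One common (if blunt) approach is to collect both sides and compare sorted rows. The helper name dfs_match is mine, and the sketch assumes the test dataframes are small enough to fit in driver memory; a real version would also compare schemas.

```python
def dfs_match(df_1, df_2):
    """Compare two (small!) dataframes by collecting and sorting their rows.

    Works on anything exposing .collect() that returns sortable rows,
    which also makes it easy to unit-test without a SparkSession.
    """
    return sorted(df_1.collect()) == sorted(df_2.collect())
```

With this helper, the assertions become assert dfs_match(my_function(data, args_1), expected_df) instead of a bare ==.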

Ok, that was a lot of (dummy) code. As always, let me know what you think, especially if you disagree (I'm @gglanzani on Twitter if you want to reach out!).



  1. Technically (thank you Andrew) this syntax mixes the DataFrame and SQL API. The DataFrame way of doing that is df.select(df.my_column) or df.select(df['my_column']) or df.select(sf.col('my_column')). I still prefer df.select('my_column') as it conveys my intent better. 

  2. You'd still want to write isolated tests, not using the pipeline, in case you introduce two regressions in different parts of the pipeline that cancel each other out! 
