Published: Jul 22, 2022 | 10 min read
Spark provides Scala programmers with more than one API for building big data pipelines. However, each of them requires some sacrifice – worse performance, additional boilerplate, or lack of type safety. We propose a new Spark API for Scala 3 that solves all of these problems.
Scala is a statically typed language. However, it's up to programmers how much type information they preserve at compile time. For example, given an integer and a string, we could store them in a tuple of type (Int, String) or in a more general collection like Seq[Any]. Following this philosophy, and partly for historical reasons, Spark offers two flavors of high-level APIs for Scala:
a precisely typed one based on Datasets
a loosely typed one based on DataFrames or SQL queries. As these two share most of their strengths and weaknesses, we will ignore SQL and focus on DataFrames from here on.
Unfortunately, neither of these options is perfect.
Let’s say we have some data representing measurements of temperature and air pressure from weather stations. We can model a single measurement as a case class like this:
case class Measurement(
  stationId: Long,
  temperature: Int /* in °C */,
  pressure: Int /* in hPa */,
  timestamp: Long
)
We would like to find the IDs and the average air pressure for all stations whose temperature amplitude (the difference between the highest and lowest measured temperature) is less than 20°C.
Let’s try to solve our problem using each of the approaches mentioned above.
We’ll go with Datasets first. Our solution will look very much like operating on standard Scala collections. This API lets us use ordinary Scala functions to manipulate our data model specified by tuples or case classes.
We'll skip the boilerplate of setting up a Spark application and assume measurements is a Dataset[Measurement] containing our data. We can now implement our core logic as shown below:
measurements
  .groupByKey(_.stationId)
  .mapGroups { (stationId, stationMeasurements) =>
    // Materialize the iterator first, as it can be traversed only once
    val entries = stationMeasurements.toSeq
    val temperatures = entries.map(_.temperature)
    val pressures = entries.map(_.pressure)
    (
      stationId,
      temperatures.min,
      temperatures.max,
      pressures.sum.toDouble / pressures.length
    )
  }
  .filter(entry => entry._3 - entry._2 < 20)
  .map(entry => (entry._1, entry._4))
Using tuples may seem convenient at first, but they will make our codebase hard to read once it grows bigger. Alternatively, we could replace tuples with case classes, as shown below. However, having to define a case class for every intermediate step of our computations might quickly become a burden as well.
case class AggregatedMeasurement(
  stationId: Long,
  minTemperature: Int,
  maxTemperature: Int,
  avgPressure: Double
)

/* … */

measurements
  .groupByKey(_.stationId)
  .mapGroups { (stationId, stationMeasurements) =>
    val entries = stationMeasurements.toSeq
    val temperatures = entries.map(_.temperature)
    val pressures = entries.map(_.pressure)
    AggregatedMeasurement(
      stationId = stationId,
      minTemperature = temperatures.min,
      maxTemperature = temperatures.max,
      avgPressure = pressures.sum.toDouble / pressures.length
    )
  }
  .filter(aggregated => aggregated.maxTemperature - aggregated.minTemperature < 20)
  .map(aggregated => (aggregated.stationId, aggregated.avgPressure))
In this approach, the compiler knows the exact type of our data model after each transformation, so it can verify the correctness of our program, at least to some extent. For example, it will raise an error if we try to refer to a column that doesn't exist in our model or whose type doesn't make sense in a given context. We'll even get code completions for the names of columns, which could help us eliminate many potential errors even before we compile our entire codebase.
Despite these amenities, our application might still surprise us with a runtime error. This would happen, for example, if we defined our helper case class inside a method instead of inside an object or a package. Doing so would cause problems with serialization that wouldn't get detected until we run our program.
A major problem with the Dataset API that also has to be mentioned is its performance. Because we can execute arbitrary Scala code in the bodies of our lambdas, Spark treats them as black boxes and cannot apply many of its optimizations.
DataFrames, in contrast to Datasets, are not parameterized with the type of data they contain, so the compiler knows nothing about it. This is similar in design to the API that Spark offers for Python, where we refer to columns by their names as strings.
This might bring some flexibility in certain cases, as we’re free from the tuple vs case class dilemma. We could also compute names of columns dynamically.
However, in most other cases it's rather annoying. More importantly, it's dangerous: as our data schema is only known at runtime, we typically won't learn about many problems in our code until we deploy and run the entire application (or at least run our tests).
Our example rewritten to the DataFrame-based API could look like this:
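A sketch, assuming the usual import spark.implicits._ is in scope for the $"..." column syntax:

import org.apache.spark.sql.functions.{avg, max, min}

measurements.toDF()
  .groupBy($"stationId")
  .agg(
    min($"temperature").as("minTemperature"),
    max($"temperature").as("maxTemperature"),
    avg($"pressure").as("avgPressure")
  )
  .where($"maxTemperature" - $"minTemperture" < 20)
  .select($"stationId", $"avgPressure")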
As you can see, the names of the methods now resemble SQL keywords rather than something one might be familiar with from Scala's standard library. If you take a closer look at the snippet, you might even spot that it's actually going to crash at runtime because of the typo in minTemperture. And even if we fix that once, something might go wrong again if at some point later on we decide to refactor our code by renaming one of the columns but forget to do it in some places.
We gave up most of the type safety, but at least we got something in exchange. Because we are restricted to using only column transformations defined inside Spark, its optimization engine can heavily speed up the computations, provided our program doesn't crash at runtime first.
You could ask yourself the question: Can we design a better API for Spark that doesn’t force users to choose between type safety, convenience of use, and efficiency?
Yes, we can!
Scala 3 provides all the tools required to achieve that. Let’s take the DataFrame approach as a starting point and try to improve it.
First, let’s stop referring to columns by their stringy names:
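A sketch of the same pipeline, with only the column references changed:

import org.apache.spark.sql.functions.{avg, max, min}

measurements.toDF()
  .groupBy($.stationId)
  .agg(
    min($.temperature).as("minTemperature"),
    max($.temperature).as("maxTemperature"),
    avg($.pressure).as("avgPressure")
  )
  .where($.maxTemperature - $.minTemperature < 20)
  .select($.stationId, $.avgPressure)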
The only thing that has changed in our code so far is that we've replaced all $"foo"-like references with $.foo. Our snippet now looks more like vanilla Scala syntax, where one refers to nested parts of data structures using the dot operator. We could make this compile without much hassle even in Scala 2, using the Dynamic marker trait.
import scala.language.dynamics
import org.apache.spark.sql.functions.col

object $ extends Dynamic {
  // Every access like $.foo becomes selectDynamic("foo") and yields an untyped column
  def selectDynamic(name: String) = col(name)
}
This might seem like magic, but it’s actually rather straightforward. Thanks to this trick, every expression like $.foo gets rewritten by the compiler into $.selectDynamic("foo"), given that $ has no statically known member called foo.
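For example, with the object above in scope, these two expressions produce the same untyped column:

// The second line compiles only because $ extends Dynamic:
// the compiler rewrites it into $.selectDynamic("temperature")
val byName  = col("temperature")
val byField = $.temperature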
However, more convenient column access by itself isn’t much of a game changer, since we still get feedback about errors only at runtime. But it turns out that in Scala 3 we can overcome this problem by using Selectable instead of Dynamic.
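A minimal sketch of this idea (the RowModel class and the typed Column wrapper below are illustrative, not Iskra's actual definitions):

import org.apache.spark.sql

// An illustrative typed wrapper around Spark's untyped Column
class Column[T](val untyped: sql.Column)

// Columns are still looked up by name at runtime, but the refinement on the
// type of a RowModel instance tells the compiler which names exist and what they contain
class RowModel(columns: Map[String, Column[?]]) extends Selectable:
  def selectDynamic(name: String): Column[?] = columns(name)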
Now, the type of $ is RowModel with some type refinement. Let's say it was RowModel { def foo: Column[Int] }. Then $.foo would turn into $.selectDynamic("foo").asInstanceOf[Column[Int]]. The desugaring contains an extra type cast, but it's safe, because the compiler took the type of foo from the refinement. If foo was not defined there, the compilation would fail.
The issue that we still need to solve is that the type refinement of RowModel has to be different depending on the circumstances in which we refer to $. These include the shape of our initial data model and the stage of the transformation pipeline we’re currently in.
Say, selecting avgPressure should be invalid before it gets computed inside the agg block. Similarly, we shouldn’t be allowed to refer to the pressure of a single measurement after the aggregation. So, how can we get the compiler to trace the correct type of $ at each step of our computations?
First, we need a refined type that represents the initial structure of our data frame. As Measurement is a case class, we can use Scala 3’s metaprogramming capabilities to construct it. We won’t go deeper into the implementation details here, but what we would like to get as the result is something like:
RowModel {
  def stationId: Column[Long]
  def temperature: Column[Int]
  def pressure: Column[Int]
  def timestamp: Column[Long]
}
Later on, when we perform a transformation such as a selection or an aggregation, we pass a block of code returning a column or a tuple of columns, which determines the shape of our data row in the next step of computations. Inside this block, $ needs to have the right, context-specific type. So why don't we use context functions, another Scala 3 feature, to achieve that?
Even if you’ve never heard of context functions before, you might have come across a Scala 2 programming pattern like:
def bar(fun: Context => Int) = ???
def baz(implicit context: Context): Int = ???

bar { implicit context =>
  baz
}
In the code above, fun is a function from Context to Int, and the implicit keyword before the argument of the lambda makes it available to the implicit search inside the lambda's body. Context function types are written almost like ordinary function types, but with ?=> instead of =>. Making our function contextual lets us get rid of the boilerplate caused by implicit context => at the beginning of the closure. Effectively, our auxiliary snippet gets simplified to:
def bar(fun: Context ?=> Int) = ???
def baz(using context: Context): Int = ???

bar {
  baz
}
NOTE: If you aren't fully familiar with Scala 3 syntax: the keyword using is a replacement for implicit when declaring an implicit parameter.
Let’s get back to our Spark API. We’ll treat RowModel with its precisely refined type as our implicitly passed context. Then we’ll use the $ method to capture it.
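A minimal sketch of that method:

// Capture the context-specific RowModel instance; the singleton return type
// rowModel.type preserves its refinement
def $(using rowModel: RowModel): rowModel.type = rowModel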
Note that the return type is rowModel.type instead of just RowModel. This lets us preserve the precise type with its refinements, which guarantees that every reference to a column in the form of $.foo is valid in the given context. We also know at compile time the exact type of data in each column. Going further, we could use this information to ensure that operations on columns also make sense, e.g. that the condition inside .where(...) indeed represents a boolean, or that we don't attempt to divide a number by a string.
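For instance, typed columns could constrain their operators at compile time. Again, this is just an illustrative sketch, not Iskra's actual API:

import org.apache.spark.sql

class Column[T](val untyped: sql.Column):
  // Subtraction is available only when T is numeric
  def -(other: Column[T])(using Numeric[T]): Column[T] =
    Column(untyped - other.untyped)
  // Comparing two columns of the same type yields a boolean column
  def <(other: Column[T])(using Ordering[T]): Column[Boolean] =
    Column(untyped < other.untyped)

// A where(...) that accepts only boolean columns; anything else won't compile
def where(condition: Column[Boolean]): Unit = ???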
Now you know the most important concepts and syntactic patterns that you could use to implement a type-safe wrapper around the loosely typed API Spark provides for DataFrames. So why don't you try it yourself? This might be a good exercise, but let us cool your enthusiasm down for a moment: the type system used by Spark internally turns out to be not so easy to model statically. Also, the number of operations one can perform on data frames and columns is huge, and it would require a lot of work to cover them all. But we still believe the goal is reachable, so we started a common initiative in the form of an open-source library called Iskra.
The intent of the project was to provide a Spark API for Scala 3 that:
is type safe, providing meaningful compilation errors
avoids boilerplate
is intuitive to use for people already familiar with Spark
works well with IDEs, e.g. providing code completions for methods and names of columns
is efficient, taking advantage of all optimizations Spark offers for DataFrame and SQL-based APIs
is extensible, giving library users the possibility to easily define their own typed wrappers for methods from the API not yet covered by the library
You can try it out right away! Here’s the complete solution to our issue with weather stations’ measurements:
//> using scala "3.2.0"
//> using lib "org.virtuslab::iskra:0.0.2"

import org.virtuslab.iskra.api.*

case class Measurement(
  stationId: Long,
  temperature: Int /* in °C */,
  pressure: Int /* in hPa */,
  timestamp: Long
)

@main def run() =
  given spark: SparkSession = SparkSession.builder()