



Scala 3 gives you the tools to design the perfect Spark API. We proved it by creating the open source library Iskra.
Spark provides Scala programmers with more than one API for building big data pipelines. However, each of them requires some sacrifice – worse performance, additional boilerplate or lack of type safety. We propose a new Spark API for Scala 3 that solves all of these problems.
Scala is a statically typed language. However, it’s up to programmers how much type information they preserve at compile time. E.g. given an integer and a string, we could store them in a tuple of type (Int, String) or in a more general collection like Seq[Any].
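As a tiny illustration (plain Scala, nothing Spark-specific yet), compare how much the compiler knows about the elements in each case:
val pair: (Int, String) = (1, "one") // element types preserved: Int and String
val seq: Seq[Any] = Seq(1, "one")    // element types forgotten: only Any remains

val length: Int = pair._2.length // compiles: the second element is known to be a String
// seq(1).length                 // does not compile: the element is only known to be an Any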
Following this philosophy, and for some historical reasons, Spark offers two flavours of high-level APIs for Scala: Datasets on one hand, and DataFrames or SQL queries on the other. As DataFrames and SQL queries have many common strengths and weaknesses, we will ignore SQL and focus on DataFrames from here on.
Unfortunately, neither of these options is perfect.
Let’s say we have some data representing measurements of temperature and air pressure from weather stations. We can model a single measurement as a case class like this:
case class Measurement(
  stationId: Long,
  temperature: Int /* in °C */,
  pressure: Int /* in hPa */,
  timestamp: Long
)
We would like to find the IDs and average air pressure for all stations with the amplitude of temperature less than 20°C.
Let’s try to solve our problem using each of the approaches mentioned above.
Dataset API: Using idiomatic Scala
We’ll go with Datasets first. Our solution will look very much like operating on standard Scala collections. This API lets us use ordinary Scala functions to manipulate our data model, specified as tuples or case classes.
We’ll skip the boilerplate of setting up a Spark application and assume that measurements is a Dataset[Measurement] containing our data. We can now implement our core logic as shown below:
measurements
  .groupByKey(_.stationId)
  .mapGroups { (stationId, stationMeasurements) =>
    // mapGroups hands us an Iterator, which can be traversed only once,
    // so we materialise it before computing several aggregates from it
    val measurementsSeq = stationMeasurements.toSeq
    val temperatures = measurementsSeq.map(_.temperature)
    val pressures = measurementsSeq.map(_.pressure)
    (
      stationId,
      temperatures.min,
      temperatures.max,
      pressures.sum.toDouble / pressures.length
    )
  }
  .filter(entry => entry._3 - entry._2 < 20)
  .map(entry => (entry._1, entry._4))
Using tuples may seem convenient at first, but they will make our codebase hard to read once it grows bigger. Alternatively, we could replace tuples with case classes, as shown below. However, having to define a case class for every intermediate step of our computations might quickly become a burden as well.
case class AggregatedMeasurement(
  stationId: Long,
  minTemperature: Int,
  maxTemperature: Int,
  avgPressure: Double
)
/* … */
measurements
  .groupByKey(_.stationId)
  .mapGroups { (stationId, stationMeasurements) =>
    // materialise the iterator so we can traverse the data more than once
    val measurementsSeq = stationMeasurements.toSeq
    val temperatures = measurementsSeq.map(_.temperature)
    val pressures = measurementsSeq.map(_.pressure)
    AggregatedMeasurement(
      stationId = stationId,
      minTemperature = temperatures.min,
      maxTemperature = temperatures.max,
      avgPressure = pressures.sum.toDouble / pressures.length
    )
  }
  .filter(aggregated => aggregated.maxTemperature - aggregated.minTemperature < 20)
  .map(aggregated => (aggregated.stationId, aggregated.avgPressure))
In this approach, the compiler knows the exact type of our data model after each transformation, so it can verify the correctness of our program, at least to some extent. E.g. it will raise an error if we try to refer to a column that doesn’t exist in our model, or if its type doesn’t make sense in a given context. We’ll even get code completions for the names of columns, which could help us eliminate many potential errors even before we compile our entire codebase.
Despite these amenities, our application might still surprise us with a runtime error. This would happen e.g. if we defined our helper case class inside a method instead of an object or a package. Doing so would cause problems with serialisation that wouldn’t get detected until we run our program.
Another major problem with the Dataset API is its performance. Because we can execute arbitrary Scala code in the bodies of our lambdas, Spark has to treat them as black boxes and cannot apply many of its optimizations.
DataFrame API: Pretending to be Python
DataFrames, in contrast to Datasets, are not parameterised with the type of data they contain, so the compiler knows nothing about it. This design is similar to the API that Spark offers for Python, where we refer to columns by their names as strings.
This might bring some flexibility in certain cases, as we’re free from the tuple vs case class dilemma. We could also compute names of columns dynamically.
However, in most other cases it’s rather annoying. More importantly, it’s dangerous. As our data schema is only known at runtime, we typically won’t learn about many problems in our code until we deploy and run the entire application (or at least until we run our tests).
Our example rewritten to the DataFrame-based API could look like this:
measurements
  .groupBy($"stationId")
  .agg(
    min($"temperature").as("minTemperature"),
    max($"temperature").as("maxTemperature"),
    avg($"pressure").as("avgPressure")
  )
  .where($"maxTemperature" - $"minTemperture" < lit(20))
  .select($"stationId", $"avgPressure")
As you can see, the method names now resemble SQL keywords rather than anything one might be familiar with from Scala’s standard library. If you take a closer look at the snippet, you might even spot that it’s actually going to crash at runtime because of the typo in minTemperture. And even if we fix that, something might go wrong again if, later on, we rename one of the columns during a refactoring but forget to update it in some places.
We gave up most of the type safety, but at least we got something in exchange. Because we are restricted to column transformations defined by Spark itself, its optimization engine can heavily speed up the computations. That is, as long as our program doesn’t crash at runtime.
You might ask yourself: can we design a better API for Spark that doesn’t force users to choose between type safety, convenience of use and efficiency?
Yes, we can!
Scala 3 provides all the tools required to achieve that. Let’s take the DataFrame approach as a starting point and try to improve it.
First, let’s stop referring to columns by their stringy names:
measurements
  .groupBy($.stationId)
  .agg(
    min($.temperature).as("minTemperature"),
    max($.temperature).as("maxTemperature"),
    avg($.pressure).as("avgPressure")
  )
  .where($.maxTemperature - $.minTemperature < lit(20))
  .select($.stationId, $.avgPressure)
The only thing that changed in our code so far is that we’ve replaced all $"foo"
-like references with $.foo
. Our snippet looks now more like vanilla Scala syntax, where one refers to nested parts of data structures using a dot operator. We could make this compile without much hassle already in Scala 2 by using the Dynamic
marker-trait.
import scala.language.dynamics
import org.apache.spark.sql.functions.col

object $ extends Dynamic {
  def selectDynamic(name: String) = col(name)
}
This might seem like magic, but it’s actually rather straightforward. Thanks to this trick, every expression like $.foo gets rewritten by the compiler into $.selectDynamic("foo"), given that $ has no statically known member called foo.
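For example, with the $ object above in scope, the snippet below compiles even though temperature is declared nowhere; the column name is only checked by Spark at runtime:
val temperatureColumn = $.temperature
// rewritten by the compiler into:
// val temperatureColumn = $.selectDynamic("temperature"), i.e. col("temperature")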
However, more convenient column access by itself isn’t much of a game changer, since we still get feedback about errors only at runtime. But it turns out that in Scala 3 we can overcome this problem by using Selectable instead of Dynamic.
import org.apache.spark.sql.{ Column => UntypedColumn }
import org.apache.spark.sql.functions.col

class Column[T](val untyped: UntypedColumn) extends AnyVal

trait RowModel extends Selectable {
  def selectDynamic(name: String) = Column(col(name))
}

def $: RowModel { /* ... */ } = new RowModel { /* ... */ }
Now, the type of $ is RowModel with some type refinement. Let’s say it was RowModel { def foo: Column[Int] }. Then, $.foo would turn into $.selectDynamic("foo").asInstanceOf[Column[Int]]. The desugaring contains an extra type cast, but it’s safe, because the compiler takes the type of foo from the refinement. If foo were not defined there, the compilation would fail.
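To make this concrete, here is a minimal sketch (with an illustrative column foo standing in for the elided refinement above, not Iskra’s actual implementation) of such a refined $ and what its usage desugars to:
// Illustrative only: a $ whose refinement exposes a single column `foo`.
// The cast is safe because every selection goes through selectDynamic anyway.
def $: RowModel { def foo: Column[Int] } =
  (new RowModel {}).asInstanceOf[RowModel { def foo: Column[Int] }]

val fooColumn: Column[Int] = $.foo
// desugared by the compiler into: $.selectDynamic("foo").asInstanceOf[Column[Int]]
// whereas $.bar would not compile, because `bar` is not declared in the refinement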
The issue that we still need to solve is that the type refinement of RowModel has to be different depending on the circumstances in which we refer to $. These include the shape of our initial data model and the stage of the transformation pipeline we’re currently in.
Say, selecting avgPressure should be invalid before it gets computed inside the agg block. Similarly, we shouldn’t be allowed to refer to the pressure of a single measurement after the aggregation. So, how can we get the compiler to track the correct type of $ at each step of our computations?
First, we need a refined type that represents the initial structure of our data frame. As Measurement is a case class, we can use Scala 3’s metaprogramming capabilities to construct it. We won’t go deeper into the implementation details here, but what we would like to get as the result is something like:
RowModel {
  def stationId: Column[Long]
  def temperature: Column[Int]
  def pressure: Column[Int]
  def timestamp: Column[Long]
}
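We won’t show the derivation itself either, but as a hint of where the names and types come from: for every case class the compiler can synthesise a Mirror that carries exactly this information at the type level (a sketch of the raw material, not Iskra’s implementation):
import scala.deriving.Mirror

// The compiler provides this instance automatically for case classes like Measurement
val mirror = summon[Mirror.ProductOf[Measurement]]

// The mirror's type members describe the fields:
//   mirror.MirroredElemLabels =:= ("stationId", "temperature", "pressure", "timestamp")
//   mirror.MirroredElemTypes  =:= (Long, Int, Int, Long)
// A macro can walk these tuples to produce the refined RowModel shown above.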
Later on, when we perform a transformation such as a selection or an aggregation, we pass a block of code returning a column or a tuple of columns, which determines the shape of our data row in the next step of the computations. Inside this block, $ needs to have the right type, which is context-specific. So why don’t we use context functions, another Scala 3 feature, to achieve that?
Even if you’ve never heard of context functions before, you might have come across a Scala 2 programming pattern like:
def bar(fun: Context => Int) = ???
def baz(implicit context: Context): Int = ???

bar { implicit context =>
  baz
}
In the code above, fun is a function from Context to Int, and the implicit keyword before the argument of the lambda makes it available to the implicit search inside the lambda’s body. Context functions are defined almost like ordinary functions, but with ?=> instead of =>. Making our function contextual lets us get rid of the boilerplate caused by implicit context => at the beginning of the closure. Effectively, our auxiliary snippet gets simplified to:
def bar(fun: Context ?=> Int) = ???
def baz(using context: Context): Int = ???

bar {
  baz
}
NOTE: If you aren’t fully familiar with Scala 3 syntax: the keyword using is the replacement for implicit when declaring an implicit parameter.
Let’s get back to our Spark API. We’ll treat RowModel with its precisely refined type as our implicitly passed context. Then we’ll use the $ method to capture it.
def $(using rowModel: RowModel): rowModel.type = rowModel
Note that the return type is rowModel.type instead of just RowModel. This lets us preserve the precise type with its refinements. That gives us a guarantee that every reference to a column in the form of $.foo is valid in the given context. We also know at compile time the exact types of the data in each column. Going further, we could use this information to ensure that operations on columns also make sense, e.g. that the condition inside .where(...) indeed represents a boolean, or that we don’t attempt to divide a number by a string.
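To sketch how this could plug into an API (a hypothetical illustration reusing the Column and RowModel definitions from above, not Iskra’s actual signatures), a transformation such as where could accept its condition as a context function, so that $ resolves to the row model of that particular data frame:
import org.apache.spark.sql.DataFrame

class TypedDataFrame[Model <: RowModel](val untyped: DataFrame, val model: Model) {
  // The condition is a context function: inside the block passed by the caller,
  // the row model is available as a given, so $ (as defined above) picks it up
  // and exposes exactly this frame's columns.
  def where(condition: Model ?=> Column[Boolean]): TypedDataFrame[Model] =
    TypedDataFrame(untyped.where(condition(using model).untyped), model)
}
Assuming arithmetic and comparison operators on Column are defined to return properly typed columns, a call like .where($.maxTemperature - $.minTemperature < lit(20)) would then compile only when both columns exist in the frame’s model and the whole condition is a Column[Boolean].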
Now you know the most important concepts and syntactic patterns you could use to implement a type-safe wrapper around the loosely typed API Spark provides for DataFrames. So why don’t you try it yourself? This might be a good exercise, but let us cool your enthusiasm down for a moment. The type system used by Spark internally turns out to be not so easy to model statically. Also, the number of operations one can perform on data frames and columns is huge, and it would require a lot of work to cover them all. But we still believe the goal is reachable, so we started a common initiative in the form of an open source library called Iskra.
The intent of the project is to provide a Spark API for Scala 3 that is type safe and convenient to use, while keeping the performance characteristics of the DataFrame and SQL based APIs.
You can try it out right away! Here’s the complete solution to our problem with the weather stations’ measurements:
//> using scala "3.2.0"
//> using lib "org.virtuslab::iskra:0.0.2"
import org.virtuslab.iskra.api.*
case class Measurement(
  stationId: Long,
  temperature: Int /* in °C */,
  pressure: Int /* in hPa */,
  timestamp: Long
)

@main def run() =
  given spark: SparkSession = SparkSession.builder()
    .master("local")
    .appName("weather-stations")
    .getOrCreate()

  val measurements = Seq(
    Measurement(1, 10, 1020, 1641399107),
    Measurement(2, -5, 1036, 1647015112),
    Measurement(1, 19, 996, 1649175104),
    Measurement(2, 25, 1015, 1657030348),
    /* more data … */
  ).toTypedDF

  import functions.{avg, min, max, lit}

  measurements
    .groupBy($.stationId)
    .agg(
      min($.temperature).as("minTemperature"),
      max($.temperature).as("maxTemperature"),
      avg($.pressure).as("avgPressure")
    )
    .where($.maxTemperature - $.minTemperature < lit(20))
    .select($.stationId, $.avgPressure)
    .show()
You can run the snippet using scala-cli. This is how:
1) Save the snippet to a file named SparkWeather.scala
2) Run scala-cli --jvm temurin:11 SparkWeather.scala from the command line
If you use VS Code with Metals as your IDE, you can also see how code completions work. To do so:
1) Open SparkWeather.scala in VS Code
2) Run scala-cli setup-ide SparkWeather.scala from the command line
3) Click Connect to build server in Metals’ sidebar menu
Try Iskra out and share your feelings about it with us! Contributions are welcome as well. Let’s make Spark in Scala better together.
If you want to read more about Scala 3, we encourage you to read this post: