Partitioning by constructor

It’s not unusual in Scala to want to take a collection with items of some algebraic data type and partition its elements by their constructors. In this Stack Overflow question, for example, we’re given a type for fruits:

sealed trait Fruit

case class Apple(id: Int, sweetness: Int) extends Fruit
case class Pear(id: Int, color: String) extends Fruit

The goal is to be able to take a collection of fruits and split it into two collections, one of apples and one of pears:

def partitionFruits(fruits: List[Fruit]): (List[Apple], List[Pear]) = ???

It’s pretty easy to use collect to solve this problem for particular cases. It’s a little trickier when we start thinking about what a more generic version of such a method would look like: we want to take a collection of items of some arbitrary algebraic data type and return an n-tuple whose elements are collections of items of each of that ADT’s constructors (and let’s require them to be typed as specifically as possible, since this is Scala). It’s not too hard to imagine how you could write a macro that would perform this operation, but it’d be messy and would probably end up feeling kind of ad hoc (at least without a lot of additional work and thinking).
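For reference, here’s roughly what the collect-based version looks like for this one particular case. It’s only a sketch of the non-generic approach (the names someApples and somePears are made up for the example), and it needs nothing beyond the standard library:

```scala
sealed trait Fruit

case class Apple(id: Int, sweetness: Int) extends Fruit
case class Pear(id: Int, color: String) extends Fruit

// One pass per constructor: each collect keeps only the matching case
// and narrows the element type from Fruit to that constructor.
def partitionFruits(fruits: List[Fruit]): (List[Apple], List[Pear]) =
  (fruits.collect { case a: Apple => a }, fruits.collect { case p: Pear => p })

val (someApples, somePears) =
  partitionFruits(List(Apple(1, 10), Pear(2, "red"), Apple(4, 6)))
```

This works fine, but every new constructor means another collect and a wider tuple in the signature, which is exactly the boilerplate a generic version should take off our hands.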

Fortunately we’ve got Shapeless 2.0, where Miles Sabin and co. have written the macros for us so we can keep our hands clean.

The key here is the Generic type class, which makes the coproduct-iness of the ADT something we can work with explicitly:

import shapeless._

trait Partitioner[C <: Coproduct] extends DepFn1[List[C]] { type Out <: HList }

object Partitioner {
  type Aux[C <: Coproduct, Out0 <: HList] = Partitioner[C] { type Out = Out0 }

  implicit def cnilPartitioner: Aux[CNil, HNil] = new Partitioner[CNil] {
    type Out = HNil

    def apply(c: List[CNil]): Out = HNil
  }

  implicit def cpPartitioner[H, T <: Coproduct, OutT <: HList](implicit
    cp: Aux[T, OutT]
  ): Aux[H :+: T, List[H] :: OutT] = new Partitioner[H :+: T] {
    type Out = List[H] :: OutT

    def apply(c: List[H :+: T]): Out =
      c.collect { case Inl(h) => h } :: cp(c.collect { case Inr(t) => t })
  }
}

def partition[A, C <: Coproduct, Out <: HList](as: List[A])(implicit
  gen: Generic.Aux[A, C],
  partitioner: Partitioner.Aux[C, Out],
  tupler: ops.hlist.Tupler[Out]
) = tupler(partitioner(as.map(gen.to)))
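To make the recursion a little less magical, here’s a sketch of the same steps done by hand for Fruit. It assumes Generic[Fruit] gives us the coproduct Apple :+: Pear :+: CNil, and the names sample, encoded, applesByHand, rest, and pearsByHand are just for illustration:

```scala
import shapeless._

sealed trait Fruit

case class Apple(id: Int, sweetness: Int) extends Fruit
case class Pear(id: Int, color: String) extends Fruit

val gen = Generic[Fruit]

val sample: List[Fruit] = List(Apple(1, 10), Pear(2, "red"), Apple(4, 6))

// Every fruit becomes either Inl(apple) or Inr(Inl(pear)).
val encoded: List[Apple :+: Pear :+: CNil] = sample.map(gen.to)

// What cpPartitioner does at the head of the coproduct: keep the Inl values...
val applesByHand: List[Apple] = encoded.collect { case Inl(a) => a }

// ...and pass the Inr values on to the instance for the tail.
val rest: List[Pear :+: CNil] = encoded.collect { case Inr(t) => t }
val pearsByHand: List[Pear] = rest.collect { case Inl(p) => p }
```

The type class instances just repeat that Inl/Inr split once per constructor until they reach CNil.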

And now if we’ve got some fruits:

val fruits: List[Fruit] = List(
  Apple(1, 10),
  Pear(2, "red"),
  Pear(3, "green"),
  Apple(4, 6),
  Pear(5, "purple")
)

We can write the following:

scala> val (apples, pears) = partition(fruits)
apples: List[Apple] = List(Apple(1,10), Apple(4,6))
pears: List[Pear] = List(Pear(2,red), Pear(3,green), Pear(5,purple))

This is pretty neat: in fewer than thirty lines of code we’ve written a completely generic partitioning method that’ll take a collection of stuff of any algebraic data type and split it up by constructor.

It’s a little annoying that it’s not immediately obvious how the tuple ended up ordered that way, though. Do the apples come first because Apple was defined first (my preference), or because it comes before Pear in the dictionary? If we were to go and dig around in the Shapeless source code we’d learn that the constructors are sorted by name, but we don’t particularly want to have to remember that fact.
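One way to check the ordering without reading the source is to ask the compiler. The sketch below uses a hypothetical Shape type whose constructors are deliberately declared in reverse alphabetical order. Assuming the name-sorted behavior described above (this is Shapeless 2 behavior, and I wouldn’t count on it elsewhere), the type annotation should only compile if Circle comes first:

```scala
import shapeless._

// Hypothetical ADT, declared "backwards" on purpose.
sealed trait Shape

case class Square(side: Double) extends Shape
case class Circle(radius: Double) extends Shape

// If constructors were taken in declaration order, the representation
// would be Square :+: Circle :+: CNil and this line would not type-check.
val shapeGen: Generic.Aux[Shape, Circle :+: Square :+: CNil] = Generic[Shape]
```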

Fortunately Shapeless’s records make a nicer syntax super easy:

import shapeless._, labelled.{ field, FieldType }
import shapeless.record._ // record syntax, used for baskets('Apple) below

trait Partitioner[C <: Coproduct] extends DepFn1[List[C]] { type Out <: HList }

object Partitioner {
  type Aux[C <: Coproduct, Out0 <: HList] = Partitioner[C] { type Out = Out0 }

  implicit def cnilPartitioner: Aux[CNil, HNil] = new Partitioner[CNil] {
    type Out = HNil

    def apply(c: List[CNil]): Out = HNil
  }

  implicit def cpPartitioner[K, H, T <: Coproduct, OutT <: HList](implicit
    cp: Aux[T, OutT]
  ): Aux[FieldType[K, H] :+: T, FieldType[K, List[H]] :: OutT] =
    new Partitioner[FieldType[K, H] :+: T] {
      type Out = FieldType[K, List[H]] :: OutT

      def apply(c: List[FieldType[K, H] :+: T]): Out =
        field[K](c.collect { case Inl(h) => (h: H) }) ::
        cp(c.collect { case Inr(t) => t })
    }
}

def partition[A, C <: Coproduct, Out <: HList](as: List[A])(implicit
  gen: LabelledGeneric.Aux[A, C],
  partitioner: Partitioner.Aux[C, Out]
) = partitioner(as.map(gen.to))

Note that apart from the imports, the first half of this block of code is exactly the same as the one above. Instead of returning a tuple, though, we return a record, which is literally just an HList whose items are tagged with labels. This gives us the following syntax:

scala> val baskets = partition(fruits)
baskets: shapeless.:: ...

scala> baskets('Apple)
res0: List[Apple] = List(Apple(1,10), Apple(4,6))

scala> baskets('Pear)
res1: List[Pear] = List(Pear(2,red), Pear(3,green), Pear(5,purple))

This is all completely type-safe: if we ask for baskets('Burger) we’ll get a nice compile-time error. It’s also generic enough that if we change our code to add a banana constructor to our fruit type, we’ll be able to write baskets('Banana) here without touching the definition of partition.
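If the idea of a record as an HList whose items are tagged with labels still feels abstract, it may help to build one by hand. This is only a sketch: it uses string keys for brevity (the partition method above gets symbol keys from LabelledGeneric), and the names handBuilt, appleIds, and pearColors are invented for the example:

```scala
import shapeless._
import shapeless.record._
import shapeless.syntax.singleton._

// Each element is a plain value whose type also carries its key.
val handBuilt =
  ("apples" ->> List(1, 4)) ::
  ("pears" ->> List("red", "green")) ::
  HNil

// Lookup is resolved at compile time, so each result keeps its own type.
val appleIds: List[Int] = handBuilt("apples")
val pearColors: List[String] = handBuilt("pears")

// handBuilt("burgers") is rejected by the compiler: there is no such key.
```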