class: center, middle

# Advanced Scala
## Type Classes

Neville Li
@sinisa_lyh
Nov 2015

---

class: center, middle

# Polymorphism

.footnote[http://eed3si9n.com/learning-scalaz/polymorphism.html]

---

# Parametric polymorphism

```scala
def head[T](xs: List[T]): T = xs(0)
```

--

# `T` can be any type

```scala
head(List(1, 2, 3)) // Int = 1
head(List("a", "b", "c")) // String = "a"

case class Car(make: String)
head(List(Car("Toyota"), Car("Volkswagen"), Car("BMW"))) // Car = Car(Toyota)
```

---

# Not all types have `.make`

```scala
def getMake[T](x: T): String = x.make
// error: value make is not a member of type parameter T
```

--

# But sub-types of `Car` do

```scala
def getMake[T <: Car](x: T): String = x.make
```

---

# Subtype polymorphism

```scala
trait Num[T] {
  def plus(that: T): T
}

def sum[T <: Num[T]](xs: T*): T = xs.reduce(_ plus _)
```

--

# Let's try that for `Complex`

```scala
case class Complex(r: Double, i: Double) extends Num[Complex] {
  def plus(that: Complex): Complex = Complex(this.r + that.r, this.i + that.i)
}
```

---

# Why is `Num[T]` parametric?

```scala
case class Vector(v: Double*) extends Num[Vector] {
  def plus(that: Vector): Vector =
    Vector(this.v.zip(that.v).map(p => p._1 + p._2): _*)
}
```

--

# `sum[T]` works for both `Complex` and `Vector`

```scala
// Num[Complex]#plus(that: Complex): Complex
sum(Complex(1.5, 0.5), Complex(10.1, 5.2))
```

- T = Complex

```scala
// Num[Vector]#plus(that: Vector): Vector
sum(Vector(1, 2, 3), Vector(20, 30, 40))
```

- T = Vector

---

# How about existing types?

```scala
def sum[T <: Num[T]](xs: T*): T = xs.reduce(_ plus _)
```

--

- ### `scala.Int`, `scala.Long`, etc. extend `AnyVal`, which extends `Any`
- ### `java.lang.String` extends `java.lang.Object`
- ### Most classes out there do not extend `Num[T]`
- ### We can't change them to add it after the fact

---

# Ad-hoc polymorphism

```scala
trait Num[T] {
  def plus(x: T, y: T): T
}

def sum[T](xs: T*)(ev: Num[T]): T = xs.reduce(ev.plus)
```

- `ev`: evidence that `T` conforms to the behaviors of `Num[T]`

---

# Add `Num[T]` behavior to `Complex`

```scala
case class Complex(r: Double, i: Double)

val complexNum = new Num[Complex] {
  override def plus(x: Complex, y: Complex) = Complex(x.r + y.r, x.i + y.i)
}
```

--

# Pass in evidence for `sum[T]`

```scala
sum(Complex(1.5, 0.5), Complex(10.1, 5.2))(complexNum)
```

---

# Works for `String` too

```scala
val stringNum = new Num[String] {
  override def plus(x: String, y: String) = (x.toDouble + y.toDouble).toString
}

sum("1.2", "3.4")(stringNum)
```

--

- ### Now we can make any type `T` summable
- ### By implementing `Num[T]`
- ### And passing it to `sum(...)(ev: Num[T])`

---

# Type class

### In computer science, a type class is a type system construct that supports ad hoc polymorphism. This is achieved by adding constraints to type variables in parametrically polymorphic types. Such a constraint typically involves a type class `T` and a type variable `a`, and means that `a` can only be instantiated to a type whose members support the overloaded operations associated with `T`.

.footnote[https://en.wikipedia.org/wiki/Type_class]

--

- ### Type class `T` and type variable `a` → `T a` in Haskell-land
- ### `M[T]` in Scala world
- ### `M=Num` and `T={Int, Double, Complex, ...}`

---

# Not a class in OOP sense

### `M[T]` → type `T` exhibits behavior of `M`

--

- ### `Ordering[T]`: comparable, with `def compare(x: T, y: T): Int`

--

- ### `Numeric[T]` extends `Ordering[T]` with `plus`, `minus`, `times`, `negate`, etc.

--

- ### `Fractional[T]` extends `Numeric[T]` with `div`, etc.

--

- ### `Integral[T]` extends `Numeric[T]` with `quot`, `rem`, etc.
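---

# Standard type classes, evidence passed by hand

A minimal sketch of the hierarchy above, using the explicit-evidence style of `sum(...)(ev)`. The object `StdTypeClasses` and the helpers `largest`, `sumOfSquares`, and `isEven` are hypothetical names made up for illustration; only `Ordering`, `Numeric`, `Integral` and their instances come from the standard library.

```scala
// Hypothetical helpers for illustration; not part of the standard library.
object StdTypeClasses {
  // Ordering[T] supplies the comparison. Throws on an empty Seq, like reduce.
  def largest[T](xs: Seq[T])(ord: Ordering[T]): T =
    xs.reduce((a, b) => if (ord.gteq(a, b)) a else b)

  // Numeric[T] supplies `times` and `plus`.
  def sumOfSquares[T](xs: Seq[T])(num: Numeric[T]): T =
    xs.map(x => num.times(x, x)).reduce(num.plus)

  // Integral[T] adds `rem` on top of Numeric[T]; `equiv` comes from Ordering[T].
  def isEven[T](x: T)(int: Integral[T]): Boolean =
    int.equiv(int.rem(x, int.fromInt(2)), int.zero)
}

// Evidence is passed explicitly, using instances the standard library provides.
StdTypeClasses.largest(Seq(3, 1, 2))(Ordering.Int)                      // 3
StdTypeClasses.sumOfSquares(Seq(1.5, 2.0))(Numeric.DoubleIsFractional)  // 6.25
StdTypeClasses.isEven(10L)(Numeric.LongIsIntegral)                      // true
```

- The same function body works for any `T` as long as the caller can supply the matching instance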
---

# Type class doesn't mean typing a lot

## Remember `PType<T>`?

```java
<T> PCollection<T> parallelDo(DoFn<S,T> doFn, PType<T> type)
```

```java
<K,V> PTable<K,V> parallelDo(DoFn<S,Pair<K,V>> doFn, PTableType<K,V> type)
```

.footnote[https://crunch.apache.org/apidocs/0.11.0/org/apache/crunch/PCollection.html]

--

- ### No currying, e.g. `def fn(x: Int)(y: Int)`
- ### No implicits, e.g. `def sort[T](xs: List[T])(implicit ev: Ordering[T])`

---

# Implicit to the rescue

```scala
implicit val stringNum = new Num[String] {
  override def plus(x: String, y: String) = (x.toDouble + y.toDouble).toString
}

def sum[T](xs: T*)(implicit ev: Num[T]): T = xs.reduce(ev.plus)
```

--

# Magic!

```scala
sum("1.2", "3.4")
```

---

# Context bound

```scala
def sum[T](xs: T*)(implicit ev: Num[T]): T = xs.reduce(ev.plus)
```

--

# Is equivalent to

```scala
def sum[T : Num](xs: T*): T = {
  val ev = implicitly[Num[T]]
  xs.reduce(ev.plus)
}
```

- `T : Num` → there exists an instance of `Num[T]`

--

```scala
@inline def implicitly[T](implicit e: T) = e // for summoning implicit values from the nether world
```

- Delaying implicit resolution with `implicitly[T]`

---

# Implicit propagation

```scala
def cusum[T](xs: T*): Seq[T] =
  (1 to xs.size).map(n => sum(xs.take(n): _*))
// could not find implicit value for evidence parameter of type Num[T]
```

- `sum` requires `Num[T]` but `cusum` does not

--

# This works

```scala
def cusum[T: Num](xs: T*): Seq[T] =
  (1 to xs.size).map(n => sum(xs.take(n): _*))
```

- `cusum` now requires `Num[T]` but doesn't need an explicit reference to `ev: Num[T]`

--

# We could also use `scanLeft`

```scala
def cusum[T](xs: T*)(implicit ev: Num[T]): Seq[T] =
  xs.tail.scanLeft(xs.head)(ev.plus)
```

---

class: center, middle

# Built-in Type Classes

---

# `Ordering[T]` simplified

```scala
trait Ordering[T] { outer =>
  def compare(x: T, y: T): Int // Need to override this!
  def lteq(x: T, y: T): Boolean = compare(x, y) <= 0
  def gteq(x: T, y: T): Boolean = compare(x, y) >= 0
  def lt(x: T, y: T): Boolean = compare(x, y) < 0
  def gt(x: T, y: T): Boolean = compare(x, y) > 0
  def equiv(x: T, y: T): Boolean = compare(x, y) == 0
  def max(x: T, y: T): T = if (gteq(x, y)) x else y
  def min(x: T, y: T): T = if (lteq(x, y)) x else y
  def reverse: Ordering[T] = new Ordering[T] {
    override def reverse = outer
    def compare(x: T, y: T) = outer.compare(y, x)
  }
}
```

--

```scala
object Ordering {
  def by[T, S](f: T => S)(implicit ord: Ordering[S]): Ordering[T] = new Ordering[T] {
    override def compare(x: T, y: T) = ord.compare(f(x), f(y))
  }
}
```

---

# Custom ordering

```scala
case class Band(name: String, members: Int, founded: Int)

val bands = List(
  Band("Behemoth", 3, 1991),
  Band("Carcass", 4, 1985),
  Band("Rammstein", 6, 1994))
```

--

```scala
bands.sorted
// No implicit Ordering defined for Band.
```

--

```scala
implicit val bandNameOrd = new Ordering[Band] {
  override def compare(x: Band, y: Band) = x.name.compare(y.name)
}

bands.sorted
```

---

# Ordering "factory"

```scala
bands.sorted(Ordering.by { b: Band => b.founded })
bands.sorted(Ordering.by { b: Band => b.founded }.reverse)
bands.sorted(Ordering.by { b: Band => -b.founded })
```

--

# Why does this not work?

```scala
bands.sorted(Ordering.by(_.founded))
// error: missing parameter type for expanded function ((x$1) => x$1.founded)
```

--

- `Ordering.by[T, S]` requires `T => S` but `T` is unknown
- Even though `List#sorted` requires `Ordering[T]`
- Local type inference vs. Hindley-Milner

.footnote[http://stackoverflow.com/questions/7234095/why-is-scalas-type-inference-not-as-powerful-as-haskells]

---

# `Numeric[T]` simplified

```scala
trait Numeric[T] extends Ordering[T] {
  def plus(x: T, y: T): T
  def minus(x: T, y: T): T
  def times(x: T, y: T): T
  def negate(x: T): T
  def fromInt(x: Int): T

  class Ops(lhs: T) {
    def +(rhs: T) = plus(lhs, rhs)
    def -(rhs: T) = minus(lhs, rhs)
    def *(rhs: T) = times(lhs, rhs)
    def unary_-() = negate(lhs)
  }
  implicit def mkNumericOps(lhs: T): Ops = new Ops(lhs)
}
```

--

- `Ops` and `implicit def mkNumericOps` → _pimp my library_

---

# We also need `div` from `Fractional[T]`

```scala
trait Fractional[T] extends Numeric[T] {
  def div(x: T, y: T): T

  class FractionalOps(lhs: T) extends Ops(lhs) {
    def /(rhs: T) = div(lhs, rhs)
  }
  override implicit def mkNumericOps(lhs: T): FractionalOps = new FractionalOps(lhs)
}
```

---

# Super duper generic `mean[T]`

```scala
def mean[T: Fractional](xs: T*): T = {
  val ev = implicitly[Fractional[T]]
  ev.div(xs.reduce(ev.plus), ev.fromInt(xs.size))
}
```

--

With `Seq[T]#sum(implicit num: Numeric[T]): T`

```scala
def mean[T: Fractional](xs: T*): T = {
  val ev = implicitly[Fractional[T]]
  ev.div(xs.sum, ev.fromInt(xs.size))
}
```

--

With `mkNumericOps`

```scala
def mean[T: Fractional](xs: T*): T = {
  val ev = implicitly[Fractional[T]]
  import ev.mkNumericOps
  xs.sum / ev.fromInt(xs.size)
}
```

---

class: center, middle

# Exercise

[RationalType.scala]

---

class: center, middle

# Abstract Algebra with Type Classes

---

# Semigroup

Given a set `\(S\)` and an operation `\(*\)`, we say that `\((S, *)\)` is a _semigroup_ if it satisfies the following properties for any `\(x, y, z \in S\)`:

- _Closure_: `\(x * y \in S\)`
- _Associativity_: `\((x * y) * z = x * (y * z)\)`

We also say that `\(S\)` _forms a semigroup under_ `\(*\)`.

# Examples

- Strings under concatenation (not commutative)
- Integers under plus (commutative)

---

# Look familiar?
```scala
trait Semigroup[T] {
  def plus(x: T, y: T): T
}
```

--

```scala
implicit val intSemigroup = new Semigroup[Int] {
  override def plus(x: Int, y: Int): Int = x + y
}

// def instead of val because a new instance must be created for every item type T
implicit def setSemigroup[T] = new Semigroup[Set[T]] {
  override def plus(x: Set[T], y: Set[T]): Set[T] = x ++ y
}
```

---

# Monoid

A monoid is a semigroup with an identity element. More formally, given a set `\(M\)` and an operation `\(*\)`, we say that `\((M, *)\)` is a _monoid_ if it satisfies the following properties for any `\(x, y, z \in M\)`:

- _Closure_: `\(x * y \in M\)`
- _Associativity_: `\((x * y) * z = x * (y * z)\)`
- _Identity_: There exists an `\(e \in M\)` such that `\(e * x = x * e = x\)`

We also say that `\(M\)` _is a monoid under_ `\(*\)`.

# Examples

- The natural numbers _N_ form a monoid under addition with `\(e = 0\)` (commutative)
- _N_ is a monoid under multiplication with `\(e = 1\)` (commutative)
- Strings form a monoid under concatenation with `\(e = \)` `""` (not commutative)

---

# A `Monoid[T]` is a `Semigroup[T]`

```scala
trait Monoid[T] extends Semigroup[T] {
  def zero: T
}
```

--

```scala
implicit val intMonoid = new Monoid[Int] {
  override def plus(x: Int, y: Int): Int = x + y
  override def zero = 0
}

implicit def setMonoid[T] = new Monoid[Set[T]] {
  override def plus(x: Set[T], y: Set[T]): Set[T] = x ++ y
  override def zero = Set()
}
```

---

# Groups

We say that `\((G, *)\)` is a _group_ if it is a monoid that also satisfies the following property:

- _Invertibility_: For every `\(x \in G\)` there is an `\(x^{-1} \in G\)` such that `\(x * x^{-1} = x^{-1} * x = e\)`

Moreover, it is an _abelian group_ if it satisfies the property:

- _Commutativity_: `\(x * y = y * x\)` for all `\(x, y \in G\)`

# Examples

- Integers `\(Z\)` are an abelian group under addition
- Natural numbers are _not_ a group under addition (given a nonzero number `\(x\)` in `\(N\)`, `\(-x\)` is not in `\(N\)`)
- Neither integers nor natural numbers are a group under multiplication, but the set of nonzero rational numbers (`\(n/d\)` for any `\(n, d \in Z, n \neq 0, d \neq 0\)`) is an (abelian) group under multiplication.

---

# A `Group[T]` is a `Monoid[T]`

```scala
trait Group[T] extends Monoid[T] {
  // must override negate or minus (or both)
  def negate(v: T): T = minus(zero, v)
  def minus(x: T, y: T): T = plus(x, negate(y))
}
```

--

```scala
implicit val intGroup = new Group[Int] {
  override def plus(x: Int, y: Int): Int = x + y
  override def zero = 0
  override def negate(v: Int) = -v
}
```

---

# Ring

Whereas a group is defined by a set and a single operation, a ring is defined by a set and two operations. Given a set `\(R\)` and operations `\(*\)` and `\(+\)`, we say that `\((R, +, *)\)` is a ring if it satisfies the following properties:

- `\((R, +)\)` is an abelian group
- For any `\(x, y, z \in R\)`, `\((x * y) * z = x * (y * z)\)`
- For any `\(x, y, z \in R\)`, `\(x * (y + z) = x * y + x * z\)` and `\((x + y) * z = x * z + y * z\)`

# Examples

- `\((Z, +, *)\)`
- The set of square matrices of a given size is a ring

---

# Field

A field is a commutative ring in which every non-zero element has a multiplicative inverse. Equivalently, a ring `\((F, +, *)\)` is a field if `\((F', *)\)` is an abelian group where `\(F'\)` is the set of non-zero elements in `\(F\)`.
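
Extending the trait hierarchy from the earlier slides, the ring and field definitions could be sketched as follows. The names `AlgebraSketch`, `Ring`, `Field`, `times`, `one`, `inverse`, `div`, and `doubleField` are assumptions made for illustration, and `Double` only approximates a field because of floating-point rounding.

```scala
// A sketch under stated assumptions, not a definitive implementation.
object AlgebraSketch {
  trait Semigroup[T] { def plus(x: T, y: T): T }
  trait Monoid[T] extends Semigroup[T] { def zero: T }
  trait Group[T] extends Monoid[T] {
    def negate(v: T): T
    def minus(x: T, y: T): T = plus(x, negate(y))
  }

  // Hypothetical: a ring adds a second associative operation that distributes over plus.
  trait Ring[T] extends Group[T] {
    def times(x: T, y: T): T
  }

  // Hypothetical: a field adds a multiplicative identity and inverses of non-zero elements.
  trait Field[T] extends Ring[T] {
    def one: T
    def inverse(v: T): T // only meaningful for v != zero
    def div(x: T, y: T): T = times(x, inverse(y))
  }

  val doubleField: Field[Double] = new Field[Double] {
    def plus(x: Double, y: Double): Double = x + y
    def zero: Double = 0.0
    def negate(v: Double): Double = -v
    def times(x: Double, y: Double): Double = x * y
    def one: Double = 1.0
    def inverse(v: Double): Double = 1.0 / v
  }
}

AlgebraSketch.doubleField.div(3.0, 4.0)   // 0.75
AlgebraSketch.doubleField.minus(5.0, 2.0) // 3.0
```
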
# Examples

- Rational numbers
- Real numbers
- Complex numbers

---

class: center, middle

# Examples

[Algebird.scala]

---

# Tuples

```scala
class Tuple2Semigroup[A, B](implicit asemigroup: Semigroup[A], bsemigroup: Semigroup[B])
  extends Semigroup[(A, B)] {
  override def plus(l: (A, B), r: (A, B)) =
    (asemigroup.plus(l._1, r._1), bsemigroup.plus(l._2, r._2))
  override def sumOption(to: TraversableOnce[(A, B)]) = {
    val buf = new ArrayBufferedOperation[(A, B), (A, B)](1000) with BufferedReduce[(A, B)] {
      def operate(items: Seq[(A, B)]) =
        (asemigroup.sumOption(items.iterator.map(_._1)).get,
         bsemigroup.sumOption(items.iterator.map(_._2)).get)
    }
    to.foreach(buf.put(_))
    buf.flush
  }
}
// class Tuple3Semigroup[A, B, C]
// class Tuple4Semigroup[A, B, C, D]
// ...

trait GeneratedSemigroupImplicits {
  implicit def semigroup2[A, B](implicit asemigroup: Semigroup[A], bsemigroup: Semigroup[B])
    : Semigroup[(A, B)] = {
    new Tuple2Semigroup[A, B]()(asemigroup, bsemigroup)
  }
  // implicit def semigroup3[A, B, C]
  // implicit def semigroup4[A, B, C, D]
  // ...
}
```

---

# Case classes

```scala
class Product2Semigroup[X, A, B](apply: (A, B) => X, unapply: X => Option[(A, B)])
  (implicit asemigroup: Semigroup[A], bsemigroup: Semigroup[B]) extends Semigroup[X] {
  override def plus(l: X, r: X) = {
    val lTuple = unapply(l).get
    val rTuple = unapply(r).get
    apply(asemigroup.plus(lTuple._1, rTuple._1), bsemigroup.plus(lTuple._2, rTuple._2))
  }
}
// class Product3Semigroup[X, A, B, C]
// class Product4Semigroup[X, A, B, C, D]
// ...

trait ProductSemigroups {
  def apply[X, A, B](applyX: (A, B) => X, unapplyX: (X) => Option[(A, B)])
    (implicit asemigroup: Semigroup[A], bsemigroup: Semigroup[B]): Semigroup[X] = {
    new Product2Semigroup[X, A, B](applyX, unapplyX)(asemigroup, bsemigroup)
  }
  // def apply[X, A, B, C]
  // def apply[X, A, B, C, D]
  // ...
}
```

---

# Tuples

```scala
// (Int, Double, (Set[Int], String))
val xs = (1 to 10).map { x => (x, x * 0.1, (Set(x % 5), x.toString)) }
sum(xs)
```

# Case classes

```scala
case class P(x: Int, y: Double)

// P.apply: (Int, Double) => P
// P.unapply: (P) => Option[(Int, Double)]
// Semigroup.apply[P, Int, Double]
val sg = Semigroup(P.apply _, P.unapply _)
sum(Seq(P(1, 2.0), P(2, 10.0)))(sg)
```

---

# Further Reading

- ## Scala Standard Library
- ## Twitter Algebird
- ## Simulacrum
- ## Haskell

---

class: center, middle

# The End

## Happy Typing