Scala generics for automatic differentiation (lang.scala, msg#00247)
I've been reading an interesting paper by Dan Piponi, "Automatic Differentiation, C++ Templates and Photogrammetry":

http://sigfpe.blogspot.com/2005/07/automatic-differentiation.html
(PDF: http://homepage.mac.com/sigfpe/paper.pdf)

In a nutshell, automatic differentiation uses dual numbers of the form a + bd. They are similar to complex numbers, except that instead of an imaginary part there is an infinitesimal part (the coefficient b of d), and d^2 = 0 rather than i^2 = -1. Given this, the first derivative of a function f at a point x is obtained by calculating f(x + d) and reading off the infinitesimal part.

Higher-order derivatives are where C++ templates make things look simple: instantiating f with Dual<Dual<float>> differentiates twice. I've been stuck trying to write the equivalent code in Scala. I'm sure it is possible and I'm just not familiar enough with Scala, so I'm turning here for ideas. Thanks in advance!

Here's the Scala code for first-order differentiation:

```scala
import scala.language.implicitConversions

case class Dual(a: Double, b: Double) {
  // the infinitesimal d = 0 + 1d
  def d(): Dual = Dual(0, 1)
  def +(y: Dual): Dual = Dual(this.a + y.a, this.b + y.b)
  // (a + bd)(c + ed) = ac + (ae + bc)d, since d^2 = 0
  def *(y: Dual): Dual = Dual(this.a * y.a, this.a * y.b + this.b * y.a)
}

object test {
  def f(x: Dual) = (x + Dual(2, 0)) * (x + Dual(1, 0))

  implicit def int2dual(x: Int): Dual = Dual(x, 0)
  implicit def dbl2dual(x: Double): Dual = Dual(x, 0)

  // f'(n): evaluate f(n + d) and read off the coefficient of d
  def diff(f: Dual => Dual)(n: Dual) = f(n + n.d()).b

  def main(args: Array[String]): Unit = {
    Console.println(diff(f)(20)) // f'(20) = 2*20 + 3
  }
}
```

This is the C++ code:

```cpp
#include <iostream>
using namespace std;

template<class X> class Dual {
public:
  X a;
  X b;
  Dual(X a0, X b0 = 0) : a(a0), b(b0) {}
  // the infinitesimal d = 0 + 1d
  static Dual<X> d() { return Dual<X>(X(0), X(1)); }
};

template<class X> Dual<X> operator+(const Dual<X> &x, const Dual<X> &y) {
  return Dual<X>(x.a + y.a, x.b + y.b);
}

template<class X> Dual<X> operator*(const Dual<X> &x, const Dual<X> &y) {
  return Dual<X>(x.a * y.a, x.a * y.b + x.b * y.a);
}

// f(x) = (x + 2)(x + 1), for any numeric type X
template<class X> X f(X x) { return (x + X(2)) * (x + X(1)); }

// g(x) = f'(x): evaluate f at x + d and take the d coefficient
template<class X> X g(X x) { return f(Dual<X>(x) + Dual<X>::d()).b; }

// h(x) = g'(x) = f''(x): here f is instantiated with Dual<Dual<float>>
float h(float x) { return g(Dual<float>(x) + Dual<float>::d()).b; }

int main() { cout << h(3) << '\n'; } // prints 2, i.e. f''(3)
```

-- 
Michel Salim
http://hircus.wordpress.com/

My theology, briefly, is that the universe was dictated but not signed.
    -- Christopher Morley
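Why reading off the d coefficient yields the derivative can be checked by hand for the polynomial f(x) = (x + 2)(x + 1) that both programs use. The multiplication rule drops the d^2 term, so f(x + d) = f(x) + f'(x) d for polynomials; at x = 20:

$$
\begin{aligned}
(a + b\,d)(c + e\,d) &= ac + (ae + bc)\,d + be\,d^2 = ac + (ae + bc)\,d,\\
f(20 + d) &= (22 + d)(21 + d) = 462 + (22 + 21)\,d = 462 + 43\,d,
\end{aligned}
$$

so f(20) = 462 and f'(20) = 2 \cdot 20 + 3 = 43. Applying the same step once more to f'(x) = 2x + 3, as the nested Dual<Dual<float>> in h does, gives f''(x) = 2.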
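One possible way to get the C++ behaviour in Scala is to replace the template parameter with a type class, so that f is written once and the compiler builds the nested dual types. The following is a minimal sketch under that assumption, not the method of the paper or of the thread: `Num`, `dualNum`, `AD`, `f1` and `f2` are hypothetical names, and `Num` is a small hand-rolled type class (only `+`, `*` and integer constants) rather than `scala.math.Numeric`. It is written in Scala 2.13 syntax.

```scala
// Hypothetical minimal numeric type class (not scala.math.Numeric):
// only the operations that f needs.
trait Num[T] {
  def plus(x: T, y: T): T
  def times(x: T, y: T): T
  def fromInt(k: Int): T

  // Infix syntax, enabled with `import n._` (the same trick scala.math.Numeric uses)
  implicit class Ops(x: T) {
    def +(y: T): T = plus(x, y)
    def *(y: T): T = times(x, y)
  }
}

// a + b d, with coefficients of any type T
case class Dual[T](a: T, b: T)

object Num {
  implicit val doubleNum: Num[Double] = new Num[Double] {
    def plus(x: Double, y: Double): Double = x + y
    def times(x: Double, y: Double): Double = x * y
    def fromInt(k: Int): Double = k.toDouble
  }

  // Duals over any Num[T] form a Num again; this is what allows nesting
  implicit def dualNum[T](implicit n: Num[T]): Num[Dual[T]] = new Num[Dual[T]] {
    def plus(x: Dual[T], y: Dual[T]): Dual[T] =
      Dual(n.plus(x.a, y.a), n.plus(x.b, y.b))
    // (a + b d)(c + e d) = ac + (ae + bc) d, since d^2 = 0
    def times(x: Dual[T], y: Dual[T]): Dual[T] =
      Dual(n.times(x.a, y.a), n.plus(n.times(x.a, y.b), n.times(x.b, y.a)))
    def fromInt(k: Int): Dual[T] = Dual(n.fromInt(k), n.fromInt(0))
  }
}

object AD {
  // g'(x): evaluate g at x + d and read off the coefficient of d
  def diff[T](g: Dual[T] => Dual[T])(x: T)(implicit n: Num[T]): T =
    g(Dual(x, n.fromInt(1))).b

  // f(x) = (x + 2)(x + 1), written once for every T with a Num instance
  def f[T](x: T)(implicit n: Num[T]): T = {
    import n._
    (x + fromInt(2)) * (x + fromInt(1))
  }

  // f'(x): runs f on Dual[T]
  def f1[T](x: T)(implicit n: Num[T]): T = diff[T](y => f(y))(x)

  // f''(x): runs f1 on Dual[T], hence f on Dual[Dual[T]]
  def f2[T](x: T)(implicit n: Num[T]): T = diff[T](y => f1(y))(x)

  def main(args: Array[String]): Unit = {
    println(f1(20.0)) // 43.0
    println(f2(3.0))  // 2.0
  }
}
```

Because `dualNum` derives a `Num[Dual[T]]` from any `Num[T]`, the compiler assembles `Num[Dual[Dual[Double]]]` for `f2` on its own, which plays the role of the `Dual<Dual<float>>` instantiation in the C++ `h`. The price compared with C++ is the explicit `(implicit n: Num[T])` parameter and the `import n._` needed for infix `+` and `*`.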