Forward Mode AD in R

R
Author

Jonny Law

Published

August 5, 2019

Forward Mode Automatic Differentiation

Automatic differentiation can be used to calculate the exact derivative of a function at a point by repeated application of the chain rule. Dual numbers provide a straightforward implementation in R using S3 generic methods. A dual number has a real component and a “dual” component, which together can be used to calculate both the value of an expression and its derivative at a specific value of \(x\). Consider the quadratic \(f(x) = 5x^2 + 3x + 10\) with derivative \(f^\prime(x) = 10x + 3\). The function and its derivative can be evaluated at a value, say \(x = 5\), using the dual number \(5 + \varepsilon\). The dual component \(\varepsilon\) is treated as nonzero but small enough that \(\varepsilon^2 = 0\). Calculating \(f(5 + \varepsilon)\) then gives:

\[\begin{align} f(5 + \varepsilon) &= 5(5 + \varepsilon)^2 + 3(5 + \varepsilon) + 10,\\ &= 5(25 + 10\varepsilon + \varepsilon^2) + 15 + 3\varepsilon + 10,\\ &= 5\varepsilon^2 + 53\varepsilon + 150,\\ &= 53\varepsilon + 150. \end{align}\]

The quadratic term vanishes in the final step because \(\varepsilon^2 = 0\). The coefficient of \(\varepsilon\) is then the derivative and the constant term is the value of the function: \(f(5) = 150\) and \(f^\prime(5) = 53\).
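
The same cancellation is not special to this quadratic. As a brief sketch, assuming \(f\) is smooth enough to have a Taylor expansion around \(a\), every term beyond first order carries a factor of \(\varepsilon^2 = 0\):

\[\begin{align} f(a + b\varepsilon) &= f(a) + f^\prime(a)\,b\varepsilon + \frac{f^{\prime\prime}(a)}{2}\,b^2\varepsilon^2 + \dots,\\ &= f(a) + b\,f^\prime(a)\,\varepsilon. \end{align}\]

Setting \(b = 1\) gives the derivative as the coefficient of \(\varepsilon\). Multiplying two dual numbers directly gives the product rule in the same way, \((a + b\varepsilon)(c + d\varepsilon) = ac + (ad + bc)\varepsilon\).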

S3 Objects

R has three systems for object oriented programming: S3, S4 and reference classes. These are covered in the relevant chapters of Advanced R. Dual numbers can be implemented as an S3 class in R:

dual <- function(real, eps) {
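  # Both components must be numeric: real holds the value, eps the derivative
  if (!is.numeric(real)) stop("real must be numeric")
  if (!is.numeric(eps)) stop("eps must be numeric")
  structure(list(real = real, eps = eps), class = "dual")
}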
var <- function(x) {
  if (!is.numeric(x)) stop("x must be numeric")
  dual(x, 1)
}
const <- function(x) {
  if (!is.numeric(x)) stop("x must be numeric")
  dual(x, 0)
}

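var represents a variable with respect to which we want to differentiate, whereas const represents a constant. Note that defining var in this way masks stats::var, the sample variance function, for the rest of the session.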

Next, primitive functions can be defined in terms of dual numbers which simultaneously evaluate the function and the derivative:

plus <- function(x, y) 
  dual(x$real + y$real, x$eps + y$eps)
minus <- function(x, y) 
  dual(x$real - y$real, x$eps - y$eps)
times <- function(x, y) 
  dual(x$real * y$real, x$eps * y$real + y$eps * x$real)
divide <- function(x, y) 
  dual(
      x$real / y$real,
      (x$eps * y$real - x$real * y$eps) / (y$real * y$real)
    )
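
As a quick sanity check, the primitives can be compared against the product and quotient rules worked out by hand. The snippet below is only a sketch: it repeats compact versions of dual, times and divide so that it runs on its own, and the names x, four, prod_res and quot_res are chosen here for illustration.

```r
# Definitions repeated from the text (without input checks) so this runs alone.
dual <- function(real, eps) structure(list(real = real, eps = eps), class = "dual")
times <- function(x, y)
  dual(x$real * y$real, x$eps * y$real + y$eps * x$real)
divide <- function(x, y)
  dual(x$real / y$real,
       (x$eps * y$real - x$real * y$eps) / (y$real * y$real))

x <- dual(3, 1)     # the variable x, evaluated at x = 3
four <- dual(4, 0)  # the constant 4

# By hand: d/dx (4x) = 4 and d/dx (4 / x) = -4 / x^2, which is -4/9 at x = 3
prod_res <- times(four, x)
quot_res <- divide(four, x)

stopifnot(prod_res$real == 12, prod_res$eps == 4)
stopifnot(isTRUE(all.equal(quot_res$real, 4 / 3)),
          isTRUE(all.equal(quot_res$eps, -4 / 9)))
```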

Group generics can be used to implement the arithmetic of dual numbers. The group generics included with base R include Math, which covers special functions such as abs and sqrt as well as trigonometric and hyperbolic functions, and Ops, which covers the basic infix operations required for arithmetic: +, -, *, / etc. For a full list of the functions associated with Math and Ops, consult the R help by typing ?groupGeneric in the R console. To implement a group generic for the S3 class dual we define Ops.dual:

Ops.dual <- function(x, y) {
  switch(
    .Generic,
    `+` = plus(x, y),
    `-` = minus(x, y),
    `*` = times(x, y),
    `/` = divide(x, y)
  )
}

switch is used to pattern match on the generic function being called within Ops by matching on .Generic. Implementing dual numbers in this way allows us to define a function using the in-built infix operators in a natural way. The function \(f(x)\) can be defined in terms of dual numbers as

f <- function(x) 
  const(5) * x * x + const(3) * x + const(10)

It is then evaluated at \(x = 5\) using the constructor var, which initialises a dual number whose dual component is 1:

f(var(5))
$real
[1] 150

$eps
[1] 53

attr(,"class")
[1] "dual"

The definition of f is cumbersome since we have to explicitly create the constants using the const constructor. The methods defined in Ops.dual can be extended to handle the case where a double is combined with a dual number by converting the double to a const. We can then automatically differentiate any univariate function built from these operations using forward mode automatic differentiation.

We can write a function which checks the arguments of plus, minus etc. and converts any argument that is a plain double, rather than a dual number, to a dual constant using const. This function takes a function f of two arguments and checks each argument in turn to determine whether it is a double, promoting it to a constant if so.

lift_function <- function(f) {
  function(x, y)
    if (is.double(x)) {
      f(const(x), y)
    } else if (is.double(y)) {
      f(x, const(y))
    } else {
      f(x, y)
    }
}
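
To see the promotion in isolation, the lifted version of times can be called directly with a bare double. This is a small sketch that repeats compact versions of the definitions above (without input checks) so that it runs on its own; lifted_times, a and b are names introduced here for illustration.

```r
# Definitions repeated from the text so the snippet runs on its own.
dual <- function(real, eps) structure(list(real = real, eps = eps), class = "dual")
const <- function(x) dual(x, 0)
var <- function(x) dual(x, 1)
times <- function(x, y)
  dual(x$real * y$real, x$eps * y$real + y$eps * x$real)

lift_function <- function(f) {
  function(x, y)
    if (is.double(x)) {
      f(const(x), y)
    } else if (is.double(y)) {
      f(x, const(y))
    } else {
      f(x, y)
    }
}

lifted_times <- lift_function(times)

# The bare double 5 is promoted to const(5), so both calls agree.
a <- lifted_times(5, var(2))
b <- times(const(5), var(2))
stopifnot(identical(a, b))

a$real  # 10, the value of 5x at x = 2
a$eps   # 5, the derivative of 5x
```

One caveat of this definition is that is.double is FALSE for integer literals such as 5L, so they would not be promoted. Checking with is.numeric instead would be one way to cover them.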

Ops.dual can then be redefined using lift_function:

Ops.dual <- function(x, y) {
  switch(
    .Generic,
    `+` = lift_function(plus)(x, y),
    `-` = lift_function(minus)(x, y),
    `*` = lift_function(times)(x, y),
    `/` = lift_function(divide)(x, y)
  )
}

Then f can be defined more naturally:

f <- function(x) 
  5 * x * x + 3 * x + 10

And the derivative calculated:

f(var(5))
$real
[1] 150

$eps
[1] 53

attr(,"class")
[1] "dual"
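
The Math group generic mentioned earlier can be handled with the same switch pattern. The following is a minimal sketch and not part of the implementation above: Math.dual and the particular set of functions it covers are assumptions chosen for illustration, with each branch pairing the function value with its chain rule term.

```r
# Compact constructors repeated so the snippet runs on its own.
dual <- function(real, eps) structure(list(real = real, eps = eps), class = "dual")
var <- function(x) dual(x, 1)

# Sketch of a Math group generic for dual numbers.
# For u = x$real and du = x$eps, the derivative of g(u) is g'(u) * du.
Math.dual <- function(x, ...) {
  u <- x$real
  du <- x$eps
  switch(
    .Generic,
    exp = dual(exp(u), exp(u) * du),
    log = dual(log(u), du / u),
    sqrt = dual(sqrt(u), du / (2 * sqrt(u))),
    sin = dual(sin(u), cos(u) * du),
    cos = dual(cos(u), -sin(u) * du),
    # Any other Math function falls through to an explicit error
    stop("Math function '", .Generic, "' is not implemented for dual numbers")
  )
}

root <- sqrt(var(4))          # value 2, derivative 1 / (2 * sqrt(4)) = 0.25
nested <- exp(sin(var(0)))    # value exp(0) = 1, derivative cos(0) * exp(0) = 1

stopifnot(root$real == 2, root$eps == 0.25)
stopifnot(nested$real == 1, nested$eps == 1)
```

Because each branch returns a dual, calls compose without further work, as the nested example shows.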

Testing using Hedgehog

Hedgehog is a package which builds on testthat to implement property based testing in R. Property based testing can be used to check a wide range of inputs to a function and determine whether the code outputs the expected value. In standard unit testing the state before the test is defined by the programmer and typically does not change. If we were to consider a test for the derivative of the quadratic function defined above, then we might write a test which evaluates the function at \(x = 5\). This verifies we are correct for \(x = 5\), but what about \(x = 0\) or another value? With property based testing, we define a random generator for the input and the test checks hundreds of potential values for failure.

The input to this property based test is a, drawn from the integers between \(-100\) and \(100\). The usual testthat syntax is then used to evaluate the gradient using forward mode AD and compare it to the exact derivative calculated by hand.

library(testthat)
library(hedgehog)

test_that("Derivative of 5x^2 + 3x + 10",
          forall(list(a = gen.c(gen.element(
            -100:100
          ))),
          function(a)
            expect_equal(object = f(var(a))$eps, expected = 10 * a + 3)))
Test passed 🎊
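
The same idea can be sketched in base R without any testing package, by sampling random inputs and comparing the forward mode derivative with a central finite difference. This is an illustrative check rather than a replacement for the Hedgehog test: the helper lift, the condensed Ops.dual, the seed, the step size h and the tolerance are all assumptions made here.

```r
# Condensed stand-ins for the definitions in the text so this runs on its own.
dual <- function(real, eps) structure(list(real = real, eps = eps), class = "dual")
var <- function(x) dual(x, 1)
lift <- function(v) if (inherits(v, "dual")) v else dual(v, 0)  # hypothetical helper
Ops.dual <- function(e1, e2) {
  x <- lift(e1)
  y <- lift(e2)
  switch(
    .Generic,
    `+` = dual(x$real + y$real, x$eps + y$eps),
    `*` = dual(x$real * y$real, x$eps * y$real + y$eps * x$real),
    stop("operator not needed for this sketch")
  )
}

f <- function(x) 5 * x * x + 3 * x + 10

set.seed(1)                             # arbitrary seed for reproducibility
a <- runif(200, min = -100, max = 100)  # non-integer inputs this time
h <- 1e-3                               # finite difference step (assumed)

ad <- f(var(a))$eps                     # forward mode AD, vectorised over a
fd <- (f(a + h) - f(a - h)) / (2 * h)   # central difference on plain doubles

max_err <- max(abs(ad - fd))
stopifnot(max_err < 1e-5)
```

The central difference is exact for a quadratic apart from rounding, so the two results should agree closely here. For other functions the finite difference carries its own truncation error, and the tolerance would need to be chosen accordingly.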
