Forward Mode Automatic Differentiation
Automatic differentiation calculates the exact derivative of a function at a point by repeated application of the chain rule. Dual numbers provide a straightforward implementation in R using S3 generic methods. A dual number has a real component and a “dual” component, which together give the exact value of an expression and of its derivative at a specific value of \(x\). Consider the quadratic \(f(x) = 5x^2 + 3x + 10\) with derivative \(f^\prime(x) = 10x + 3\). The function and its derivative can be evaluated at a value, say \(x = 5\), using the dual number \(5 + \varepsilon\). The dual component \(\varepsilon\) is treated as a nonzero quantity satisfying \(\varepsilon^2 = 0\). Calculating \(f(5 + \varepsilon)\) gives:
\[\begin{align} f(5 + \varepsilon) &= 5(5 + \varepsilon)^2 + 3(5 + \varepsilon) + 10,\\ &= 5(25 + 10\varepsilon + \varepsilon^2) + 15 + 3\varepsilon + 10,\\ &= 5\varepsilon^2 + 53\varepsilon + 150,\\ &= 150 + 53\varepsilon \qquad \text{since } \varepsilon^2 = 0. \end{align}\]
The coefficient of \(\varepsilon\) is the derivative and the constant term is the value of the function: \(f(5) = 150\) and \(f^\prime(5) = 53\).
S3 Objects
R has three systems for object oriented programming: S3, S4 and reference classes, all of which are described in the relevant chapter of Advanced R. Dual numbers can be implemented as an S3 class in R:
# A dual number stores the value (real) and the derivative (eps)
dual <- function(real, eps) {
  if (!is.numeric(real)) stop("real must be numeric")
  structure(list(real = real, eps = eps), class = "dual")
}

# A variable: its derivative with respect to itself is 1
var <- function(x) {
  if (!is.numeric(x)) stop("x must be numeric")
  dual(x, 1)
}

# A constant: its derivative is 0
const <- function(x) {
  if (!is.numeric(x)) stop("x must be numeric")
  dual(x, 0)
}
var represents a variable with respect to which we want to differentiate, whereas const represents a constant.
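The arithmetic rules for dual numbers follow directly from \(\varepsilon^2 = 0\). As a short derivation (the symbols \(a, b, c, d\) are introduced here for illustration only), take two dual numbers \(a + b\varepsilon\) and \(c + d\varepsilon\), with \(c \neq 0\) for the quotient:
\[\begin{align} (a + b\varepsilon) \pm (c + d\varepsilon) &= (a \pm c) + (b \pm d)\varepsilon,\\ (a + b\varepsilon)(c + d\varepsilon) &= ac + (ad + bc)\varepsilon + bd\,\varepsilon^2 = ac + (ad + bc)\varepsilon,\\ \frac{a + b\varepsilon}{c + d\varepsilon} &= \frac{(a + b\varepsilon)(c - d\varepsilon)}{c^2 - d^2\varepsilon^2} = \frac{a}{c} + \frac{bc - ad}{c^2}\,\varepsilon. \end{align}\]
The \(\varepsilon\) coefficients are the sum, product and quotient rules of differentiation.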
Next, primitive functions can be defined in terms of dual numbers which simultaneously evaluate the function and the derivative:
# Sum rule
plus <- function(x, y)
  dual(x$real + y$real, x$eps + y$eps)

# Difference rule
minus <- function(x, y)
  dual(x$real - y$real, x$eps - y$eps)

# Product rule
times <- function(x, y)
  dual(x$real * y$real, x$eps * y$real + y$eps * x$real)

# Quotient rule
divide <- function(x, y)
  dual(
    x$real / y$real,
    (x$eps * y$real - x$real * y$eps) / (y$real * y$real)
  )
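As a quick sanity check of the primitives, the sketch below repeats minimal versions of the constructors and of times and divide so that it runs on its own (the input checks are dropped for brevity, which is a simplification made here), then compares the eps components with derivatives worked out by hand.

```r
# Minimal copies of the constructors (input checks omitted for brevity)
dual  <- function(real, eps) structure(list(real = real, eps = eps), class = "dual")
var   <- function(x) dual(x, 1)
const <- function(x) dual(x, 0)

times <- function(x, y)
  dual(x$real * y$real, x$eps * y$real + y$eps * x$real)

divide <- function(x, y)
  dual(
    x$real / y$real,
    (x$eps * y$real - x$real * y$eps) / (y$real * y$real)
  )

x <- var(3)

# x * x at x = 3: value 9, derivative 2 * 3 = 6
sq <- times(x, x)
sq$real
sq$eps

# 1 / x at x = 3: value 1/3, derivative -1 / 3^2 = -1/9
inv <- divide(const(1), x)
inv$real
inv$eps
```

Note that defining var at the top level masks stats::var for the rest of the session, as it does in the post itself.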
Group generics can be used to implement the mathematics of dual numbers. The group generics included with base R include Math, which contains special functions such as abs and sqrt as well as trigonometric and hyperbolic functions, and Ops, which contains the basic infix operations required for arithmetic: +, -, *, / etc. For a full list of the functions in the Math and Ops groups, consult the R help by typing ?groupGeneric in the R console. To implement a group generic for the S3 class dual we define Ops.dual:
Ops.dual <- function(x, y) {
  switch(
    .Generic,
    `+` = plus(x, y),
    `-` = minus(x, y),
    `*` = times(x, y),
    `/` = divide(x, y)
  )
}
switch is used to pattern match on the generic function being called within Ops by matching on .Generic. Implementing dual numbers in this way allows us to define a function using the built-in infix operators in a natural way. The function \(f(x)\) can be defined in terms of dual numbers as
f <- function(x)
  const(5) * x * x + const(3) * x + const(10)
It can then be evaluated at \(x = 5\) using the constructor var, which initialises a dual number with the coefficient of \(\varepsilon\) set to 1.
f(var(5))
$real
[1] 150
$eps
[1] 53
attr(,"class")
[1] "dual"
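The Math group mentioned earlier can be handled with the same switch pattern as Ops. The following is only a sketch and is not part of the original post: the choice of functions (exp, log, sin, cos, sqrt) and the error raised for anything else are assumptions made here, and the constructors are repeated in a minimal form so the block runs on its own.

```r
# Minimal copies of the constructors (input checks omitted for brevity)
dual <- function(real, eps) structure(list(real = real, eps = eps), class = "dual")
var  <- function(x) dual(x, 1)

# Chain rule for a few members of the Math group:
# each branch returns (g(x), x' * g'(x))
Math.dual <- function(x, ...) {
  switch(
    .Generic,
    exp  = dual(exp(x$real), x$eps * exp(x$real)),
    log  = dual(log(x$real), x$eps / x$real),
    sin  = dual(sin(x$real), x$eps * cos(x$real)),
    cos  = dual(cos(x$real), -x$eps * sin(x$real)),
    sqrt = dual(sqrt(x$real), x$eps / (2 * sqrt(x$real))),
    stop("Math function not implemented for dual: ", .Generic)
  )
}

s <- sin(var(0))
s$real  # sin(0) = 0
s$eps   # cos(0) = 1

r <- sqrt(var(4))
r$real  # 2
r$eps   # 1 / (2 * sqrt(4)) = 0.25
```

Because each branch multiplies by x$eps, compositions such as sin(sqrt(x)) propagate derivatives through the chain rule without further code.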
The definition of f is cumbersome since we have to explicitly create the constants using the const constructor. The methods defined in Ops.dual can be extended so that when a double is combined with a dual number, the double is first converted to a const. We can then automatically differentiate any univariate function built from these arithmetic operations using forward mode automatic differentiation.
We can write a function which checks the arguments of plus, minus, etc. If an argument is a plain double rather than a dual number created with var or const, it is converted to a dual constant using const. The function below checks each argument of a two-argument function f in turn and promotes any double to a constant.
lift_function <- function(f) {
  function(x, y)
    if (is.double(x)) {
      f(const(x), y)
    } else if (is.double(y)) {
      f(x, const(y))
    } else {
      f(x, y)
    }
}
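To see what the lifting does, the sketch below applies a lifted times to a double and a variable in both argument orders. The constructors and times are repeated in a minimal form so the block is self-contained; the input checks are left out, which is a simplification made here.

```r
# Minimal copies of the constructors and of times (input checks omitted)
dual  <- function(real, eps) structure(list(real = real, eps = eps), class = "dual")
var   <- function(x) dual(x, 1)
const <- function(x) dual(x, 0)

times <- function(x, y)
  dual(x$real * y$real, x$eps * y$real + y$eps * x$real)

# Promote a double on either side to a dual constant before calling f
lift_function <- function(f) {
  function(x, y) {
    if (is.double(x)) {
      f(const(x), y)
    } else if (is.double(y)) {
      f(x, const(y))
    } else {
      f(x, y)
    }
  }
}

lifted_times <- lift_function(times)

left  <- lifted_times(2, var(3))  # 2 * x at x = 3
right <- lifted_times(var(3), 2)  # x * 2 at x = 3

left$real   # 6
left$eps    # 2
right$real  # 6
right$eps   # 2
```

Since the check is is.double, an integer literal such as 2L would not be promoted; using is.numeric instead would be one way to cover that case.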
The operators can then be redefined using lift_function:
Ops.dual <- function(x, y) {
  switch(
    .Generic,
    `+` = lift_function(plus)(x, y),
    `-` = lift_function(minus)(x, y),
    `*` = lift_function(times)(x, y),
    `/` = lift_function(divide)(x, y)
  )
}
Then f can be defined more naturally:
f <- function(x)
  5 * x * x + 3 * x + 10
And the derivative calculated:
f(var(5))
$real
[1] 150
$eps
[1] 53
attr(,"class")
[1] "dual"
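The quadratic only exercises plus and times. As a further sketch, the rational function \(g(x) = (x^2 + 1)/(x - 2)\) (chosen here for illustration, not taken from the original post) also uses minus and divide. By the quotient rule \(g^\prime(x) = (x^2 - 4x - 1)/(x - 2)^2\), so \(g(3) = 10\) and \(g^\prime(3) = -4\). The definitions are repeated in a compact form, without input checks, so the block runs on its own.

```r
# Compact copies of the definitions from the post (input checks omitted)
dual  <- function(real, eps) structure(list(real = real, eps = eps), class = "dual")
var   <- function(x) dual(x, 1)
const <- function(x) dual(x, 0)

plus   <- function(x, y) dual(x$real + y$real, x$eps + y$eps)
minus  <- function(x, y) dual(x$real - y$real, x$eps - y$eps)
times  <- function(x, y) dual(x$real * y$real, x$eps * y$real + y$eps * x$real)
divide <- function(x, y)
  dual(x$real / y$real, (x$eps * y$real - x$real * y$eps) / (y$real * y$real))

lift_function <- function(f) {
  function(x, y) {
    if (is.double(x)) f(const(x), y)
    else if (is.double(y)) f(x, const(y))
    else f(x, y)
  }
}

Ops.dual <- function(x, y) {
  switch(
    .Generic,
    `+` = lift_function(plus)(x, y),
    `-` = lift_function(minus)(x, y),
    `*` = lift_function(times)(x, y),
    `/` = lift_function(divide)(x, y)
  )
}

# g(x) = (x^2 + 1) / (x - 2)
g <- function(x) (x * x + 1) / (x - 2)

res <- g(var(3))
res$real  # 10
res$eps   # -4
```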
Testing using Hedgehog
Hedgehog is a package which utilises testthat to implement property based testing in R. Property based testing can be used to check a wide range of inputs to a function and determine whether the code outputs the expected value. In standard unit testing the state before the test is defined by the programmer and typically does not change. If we were to consider a test for the derivative of the quadratic function defined above, then we might write a test which evaluates the function at \(x = 5\). This verifies we are correct for \(x = 5\), but what about \(x = 0\) or another value? With property based testing, we define a random generator for the input and the test checks hundreds of potential values for failure.
The input to this property based test is a, generated from the integers between \(-100\) and \(100\). The usual testthat syntax is then used to evaluate the gradient using forward mode AD and compare it to the exact derivative calculated by hand.
library(testthat)
library(hedgehog)

test_that("Derivative of 5x^2 + 3x + 10",
  forall(
    list(a = gen.c(gen.element(-100:100))),
    function(a) expect_equal(object = f(var(a))$eps, expected = 10 * a + 3)
  )
)
Test passed 🎊
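When no hand-derived derivative is available, a similar randomised check can compare the AD result with a central finite difference. The sketch below uses base R only rather than Hedgehog, so it is a swapped-in technique and not the test from the post. The helper lift, the step size h and the tolerance are assumptions made for illustration, and Ops.dual is written in a compact form that covers only + and *.

```r
# Compact dual-number machinery (only + and * are needed for f)
dual  <- function(real, eps) structure(list(real = real, eps = eps), class = "dual")
var   <- function(x) dual(x, 1)
const <- function(x) dual(x, 0)

# Hypothetical helper: promote anything that is not a dual to a constant
lift <- function(v) if (inherits(v, "dual")) v else const(v)

Ops.dual <- function(x, y) {
  x <- lift(x)
  y <- lift(y)
  switch(
    .Generic,
    `+` = dual(x$real + y$real, x$eps + y$eps),
    `*` = dual(x$real * y$real, x$eps * y$real + y$eps * x$real),
    stop("operator not implemented in this sketch: ", .Generic)
  )
}

f <- function(x) 5 * x * x + 3 * x + 10

set.seed(1)
a <- sample(-100:100, 200, replace = TRUE)  # random integer inputs
h <- 1e-3                                   # finite-difference step (assumed)

ad <- f(var(a))$eps                     # forward mode AD, vectorised over a
fd <- (f(a + h) - f(a - h)) / (2 * h)   # central difference on plain numerics

max_err <- max(abs(ad - fd))
stopifnot(max_err < 1e-5)
```

The central difference is exact for a quadratic apart from rounding error, so the tolerance can be tight here; for other functions it would need to allow for truncation error of order h squared.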
Citation
@online{law2019,
author = {Jonny Law},
title = {Forward {Mode} {AD} in {R}},
date = {2019-08-05},
langid = {en}
}