Forward Mode Automatic Differentiation
Automatic differentiation can be used to calculate the exact derivative of a function at a point by repeated application of the chain rule. Dual numbers provide a straightforward implementation in R using S3 generic methods. A dual number a + bε has a real component a and a “dual” component b, where ε² = 0, and can be used to exactly calculate an expression and its derivative at a specific value of x. Evaluating a function f at x + ε and expanding gives f(x) + f'(x)ε, since every term in ε² or higher vanishes.
Then the coefficient of ε is the derivative of f at x.
S3 Objects
R has three systems for object oriented programming, S3, S4 and reference classes, which are covered in the relevant chapter of Advanced R. Dual numbers can be implemented as an S3 class in R:
# A dual number: the real part plus the coefficient of epsilon
dual <- function(real, eps) {
  if (!is.numeric(real)) stop("real must be numeric")
  structure(list(real = real, eps = eps), class = "dual")
}
# A variable: the derivative of x with respect to itself is 1
var <- function(x) {
  if (!is.numeric(x)) stop("x must be numeric")
  dual(x, 1)
}
# A constant: its derivative is 0
const <- function(x) {
  if (!is.numeric(x)) stop("x must be numeric")
  dual(x, 0)
}
var creates the variable with respect to which we want to differentiate (eps = 1), whereas const creates a constant (eps = 0).
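The arithmetic of dual numbers follows from treating ε as a symbol with ε² = 0. As a short derivation, writing two dual numbers as a + bε and c + dε (symbols chosen here purely for illustration), the sum, product and quotient are:

```latex
\begin{aligned}
(a + b\varepsilon) + (c + d\varepsilon) &= (a + c) + (b + d)\varepsilon \\
(a + b\varepsilon)(c + d\varepsilon) &= ac + (ad + bc)\varepsilon + bd\,\varepsilon^2
  = ac + (ad + bc)\varepsilon \\
\frac{a + b\varepsilon}{c + d\varepsilon}
  &= \frac{(a + b\varepsilon)(c - d\varepsilon)}{c^2}
  = \frac{a}{c} + \frac{bc - ad}{c^2}\,\varepsilon
\end{aligned}
```

The ε coefficients are the familiar sum, product and quotient rules of differentiation.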
Next, primitive functions can be defined in terms of dual numbers which simultaneously evaluate the function and the derivative:
plus <- function(x, y)
  dual(x$real + y$real, x$eps + y$eps)
minus <- function(x, y)
  dual(x$real - y$real, x$eps - y$eps)
# Product rule
times <- function(x, y)
  dual(x$real * y$real, x$eps * y$real + y$eps * x$real)
# Quotient rule
divide <- function(x, y)
  dual(
    x$real / y$real,
    (x$eps * y$real - x$real * y$eps) / (y$real * y$real)
  )
Group generics can be used to implement the mathematics of dual numbers. The group generics included with base R include Math, which covers special functions such as abs and sqrt as well as trigonometric and hyperbolic functions, and Ops, which covers the basic infix operations required for arithmetic: +, -, *, / etc. For a full list of the functions in Math and Ops, consult the R help by typing ?groupGeneric in the R console. To implement a group generic for the S3 class dual, we define Ops.dual:
Ops.dual <- function(x, y) {
  switch(
    .Generic,
    `+` = plus(x, y),
    `-` = minus(x, y),
    `*` = times(x, y),
    `/` = divide(x, y)
  )
}
switch is used to pattern match on the generic function being called within Ops by matching on .Generic. Implementing dual numbers in this way allows us to define a function using the in-built infix operators in a natural way. The function f(x) = 5x^2 + 3x + 10 can be written as:
f <- function(x)
  const(5) * x * x + const(3) * x + const(10)
It can then be evaluated at x = 5 using var, which initialises a dual with eps = 1:
f(var(5))
$real
[1] 150
$eps
[1] 53
attr(,"class")
[1] "dual"
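As a check on this output, the same result can be worked by hand by substituting 5 + ε into the polynomial and discarding the ε² term (ε is the dual unit, with ε² = 0):

```latex
\begin{aligned}
f(5 + \varepsilon) &= 5(5 + \varepsilon)^2 + 3(5 + \varepsilon) + 10 \\
&= 5(25 + 10\varepsilon + \varepsilon^2) + 15 + 3\varepsilon + 10 \\
&= 125 + 50\varepsilon + 25 + 3\varepsilon \\
&= 150 + 53\varepsilon
\end{aligned}
```

The real part, 150, matches $real and the coefficient of ε, 53, matches $eps.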
The definition of f is cumbersome, since we have to explicitly create the constants using the const constructor. The methods defined in Ops.dual can be extended so that when a double is combined with a dual number, the double is first converted to a const. We can then automatically differentiate any univariate function built from these operators using forward mode automatic differentiation.
We can write a function which checks the arguments of plus, minus etc.: any argument that is not already a dual number (for example, one created with var) is converted to a dual constant using const. The function below takes a function f of two arguments and checks each argument in turn, promoting it to a constant if it is a double.
lift_function <- function(f) {
  function(x, y) {
    if (is.double(x)) {
      f(const(x), y)
    } else if (is.double(y)) {
      f(x, const(y))
    } else {
      f(x, y)
    }
  }
}
The ops can then be re-defined using lift_function:
Ops.dual <- function(x, y) {
  switch(
    .Generic,
    `+` = lift_function(plus)(x, y),
    `-` = lift_function(minus)(x, y),
    `*` = lift_function(times)(x, y),
    `/` = lift_function(divide)(x, y)
  )
}
Then f can be defined more naturally:
f <- function(x)
  5 * x * x + 3 * x + 10
And the derivative calculated:
f(var(5))
$real
[1] 150
$eps
[1] 53
attr(,"class")
[1] "dual"
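The Math group generic mentioned earlier is not implemented in the code above. A minimal sketch of how it might look is given here, assuming the same list-based dual representation. The constructor is restated without its input check so the block runs on its own, and the set of functions covered is an illustrative choice rather than a complete list:

```r
# Restated constructor (assumption: same structure as the dual class above)
dual <- function(real, eps) structure(list(real = real, eps = eps), class = "dual")

# Sketch of a Math group generic for dual numbers.
# Each branch applies the chain rule: d/dx g(u) = g'(u) * u'
Math.dual <- function(x, ...) {
  switch(
    .Generic,
    exp  = dual(exp(x$real), exp(x$real) * x$eps),
    log  = dual(log(x$real), x$eps / x$real),
    sqrt = dual(sqrt(x$real), x$eps / (2 * sqrt(x$real))),
    sin  = dual(sin(x$real), cos(x$real) * x$eps),
    cos  = dual(cos(x$real), -sin(x$real) * x$eps),
    # Unnamed final argument is the default branch of switch
    stop("Math function not implemented for dual: ", .Generic)
  )
}

y <- sqrt(dual(4, 1))
y$real  # 2
y$eps   # 0.25, since d/dx sqrt(x) = 1 / (2 * sqrt(x))
```

Functions in the Math group that are not listed fall through to the default branch and raise an error, rather than silently returning NULL.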
Testing using Hedgehog
Hedgehog is a package which builds on testthat to implement property based testing in R. Property based testing checks a function against a wide range of generated inputs and determines whether the code outputs the expected value for each. In standard unit testing the state before the test is defined by the programmer and typically does not change: if we were to consider a test for the derivative of the quadratic function defined above, then we might write a test which evaluates the function at a single fixed point, for example x = 5, and compares the result with the known derivative there.
The input to this property based test is a, generated from the integers between -100 and 100. testthat syntax is then used to evaluate the gradient using forward mode AD and compare it to the exact derivative calculated by hand.
library(testthat)
library(hedgehog)
test_that("Derivative of 5x^2 + 3x + 10", {
  forall(
    list(a = gen.c(gen.element(-100:100))),
    function(a) expect_equal(object = f(var(a))$eps, expected = 10 * a + 3)
  )
})
Test passed 🎊
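The same property can also be checked roughly without Hedgehog, using only base R. The sketch below is a stand-in for illustration and not a replacement for property based testing (there is no shrinking of failing cases). It restates a compact version of the dual-number code so that it runs on its own. The promotion of bare numbers with inherits is a simplification of lift_function, and unary operators are not handled:

```r
# Compact restatement of the dual-number pieces (assumed equivalent to the text)
dual  <- function(real, eps) structure(list(real = real, eps = eps), class = "dual")
var   <- function(x) dual(x, 1)  # note: masks stats::var, as in the text
const <- function(x) dual(x, 0)

Ops.dual <- function(x, y) {
  # Promote bare numbers to constants (binary operators only)
  if (!inherits(x, "dual")) x <- const(x)
  if (!inherits(y, "dual")) y <- const(y)
  switch(
    .Generic,
    `+` = dual(x$real + y$real, x$eps + y$eps),
    `-` = dual(x$real - y$real, x$eps - y$eps),
    `*` = dual(x$real * y$real, x$eps * y$real + y$eps * x$real),
    `/` = dual(x$real / y$real,
               (x$eps * y$real - x$real * y$eps) / (y$real * y$real)),
    stop("operator not implemented for dual: ", .Generic)
  )
}

f <- function(x) 5 * x * x + 3 * x + 10

# Draw random inputs and compare the AD derivative with 10x + 3
set.seed(1)
inputs <- runif(200, min = -100, max = 100)
ad <- vapply(inputs, function(a) f(var(a))$eps, numeric(1))
stopifnot(isTRUE(all.equal(ad, 10 * inputs + 3)))
```

If the comparison fails, stopifnot raises an error; otherwise the script finishes silently.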