breandan/kotlingrad:

Open-source project name:

breandan/kotlingrad

Open-source project URL:

https://github.com/breandan/kotlingrad

Programming language:

Kotlin 100.0%

Project introduction:

Kotlin∇: Type-safe Symbolic Differentiation for the JVM


Kotlin∇ is a type-safe automatic differentiation framework written in Kotlin. It allows users to express differentiable programs with higher-dimensional data structures and operators. It attempts to restrict syntactically valid constructions to those that are algebraically valid and can be checked at compile time. By enforcing these constraints in the type system, it eliminates certain classes of runtime errors that may occur during the execution of a differentiable program. Thanks to type inference, most type declarations may be safely omitted by the end user. Kotlin∇ strives to be expressive, safe, and notationally similar to mathematics.


Introduction

Inspired by Stalin∇, Autograd, DiffSharp, Myia, Nexus, Tangent, Lantern et al., Kotlin∇ attempts to port recent advancements in automatic differentiation (AD) to the Kotlin language. AD is useful for gradient descent and has a variety of applications in numerical optimization and machine learning. Our implementation adds a number of experimental ideas, including compile-time shape-safety, algebraic simplification and numerical stability checking with property-based testing. We aim to provide an algebraically-grounded implementation of AD for shape-safe tensor operations. Tensors in Kotlin∇ are represented as multidimensional arrays.
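To make the idea of AD concrete, here is a minimal forward-mode sketch using dual numbers in plain Kotlin. It is not how Kotlin∇ is implemented; the names `Dual`, `constant`, `variable` and `f` are hypothetical and exist only for this illustration.

```kotlin
// A dual number carries a value together with its derivative.
data class Dual(val value: Double, val deriv: Double) {
  operator fun plus(o: Dual) = Dual(value + o.value, deriv + o.deriv)
  // Product rule: (fg)' = f'g + fg'
  operator fun times(o: Dual) = Dual(value * o.value, deriv * o.value + value * o.deriv)
  operator fun unaryMinus() = Dual(-value, -deriv)
}

// Chain rule for sine: (sin f)' = cos(f) * f'
fun sin(d: Dual) = Dual(kotlin.math.sin(d.value), kotlin.math.cos(d.value) * d.deriv)

// A constant has derivative 0; the variable of differentiation has derivative 1.
fun constant(c: Double) = Dual(c, 0.0)
fun variable(v: Double) = Dual(v, 1.0)

// Example function: f(x) = x * (-sin(x) + 1) * 4
fun f(x: Dual): Dual = x * (-sin(x) + constant(1.0)) * constant(4.0)

fun main() {
  val result = f(variable(0.0))
  // f'(x) = 4 * (1 - sin(x)) - 4 * x * cos(x), so f'(0) = 4
  println("f(0) = ${result.value}, f'(0) = ${result.deriv}")
}
```

Evaluating `f` once yields both the value and the exact derivative, with no finite-difference step size to tune.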

Features

Kotlin∇ currently supports the following features:

  • Arithmetical operations on scalars, vectors and matrices
  • Shape-safe vector and matrix algebra
  • Partial and higher-order differentiation on scalars
  • Property-based testing for numerical gradient checking
  • Recovery of symbolic derivatives from AD

Additionally, it aims to support:

All of these features are implemented without access to bytecode or special compiler tricks, using only higher-order functions and lambdas as shown in Lambda the Ultimate Backpropagator, embedded DSLs à la Lightweight Modular Staging, and ordinary generics. Please see below for a more detailed feature comparison.
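The closure-based style mentioned above can be sketched in a few lines of ordinary Kotlin. This is an illustrative reverse-mode toy, not Kotlin∇'s actual code; `Node` and `Input` are hypothetical names.

```kotlin
// A value paired with its backpropagator: a closure that receives the sensitivity
// of the final result with respect to this value and pushes it on to the inputs.
class Node(val value: Double, val backprop: (Double) -> Unit)

// An input variable accumulates every sensitivity that reaches it.
class Input(val value: Double) {
  var grad = 0.0
  val node = Node(value) { g -> grad += g }
}

// Sum rule: both operands receive the upstream sensitivity unchanged.
operator fun Node.plus(o: Node): Node =
  Node(value + o.value) { g -> backprop(g); o.backprop(g) }

// Product rule: each operand receives the upstream sensitivity scaled by the other operand.
operator fun Node.times(o: Node): Node =
  Node(value * o.value) { g -> backprop(g * o.value); o.backprop(g * value) }

fun main() {
  val x = Input(3.0)
  val y = Input(2.0)
  val z = x.node * y.node + x.node // z = x * y + x
  z.backprop(1.0)                  // seed dz/dz = 1
  println("z = ${z.value}, dz/dx = ${x.grad}, dz/dy = ${y.grad}")
}
```

No bytecode manipulation or compiler plugin is involved: the reverse pass is just a chain of lambdas built while the forward pass runs.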

Usage

Installation

Kotlin∇ is hosted on Maven Central. An example project is provided here.

Gradle

dependencies {
  implementation("ai.hypergraph:kotlingrad:0.4.7")
}

Maven

<dependency>
  <groupId>ai.hypergraph</groupId>
  <artifactId>kotlingrad</artifactId>
  <version>0.4.7</version>
</dependency>

Jupyter Notebook

To access Kotlin∇'s notebook support, use the following line magic:

@file:DependsOn("ai.hypergraph:kotlingrad:0.4.7")

For more information, explore the tutorial.

Notation

Kotlin∇ operators are higher-order functions that take at most two inputs and return a single output. The inputs and the output are all functions with the same numerical type, and their shapes are denoted using superscripts in the rightmost column below.

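| Math | Infix | Prefix | Postfix | Operator Type Signature |
|------|-------|--------|---------|-------------------------|
| A(B) | a(b), a of b | | | (a: ℝ^τ→ℝ^π, b: ℝ^λ→ℝ^τ) → (ℝ^λ→ℝ^π) |
| A ± B | a + b, a - b | plus(a, b), minus(a, b) | | (a: ℝ^τ→ℝ^π, b: ℝ^λ→ℝ^π) → (ℝ^?→ℝ^π) |
| AB | a * b, a.times(b) | times(a, b) | | (a: ℝ^τ→ℝ^(m×n), b: ℝ^λ→ℝ^(n×p)) → (ℝ^?→ℝ^(m×p)) |
| A/B § | a / b, a.div(b) | div(a, b) | | (a: ℝ^τ→ℝ^(m×n), b: ℝ^λ→ℝ^(p×n)) → (ℝ^?→ℝ^(m×p)) |
| ±A | | -a, +a | a.unaryMinus(), a.unaryPlus() | (a: ℝ^τ→ℝ^π) → (ℝ^τ→ℝ^π) |
| sin(a), cos(a), tan(a) | | sin(a), cos(a), tan(a) | a.sin(), a.cos(), a.tan() | (a: ℝ→ℝ) → (ℝ→ℝ) |
| ln(a) | | ln(a), log(a) | a.ln(), a.log() | (a: ℝ^τ→ℝ^(m×m)) → (ℝ^τ→ℝ^(m×m)) |
| log_b(a) | a.log(b) | log(a, b) | | (a: ℝ^τ→ℝ^(m×m), b: ℝ^λ→ℝ^(m×m)) → (ℝ^?→ℝ) |
| A^b | a.pow(b) | pow(a, b) | | (a: ℝ^τ→ℝ^(m×m), b: ℝ^λ→ℝ) → (ℝ^?→ℝ^(m×m)) |
| √A, ∛A | a.pow(1.0/2), a.root(3) | sqrt(a), cbrt(a) | a.sqrt(), a.cbrt() | (a: ℝ^τ→ℝ^(m×m)) → (ℝ^τ→ℝ^(m×m)) |
| da/db | a.d(b), d(a) / d(b) | grad(a)[b] | | (a: C(ℝ^τ→ℝ)*, b: C(ℝ^λ→ℝ)) → (ℝ^?→ℝ) |
| ∇a | | grad(a) | a.grad() | (a: C(ℝ^τ→ℝ)) → (ℝ^τ→ℝ^τ) |
| ∇_b a | a.d(b), a.grad(b) | grad(a, b), grad(a)[b] | | (a: C(ℝ^τ→ℝ^π), b: C(ℝ^λ→ℝ^ω)) → (ℝ^?→ℝ^(π×ω)) |
| ∇·a | | divg(a) | a.divg() | (a: C(ℝ^τ→ℝ^m)) → (ℝ^τ→ℝ) |
| ∇×a | | curl(a) | a.curl() | (a: C(ℝ^3→ℝ^3)) → (ℝ^3→ℝ^3) |
| J(a) | | grad(a) | a.grad() | (a: C(ℝ^τ→ℝ^m)) → (ℝ^τ→ℝ^(m×τ)) |
| H(a) | | hess(a) | a.hess() | (a: C(ℝ^τ→ℝ)) → (ℝ^τ→ℝ^(τ×τ)) |
| Δa | | lapl(a) | a.lapl() | (a: C(ℝ^τ→ℝ)) → (ℝ^τ→ℝ^τ) |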

ℝ can be a Double, Float or BigDecimal. Specialized operators are defined for subsets of ℝ, e.g., Int, Short or BigInteger for subsets of ℤ; however, differentiation is only defined for continuously differentiable functions on ℝ.

a and b are higher-order functions. These may be constants (e.g., 0, 1.0), variables (e.g., Var()) or expressions (e.g., x + 1, 2 * x + y).

For infix notation, . is optional. Parentheses are also optional depending on precedence.

§ Matrix division is defined iff B is invertible, although it could be possible to redefine this operator using the Moore-Penrose inverse.

* C(ℝ^m) is the space of all continuous functions over ℝ. If the function is not over ℝ, it will fail at compile time. If the function is over ℝ but is not continuously differentiable at the point under consideration, it will fail at runtime.

? The input shape is tracked at runtime, but not at the type level. While it would be nice to infer a union type bound over the inputs of binary functions, it is likely impossible using the Kotlin type system without great effort. If the user desires type checking when invoking higher-order functions with literal values, they will need to specify the combined input type explicitly or do so at runtime.

τ, λ, π, ω: arbitrary products.
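As a rough illustration of how Leibniz-style notation such as d(a) / d(b) can be expressed with ordinary operator overloading, consider the following self-contained sketch. It is a toy symbolic differentiator, not the library's API; the types `Expr`, `Const`, `Var`, `Sum`, `Prod` and `Differential` are hypothetical stand-ins defined here.

```kotlin
// A tiny expression tree with overloaded + and *.
sealed class Expr {
  operator fun plus(o: Expr): Expr = Sum(this, o)
  operator fun times(o: Expr): Expr = Prod(this, o)
}
data class Const(val c: Double) : Expr() { override fun toString() = "$c" }
data class Var(val name: String) : Expr() { override fun toString() = name }
data class Sum(val l: Expr, val r: Expr) : Expr() { override fun toString() = "($l + $r)" }
data class Prod(val l: Expr, val r: Expr) : Expr() { override fun toString() = "($l * $r)" }

// Symbolic derivative with respect to the variable v (sum and product rules only).
fun Expr.diff(v: Var): Expr = when (this) {
  is Const -> Const(0.0)
  is Var -> Const(if (this == v) 1.0 else 0.0)
  is Sum -> l.diff(v) + r.diff(v)
  is Prod -> l.diff(v) * r + l * r.diff(v)
}

// Numerical evaluation under a binding of variables to values.
fun Expr.eval(env: Map<Var, Double>): Double = when (this) {
  is Const -> c
  is Var -> env.getValue(this)
  is Sum -> l.eval(env) + r.eval(env)
  is Prod -> l.eval(env) * r.eval(env)
}

// d(z) wraps an expression; dividing two wrappers yields the derivative.
class Differential(val e: Expr) {
  operator fun div(o: Differential): Expr {
    val v = o.e as? Var ?: error("denominator must be a variable")
    return e.diff(v)
  }
}
fun d(e: Expr) = Differential(e)

fun main() {
  val x = Var("x")
  val y = Var("y")
  val z = x * y + x
  val dzdx = d(z) / d(x) // unsimplified symbolic form of y + 1
  println("dz/dx = $dzdx")
  println("at x=2, y=3: ${dzdx.eval(mapOf(x to 2.0, y to 3.0))}")
}
```

The result of `d(z) / d(x)` is itself an `Expr`, so it can be printed, evaluated, or differentiated again.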

Higher-Rank Derivatives

Kotlin∇ supports derivatives between tensors of up to rank 2. The shape of a tensor derivative depends on (1) the shape of the function under differentiation and (2) the shape of the variable with respect to which we are differentiating.

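| I/O Shape | ?→ℝ | ?→ℝ^m | ?→ℝ^(j×k) |
|-----------|-----|-------|-----------|
| ?→ℝ | ?→ℝ | ?→ℝ^m | ?→ℝ^(j×k) |
| ?→ℝ^n | ?→ℝ^n | ?→ℝ^(m×n) | unsupported |
| ?→ℝ^(h×i) | ?→ℝ^(h×i) | unsupported | unsupported |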

Matrix-by-vector, vector-by-matrix, and matrix-by-matrix derivatives require rank 3+ tensors and are currently unsupported.
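For the vector-by-vector case, the m×n shape follows from the usual Jacobian definition. The block below is a standard textbook statement, assuming the convention in which rows index outputs and columns index inputs:

```latex
\mathbf{f} : \mathbb{R}^{?} \to \mathbb{R}^{m}, \quad
\mathbf{x} : \mathbb{R}^{?} \to \mathbb{R}^{n}, \qquad
\frac{\partial \mathbf{f}}{\partial \mathbf{x}} =
\begin{bmatrix}
\frac{\partial f_1}{\partial x_1} & \cdots & \frac{\partial f_1}{\partial x_n} \\
\vdots & \ddots & \vdots \\
\frac{\partial f_m}{\partial x_1} & \cdots & \frac{\partial f_m}{\partial x_n}
\end{bmatrix}
\in \mathbb{R}^{m \times n}
```

By the same counting, differentiating a j×k matrix with respect to an n-vector would need j·k·n entries, which is why those cases require rank-3 tensors.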

Higher-order derivatives

Kotlin∇ supports arbitrary order derivatives on scalar functions, and up to 2nd order derivatives on vector functions. Higher-order derivatives on matrix functions are unsupported.

Shape safety

Shape safety is an important concept in Kotlin∇. There are three broad strategies for handling shape errors:

  • Hide the error somehow by implicitly reshaping or broadcasting arrays
  • Announce the error at runtime, with a relevant message, e.g., InvalidArgumentError
  • Do not allow programs which can result in a shape error to compile

In Kotlin∇, we use the last strategy to check the shape of tensor operations. Consider the following program:

// Inferred type: Vec<Double, D2>
val a = Vec(1.0, 2.0)
// Inferred type: Vec<Double, D3>
val b = Vec(1.0, 2.0, 3.0)

val c = b + b

// Does not compile, shape mismatch
// a + b

Attempting to sum two vectors whose shapes do not match will fail to compile; the vectors must first be explicitly resized.

// Inferred type: Mat<Double, D1, D4>
val a = Mat1x4(1.0, 2.0, 3.0, 4.0)
// Inferred type: Mat<Double, D4, D1>
val b = Mat4x1(1.0, 2.0, 3.0, 4.0)

val c = a * b

// Does not compile, inner dimension mismatch
// a * a
// b * b

Similarly, attempting to multiply two matrices whose inner dimensions do not match will fail to compile.

val a = Mat2x4( 
  1.0, 2.0, 3.0, 4.0,
  5.0, 6.0, 7.0, 8.0
)

val b = Mat4x2( 
  1.0, 2.0,
  3.0, 4.0,
  5.0, 6.0,
  7.0, 8.0
)

// Types are optional, but encouraged
val c: Mat<Double, D2, D2> = a * b 

val d = Mat2x1(1.0, 2.0)

val e = c * d

val f = Mat3x1(1.0, 2.0, 3.0)

// Does not compile, inner dimension mismatch
// e * f

Explicit types are optional but encouraged. Type inference helps preserve shape information over long programs.

fun someMatFun(m: Mat<Double, D3, D1>): Mat<Double, D3, D3> = ...
fun someMatFun(m: Mat<Double, D2, D2>) = ...

When writing a function, it is mandatory to declare the input type(s), but the return type may be omitted. Shape-safety is currently supported up to rank-2 tensors, i.e. matrices.
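One way to obtain this kind of compile-time checking with ordinary generics is to tag each container with a dimension type. The sketch below is an independent illustration and does not reproduce Kotlin∇'s own `Vec` and `Mat` classes; every declaration in it is hypothetical, and unlike the library it does not verify that the supplied matrix data actually matches its tags.

```kotlin
// Dimension tags: each object is its own type, so D2 and D3 are distinct at compile time.
abstract class Dim(val size: Int)
object D2 : Dim(2)
object D3 : Dim(3)

class Vec<L : Dim> private constructor(val length: L, val elems: List<Double>) {
  // Only a vector carrying the same length tag L is accepted.
  operator fun plus(other: Vec<L>): Vec<L> =
    Vec(length, elems.zip(other.elems) { a, b -> a + b })

  companion object {
    fun of(a: Double, b: Double): Vec<D2> = Vec(D2, listOf(a, b))
    fun of(a: Double, b: Double, c: Double): Vec<D3> = Vec(D3, listOf(a, b, c))
  }
}

class Mat<R : Dim, C : Dim>(val rows: R, val cols: C, val data: List<List<Double>>) {
  // The right operand's row tag must equal this matrix's column tag C.
  operator fun <P : Dim> times(o: Mat<C, P>): Mat<R, P> =
    Mat(rows, o.cols, data.map { row ->
      (0 until o.cols.size).map { j -> row.indices.map { k -> row[k] * o.data[k][j] }.sum() }
    })
}

fun main() {
  val a = Vec.of(1.0, 2.0)
  val b = Vec.of(1.0, 2.0, 3.0)
  println((b + b).elems)
  // a + b  // does not compile: Vec<D3> is not a Vec<D2>

  val m = Mat(D2, D3, listOf(listOf(1.0, 2.0, 3.0), listOf(4.0, 5.0, 6.0)))
  val n = Mat(D3, D2, listOf(listOf(1.0, 0.0), listOf(0.0, 1.0), listOf(1.0, 1.0)))
  val p: Mat<D2, D2> = m * n
  println(p.data)
  // m * m  // does not compile: inner dimensions D3 and D2 differ
}
```

Because the tags are plain type parameters, the compiler infers them through chains of operations, which is why explicit annotations can usually be omitted.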

Example

The following example shows how to derive higher-order partials of a function z of type ℝ²→ℝ:

val z = x * (-sin(x * y) + y) * 4  // Infix notation
val `∂z∕∂x` = d(z) / d(x)          // Leibniz notation [Christianson, 2012]
val `∂z∕∂y` = d(z) / d(y)          // Partial derivatives
val `∂²z∕∂x²` = d(`∂z∕∂x`) / d(x)  // Higher-order derivatives
val `∂²z∕∂x∂y` = d(`∂z∕∂x`) / d(y) // Higher-order partials
val `∇z` = z.grad()                // Gradient operator

val values = arrayOf(x to 0, y to 1)

println("z(x, y) \t= $z\n" +
  "z(${values.map { it.second }.joinToString()}) \t\t= ${z(*values)}\n" +
  "∂z/∂x \t\t= $`∂z∕∂x` \n\t\t= " + `∂z∕∂x`(*values) + "\n" +
  "∂z/∂y \t\t= $`∂z∕∂y` \n\t\t= " + `∂z∕∂y`(*values) + "\n" +
  "∂²z/∂x² \t= $`∂²z∕∂x²` \n\t\t= " + `∂²z∕∂x²`(*values) + "\n" +
  "∂²z/∂x∂y \t= $`∂²z∕∂x∂y` \n\t\t= " + `∂²z∕∂x∂y`(*values) + "\n" +
  "∇z \t\t= $`∇z` \n\t\t= [${`∇z`[x]!!(*values)}, ${`∇z`[y]!!(*values)}]ᵀ")

Any backticks and unicode characters above are simply for readability and have no effect on the behavior. Running this program via ./gradlew HelloKotlingrad should produce the following output:

z(x, y)         = ((x) * ((- (sin((x) * (y)))) + (y))) * (4.0)
z(0, 1)         = 0.0
∂z/∂x           = d(((x) * ((- (sin((x) * (y)))) + (y))) * (4.0)) / d(x) 
                = 4.0
∂z/∂y           = d(((x) * ((- (sin((x) * (y)))) + (y))) * (4.0)) / d(y) 
                = 0.0
∂²z/∂x²         = d(d(((x) * ((- (sin((x) * (y)))) + (y))) * (4.0)) / d(x)) / d(x) 
                = -8.0
∂²z/∂x∂y        = d(d(((x) * ((- (sin((x) * (y)))) + (y))) * (4.0)) / d(x)) / d(y) 
                = 4.0
∇z              = {y=d(((x) * ((- (sin((x) * (y)))) + (y))) * (4.0)) / d(y), x=d(((x) * ((- (sin((x) * (y)))) + (y))) * (4.0)) / d(x)} 
                = [4.0, 0.0]ᵀ
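The numerical values can be checked by hand. Differentiating z = 4x(y − sin(xy)) with the product and chain rules and evaluating at (x, y) = (0, 1) gives:

```latex
\begin{aligned}
z &= 4x\,\bigl(y - \sin(xy)\bigr) \\
\frac{\partial z}{\partial x} &= 4y - 4\sin(xy) - 4xy\cos(xy)
  &&\Rightarrow\ \left.\frac{\partial z}{\partial x}\right|_{(0,1)} = 4 \\
\frac{\partial z}{\partial y} &= 4x - 4x^{2}\cos(xy)
  &&\Rightarrow\ \left.\frac{\partial z}{\partial y}\right|_{(0,1)} = 0 \\
\frac{\partial^{2} z}{\partial x^{2}} &= -8y\cos(xy) + 4xy^{2}\sin(xy)
  &&\Rightarrow\ \left.\frac{\partial^{2} z}{\partial x^{2}}\right|_{(0,1)} = -8 \\
\frac{\partial}{\partial y}\frac{\partial z}{\partial x} &= 4 - 8x\cos(xy) + 4x^{2}y\sin(xy)
  &&\Rightarrow\ \left.\frac{\partial}{\partial y}\frac{\partial z}{\partial x}\right|_{(0,1)} = 4
\end{aligned}
```

The first two lines together give the gradient ∇z(0, 1) = [4, 0]ᵀ.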

Variable capture

Not only does Kotlin∇'s type system encode output shape, it is also capable of tracking free and bound variables, for order-independent name binding and partial application. Expressions inhabited by free variables are typed as functions until fully bound, at which time they return a concrete value. Consider the following example:

val q = X + Y * Z + Y + 0.0
val p0 = q(X to 1.0, Y to 2.0, Z to 3.0) // Name binding
val p1 = q(X to 1.0, Y to 1.0)(Z to 1.0) // Variadic currying
val p3 = q(Z to 1.0)(X to 1.0, Y to 1.0) // Any order is possible
val p4 = q(Z to 1.0)(X to 1.0)(Y to 1.0) // Proper currying
val p5 = q(Z to 1.0)(X to 1.0) // Returns a partially applied function
val p6 = (X + Z + 0)(Y to 1.0) // Does not compile

This feature is made possible by encoding a type-level Hasse diagram over a small set of predefined variable names, with skip-connections for variadic combination and partial application. Curious readers may glean further details by referring to the implementation and usage example.
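The binding behaviour itself (as opposed to its type-level encoding) can be mimicked at runtime in a few lines. The following sketch is only meant to show order-independent binding and partial application; it checks names at runtime rather than at compile time, and the class `Fn` is a hypothetical stand-in rather than anything from the library.

```kotlin
// An expression is modelled by its set of free variable names
// and an evaluator that expects all of them to be bound.
class Fn(val free: Set<String>, val eval: (Map<String, Double>) -> Double) {
  // Binding some names returns a new Fn over the remaining free variables.
  operator fun invoke(vararg bindings: Pair<String, Double>): Fn {
    val bound = bindings.toMap()
    require(bound.keys.all { it in free }) { "not a free variable: ${bound.keys - free}" }
    return Fn(free - bound.keys) { rest -> eval(rest + bound) }
  }

  // Only a fully bound expression can be unwrapped into a number.
  fun value(): Double {
    check(free.isEmpty()) { "unbound variables: $free" }
    return eval(emptyMap())
  }
}

fun main() {
  // q = X + Y * Z + Y + 0.0
  val q = Fn(setOf("X", "Y", "Z")) { e ->
    e.getValue("X") + e.getValue("Y") * e.getValue("Z") + e.getValue("Y") + 0.0
  }
  val p0 = q("X" to 1.0, "Y" to 2.0, "Z" to 3.0)   // name binding
  val p1 = q("X" to 1.0, "Y" to 1.0)("Z" to 1.0)   // variadic currying
  val p4 = q("Z" to 1.0)("X" to 1.0)("Y" to 1.0)   // one variable at a time, any order
  val p5 = q("Z" to 1.0)("X" to 1.0)               // still a function of Y
  println("${p0.value()} ${p1.value()} ${p4.value()} free in p5: ${p5.free}")
}
```

In this runtime version, binding a name that is not free throws an exception; the type-level encoding described above turns that failure into a compile error instead.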

Visualization tools

Kotlin∇ provides various graphical tools that can be used for visual debugging.

Dataflow graphs

Kotlin∇ functions are a type of directed acyclic graph called a dataflow graph (DFG). For example, running the expression ((1 + x * 2 - 3 + y + z / y).d(y).d(x) + z / y * 3 - 2).render() will display the following DFG:

Red and blue edges indicate the right and left inputs to a binary operator, respectively. Consider the DFG for a batch of stochastic gradients on linear regression, which can be written in matrix form as :

Thetas represent the hidden parameters under differentiation and the constants are the batch inputs (X) and targets (Y). When all the free variables are bound to numerical values, the graph collapses into a single node, which can be unwrapped into a Kotlin Number.
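For orientation, one common way to write the gradient of a linear regression loss in matrix form is shown below. It assumes a mean-squared-error loss over a batch of n rows, which is an assumption of this sketch; the symbols X, Y and Θ follow the paragraph above.

```latex
\mathcal{L}(\Theta) = \frac{1}{n}\,\lVert X\Theta - Y \rVert^{2}
\qquad\Longrightarrow\qquad
\nabla_{\Theta}\,\mathcal{L} = \frac{2}{n}\,X^{\top}\bigl(X\Theta - Y\bigr)
```

Here X and Y are constants in the graph and Θ is the only free variable, so binding Θ to numerical values leaves nothing symbolic behind.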

Plotting

To generate the sample 2D plots below, run ./gradlew Plot2D.

Plotting is also possible in higher dimensions, for example in 3D via ./gradlew Plot3D:

Loss curves

Gradient descent is one application for Kotlin∇. Below is a typical loss curve of SGD on a multilayer perceptron:

To train the model, execute ./gradlew MLP from within the parent directory.

Testing

To run the tests, execute ../gradlew allTests from the core directory.

Kotlin∇ claims to eliminate certain runtime errors, but how do we know the proposed implementation is correct? One method, borrowed from the Haskell community, is property-based testing (PBT), which is closely related to metamorphic testing. Notable implementations include QuickCheck, Hypothesis and ScalaTest (ported to Kotlin in Kotest). PBT uses algebraic properties to verify the result of an operation by constructing semantically equivalent but syntactically distinct expressions, which should produce the same answer. Kotlin∇ uses two such equivalences to validate its AD implementation:

For example, consider the following test, which checks whether the analytical derivative and the automatic derivative, when evaluated at a given point, are equal to each other within the limits of numerical precision:

val x by Var()
val y by Var()
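The same style of check can be illustrated without the library. The sketch below compares a hand-derived derivative against a central finite difference at randomly sampled points; it uses only the Kotlin standard library, and the function names, step size and tolerance are arbitrary choices made for this example.

```kotlin
import kotlin.math.abs
import kotlin.math.cos
import kotlin.math.sin
import kotlin.random.Random

// Function under test and its hand-derived (analytical) derivative.
fun g(x: Double): Double = x * sin(x)
fun dgAnalytic(x: Double): Double = sin(x) + x * cos(x)

// Central finite difference: (f(x + h) - f(x - h)) / (2h)
fun centralDiff(f: (Double) -> Double, x: Double, h: Double = 1e-5): Double =
  (f(x + h) - f(x - h)) / (2 * h)

// Property: at every sampled point, both derivatives agree within the tolerance.
fun derivativesAgree(trials: Int = 1000, tol: Double = 1e-6, seed: Int = 42): Boolean {
  val rng = Random(seed)
  return (1..trials).all {
    val x = rng.nextDouble(-10.0, 10.0)
    abs(centralDiff(::g, x) - dgAnalytic(x)) < tol
  }
}

fun main() {
  println("derivatives agree: ${derivativesAgree()}")
}
```

A property-based testing framework would additionally shrink failing inputs and sample the space of functions, not just the space of points.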


                      
