Note: This blog is inspired by JAX's Autodidax tutorial, which implements the JAX core from scratch. I found it quite confusing, so I wrote my own version. I hope I was able to do it justice.
JAX - Why is Everyone So Excited About This Framework
The field of AI is going through an exponential surge, with new findings appearing at an unprecedented rate. Accounting for a Moore’s-law-like growth in data, a highly performant framework for ML is an absolute necessity, since unlocking the machine's FLOPS is arguably the main goal of any framework. Many frameworks, such as TensorFlow, PyTorch and more recently JAX, have tried to unlock these FLOPS, and for the purpose of this blog we’ll focus on JAX. There are a lot of things that make JAX unique, so let us jump right into it.
What is JAX?
JAX has been gaining a lot of traction in recent times, and for the right reasons. JAX allows researchers to write Python programs that scale to accelerators and supercomputers with little additional effort. JAX was developed at Google, and is used heavily at DeepMind, to meet a simple goal: balancing rapid prototyping and quick iteration with the ability to deploy experiments at scale. For those familiar with NumPy, think of JAX as NumPy with autodiff and nice distributed support. Keep these points in the back of your mind, and let us try to understand why these qualities are so important and why many people (and frontier AI labs) are pivoting to JAX.
Core Features - JAX
Don’t just take my word for it: Francois Chollet (creator of Keras) tweeted recently that almost all players in generative AI are pivoting to JAX because it is fast, scales really well and has TPU support too. A one-line explanation of JAX would go something like this: JAX is basically NumPy on steroids, made for researchers to write efficient, performant and scalable workloads. JAX offers a lot of unique features for a performance-oriented library, so let us look into what makes it special:
Automatic Differentiation
Autodiff computes gradients of numerical programs automatically by keeping track of the operations applied to each value, which is essential for ML workflows. We’ll cover autodiff in a lot of detail in the coming sections.
Just-In-Time Compilation
JAX uses a JIT compiler to speed up entire blocks of code by exploiting any parallelism between them. A function is compiled on its first use, and the optimized version is reused on later calls, which makes repeated computation efficient.
VMap(Auto Vectorization)
VMap(advanced vectorization) allows us to apply some function on one or more axes of a tensor. VMap vectorizes a function by adding a batch dimension to every primitive operation in the function.
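To make the idea concrete, here is a minimal pure-Python sketch of what vectorizing a function means semantically. The names `naive_vmap` and `squared_norm` are made up for this illustration, and the explicit loop is exactly what JAX avoids: the real vmap rewrites every primitive so that it works on the whole batch at once.

```python
def naive_vmap(f):
    """Reference semantics only: apply f to each element along axis 0."""
    def batched_f(xs):
        return [f(x) for x in xs]
    return batched_f

def squared_norm(v):
    # Per-example function: works on a single vector.
    return sum(a * a for a in v)

batch = [[1.0, 2.0], [3.0, 4.0], [0.0, 5.0]]
print(naive_vmap(squared_norm)(batch))  # [5.0, 25.0, 25.0]
```

The function itself is written for one example, and the transformation takes care of the batch dimension.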
PMap(SPMD programming)
JAX has built-in support for Single Program Multiple Data (SPMD) programming, allowing the same function to run in parallel, with each copy on its own XLA device.
In the introductory JAX blog post, it was mentioned that “JAX has una anima di pura programmazione funzionale”(has a soul of pure functional programming), so let us now try to understand why!
Understanding JAX on a deeper level
There is a difference between a high-level overview and a low-level understanding of a system, and in order to fully nail the concepts, we are going to implement the core JAX structure from scratch, completely in Python. A high-level overview of what we are going to cover is given in the diagram below. This blog was originally inspired by Autodidax, which felt a little difficult to grasp, so I rewrote it in simpler terms. Gear up, because there is going to be a lot of code and even more explanation. This is a work in progress, so a lot may be added in future versions of this blog too.
For now, we are dividing this blog into 4 parts (all of which are going to be covered in great detail):
- Part 1: Transformations and Interpretation.
- Part 2: JaxPrs(Jax Expressions)
- Part 3: JIT(Just In Time Compilation)
- Part 4: VJP(Vector-Jacobian Product)
We’ll implement everything from the ground up, in pure Python (apart from the standard library, we are just going to use NumPy for the array operations and the XLA package to transfer computational workloads). Let us start tinkering.
Part 1 - Transformations and Interpretation
Let us start from the atomic unit of JAX: functions and operators. Traditionally, we apply these operations to numerical inputs and get numerical answers. In JAX, however, we want to override this behavior: operators and functions (which we treat as atomic units of processing called primitives, rather than as compositions) are converted into JaxPrs. JaxPrs (JAX expressions) are covered in great detail in the next part; basically, they are an intermediate representation of the program, which JAX uses instead of pure Python code. Converting functions and operators into JaxPrs lets JAX represent the function in a small, well-behaved intermediate form that is then interpreted with transformation-specific interpretation rules. Transformations are basically higher-order functions that transform JaxPrs. Not all Python programs can be converted to JaxPrs, but many scientific computing and ML workflows can. Examples of transformations include:
- jax.grad(): A function to evaluate the gradient of the input function.
- jax.vmap(): A function to implement automatic vectorization.
- jax.pmap(): A function to implement data parallelism across processing units.
Let us try to define some primitives so we can understand their application:
import numpy as np
from typing import NamedTuple, Any
# An object with a name, to which we attach interpretation rules.
class Primitive(NamedTuple):
    name: str
# Define the primitives - we'll start with the basic ones.
add_p = Primitive("add")
mul_p = Primitive("mul")
neg_p = Primitive("neg")
sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")
# Bind is the interception point. bind1 is a convenience wrapper for
# primitives that return exactly one output.
def bind1(prim, *args, **kwargs):
    out, = bind(prim, *args, **kwargs)
    return out
# Values as positional args, Metadata as kwargs.
def add(x,y): return bind1(add_p, x, y)
def mul(x,y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x) : return bind1(sin_p, x)
def cos(x) : return bind1(cos_p, x)
def greater(x,y): return bind1(greater_p, x, y)
def less(x,y): return bind1(less_p, x, y)
def transpose(x,perm) : return bind1(transpose_p, x, perm=perm)
def broadcast(x,shape,axes) : return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x,axis) : return bind1(reduce_sum_p, x, axis=axis)
We’ll attach our interpretation rules to the Primitive object, one for each transformation. The interception point is bind, which will figure out which transformation rules to apply (based on certain rules which we are going to cover later).
All the arguments of a pure Python function are wrapped in Tracer objects, which record all the operations performed on them, creating a Jaxpr. The Tracer object carries the shape and dtype of the initial arguments (not their values), which allows JAX to reuse the cached compiled program directly. A change in shape or dtype triggers re-tracing, but a change in value does not. This is the reason why only “functionally pure” functions (functions without side effects that do not rely on values outside their arguments) should be used with JAX.
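The caching behaviour described above can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions: `toy_jit` and `abstract_signature` are hypothetical names, the "shape and dtype" of an argument is approximated by the length and element type of a list, and nothing is actually compiled.

```python
def abstract_signature(args):
    # Stand-in for (shape, dtype): the length and element type of each list.
    return tuple((len(a), type(a[0]).__name__) for a in args)

def toy_jit(f):
    cache = {}
    def wrapped(*args):
        key = abstract_signature(args)
        if key not in cache:
            wrapped.trace_count += 1   # a real system would trace and compile here
            cache[key] = f             # placeholder for the compiled program
        return cache[key](*args)
    wrapped.trace_count = 0
    return wrapped

@toy_jit
def total(xs):
    return sum(xs)

total([1.0, 2.0, 3.0])    # first call: traced
total([4.0, 5.0, 6.0])    # same length and type, new values: cache hit
total([1.0, 2.0])         # new "shape": traced again
print(total.trace_count)  # 2
```

Because the cache key ignores values, a function that secretly depends on anything other than its arguments would keep returning results from a stale trace, which is why purity matters.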
In the code below, MainTrace is basically an interpreter, and we represent the active interpreters as a stack. When we are about to apply a transformation, we push another interpreter onto the stack using new_main. At the bottom of the stack there is an evaluation interpreter (EvalTrace, which we are going to see later in this section).
from contextlib import contextmanager
from typing import Any

class MainTrace(NamedTuple):
    level: int
    trace_type: type['Trace']
    global_data: Any | None

trace_stack: list[MainTrace] = []
dynamic_trace: MainTrace | None = None  # Later

@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
    level = len(trace_stack)
    main = MainTrace(level, trace_type, global_data)
    # Push the new interpreter, and pop it again when the `with` block exits.
    trace_stack.append(main)
    try:
        yield main
    finally:
        trace_stack.pop()
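The stack discipline is easier to see in isolation. The following is a small stand-alone sketch, assuming interpreters are represented by plain strings; `push_interpreter` is a hypothetical helper that only mirrors the push-and-pop idea, not the full MainTrace bookkeeping.

```python
from contextlib import contextmanager

stack = ["eval"]  # the evaluation interpreter always sits at the bottom

@contextmanager
def push_interpreter(name):
    stack.append(name)
    try:
        yield len(stack) - 1   # the level of the new interpreter
    finally:
        stack.pop()            # always restore the stack, even on error

with push_interpreter("jvp") as jvp_level:
    with push_interpreter("vmap") as vmap_level:
        inside = list(stack)   # snapshot while both transformations are active

print(inside)  # ['eval', 'jvp', 'vmap']
print(stack)   # ['eval']
```

Nesting `with` blocks is what makes composed transformations possible: the innermost transformation gets the highest level, and everything is unwound in reverse order.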
We’ll implement the Trace and Tracer base classes. A Tracer is basically an object that flows through the Python program that we are transforming. It represents a boxed-up value with data to be used by the interpreter (or MainTrace in this case). A Trace, on the other hand, boxes up values into Tracer objects and also handles primitive applications.
class Trace:
    main: MainTrace

    def __init__(self, main) -> None:
        self.main = main

    # Box up a plain (non-Tracer) constant.
    def pure(self, val):
        raise NotImplementedError()

    # Box up a Tracer that belongs to a lower-level interpreter.
    def lift(self, val):
        raise NotImplementedError()

    def process_primitive(self, primitive, tracers, params):
        raise NotImplementedError()

# One tracer per transformation
class Tracer:
    # stores Trace.
    _trace: Trace
    __array_priority__ = 69

    @property
    def aval(self):
        # Tracer carries an abstract value. One abstract value per base type.
        raise NotImplementedError()

    def full_lower(self):
        return self

    # Operators are forwarded to the abstract value, which calls the primitives.
    def __neg__(self): return self.aval._neg(self)
    def __add__(self, other): return self.aval._add(self, other)
    def __radd__(self, other): return self.aval._radd(self, other)
    def __mul__(self, other): return self.aval._mul(self, other)
    def __rmul__(self, other): return self.aval._rmul(self, other)
    def __gt__(self, other): return self.aval._gt(self, other)
    def __lt__(self, other): return self.aval._lt(self, other)
    def __bool__(self): return self.aval._bool(self)
    def __nonzero__(self): return self.aval._nonzero(self)

    # Any other attribute (shape, dtype, ndim, ...) is looked up on the abstract value.
    def __getattr__(self, name: str) -> Any:
        try:
            return getattr(self.aval, name)
        except AttributeError:
            raise AttributeError(f"No attribute exists : {name}")

def swap(f): return lambda x, y : f(y, x)
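To see why operator overloading is enough to observe a whole program, here is a toy tracer that simply logs every operation applied to it. It is a stand-alone sketch: `RecordingTracer` and its log format are invented for this example and are much simpler than the Tracer class above, which forwards operators to an abstract value instead.

```python
class RecordingTracer:
    """Toy tracer: carries a name and logs every operation applied to it."""
    log = []      # shared record of operations, in program order
    counter = 0   # used to generate fresh variable names

    def __init__(self, name):
        self.name = name

    @classmethod
    def fresh(cls):
        cls.counter += 1
        return cls(f"t{cls.counter}")

    def _record(self, op, other):
        out = RecordingTracer.fresh()
        other_name = other.name if isinstance(other, RecordingTracer) else repr(other)
        RecordingTracer.log.append(f"{out.name} = {op} {self.name} {other_name}")
        return out

    def __add__(self, other): return self._record("add", other)
    def __mul__(self, other): return self._record("mul", other)
    # Both operations commute, so the reflected versions can reuse the same code.
    __radd__ = __add__
    __rmul__ = __mul__

def f(x):
    return 2.0 * x + x * x

f(RecordingTracer("x"))
print("\n".join(RecordingTracer.log))
# t1 = mul x 2.0
# t2 = mul x x
# t3 = add t1 t2
```

The function `f` is ordinary Python and never learns that it was called with something other than a number, yet the log ends up containing a straight-line program for it.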
For our use case, we are going to focus on abstract values that wrap arrays, divided into two classes based on different levels of abstraction. These are:
- ShapedArray: This class represents the set of all possible arrays with a given shape and datatype.
- ConcreteArray: This class represents a singleton set consisting of a single array value.
class ShapedArray:
    """ Set of all possible arrays with a given shape and dtype. """
    array_abstraction_level = 1
    shape: tuple[int, ...]
    dtype: np.dtype

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

    @property
    def ndim(self):
        return len(self.shape)

    _neg = staticmethod(neg)
    _add = staticmethod(add)
    _radd = staticmethod(swap(add))
    _mul = staticmethod(mul)
    _rmul = staticmethod(swap(mul))
    _gt = staticmethod(greater)
    _lt = staticmethod(less)

    # Only the shape and dtype are known here, so there is no value to test.
    @staticmethod
    def _bool(tracer):
        raise Exception("Can't convert to bool")

    @staticmethod
    def _nonzero(tracer):
        raise Exception("Can't convert to bool")

    def str_short(self):
        return f'{self.dtype.name}[{",".join(str(d) for d in self.shape)}]'

    # Defining __eq__ disables the default hash, so define __hash__ as well.
    def __hash__(self):
        return hash((self.shape, self.dtype))

    def __eq__(self, other):
        return (type(self) == type(other) and self.shape == other.shape and self.dtype == other.dtype)

class ConcreteArray(ShapedArray):
    """ Singleton set consisting of a single array value. """
    array_abstraction_level = 2
    val: np.ndarray

    def __init__(self, val):
        self.val = val
        self.shape = val.shape
        self.dtype = val.dtype

    # The value is known, so conversion to bool is possible.
    @staticmethod
    def _bool(tracer):
        return bool(tracer.aval.val)

    @staticmethod
    def _nonzero(tracer):
        return bool(tracer.aval.val)

def get_aval(x):
    if isinstance(x, Tracer):
        return x.aval
    elif type(x) in jax_types:
        return ConcreteArray(np.asarray(x))
    raise TypeError(x)

jax_types = {bool, int, float, np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}
After setting up the interpreter stack, the base classes for Trace and Tracer, and the base classes for abstract values, we should come back and implement the bind function, which, if you remember, is our interception point for figuring out which transformation rules to apply.
The steps performed by the bind function are:
- Find the top trace: figure out which interpreter should handle the primitive application.
- Full raise: ensure that the inputs are boxed in the top trace’s Tracer instances.
- Call the top trace’s process_primitive so the trace can apply its interpretation rule.
- Full lower: an optional optimization, so that we unbox values out of Tracers as much as possible.
The main action is that we figure out which interpreter should handle this primitive application. We then call the top trace’s process_primitive so that the trace can apply its interpretation rules. The calls to full_raise just ensure that the inputs are boxed in the top trace’s Tracer instances, and the call to full_lower is an optional optimization so that we unbox values out of Tracers as much as possible.
def bind(prim, *args, **params):
    top_trace = find_top_trace(args)
    tracers = [full_raise(top_trace, arg) for arg in args]
    outs = top_trace.process_primitive(prim, tracers, params)
    return [full_lower(out) for out in outs]

import operator as op

# Returns the highest level interpreter associated with Tracer, otherwise returns the EvalTrace.
def find_top_trace(xs) -> Trace:
    top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
                   default=trace_stack[0], key=op.attrgetter('level'))
    if dynamic_trace and dynamic_trace.level > top_main.level:
        top_main = dynamic_trace
    return top_main.trace_type(top_main)

def full_lower(val):
    if isinstance(val, Tracer):
        return val.full_lower()
    return val

# Boxing up values into Tracer's for a particular Trace.
# Trace.pure is called for non-tracer constants, and Trace.lift called for values that are already Tracer's from a lower level interpreter.
def full_raise(trace, val):
    if not isinstance(val, Tracer):
        assert type(val) in jax_types
        return trace.pure(val)
    level = trace.main.level
    if val._trace.main is trace.main:
        return val
    elif val._trace.main.level < level:
        return trace.lift(val)
    elif val._trace.main.level > level:
        raise Exception("Cannot lift level")
    raise Exception("Different traces at same level.")
Evaluation Interpreter
As explained earlier, the evaluation interpreter sits at the bottom of the interpreter stack. Since this is the easiest one to implement, we’ll start with it.
EvalTrace extends the Trace base class and implements the process_primitive function, which basically applies the implementation rule of the primitive.
As mentioned, the trace_stack (which is basically a list) has the EvalTrace at the bottom. After that, we implement all the primitive functions (remember, we are just doing the array operation with NumPy; the only addition is an interception point that figures out which transformations to apply).
class EvalTrace(Trace):
    pure = lift = lambda self, x: x  # no boxing, Tracers are not needed here
    def process_primitive(self, primitive, tracers, params):
        return impl_rules[primitive](*tracers, **params)
trace_stack.append(MainTrace(0, EvalTrace, None))
impl_rules = {}
impl_rules[add_p] = lambda x, y: [np.add(x, y)]
impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]
impl_rules[neg_p] = lambda x: [np.negative(x)]
impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x,y)]
impl_rules[less_p] = lambda x, y: [np.less(x,y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
def broadcast_impl(x, *, shape, axes):
    for axis in sorted(axes):
        x = np.expand_dims(x, axis)
    return [np.broadcast_to(x, shape)]  # read broadcasting rules!
impl_rules[broadcast_p] = broadcast_impl
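With only the evaluation interpreter on the stack, calling a primitive boils down to a dictionary lookup followed by a function call. The sketch below condenses that path into a self-contained snippet. It is a simplification under stated assumptions: it uses `math` instead of NumPy, and its `bind` skips the trace lookup and the raise/lower steps entirely.

```python
import math
from typing import NamedTuple

class Primitive(NamedTuple):
    name: str

sin_p, mul_p = Primitive("sin"), Primitive("mul")

# Implementation rules: one list-returning function per primitive.
impl_rules = {
    sin_p: lambda x: [math.sin(x)],
    mul_p: lambda x, y: [x * y],
}

def bind(prim, *args):
    # With no transformation active, every call lands on the evaluation rule.
    return impl_rules[prim](*args)

def sin(x): return bind(sin_p, x)[0]
def mul(x, y): return bind(mul_p, x, y)[0]

print(mul(2.0, sin(0.0)))  # 0.0
```

The indirection looks pointless for plain evaluation, but it is what lets other interpreters swap in their own rule tables for the same primitives.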
We mentioned earlier that JAX is well suited for ML, and that means JAX has good (and general) support for automatic differentiation (AD). AD can obtain gradients of numerical programs very efficiently; in ML we generally use them to backpropagate the loss and update the parameters so that the loss is minimized. Automatic differentiation basically decomposes a function into a set of elementary operations and automatically computes the gradients by applying the chain rule. It builds a set of equations with intermediate variables, which forms a computational graph, and then computes the gradients along it.
To accommodate generality in its AD system, JAX implements both forward and reverse mode automatic differentiation. The ever-so-used grad function is built on reverse mode AD, while for forward mode JAX uses the JVP (Jacobian-Vector Product). JVPs are evaluated on the fly, so they are memory efficient, but in ML we rarely see forward mode differentiation.[1]
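The bookkeeping that forward mode performs can be sketched with dual numbers, where every value carries its derivative along with it. The snippet below is a minimal stand-alone illustration: `Dual` and `dual_sin` are hypothetical names, and only multiplication and sine are supported.

```python
import math

class Dual:
    """A (primal, tangent) pair; arithmetic applies the chain rule."""
    def __init__(self, primal, tangent):
        self.primal, self.tangent = primal, tangent

    def __mul__(self, other):
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

def dual_sin(d):
    # Chain rule: sin(u)' = cos(u) * u'
    return Dual(math.sin(d.primal), math.cos(d.primal) * d.tangent)

def f(d):
    return dual_sin(d) * d      # f(x) = x * sin(x)

out = f(Dual(2.0, 1.0))         # seed the tangent with 1.0 to get df/dx
expected = math.sin(2.0) + 2.0 * math.cos(2.0)
print(abs(out.tangent - expected) < 1e-12)  # True
```

The derivative is produced in the same pass as the value and nothing has to be stored for later, which is the memory advantage mentioned above.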
Let us implement a JVP-based Tracer that calculates both the primal (basically the value of the function at a point) and the tangent (the forward mode derivative value associated with the function at that point). Before that, let us define some helper functions we are going to use.
import builtins

# Get an array of zeros with the shape and dtype of the value's abstract value.
def zeros_like(val):
    aval = get_aval(val)
    return np.zeros(aval.shape, aval.dtype)

# Given an iterable of pairs, unpack them into two lists.
def unzip2(pairs):
    lst1, lst2 = [], []
    for x1, x2 in pairs:
        lst1.append(x1)
        lst2.append(x2)
    return lst1, lst2

# Map values and wrap them in a list.
def map(f, *xs):
    return list(builtins.map(f, *xs))

# Returns a list of tuples, checking that all inputs have the same length.
def zip(*args):
    fst, *rest = args = map(list, args)
    n = len(fst)
    for arg in rest:
        assert len(arg) == n
    return list(builtins.zip(*args))
For forward mode differentiation, the JVPTracer carries the boxed-up primal-tangent pair of values, while the JVPTrace applies the JVP rules. For initialization, we want pure and lift to “package” a value with a zero tangent. After doing that, let us add some JVP rules for the primitives.[2] In the end, let us add a “Transformation API” (jvp_v1) which pushes another interpreter onto the stack using new_main and gives us the primal and the tangent associated with the function output.
class JVPTracer(Tracer):
    def __init__(self, trace, primal, tangent):
        self._trace = trace
        self.primal = primal
        self.tangent = tangent

    @property
    def aval(self):
        return get_aval(self.primal)

class JVPTrace(Trace):
    # Constants and lower-level tracers get a zero tangent.
    pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))

    def process_primitive(self, primitive, tracers, params):
        primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
        jvp_rule = jvp_rules[primitive]
        primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)
        return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]

jvp_rules = {}

# Add some JVP rules
def add_jvp(primals, tangents):
    (x, y), (x_dot, y_dot) = primals, tangents
    return [x + y], [x_dot + y_dot]
jvp_rules[add_p] = add_jvp

def mul_jvp(primals, tangents):
    (x, y), (x_dot, y_dot) = primals, tangents
    return [x * y], [x_dot*y + x*y_dot]
jvp_rules[mul_p] = mul_jvp

def sin_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return [sin(x)], [cos(x) * x_dot]
jvp_rules[sin_p] = sin_jvp

def cos_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return [cos(x)], [-sin(x) * x_dot]
jvp_rules[cos_p] = cos_jvp

def neg_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return [neg(x)], [neg(x_dot)]
jvp_rules[neg_p] = neg_jvp

def reduce_sum_jvp(primals, tangents, *, axis):
    (x,), (x_dot,) = primals, tangents
    return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]
jvp_rules[reduce_sum_p] = reduce_sum_jvp

# Comparisons are piecewise constant, so their tangent is zero.
def greater_jvp(primals, tangents):
    (x, y), _ = primals, tangents
    out_primal = greater(x, y)
    return [out_primal], [zeros_like(out_primal)]
jvp_rules[greater_p] = greater_jvp

def less_jvp(primals, tangents):
    (x, y), _ = primals, tangents
    out_primal = less(x, y)
    return [out_primal], [zeros_like(out_primal)]
jvp_rules[less_p] = less_jvp

# Transformation API.
def jvp_v1(f, primals, tangents):
    with new_main(JVPTrace) as main:
        trace = JVPTrace(main)
        tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
        out = f(*tracers_in)
        tracer_out = full_raise(trace, out)
        primal_out, tangent_out = tracer_out.primal, tracer_out.tangent
    return primal_out, tangent_out
With all these in place, we can now differentiate functions with our JAX-like core. Here’s some example code:
#Example to demonstrate JVP_V1
x = 3.0
y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
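One way to sanity-check a derivative like this, independently of the JVP machinery, is to compare it against a finite-difference estimate. The helper below is a hypothetical stand-alone check, assuming a smooth scalar function; it does not use any of the code above.

```python
import math

def central_difference(f, x, eps=1e-6):
    # Symmetric finite difference: the error shrinks like eps**2 for smooth f.
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x = 3.0
numeric = central_difference(math.sin, x)
exact = math.cos(x)                 # what a correct JVP of sin should return
print(abs(numeric - exact) < 1e-8)  # True
```

Finite differences are only approximate and need one function evaluation per perturbed input, whereas the JVP gives the derivative to machine precision in a single pass.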
There’s a reason we named our Transformation API jvp_v1. Its limitation is that it accepts only arrays as input and gives a single array as output. In the next iteration of jvp, we will deal with nested inputs and nested outputs. This could mean that at each layer of the stack, we might have to deal with nested inputs.
In order to deal with this, we are going to wrap the user function so that it accepts arrays as input and returns a flat list of arrays as output. Since we are accepting user functions that have arbitrary “containers” in the inputs and outputs, just flattening a list inside of a list isn’t going to be very general, and we are actually going to need a tree data structure. Let us try to understand why we need a tree for this.
In our use case, the input and output containers can be of arbitrary depth. We can represent the topmost level as the root node, where each node can have multiple children and can represent either a container such as a list (internal node) or a value (leaf node). In any scenario where the data is arranged in a hierarchical manner, a tree is the natural data structure to use. A real-life example of hierarchical data represented by a tree is a folder/directory structure. In that case as well:
- Each node can have multiple children(either subfolders or files)
- Nodes can represent both folders and files.
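The tree view of a nested container can be made concrete with a short recursive walk. The generator below is a simplified sketch: `iter_leaves` is a hypothetical name, it only knows about lists, tuples and dicts, and unlike a full flatten it does not record the structure needed to rebuild the container.

```python
def iter_leaves(tree):
    """Yield leaves left to right; lists, tuples and dicts are internal nodes."""
    if isinstance(tree, (list, tuple)):
        for child in tree:
            yield from iter_leaves(child)
    elif isinstance(tree, dict):
        for key in sorted(tree):   # fixed key order keeps the result deterministic
            yield from iter_leaves(tree[key])
    else:
        yield tree                 # anything else is a leaf

nested = {"weights": [1.0, (2.0, 3.0)], "bias": 4.0}
print(list(iter_leaves(nested)))  # [4.0, 1.0, 2.0, 3.0]
```

Recursion handles any nesting depth for free, which is why a tree is a better fit than special-casing "a list inside a list".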
def jvp_flat(f, primals, tangents):
    with new_main(JVPTrace) as main:
        trace = JVPTrace(main)
        tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
        outs = f(*tracers_in)
        tracers_out = [full_raise(trace, out) for out in outs]
        primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
    return primals_out, tangents_out

def jvp(f, primals, tangents):
    # Flatten the (possibly nested) inputs and check that both have the same structure.
    primals_flat, in_tree = tree_flatten(primals)
    tangents_flat, in_tree2 = tree_flatten(tangents)
    if in_tree != in_tree2: raise TypeError
    f, out_tree = flatten_fun(f, in_tree)
    primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
    # Rebuild the nested output structure from the flat results.
    primals_out = tree_unflatten(out_tree(), primals_out_flat)
    tangents_out = tree_unflatten(out_tree(), tangents_out_flat)
    return primals_out, tangents_out
Now that we have understood why we need tree flatten/unflatten functions, let us implement them.
Notice in the flatten_fun function that the user function might actually require the unflattened version of the arguments, so we provide them in unflattened form. The structure of the output isn’t available until we run the function, which is also why we return a Store alongside flat_fun: it gets filled in with the output tree once flat_fun has been called.
# Notice that we need to provide the unflattened version of the args.
def flatten_fun(f, in_tree):
    store = Store()
    def flat_fun(*args_flat):
        pytree_args = tree_unflatten(in_tree, args_flat)
        out = f(*pytree_args)
        out_flat, out_tree = tree_flatten(out)
        # Record the output tree so the caller can unflatten the results.
        store.set_value(out_tree)
        return out_flat
    return flat_fun, store
# Helper classes.
class Empty: pass
empty = Empty()
class Store:
    val = empty
    # Can be set only once.
    def set_value(self, val):
        assert self.val is empty
        self.val = val
    def __call__(self):
        return self.val
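The pattern of returning a wrapped function together with a cell that is filled in later can look odd at first, so here is a small stand-alone sketch of it. `Box` and `with_output_kind` are hypothetical names; the box plays the role of the Store and records the type of the output, which is known only after the function has run.

```python
class Box:
    """Write-once cell, filled in as a side effect of calling a function."""
    _unset = object()

    def __init__(self):
        self.value = Box._unset

    def set(self, value):
        assert self.value is Box._unset, "already set"
        self.value = value

    def get(self):
        assert self.value is not Box._unset, "function has not run yet"
        return self.value

def with_output_kind(f):
    box = Box()
    def wrapped(*args):
        out = f(*args)
        box.set(type(out).__name__)   # only known once f has actually run
        return out
    return wrapped, box

wrapped, kind = with_output_kind(lambda x: {"double": 2 * x})
result = wrapped(21)
print(result, kind.get())  # {'double': 42} dict
```

Reading the box before calling `wrapped` fails loudly, which is the same protection the assert in Store gives against using the output tree too early or setting it twice.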
# PyTree handling of JVP implementation.
from collections.abc import Hashable
import itertools as it
from collections.abc import Callable

# Represents the node of the data container.
class NodeType(NamedTuple):
    name: str
    to_iterable: Callable
    from_iterable: Callable

# Register a container type as a valid tree node, with functions to split it
# into (metadata, children) and to rebuild it.
def register_pytree_node(ty, to_iter, from_iter):
    node_types[ty] = NodeType(str(ty), to_iter, from_iter)

node_types = {}

# Only dict, tuple and list can represent PyTree. This also acts as a typecheck.
register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))
register_pytree_node(list, lambda l: (None, l), lambda _, xs: list(xs))
register_pytree_node(dict, lambda d: map(tuple, unzip2(sorted(d.items()))), lambda keys, vals: dict(zip(keys, vals)))

class PyTreeDef(NamedTuple):
    node_type: NodeType
    node_metadata: Hashable
    child_treedefs: tuple['PyTreeDef', ...]

class Leaf: pass
leaf = Leaf()

# Flatten the tree.
def tree_flatten(x):
    children_iter, tree_def = _tree_flatten(x)
    return list(children_iter), tree_def

def _tree_flatten(x):
    node_type = node_types.get(type(x))
    if node_type:
        node_metadata, children = node_type.to_iterable(x)
        children_flat, child_trees = unzip2(map(_tree_flatten, children))
        flattened = it.chain.from_iterable(children_flat)
        return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))
    # Unregistered types are treated as leaves.
    return [x], leaf

# Unflatten the tree.
def tree_unflatten(tree_def, xs):
    return _tree_unflatten(tree_def, iter(xs))

def _tree_unflatten(tree_def, xs):
    if tree_def is leaf:
        return next(xs)
    children = (_tree_unflatten(t, xs) for t in tree_def.child_treedefs)
    return tree_def.node_type.from_iterable(tree_def.node_metadata, children)
We have successfully implemented arbitrary depth input/output containers. These will be helpful with future transformations as well. Here’s some example code to understand how it works:
# Define some arbitrary Pythonic function
def f(x):
    y = 3.0 * sin(x) * cos(x)
    z = x*x + y*y
    # We can handle arbitrary depth outputs.
    return {'Rick': z, 'Astley': [x, y]}

# Evaluate the function at a specific value using JVP.
x, xdot = 1.0, 1.5
y, ydot = jvp(f, (x,), (xdot,))

# Print the results
print(y)
print(ydot)
With the improved JVP implementation now in place, we can move on to another important bottleneck that JAX tries to resolve, especially in machine learning workflows: vectorized batching. We batch data in order to process a chunk of it in parallel rather than element by element, but a Python loop over the batch can be slow and can become a bottleneck in performance-critical code. JAX’s vmap addresses these issues and provides the following functionality:
- Applying operations to the entire array at once, not individual elements.
- Processing multiple inputs simultaneously.
- Seamless integration with other JAX transformations.
Let us implement vectorized batching with vmap, but before that let us implement some helper functions:
- mapped_aval: It produces mapped abstract values from the unmapped ones, by removing an axis from them.
- move_batch_axis: It is used to move the batch dimension around (by basically moving the axis).
# Produce mapped values from the unmapped ones.
def mapped_aval(batch_dim, aval):
    shape = list(aval.shape)
    del shape[batch_dim]
    return ShapedArray(tuple(shape), aval.dtype)

# Move the batch axis by moving the axis.
def move_batch_axis(axis_size, src, dst, x):
    if src is not_mapped:
        # Unbatched value: broadcast it along a new batch axis at dst.
        target_shape = list(np.shape(x))
        target_shape.insert(dst, axis_size)
        return broadcast(x, target_shape, [dst])
    elif src == dst:
        return x
    return move_axis(x, src, dst)

# Move the axis from src to dst.
def move_axis(x, src, dst):
    perm = [i for i in range(np.ndim(x)) if i != src]
    perm.insert(dst, src)
    return transpose(x, perm)
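The permutation built in move_axis can be checked directly against NumPy. The snippet below is a stand-alone sketch: `axis_permutation` is a hypothetical helper that repeats the same two lines, and `np.moveaxis` is used only as a reference to compare against.

```python
import numpy as np

def axis_permutation(ndim, src, dst):
    # Same recipe as move_axis: drop src, then re-insert it at position dst.
    perm = [i for i in range(ndim) if i != src]
    perm.insert(dst, src)
    return perm

x = np.arange(24).reshape(2, 3, 4)
perm = axis_permutation(x.ndim, src=0, dst=2)
moved = np.transpose(x, perm)
print(perm, moved.shape)                            # [1, 2, 0] (3, 4, 2)
print(np.array_equal(moved, np.moveaxis(x, 0, 2)))  # True
```

Expressing the move as a transpose means the batching code only needs the transpose primitive, rather than a separate primitive for moving axes.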
With the helper functions in place, let us shift our focus to implementing a BatchTracer for vectorized batching. The tracer carries a batched value and an optional integer indicating which axis is the batch axis. Similar to other trace classes, BatchTrace implements the pure and lift methods, which box up values in a BatchTracer instance. We use the MainTrace’s global data field to store the batch axis size.
from typing import Union

# Sentinel marking a value that has no batch axis.
class NotMapped: pass
not_mapped = NotMapped()

# A batch axis is either an integer axis index or not_mapped.
BatchAxis = Union[NotMapped, int]

# Tracer to accommodate batching of data.
class BatchTracer(Tracer):
    def __init__(self, trace, val, batch_dim):
        self._trace = trace
        self.val = val
        self.batch_dim = batch_dim

    @property
    def aval(self):
        if self.batch_dim is not_mapped:
            return get_aval(self.val)
        return mapped_aval(self.batch_dim, get_aval(self.val))

    def full_lower(self):
        if self.batch_dim is not_mapped:
            return full_lower(self.val)
        return self

# Trace
class BatchTrace(Trace):
    pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)

    def process_primitive(self, primitive, tracers, params):
        vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)
        vmap_rule = vmap_rules[primitive]
        val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)
        return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]

    # The batch axis size is stored in the MainTrace's global data.
    @property
    def axis_size(self):
        return self.main.global_data

vmap_rules = {}
The next step is to implement the batched interpreter rules for each primitive. The implementation is divided into three classes of rules: one for binary operators (the addition and multiplication primitives), one for unary operators (such as the sin, cos and negation primitives), and a separate one for reduce sum. With all these in place, we add a transformation API to start the trace (see vmap_flat and vmap).
from functools import partial

# primitive rules - addition, multiplication
def binop_batching_rule(op, axis_size, vals_in, dims_in):
    (x, y), (x_bdim, y_bdim) = vals_in, dims_in
    if x_bdim != y_bdim:
        # Bring both operands to a common batch axis before applying the op.
        if x_bdim is not_mapped:
            x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
            x_bdim = y_bdim
        else:
            y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
    return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)

# primitive rules - sin, cos, negation
def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
    (x,), (x_bdim,) = vals_in, dims_in
    return [op(x)], [x_bdim]
vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)

# primitive rules - reduce sum
def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
    (x,), (x_bdim,) = vals_in, dims_in
    # Shift the reduced axes past the batch axis, and track where the batch axis ends up.
    new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
    out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
    return [reduce_sum(x, new_axis)], [out_bdim]
vmap_rules[reduce_sum_p] = reduce_sum_batching_rule

# Transformation API.
def vmap_flat(f, in_axes, *args):
    axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes) if ax is not not_mapped}
    with new_main(BatchTrace, axis_size) as main:
        trace = BatchTrace(main)
        tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x for x, ax in zip(args, in_axes)]
        outs = f(*tracers_in)
        tracers_out = [full_raise(trace, out) for out in outs]
        vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)
    # Move every output's batch axis to the front.
    outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out) for val_out, bdim in zip(vals_out, bdims_out)]
    return outs_transposed

def vmap(f, in_axes):
    def batched_f(*args):
        args_flat, in_tree = tree_flatten(args)
        in_axes_flat, in_tree2 = tree_flatten(in_axes)
        if in_tree != in_tree2: raise TypeError
        f_flat, out_tree = flatten_fun(f, in_tree)
        outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
        return tree_unflatten(out_tree(), outs_flat)
    return batched_f
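The axis arithmetic in reduce_sum_batching_rule is the least obvious of the rules, so here is a stand-alone NumPy check of it. This is a sketch under stated assumptions: `batched_reduce_axes` is a hypothetical helper that repeats the two lines of index arithmetic, and an explicit loop over the batch serves as the reference.

```python
import numpy as np

def batched_reduce_axes(x_bdim, axis):
    # Same arithmetic as reduce_sum_batching_rule above.
    new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
    out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
    return new_axis, out_bdim

x = np.arange(60.0).reshape(3, 5, 4)   # batch of 5 examples along axis 1
new_axis, out_bdim = batched_reduce_axes(x_bdim=1, axis=(0,))
batched = np.sum(x, new_axis)          # one call over the whole batch

# Reference: sum each (3, 4) example over its own axis 0, one at a time.
looped = np.stack([np.sum(x[:, i, :], axis=0) for i in range(5)], axis=out_bdim)
print(new_axis, out_bdim, np.array_equal(batched, looped))  # (0,) 0 True
```

Here the reduced axis sits before the batch axis, so it keeps its index while the batch axis slides from position 1 to position 0 in the output.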
Let’s see our implementation in action!
# Pythonic function
def add_to_a_scalar(scalar):
    assert np.ndim(scalar) == 0
    return 69 + scalar

# Vectorized operation using VMAP
vector_in = np.arange(420.0)
vector_out = vmap(add_to_a_scalar, (0,))(vector_in)

# Output
print(vector_out)
With the implementations of vmap and JVP (basically, forward mode autodiff) in place, the next transformations to implement are JIT and VJP (for reverse mode autodiff). The transformations implemented so far only needed each Tracer to carry an extra bit of context, but JIT and VJP need much richer context (the next few sections explain how), and for that we need to represent Python programs as JaxPrs.
Part 2 - JaxPrs
JaxPrs are JAX’s internal representation of programs. For JIT, we need JaxPrs because JIT needs to stage computation out of Python (mostly to XLA), so we trace the Python function and represent the result as a JaxPr. In the case of VJP, JaxPrs provide a way to represent the computation for the backward pass of reverse mode autodiff. To represent JaxPrs as Python data structures, we reuse the ShapedArray class defined before (for types) and represent the term syntax with a few Python structs.
# Class to hold abstract value as ShapedArray.
class Var:
    aval: ShapedArray
    def __init__(self, aval): self.aval = aval

# Class for holding value(both normal and abstract)
class Lit:
    val: Any
    aval: ShapedArray
    def __init__(self, val):
        self.aval = aval = raise_to_shaped(get_aval(val))
        self.val = np.array(val, aval.dtype)

# Atom is the building block for JaxPrs.
Atom = Union[Var, Lit]

# A JaxprEqn is basically a class holding the primitive, and the inputs and outputs associated with it.
class JaxprEqn(NamedTuple):
    primitive: Primitive
    inputs: list[Atom]
    params: dict[str, Any]
    out_binders: list[Var]

# A JaxPr can hold multiple JaxprEqn
class Jaxpr(NamedTuple):
    in_binders: list[Var]
    eqns: list[JaxprEqn]
    outs: list[Atom]
    # Hash and compare by identity rather than by contents.
    def __hash__(self): return id(self)
    __eq__ = op.is_

def raise_to_shaped(aval):
    return ShapedArray(aval.shape, aval.dtype)
Type checking is very strict in JAX, which is crucial for speeding up computational workflows: it allows JAX to perform type specialization and optimize code for specific data types. For JaxPrs, type checking means verifying that there are no unbound variables, that every variable is bound only once, and that for each equation the type produced by the primitive matches the type of its output binders. JaxPrs are platform-agnostic, so type checking ensures consistency across platforms.
class JaxprsType(NamedTuple):
  in_types: list[ShapedArray]
  out_types: list[ShapedArray]

  def __repr__(self):
    in_types = ", ".join(aval.str_short() for aval in self.in_types)
    out_types = ", ".join(aval.str_short() for aval in self.out_types)
    return f'({in_types}) -> ({out_types})'

# Typecheck for the reasons mentioned above.
def typecheck_jaxpr(jaxpr):
  env: set[Var] = set()

  for v in jaxpr.in_binders:
    if v in env: raise TypeError
    env.add(v)  # record the binder so later uses count as bound

  for eqn in jaxpr.eqns:
    in_types = [typecheck_atom(env, x) for x in eqn.inputs]
    out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)
    for out_binder, out_type in zip(eqn.out_binders, out_types):
      if not out_type == out_binder.aval: raise TypeError
    for out_binder in eqn.out_binders:
      if out_binder in env: raise TypeError
      env.add(out_binder)  # each variable may be bound only once

  in_types = [v.aval for v in jaxpr.in_binders]
  out_types = [typecheck_atom(env, x) for x in jaxpr.outs]
  return JaxprsType(in_types, out_types)

def typecheck_atom(env, x):
  if isinstance(x, Var):
    if x not in env: raise TypeError("Unbound Variable")
    return x.aval
  elif isinstance(x, Lit):
    return raise_to_shaped(get_aval(x.val))
  else:
    assert False
# This is a simple JaxPr interpreter, meant to run on type-checked JaxPrs.
# Note: it assumes an eager, list-returning `map` helper (as Autodidax defines);
# Python's lazy built-in map would never actually call `write`.
def eval_jaxpr(jaxpr, args):
  env = {}

  def read(x):
    return env[x] if type(x) is Var else x.val

  def write(v, val):
    assert v not in env
    env[v] = val

  map(write, jaxpr.in_binders, args)
  for eqn in jaxpr.eqns:
    in_vals = map(read, eqn.inputs)
    outs = bind(eqn.primitive, *in_vals, **eqn.params)  # Using bind makes this interpreter traceable too.
    map(write, eqn.out_binders, outs)
  return map(read, jaxpr.outs)

def jaxpr_as_fun(jaxpr):
  return lambda *args: eval_jaxpr(jaxpr, args)
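To make the data flow of the interpreter concrete, here is a minimal, self-contained sketch of the same idea using plain Python floats. The names ToyVar, ToyEqn, ToyJaxpr, toy_eval and the TOY_IMPLS table are hypothetical stand-ins for Var, JaxprEqn, Jaxpr, eval_jaxpr and bind; they are not part of JAX or of the code above, and each equation is assumed to have a single output.

```python
import math
from typing import NamedTuple

class ToyVar:
    """A variable; object identity (not the name) distinguishes variables."""
    def __init__(self, name):
        self.name = name

class ToyEqn(NamedTuple):
    prim: str      # primitive name, e.g. "mul"
    inputs: list   # ToyVar instances or literal floats
    out: ToyVar    # single output binder (an assumption of this sketch)

class ToyJaxpr(NamedTuple):
    in_binders: list
    eqns: list
    outs: list

# Hypothetical table mapping primitive names to implementations.
TOY_IMPLS = {
    "add": lambda x, y: x + y,
    "mul": lambda x, y: x * y,
    "sin": math.sin,
}

def toy_eval(jaxpr, args):
    env = {}

    def read(x):
        # Variables are looked up; literals evaluate to themselves.
        return env[x] if isinstance(x, ToyVar) else x

    def write(v, val):
        assert v not in env, "each variable is bound only once"
        env[v] = val

    for v, a in zip(jaxpr.in_binders, args):
        write(v, a)
    for eqn in jaxpr.eqns:
        in_vals = [read(x) for x in eqn.inputs]
        write(eqn.out, TOY_IMPLS[eqn.prim](*in_vals))
    return [read(x) for x in jaxpr.outs]

# Program for f(x) = sin(x) * 2.0 + x
x, a, b, c = ToyVar("x"), ToyVar("a"), ToyVar("b"), ToyVar("c")
prog = ToyJaxpr(
    in_binders=[x],
    eqns=[ToyEqn("sin", [x], a),
          ToyEqn("mul", [a, 2.0], b),
          ToyEqn("add", [b, x], c)],
    outs=[c],
)
result = toy_eval(prog, [3.0])
```

The environment plays the same role as env in eval_jaxpr: inputs are written once, each equation reads its operands and writes its output, and the outputs are read at the end.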
As with the other interpreters, we’ll now enable tracing for JaxPrs. There are two ways to do this, and we’ll start with the one jit uses. But first, let us define some helper functions and then build up JaxprTrace and JaxprTracer.
The JaxprTrace class implements a new_arg method that creates a Tracer for an input and registers it with the builder. The get_or_make_const_tracer method returns the tracer already associated with a constant (looked up using the id of the value) or, if none exists yet, creates one and adds it to the builder. Both pure and lift are aliases for this method. The process_primitive method is similar to the ones described before, the only difference being that it records a JaxprEqn instead of computing a value.
# Helper functions: Jaxprs with Tracing.
# Split a list at index n and return both parts.
def split_list(lst, n):
  assert 0 <= n <= len(lst)
  return lst[:n], lst[n:]

# Partition a list by a list of booleans: False entries go to the first output, True entries to the second.
def partition_list(bs, l):
  assert len(bs) == len(l)
  lists = lst1, lst2 = [], []
  for b, x in zip(bs, l):
    lists[b].append(x)
  return lst1, lst2
# Tracer, as mentioned, contains the boxed-up values.
class JaxprTracer(Tracer):
  __slots__ = ['aval']
  aval: ShapedArray

  def __init__(self, trace, aval):
    self._trace = trace
    self.aval = aval
# Main JaxPrTrace class.
class JaxprTrace(Trace):
  def new_arg(self, aval):
    aval = raise_to_shaped(aval)
    tracer = self.builder.new_tracer(self, aval)
    self.builder.tracer_to_var[id(tracer)] = Var(aval)
    return tracer

  def get_or_make_const_tracer(self, val):
    tracer = self.builder.const_tracers.get(id(val))
    if tracer is None:
      tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))
      self.builder.add_const(tracer, val)
    return tracer
  pure = lift = get_or_make_const_tracer

  def process_primitive(self, primitive, tracers, params):
    avals_in = [t.aval for t in tracers]
    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
    out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]
    inputs = [self.builder.getvar(t) for t in tracers]
    outvars = [self.builder.add_var(t) for t in out_tracers]
    self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))
    return out_tracers

  # `builder` is accessed as an attribute (self.builder), so it must be a property.
  @property
  def builder(self):
    return self.main.global_data

abstract_eval_rules = {}
JaxprBuilder is the container we use to keep track of the variables, constants, and equations. It serves as the interpreter's global data and is referenced as we build up the JaxPrs. The implementation follows.
# Container class to hold up data.
class JaxprBuilder:
  eqns: list[JaxprEqn]
  tracer_to_var: dict[int, Var]
  const_tracers: dict[int, JaxprTracer]
  constvals: dict[Var, Any]
  tracers: list[JaxprTracer]

  def __init__(self):
    self.eqns = []
    self.tracer_to_var = {}
    self.const_tracers = {}
    self.constvals = {}
    self.tracers = []

  # Add a new tracer with a given aval
  def new_tracer(self, trace, aval):
    tracer = JaxprTracer(trace, aval)
    self.tracers.append(tracer)  # keep the tracer alive so its id() stays unique
    return tracer

  ## Other getter and setter methods for the class, self-explanatory.
  def add_eqn(self, eqn):
    self.eqns.append(eqn)

  def add_var(self, tracer):
    assert id(tracer) not in self.tracer_to_var
    var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)
    return var

  def getvar(self, tracer):
    var = self.tracer_to_var.get(id(tracer))
    assert var is not None
    return var

  def add_const(self, tracer, val):
    var = self.add_var(tracer)
    self.const_tracers[id(val)] = tracer
    self.constvals[var] = val
    return var

  def build(self, in_tracers, out_tracers):
    constvars, constvals = unzip2(self.constvals.items())
    t2v = lambda t: self.tracer_to_var[id(t)]
    in_binders = constvars + [t2v(t) for t in in_tracers]
    out_vars = [t2v(t) for t in out_tracers]
    jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
    typecheck_jaxpr(jaxpr)  # important step!
    jaxpr, constvals = _inline_literals(jaxpr, constvals)
    return jaxpr, constvals
# Replace scalar constants by inline literals so they no longer need a binder.
def _inline_literals(jaxpr, consts):
  const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
  scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
  new_const_binders, lit_binders = partition_list(scalars, const_binders)
  new_consts, lit_vals = partition_list(scalars, consts)
  literals = dict(zip(lit_binders, map(Lit, lit_vals)))
  new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
                       eqn.params, eqn.out_binders)
              for eqn in jaxpr.eqns]
  new_outs = [literals.get(x, x) for x in jaxpr.outs]
  new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
  return new_jaxpr, new_consts
With the Tracer and Trace implementations in place, let us implement the abstract evaluation rules (abstract_eval_rules), as we did for the other interpreters. Most of these are very general, with the intention that these abstractions will be reused by other JaxPr-producing trace methods.
# Binop for Add, Multiply.
def binop_abstract_eval(x, y):
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
  return [ShapedArray(x.shape, x.dtype)]
abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval

# Compare for less than, greater than.
def compare_abstract_eval(x, y):
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if x.shape != y.shape: raise TypeError
  return [ShapedArray(x.shape, np.dtype('bool'))]
abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval

# Vectorized Op for Sin, Cosine and Negation.
def vectorized_unop_abstract_eval(x):
  return [ShapedArray(x.shape, x.dtype)]
abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval

# Different eval for reduce_sum: the reduced axes disappear from the shape.
def reduce_sum_abstract_eval(x, *, axis):
  axis_ = set(axis)
  new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]
  return [ShapedArray(tuple(new_shape), x.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval

# One for broadcast as well.
def broadcast_abstract_eval(x, *, shape, axes):
  return [ShapedArray(tuple(shape), x.dtype)]
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
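The point of these rules is that they compute only shapes and dtypes and never touch array data. As a rough, standalone illustration, the sketch below mimics three of them with a hypothetical ToyShaped tuple standing in for ShapedArray (dtypes are plain strings here, which is an assumption of the sketch, not how the code above works).

```python
from typing import NamedTuple

class ToyShaped(NamedTuple):
    """Hypothetical stand-in for ShapedArray: a shape and a dtype name."""
    shape: tuple
    dtype: str

def toy_binop_rule(x, y):
    # Binary ops require matching shape and dtype; the output has the same type.
    if x != y:
        raise TypeError(f"mismatched operands: {x} vs {y}")
    return [ToyShaped(x.shape, x.dtype)]

def toy_reduce_sum_rule(x, *, axis):
    # Summing over `axis` drops those dimensions from the shape.
    dropped = set(axis)
    new_shape = tuple(d for i, d in enumerate(x.shape) if i not in dropped)
    return [ToyShaped(new_shape, x.dtype)]

def toy_compare_rule(x, y):
    # Comparisons keep the shape but always produce booleans.
    if x.shape != y.shape:
        raise TypeError("shape mismatch")
    return [ToyShaped(x.shape, "bool")]

a = ToyShaped((2, 3, 4), "float32")
summed, = toy_reduce_sum_rule(a, axis=(0, 2))
added, = toy_binop_rule(a, a)
compared, = toy_compare_rule(a, a)

try:
    toy_binop_rule(a, summed)   # (2, 3, 4) vs (3,): rejected at trace time
    mismatch_detected = False
except TypeError:
    mismatch_detected = True
```

Because the rules run on types alone, a mismatch like the last call is caught while the JaxPr is being built, before any computation is staged out.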
With everything in place, we can kick off our transformation API. There is, however, a fundamental flaw in make_jaxpr_v1, which maybe deserves a blog post of its own (see footnote 3). In short, operations whose inputs are not boxed up in JaxprTracer instances are not staged out: they run eagerly at trace time, which can waste memory and dispatch time and maybe even fragment memory.
The fix, “omnistaging”, ensures that the JaxprTrace started by make_jaxpr is always applied, whether or not an input is boxed in a tracer. Conceptually, the dynamic trace is identical to stashing the current interpreter stack and starting a new one with the JaxprTrace at the bottom. The new transformation API (make_jaxpr) uses the dynamic_trace global (see Part 1) for this reason.
from contextlib import contextmanager
from functools import lru_cache, partial

@lru_cache  # caching assumes ShapedArray is hashable
def make_jaxpr_v1(f, *avals_in):
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)

  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    trace = JaxprTrace(main)
    tracers_in = [trace.new_arg(aval) for aval in avals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree()

# There is a limitation, though: this version can't stage out all the primitive operations performed by the Python callable.
@contextmanager
def new_dynamic(main):
  global dynamic_trace
  prev_dynamic_trace, dynamic_trace = dynamic_trace, main
  try:
    yield
  finally:
    dynamic_trace = prev_dynamic_trace

@lru_cache  # caching assumes ShapedArray is hashable
def make_jaxpr(f, *avals_in):
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)

  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    with new_dynamic(main):
      trace = JaxprTrace(main)
      tracers_in = [trace.new_arg(aval) for aval in avals_in]
      outs = f(*tracers_in)
      tracers_out = [full_raise(trace, out) for out in outs]
      jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree()
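The limitation of tracer-driven staging is easy to reproduce in a few lines. The sketch below is a toy recorder, not the JaxprTrace above: ToyBuilder, ToyTracer and toy_make_jaxpr are hypothetical names, and equations are recorded as plain tuples. It stages an operation only when at least one operand is a tracer, which is the make_jaxpr_v1 behaviour; it does not implement the dynamic-trace fix.

```python
class ToyBuilder:
    def __init__(self):
        self.eqns = []     # recorded (primitive, operand names, output name)
        self.counter = 0

    def fresh(self):
        self.counter += 1
        return f"v{self.counter}"

class ToyTracer:
    def __init__(self, builder, name):
        self.builder, self.name = builder, name

    def _record(self, prim, other):
        # Non-tracer operands are stored as literals (their repr).
        rhs = other.name if isinstance(other, ToyTracer) else repr(other)
        out = ToyTracer(self.builder, self.builder.fresh())
        self.builder.eqns.append((prim, (self.name, rhs), out.name))
        return out

    def __add__(self, other): return self._record("add", other)
    def __mul__(self, other): return self._record("mul", other)
    # add and mul commute, so the reflected versions can reuse the same rule.
    __radd__ = __add__
    __rmul__ = __mul__

def toy_make_jaxpr(f, num_args):
    builder = ToyBuilder()
    args = [ToyTracer(builder, builder.fresh()) for _ in range(num_args)]
    out = f(*args)
    return builder.eqns, out.name

def f(x):
    k = 2.0 * 2.0      # no tracer involved: Python computes this eagerly
    return x * k + x

eqns, out_name = toy_make_jaxpr(f, 1)
```

Here the multiplication 2.0 * 2.0 never shows up in eqns; only its result 4.0 appears as a literal. That is exactly the kind of operation an always-on dynamic trace is meant to capture.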
Part 3 - JIT
After converting Python functions into JaxPrs, the next step is taken up by the JIT compiler. The JIT (Just-In-Time) compiler analyzes the JaxPr, identifies the specific operations needed, and generates optimized machine code for them. JIT compiles only the necessary code (and caches the compiled machine code for future use), giving significant speedups for computationally expensive workflows.
Like make_jaxpr and jvp, jit has a transformation-like API that takes a Python function, but conceptually JIT is a higher-order primitive under the hood (a higher-order primitive is one that is parameterized by a function) rather than a transformation. The jit wrapper takes a Python function and returns a “transformed” function (in this case, an optimized version of it), while the underlying primitive takes a JaxPr as a parameter and changes how it is executed.
To handle higher-order primitives, JIT uses a staged processing approach: we call make_jaxpr in the primitive wrapper to form the JaxPr up front and skip the Python function entirely (see footnote 4), which is what we need to stage these computations to XLA (Accelerated Linear Algebra) for ML workflows.
Since JIT is a higher-order primitive, we need to give it transformation rules. When we evaluate an xla_call_p application, we stage the computation out to XLA by translating the JaxPr into an XLA HLO program. This includes transferring the argument values to the XLA device, executing the XLA program (the compiled program is cached per shape and dtype signature), and transferring the results back.
# This is the staged JIT wrapper, with computation done by XLA.
def jit(f):
  def f_jitted(*args):
    avals_in = [raise_to_shaped(get_aval(x)) for x in args]
    jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
    outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))
    return tree_unflatten(out_tree, outs)
  return f_jitted

xla_call_p = Primitive('xla_call')

# Utility for XLA call: hash and compare by object identity, so JaxPrs and constants can be used as cache keys.
class IDhashable:
  val: Any

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return id(self.val)

  def __eq__(self, other):
    return type(other) is IDhashable and id(self.val) == id(other.val)

from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
xe = xc._xla
xops = xc._xla.ops

def xla_call_impl(*args, jaxpr, num_consts):
  consts, args = args[:num_consts], args[num_consts:]
  hashable_consts = tuple(map(IDhashable, consts))
  execute = xla_callable(IDhashable(jaxpr), hashable_consts)
  return execute(*args)
impl_rules[xla_call_p] = xla_call_impl
@lru_cache  # one compilation per (jaxpr, consts) pair; this is why the arguments are IDhashable
def xla_callable(hashable_jaxpr, hashable_consts):
  jaxpr = hashable_jaxpr.val
  consts = [x.val for x in hashable_consts]
  in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
  c = xc.XlaBuilder('xla_call')
  xla_consts = _xla_consts(c, consts)
  xla_params = _xla_params(c, in_avals)
  outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)
  out = xops.Tuple(c, outs)
  compiled = xb.get_backend(None).compile(
  return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])
def _xla_consts(c, consts):
  # Deduplicate constants by identity so each one is embedded only once.
  unique_consts = {id(cnst): cnst for cnst in consts}
  xla_consts = {
      id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()
  }
  return [xla_consts[id(cnst)] for cnst in consts]

def _xla_params(c, avals_in):
  return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]

def _xla_shape(aval):
  return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)
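The caching behaviour described above (compile once per shape and dtype signature, then reuse) can be sketched without XLA at all. Everything below is hypothetical scaffolding: toy_compile only pretends to compile, and toy_signature is a crude stand-in for abstract values that assumes arguments are scalars or non-empty flat lists.

```python
from functools import lru_cache

stats = {"compiles": 0}

@lru_cache(maxsize=None)
def toy_compile(fn, signature):
    """Pretend to compile `fn` for one (shape, dtype) signature."""
    stats["compiles"] += 1
    return fn  # a real backend would return an executable here

def toy_signature(args):
    # Hypothetical abstract value: (shape, dtype name) per argument.
    sig = []
    for a in args:
        if isinstance(a, (list, tuple)):
            sig.append(((len(a),), type(a[0]).__name__))
        else:
            sig.append(((), type(a).__name__))
    return tuple(sig)

def toy_jit(fn):
    def wrapped(*args):
        # The cache key contains only shapes and dtypes, never the values.
        executable = toy_compile(fn, toy_signature(args))
        return executable(*args)
    return wrapped

@toy_jit
def scale(xs, k):
    return [x * k for x in xs]

r1 = scale([1.0, 2.0], 3.0)        # new signature: compiles
r2 = scale([4.0, 5.0], 3.0)        # same shapes and dtypes: cache hit
r3 = scale([1.0, 2.0, 3.0], 3.0)   # new shape: compiles again
```

Calling scale with different values but the same shapes reuses the cached entry, while a new shape triggers a second compilation.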
Let us now define the transformation rules for xla_call_p, other than its evaluation rule.
# JVP rule for XLA call.
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
  outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  n = len(outs) // 2
  primals_out, tangents_out = outs[:n], outs[n:]
  return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule

# JVP for the JaxPrs.
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
  def jvp_traceable(*primals_and_tangents):
    n = len(primals_and_tangents) // 2
    primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
    return jvp(jaxpr_as_fun(jaxpr), primals, tangents)

  in_avals = [v.aval for v in jaxpr.in_binders]
  new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)
  return new_jaxpr, new_consts

# VMAP rule for XLA call.
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
  outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule

def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
               ) -> tuple[Jaxpr, list[Any]]:
  vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
  in_avals = [unmapped_aval(axis_size, d, v.aval)
              for v, d in zip(jaxpr.in_binders, bdims_in)]
  new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)
  return new_jaxpr, new_consts

def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
                  ) -> ShapedArray:
  if batch_dim is not_mapped:
    return aval
  else:
    shape = list(aval.shape)
    shape.insert(batch_dim, axis_size)
    return ShapedArray(tuple(shape), aval.dtype)

# Abstract Eval XLA call rule.
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
  del num_consts  # Unused
  jaxpr_type = typecheck_jaxpr(jaxpr)
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule

def xla_call_translation(c, in_avals, out_avals, in_vals, *, jaxpr, num_consts):
  del num_consts, out_avals
  # Calling jaxpr_subcomp directly would inline.
  with ir.InsertionPoint(c.module.body):
    @func.func(*(aval_to_ir_type(aval) for aval in in_avals))
    def inner_xla_call(*params):
      return jaxpr_subcomp(c, jaxpr, params)
    name = c.symbol_table.insert(inner_xla_call.func_op)
  return func.CallOp(inner_xla_call.func_op, in_vals).results
hlo_translations[xla_call_p] = xla_call_translation
With all the rules in place, we turn our attention to an important issue: memory persistence for arrays. So far, once an XLA operation finishes we transfer its results back to CPU memory as an np.array, even though most of the time those results are needed on the device again for the next operation. To avoid that round trip, we’ll introduce an Array class that wraps XLA buffers.
def handle_result(aval: ShapedArray, buf):  # noqa: F811
  return Array(aval, buf)

class Array:
  buf: Any
  aval: ShapedArray

  def __init__(self, aval, buf):
    self.aval = aval
    self.buf = buf

  dtype = property(lambda self: self.aval.dtype)
  shape = property(lambda self: self.aval.shape)
  ndim = property(lambda self: self.aval.ndim)

  def __array__(self): return np.asarray(self.buf)
  def __repr__(self): return repr(np.asarray(self.buf))
  def __str__(self): return str(np.asarray(self.buf))

  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(add)
  _mul = staticmethod(mul)
  _rmul = staticmethod(mul)
  _gt = staticmethod(greater)
  _lt = staticmethod(less)

input_handlers[Array] = lambda x: x.buf
With that, we have implemented another core feature of JAX. Let’s move on to the next part, where we implement some special autodiff functions, linearize and vjp, which have some caveats.
Part 4 - linearize and vjp
Here’s a diagram summarizing the key points of linearize and VJP in JAX. Let us implement them!
Implementing linearize
Linearize computes the linear approximation of a function, and it operates in the tangent space. We therefore want to stage out the linear part of a JVP computation, that is, build a JaxPr from a JVP. To do this, we need partial evaluation: all the primal values are evaluated as we trace, while the tangent computations are staged into a JaxPr. Unlike the previous make_jaxpr functions, this approach stages out only those primitive binds that depend on tangent inputs.
def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
  assert not len(lst) % 2
  return split_list(lst, len(lst) // 2)

def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
  l1, l2 = iter(l1), iter(l2)
  out = [next(l2) if b else next(l1) for b in which]
  assert next(l1, None) is next(l2, None) is None
  return out

def linearize_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])

  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]

  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
  return primals_out, f_lin

# This linearize function has JVP and partial evaluation combined.
def linearize(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_lin(*tangents_in):
    tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
    if in_tree != in_tree2: raise TypeError
    tangents_out_flat = f_lin_flat(*tangents_in_flat)
    return tree_unflatten(out_tree(), tangents_out_flat)

  return primals_out, f_lin

def vspace(aval: ShapedArray) -> ShapedArray:
  return raise_to_shaped(aval)  # TODO handle integers?
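To see what linearize is supposed to hand back, it helps to look at a scalar toy version. The sketch below uses a hypothetical Dual class for forward-mode differentiation and, instead of partial evaluation, exploits the fact that a scalar linear map is fully determined by its value at 1.0. It is a swapped-in shortcut for illustration, not the mechanism used above.

```python
import math

class Dual:
    """A (primal, tangent) pair for forward-mode differentiation."""
    def __init__(self, primal, tangent):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)
    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule for the tangent part.
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)
    __rmul__ = __mul__

    def __neg__(self):
        return Dual(-self.primal, -self.tangent)

def _lift(x):
    # Plain numbers are constants: their tangent is zero.
    return x if isinstance(x, Dual) else Dual(x, 0.0)

def dsin(x):
    x = _lift(x)
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)

def toy_jvp(f, primal, tangent):
    out = f(Dual(primal, tangent))
    return out.primal, out.tangent

def toy_linearize(f, primal):
    # One forward pass does all the nonlinear work (sin, cos, ...).
    y, slope = toy_jvp(f, primal, 1.0)
    # Applying the linear map later costs a single multiplication.
    return y, lambda t: slope * t

def f(x):
    y = dsin(x) * 2.0
    return -y + x

y, f_lin = toy_linearize(f, 3.0)
y_dot = f_lin(1.0)
```

The important property is the same as in the real implementation: the expensive primal work happens once inside toy_linearize, and f_lin can then be applied to many tangents cheaply.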
As mentioned, linearize combines JVP with a general partial evaluation transformation. The workflow is simple: turn a Python callable into two kinds of outputs, those that can be computed from the known inputs alone, and a partial JaxPr that can only be run once its required inputs are known (see footnote 5). Think of partial evaluation as “unzipping” one computation into two (one for the primal values, and one tangent jaxpr). We only want to form a JaxPr for the operations that must be delayed because they depend on unknown inputs, which avoids unnecessary evaluations.
For the reasons mentioned above, let us start our implementation by creating a PartialVal class. Partial evaluation takes a list of PartialVal representing the inputs and returns a list of PartialVal outputs along with a jaxpr representing the delayed computation:
class PartialVal(NamedTuple):
  aval: ShapedArray
  const: Any | None

  @classmethod
  def known(cls, val: Any):
    return PartialVal(get_aval(val), val)

  @classmethod
  def unknown(cls, aval: ShapedArray):
    return PartialVal(aval, None)

  # To check whether the inputs are known or unknown, so we distribute accordingly.
  is_known = property(lambda self: self.const is not None)
  is_unknown = property(lambda self: self.const is None)

# The transformation API.
def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
                      ) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
  with new_main(PartialEvalTrace) as main:
    trace = PartialEvalTrace(main)
    tracers_in = [trace.new_arg(pval) for pval in pvals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    pvals_out = [t.pval for t in tracers_out]
    unk_tracers_in = [t for t in tracers_in if t.pval.is_unknown]
    unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
    jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)
  return jaxpr, pvals_out, consts
Now, we’ll implement PartialEvalTrace and PartialEvalTracer. The difference from the previous Trace and Tracer classes is that this interpreter builds JaxPrs on the fly while keeping track of data dependencies. To do so, it maintains a bipartite DAG (Directed Acyclic Graph) between PartialEvalTracer nodes (representing staged-out values) and JaxprRecipe nodes (representing formulas for how to compute some values from others). The bipartite graph structure is chosen to simplify dependency management (modular updates and extensions).
Recipes come in several types: JaxprEqnRecipe (corresponding to a JaxprEqn’s primitive application), plus recipes for constants and lambda binders.
from weakref import ref, ReferenceType

# Wrapper for a lambda recipe (carries no data).
class LambdaBindingRecipe(NamedTuple):
  pass

# Wrapper for a const recipe.
class ConstRecipe(NamedTuple):
  val: Any

# Wrapper for Jaxpr recipe
class JaxprEqnRecipe(NamedTuple):
  prim: Primitive
  tracers_in: list['PartialEvalTracer']
  params: dict[str, Any]
  avals_out: list[ShapedArray]
  tracer_refs_out: list['ReferenceType[PartialEvalTracer]']

JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]

# Partial Eval Tracer - contains boxed-up values.
class PartialEvalTracer(Tracer):
  pval: PartialVal
  recipe: JaxprRecipe | None

  def __init__(self, trace, pval, recipe):
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  aval = property(lambda self: self.pval.aval)

  def full_lower(self):
    if self.pval.is_known:
      return full_lower(self.pval.const)
    return self
With these implementations in place, let us now implement PartialEvalTrace. Each argument corresponds to a LambdaBindingRecipe leaf node, and each constant is a ConstRecipe leaf node holding a reference to the constant.
The implementation of process_primitive is also straightforward. If all inputs are known, we bind the primitive on the known values (basically, evaluate it in Python). If any input is unknown, we instead stage out a JaxprEqnRecipe representing the primitive application. Every primitive except xla_call_p follows this logic (the XLA primitive requires recursive treatment).
class PartialEvalTrace(Trace):
  def new_arg(self, pval: PartialVal) -> Any:
    return PartialEvalTracer(self, pval, LambdaBindingRecipe())

  def lift(self, val: Any) -> PartialEvalTracer:
    return PartialEvalTracer(self, PartialVal.known(val), None)
  pure = lift

  def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
    if tracer.pval.is_unknown:
      return tracer
    else:
      pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
      return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))

  def process_primitive(self, primitive, tracers, params):
    # All inputs known: just evaluate.
    if all(t.pval.is_known for t in tracers):
      return bind(primitive, *map(full_lower, tracers), **params)
    # Primitives with a custom rule (xla_call_p) handle themselves.
    rule = partial_eval_rules.get(primitive)
    if rule: return rule(self, tracers, **params)
    # Otherwise stage the application out as a recipe.
    tracers_in = [self.instantiate_const(t) for t in tracers]
    avals_in = [t.aval for t in tracers_in]
    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
    tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
                   for aval in avals_out]
    eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
                         map(ref, tracers_out))
    for t in tracers_out: t.recipe = eqn
    return tracers_out

partial_eval_rules = {}
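The known/unknown split in process_primitive can be illustrated on a tiny equation list. In the sketch below, toy_partial_eval, toy_run_staged and the TOY_OPS table are hypothetical names, and equations are plain (op, inputs, output) tuples rather than recipes. Equations whose inputs are all known are evaluated immediately, the rest are staged, and the known values the staged part reads become residuals.

```python
import math

TOY_OPS = {
    "sin": math.sin,
    "cos": math.cos,
    "mul": lambda a, b: a * b,
}

def toy_partial_eval(eqns, known_env, unknown_names):
    """Split `eqns` into values computed now and equations staged for later.

    eqns:          list of (op, input_names, output_name)
    known_env:     dict name -> value for inputs known right now
    unknown_names: names whose values only arrive later
    """
    env = dict(known_env)
    unknown = set(unknown_names)
    staged = []
    for op, ins, out in eqns:
        if any(name in unknown for name in ins):
            staged.append((op, ins, out))   # depends on an unknown: delay it
            unknown.add(out)
        else:
            env[out] = TOY_OPS[op](*(env[n] for n in ins))  # evaluate now
    # Known values read by staged equations must be saved as residuals.
    residuals = {n: env[n] for _, ins, _ in staged for n in ins if n in env}
    return env, staged, residuals

def toy_run_staged(staged, residuals, unknown_env):
    env = {**residuals, **unknown_env}
    for op, ins, out in staged:
        env[out] = TOY_OPS[op](*(env[n] for n in ins))
    return env

# JVP of y = sin(x): primal y = sin(x), tangent y_dot = cos(x) * x_dot
jvp_eqns = [("sin", ("x",), "y"),
            ("cos", ("x",), "c"),
            ("mul", ("c", "x_dot"), "y_dot")]

known, staged, residuals = toy_partial_eval(jvp_eqns, {"x": 3.0}, ["x_dot"])
later = toy_run_staged(staged, residuals, {"x_dot": 1.0})
```

With x known and x_dot unknown, sin and cos run right away, only the final multiplication is staged, and cos(3.0) is carried over as the single residual.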
Now we can build graph representations of JaxPrs with PartialEvalTrace; we just need a mechanism to convert the graph representation into a standard JaxPr, which corresponds to a topological sort of the graph.
def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
                     tracers_out: list[PartialEvalTracer]):
  tracer_to_var: dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
                                   for t in tracers_in}
  constvar_to_val: dict[int, Any] = {}
  constid_to_var: dict[int, Var] = {}
  processed_eqns: set[int] = set()
  eqns: list[JaxprEqn] = []
  for t in toposort(tracers_out, tracer_parents):
    if isinstance(t.recipe, LambdaBindingRecipe):
      assert id(t) in set(map(id, tracers_in))
    elif isinstance(t.recipe, ConstRecipe):
      val = t.recipe.val
      var = constid_to_var.get(id(val))
      if var is None:
        aval = raise_to_shaped(get_aval(val))
        var = constid_to_var[id(val)] = Var(aval)
        constvar_to_val[var] = val
      tracer_to_var[id(t)] = var
    elif isinstance(t.recipe, JaxprEqnRecipe):
      # An equation with several outputs is visited once per output; emit it only once.
      if id(t.recipe) not in processed_eqns:
        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
        processed_eqns.add(id(t.recipe))
    else:
      raise TypeError(t.recipe)

  constvars, constvals = unzip2(constvar_to_val.items())
  in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
  out_vars = [tracer_to_var[id(t)] for t in tracers_out]
  jaxpr = Jaxpr(in_binders, eqns, out_vars)
  return jaxpr, constvals

def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
                  ) -> JaxprEqn:
  inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
  out_binders = [Var(aval) for aval in recipe.avals_out]
  for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
    if t_ref() is not None: tracer_to_var[id(t_ref())] = var
  return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)

def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []

## Toposort and stuff
def toposort(out_nodes: list[Any], parents: Callable[[Any], list[Any]]):
  if not out_nodes: return []
  out_nodes = remove_duplicates(out_nodes)

  # Count how many children reference each reachable node.
  child_counts = {}
  stack = list(out_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(parents(node))
  for node in out_nodes:
    child_counts[id(node)] -= 1

  # Repeatedly emit nodes with no remaining children, then reverse.
  sorted_nodes = []
  childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in parents(node):
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1

  sorted_nodes = sorted_nodes[::-1]
  check_toposort(sorted_nodes, parents)
  return sorted_nodes

def remove_duplicates(lst):
  seen = set()
  return [x for x in lst if id(x) not in seen and not seen.add(id(x))]

def check_toposort(nodes: list[Any], parents: Callable[[Any], list[Any]]):
  seen = set()
  for node in nodes:
    assert all(id(parent) in seen for parent in parents(node))
    seen.add(id(node))
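The toposort above is easier to follow on a small named graph. The sketch below is a standalone variant of the same child-counting scheme, with two simplifying assumptions: nodes are hashable strings (so no id() bookkeeping is needed) and out_nodes contains no duplicates. The name toy_toposort and the example graph are made up for illustration.

```python
def toy_toposort(out_nodes, parents):
    """Return every node reachable from out_nodes, parents before children."""
    # Pass 1: count how many children point at each reachable node.
    child_counts = {}
    stack = list(out_nodes)
    while stack:
        node = stack.pop()
        if node in child_counts:
            child_counts[node] += 1
        else:
            child_counts[node] = 1
            stack.extend(parents(node))
    for node in out_nodes:
        child_counts[node] -= 1   # outputs were counted once with no child

    # Pass 2: emit nodes whose children are all emitted, children first.
    ordered = []
    childless = [n for n in out_nodes if child_counts[n] == 0]
    while childless:
        node = childless.pop()
        ordered.append(node)
        for parent in parents(node):
            if child_counts[parent] == 1:
                childless.append(parent)   # this was the last child
            else:
                child_counts[parent] -= 1
    return ordered[::-1]   # reverse so parents come first

# node -> list of parents (the values it is computed from)
graph = {
    "x": [],
    "c": [],
    "a": ["x"],
    "b": ["a", "c"],
    "out": ["a", "b"],
}
order = toy_toposort(["out"], graph.__getitem__)
position = {name: i for i, name in enumerate(order)}
```

Every node ends up after all of its parents, which is the order in which equations have to appear in a JaxPr.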
Let us test it in action. We also need to implement the partial evaluation rules for xla_call_p (to handle JIT and related functions). There are two rules to write: one for trace-time partial evaluation (xla_call_partial_eval) and one for partial evaluation of Jaxprs (xla_call_peval_eqn).
# Example usage.
y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))
## Rules for XLA Primitive(trace time and Partial Eval)
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
  del num_consts  # Unused
  in_unknowns = [not t.pval.is_known for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
  outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in jaxpr2.outs]
  eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
                       dict(jaxpr=jaxpr2, num_consts=0),
                       [v.aval for v in jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
                       instantiate: list[bool] | None = None,
                       ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
  env: dict[Var, bool] = {}
  residuals: set[Var] = set()

  def read(x: Atom) -> bool:
    return type(x) is Var and env[x]

  def write(unk: bool, v: Var) -> None:
    env[v] = unk

  def new_res(x: Atom) -> Atom:
    if type(x) is Var: residuals.add(x)
    return x

  eqns1, eqns2 = [], []
  map(write, in_unknowns, jaxpr.in_binders)
  for eqn in jaxpr.eqns:
    unks_in = map(read, eqn.inputs)
    rule = partial_eval_jaxpr_rules.get(eqn.primitive)
    if rule:
      eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
      eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
      map(write, unks_out, eqn.out_binders)
    elif any(unks_in):
      # Depends on an unknown input: goes to the staged jaxpr.
      inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
      eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
      map(partial(write, True), eqn.out_binders)
    else:
      # All inputs known: stays in the known jaxpr.
      eqns1.append(eqn)
      map(partial(write, False), eqn.out_binders)
  out_unknowns = map(read, jaxpr.outs)
  if instantiate is not None:
    for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
      if inst and not uk: new_res(v)
    out_unknowns = map(op.or_, out_unknowns, instantiate)

  residuals, num_res = list(residuals), len(residuals)
  assert all(type(v) is Var for v in residuals), residuals

  ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
  outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)

  jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
  jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
  typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)

  return jaxpr1, jaxpr2, out_unknowns, num_res
def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
  jaxprty = typecheck_jaxpr(jaxpr)    # (a1, a2) -> (b1, b2)
  jaxpr1ty = typecheck_jaxpr(jaxpr1)  # a1 -> (b1, res)
  jaxpr2ty = typecheck_jaxpr(jaxpr2)  # (res, a2) -> b2

  a1, a2 = partition_list(unks_in, jaxprty.in_types)
  b1, b2 = partition_list(unks_out, jaxprty.out_types)
  b1_, res = split_list(jaxpr1ty.out_types, len(b1))
  res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
  b2_ = jaxpr2ty.out_types

  if jaxpr1ty.in_types != a1: raise TypeError
  if jaxpr2ty.out_types != b2: raise TypeError
  if b1 != b1_: raise TypeError
  if res != res_: raise TypeError
  if a2 != a2_: raise TypeError
  if b2 != b2_: raise TypeError

partial_eval_jaxpr_rules = {}

def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                       ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
  jaxpr = eqn.params['jaxpr']
  jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
  ins1, ins2 = partition_list(unks_in, eqn.inputs)
  out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
  residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
  eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
                  out_binders1 + residuals)
  eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
                  dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
  return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn
Now we can compose linearize and jit:
# Example usage.
@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
Implementing vjp and grad
The vjp transformation is very similar to linearize. The only difference is that we transpose the linear part of the computation before returning it, so the implementation is fairly straightforward. Since the linear computation is already available as a jaxpr, we can implement the transpose transformation as a jaxpr interpreter. UndefPrimal instances mark the arguments with respect to which we want to transpose. We also register UndefPrimal as a pytree node, which gives us a handy way to prune these placeholders out of argument lists.
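# VJP, conceptually. This is only a sketch: transpose is not a function we
# define under that name, and the full implementation follows.
def vjp(f, x):
  y, f_lin = linearize(f, x)
  f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
  return y, f_vjp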
# Full VJP: linearize via partial evaluation of jvp, then transpose the
# resulting linear jaxpr.
def vjp_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)  # linearize
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]
  f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)
  return primals_out, f_vjp
# Pytree-aware wrapper around vjp_flat.
def vjp(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)
  def f_vjp(*cotangents_out):
    cotangents_out_flat, _ = tree_flatten(cotangents_out)
    cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
    return tree_unflatten(in_tree, cotangents_in_flat)
  return primals_out, f_vjp
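class UndefPrimal(NamedTuple):
  aval: ShapedArray
# Register UndefPrimal as a pytree node so placeholders can be pruned from
# argument lists.
register_pytree_node(UndefPrimal,
                     lambda u: (u.aval, ()),
                     lambda aval, _: UndefPrimal(aval))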
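To make "transposing the linear part" concrete, here is a minimal sketch that sits outside the blog's implementation. It assumes the linearized function can be written as an explicit matrix J (a hypothetical 2 x 3 example), so that the forward map is t -> J t and the transposed map is ct -> J^T ct. The helper names apply_linear, transpose_linear and dot are made up for this illustration. The check at the end is the defining property of a transpose: the inner product of ct with J t equals the inner product of J^T ct with t.

```python
# Toy illustration only: a linear map stored as an explicit matrix.

def apply_linear(matrix, tangent):
    """Forward (JVP-style) application: returns matrix @ tangent."""
    return [sum(m * t for m, t in zip(row, tangent)) for row in matrix]

def transpose_linear(matrix):
    """Return a function applying the transposed map: ct -> matrix^T @ ct."""
    columns = list(zip(*matrix))
    return lambda cotangent: [sum(m * c for m, c in zip(col, cotangent))
                              for col in columns]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

# Hypothetical linearized function f_lin: R^3 -> R^2.
J = [[1.0, 2.0, 0.0],
     [0.0, -1.0, 3.0]]
f_lin = lambda t: apply_linear(J, t)
f_vjp = transpose_linear(J)

tangent = [1.0, 0.5, -2.0]
cotangent = [2.0, -1.0]

lhs = dot(cotangent, f_lin(tangent))   # <ct, J t>
rhs = dot(f_vjp(cotangent), tangent)   # <J^T ct, t>
print(lhs, rhs)                        # both sides agree
```

The real implementation never materializes J. It reaches the same result by interpreting the linear jaxpr backwards, one primitive at a time.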
Next, we can write eval_jaxpr_transposed, along with the transpose rules for all the primitives that can be linear in at least one argument.
# NB: the analogous function in JAX is called 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
                          ) -> list[Any]:
  primal_env: dict[Var, Any] = {}
  ct_env: dict[Var, Any] = {}
  def read_primal(x: Atom) -> Any:
    return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val
  def write_primal(v: Var, val: Any) -> None:
    if type(val) is not UndefPrimal:
      primal_env[v] = val
  def read_cotangent(v: Var) -> Any:
    return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))
  def write_cotangent(x: Atom, val: Any):
    if type(x) is Var and val is not None:
      ct_env[x] = add(ct_env[x], val) if x in ct_env else val
  map(write_primal, jaxpr.in_binders, args)
  map(write_cotangent, jaxpr.outs, cotangents)
  # Walk the equations in reverse, pushing cotangents from outputs to inputs.
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.inputs)
    cts_in = map(read_cotangent, eqn.out_binders)
    rule = transpose_rules[eqn.primitive]
    cts_out = rule(cts_in, *primals_in, **eqn.params)
    map(write_cotangent, eqn.inputs, cts_out)
  return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)
          if type(x) is UndefPrimal]
transpose_rules = {}
# Rules
def mul_transpose_rule(cts, x, y):
  z_bar, = cts
  assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)
  return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]
transpose_rules[mul_p] = mul_transpose_rule
def neg_transpose_rule(cts, x):
  ybar, = cts
  assert type(x) is UndefPrimal
  return [neg(ybar)]
transpose_rules[neg_p] = neg_transpose_rule
def add_transpose_rule(cts, x, y):
  z_bar, = cts
  return [z_bar, z_bar]
transpose_rules[add_p] = add_transpose_rule
def reduce_sum_transpose_rule(cts, x, *, axis):
  y_bar, = cts
  return [broadcast(y_bar, x.aval.shape, axis)]
transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule
def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
  del num_consts  # Unused
  undef_primals = [type(x) is UndefPrimal for x in invals]
  transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
  residuals, _ = partition_list(undef_primals, invals)
  outs = bind(xla_call_p, *new_consts, *residuals, *cts,
              jaxpr=transposed_jaxpr, num_consts=len(new_consts))
  outs = iter(outs)
  return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
                    ) -> tuple[Jaxpr, list[Any]]:
  avals_in, avals_out = typecheck_jaxpr(jaxpr)
  traceable = partial(eval_jaxpr_transposed, jaxpr)
  args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
  trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
  return trans_jaxpr, consts
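The backward walk performed by eval_jaxpr_transposed can be hard to picture from the code alone, so here is a stripped-down sketch of the same idea in plain Python. It is not the implementation above. Equations are hypothetical tuples rather than JaxprEqn objects, variables are plain strings, and mul_const is a made-up primitive that stands for multiplication by a known residual. The toy program is the linear part of f(x) = -2 sin(x) + x at x = 3, namely t -> -(2 cos(x) t) + t.

```python
import math

def transpose_program(eqns, in_name, out_name, ct_out):
    """Walk the equations in reverse, accumulating cotangents by name."""
    ct_env = {out_name: ct_out}

    def read_ct(name):
        # A variable that never received a cotangent contributes zero.
        return ct_env.pop(name, 0.0)

    def write_ct(name, val):
        # Cotangents of a variable used more than once are summed.
        ct_env[name] = ct_env.get(name, 0.0) + val

    for prim, inputs, output, const in reversed(eqns):
        ct = read_ct(output)
        if prim == "mul_const":    # out = const * x  ->  x_bar = const * out_bar
            write_ct(inputs[0], const * ct)
        elif prim == "neg":        # out = -x         ->  x_bar = -out_bar
            write_ct(inputs[0], -ct)
        elif prim == "add":        # out = x + y      ->  x_bar = y_bar = out_bar
            for name in inputs:
                write_ct(name, ct)
        else:
            raise NotImplementedError(prim)
    return read_ct(in_name)

x = 3.0
residual = 2.0 * math.cos(x)       # known value saved by linearization
linear_eqns = [
    ("mul_const", ["t"], "a", residual),   # a = residual * t
    ("neg", ["a"], "b", None),             # b = -a
    ("add", ["b", "t"], "out", None),      # out = b + t
]
x_bar = transpose_program(linear_eqns, "t", "out", 1.0)
print(x_bar, 1.0 - 2.0 * math.cos(x))      # transposed result vs. hand-derived f'(x)
```

Note that t is used twice, once by mul_const and once by add, so its cotangent is the sum of two contributions. That is the role write_cotangent plays in the real interpreter.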
Now that we can linearize and transpose, we can finally write grad:
# Grad implementation.
def grad(f):
  def gradfun(x, *xs):
    y, f_vjp = vjp(f, x, *xs)
    if np.shape(y) != (): raise TypeError
    x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
    return x_bar
  return gradfun
# Example usage.
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z
print(grad(f)(3.))
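As a sanity check on what grad should return for this f at x = 3., the derivative can be worked out by hand and compared against a finite difference. The sketch below uses only the math module, not the implementation from this post. The names f_prime and central_difference are made up for the illustration, and the step size eps is an arbitrary choice.

```python
import math

def f(x):
    y = math.sin(x) * 2.0
    z = -y + x
    return z

def f_prime(x):
    # Hand-derived: d/dx (-2 sin x + x) = -2 cos x + 1
    return -2.0 * math.cos(x) + 1.0

def central_difference(fn, x, eps=1e-6):
    # Symmetric finite difference, accurate to O(eps^2).
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)

analytic = f_prime(3.0)
numeric = central_difference(f, 3.0)
print(analytic, numeric)  # both are approximately 2.98
```

If the grad built in this post is working, grad(f)(3.) should agree with these two numbers up to floating-point precision.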
Finally, we are done with our implementation! Give yourself a pat on the back: you now have your own version of JAX, in Python, spelled out completely. I’ll be maintaining a repository for this blog and will update things as I learn more about this amazing library. I would like to thank Lucas Beyer for replying to my tweet, which motivated me to understand more about this framework.
- Autodidax - Implementing the JAX Core from scratch.
- PyTorch is Dead, Long Live JAX - Neel Gupta.
- Is JAX better than PyTorch - Reddit Discussion.
- JAX - The Sharp Bits - JAX Docs.
A better and more detailed explanation (without abstracting away any math) goes something like this, according to the JAX docs:
To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with “tall” Jacobians, but inefficient for “wide” Jacobians.
If you’re doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^n$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial f(x) \in \mathbb{R}^{1 \times n}$, which we often identify with the gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won’t scale. To do better for functions like this, we just need to use reverse-mode autodiff.
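The cost argument above can be made concrete with a small pure-Python sketch. It assumes a toy loss f(x) = sum of x_i^2 with hand-written JVP and VJP functions, so none of this is the JAX API. The function names and the pass counter are made up for the illustration. Recovering the gradient with JVPs takes one pass per input, while a single VJP with cotangent 1.0 returns every entry at once.

```python
def f(x):
    """Scalar loss: sum of squares, f: R^n -> R."""
    return sum(xi * xi for xi in x)

def f_jvp(x, tangent):
    """Hand-written JVP: directional derivative of f at x along tangent."""
    return sum(2.0 * xi * ti for xi, ti in zip(x, tangent))

def f_vjp(x, cotangent):
    """Hand-written VJP: cotangent times the 1 x n Jacobian, all at once."""
    return [cotangent * 2.0 * xi for xi in x]

def gradient_via_jvp(x):
    """One JVP per one-hot tangent: n passes for n inputs."""
    n = len(x)
    grad, passes = [], 0
    for i in range(n):
        one_hot = [1.0 if j == i else 0.0 for j in range(n)]
        grad.append(f_jvp(x, one_hot))   # reveals column i of the Jacobian
        passes += 1
    return grad, passes

def gradient_via_vjp(x):
    """A single VJP with cotangent 1.0 yields the whole gradient."""
    return f_vjp(x, 1.0), 1

x = [1.0, -2.0, 3.0, 0.5]
grad_fwd, fwd_passes = gradient_via_jvp(x)
grad_rev, rev_passes = gradient_via_vjp(x)
print(grad_fwd, fwd_passes)  # same gradient, n = 4 passes
print(grad_rev, rev_passes)  # same gradient, 1 pass
```

With n in the millions, that factor of n is the difference between forward mode and reverse mode for scalar losses.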
For a refresher on autodiff, refer to this Hacker News post, which has some really awesome explanations by really awesome people (and covers possible pitfalls).
This is often referred to as the “omnistaging” issue in the JAX-ML repo. Even if I tried, I couldn’t match the level of detail in which this PR describes it. I highly recommend reading it.
As per the official documentation, this is what happens in staged processing:
In this case, make_jaxpr puts its JaxprTrace at the top of the interpreter stack, and no transformations lower in the stack, which might enter via closed-over Tracers, are applied to the Python callable as we trace it. (Transformations applied within the Python callable are applied as usual, being added to the stack above the JaxprTrace.) Instead, the transformations lower in the stack are later applied to the call primitive, and the call primitive’s rules must then transform the jaxpr itself. Because we trace to a jaxpr up-front, this approach can’t support data-dependent Python control flow, but it is more straightforward to implement. We refer to this kind of higher-order primitive as an “initial-style higher-order primitive”, and say that its jaxpr-processing transformation rules are “initial-style transformation rules.”
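The remark that tracing up-front "can’t support data-dependent Python control flow" can be illustrated with a toy tracer. This is a sketch under simplifying assumptions, not the JaxprTrace from this post. ToyTracer, trace and the string-based equation format are all hypothetical. The point is that a traced value has no concrete number behind it, so Python cannot evaluate an if condition on it.

```python
class ToyTracer:
    """Stand-in for a value that is unknown at trace time."""
    def __init__(self, name, eqns):
        self.name, self.eqns = name, eqns

    def _record(self, op, other):
        # Append one equation to the shared list and return its output tracer.
        out = ToyTracer(f"v{len(self.eqns)}", self.eqns)
        other_name = other.name if isinstance(other, ToyTracer) else repr(other)
        self.eqns.append(f"{out.name} = {op} {self.name} {other_name}")
        return out

    def __mul__(self, other): return self._record("mul", other)
    def __add__(self, other): return self._record("add", other)
    def __gt__(self, other): return self._record("gt", other)

    def __bool__(self):
        # No concrete value exists yet, so Python cannot branch on it.
        raise TypeError("cannot branch on a traced value")

def trace(fn):
    eqns = []
    out = fn(ToyTracer("x", eqns))
    return eqns, out.name

def straight_line(x):
    return x * 2.0 + x

def data_dependent(x):
    if x > 0.0:          # needs the runtime value of x
        return x * 2.0
    return x

eqns, out_name = trace(straight_line)
print(eqns, out_name)    # two recorded equations, output variable v1

try:
    trace(data_dependent)
    traced_ok = True
except TypeError:
    traced_ok = False
print(traced_ok)         # False: the branch could not be traced
```

Straight-line code traces cleanly into a list of equations, while the branching function fails the moment Python asks the tracer for a truth value.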
This transformation is tricky to summarize in a type signature. If we assume the input function’s type signature is (a1, a2) -> (b1, b2), where a1 and a2 represent the known and unknown inputs, respectively, and where b1 only has a data dependency on a1 while b2 has some data dependency on a2, then we might write:
partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)
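One way to read that signature is through a hand-staged example. The sketch below is an assumption-laden illustration in plain Python, not the partial_eval_jaxpr machinery. The function f is a made-up stand-in with the shape (a1, a2) -> (b1, b2), chosen to match the JVP of sin, and partial_eval_f is written by hand for this one function. It returns b1, a residual r, and a deferred function of type (r, a2) -> b2.

```python
import math

def f(x, t):
    """(a1, a2) -> (b1, b2): b1 depends only on x; b2 also depends on t."""
    b1 = math.sin(x)
    b2 = math.cos(x) * t
    return b1, b2

def partial_eval_f(x):
    """Hand-staged version of f for a known x and an unknown t.

    Returns (b1, r, f2), where r is the residual and f2 : (r, a2) -> b2.
    """
    b1 = math.sin(x)      # known work, done now
    r = math.cos(x)       # residual: a known value the deferred part needs
    def f2(r, t):         # unknown work, deferred until t arrives
        return r * t
    return b1, r, f2

b1, r, f2 = partial_eval_f(3.0)   # stage 1: only x is available
b2 = f2(r, 1.0)                   # stage 2: t arrives later
print((b1, b2), f(3.0, 1.0))      # staged and unstaged results match
```

The "exists r" in the signature reflects that the type of the residual depends on the function being split: here it is a single float, but in general it is whatever known intermediate values the deferred computation needs.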