Note: This blog is inspired my the Autodiadax blog of JAX - where they implement JAX core from scratch. I found it very confusing, so I made my own thing. Hope I was able to do it justice

JAX - Why is Everyone So Excited About This Framework

Introduction

The field of AI is going through an exponential surge, with new findings springing at an unprecedented rate. Accounting for Moore’s law for data, a need for a highly performant framework to do ML is the an absolute necessity, as ultimately, unlocking the machine FLOPS is probably the main goal of any framework. There have been a lot of frameworks such as Tensorflow, PyTorch, and recently JAX that have tried to unlock these machine FLOPS, and for the purpose of this blog, we’ll focus on JAX. There are a lot of this that make JAX unique, so let us jump right into it.

What is JAX?

JAX has been gaining a lot of traction in recent times, and for the right reasons. JAX allows researchers to write Python programs that are automatically scaled to leverage accelerators and supercomputers(without any additional effort). JAX was developed by Deepmind to meet a simple goal which is to balance rapid prototyping, quick iteration with the ability to deploy experiments at scale. For those aware with NumPy, think of JAX as just NumPy with Autodiff and nice distributed support. Keep these point in the back of your mind, and let us try to understand why these qualities in JAX are so important and have many people(and frontier AI labs) pivoting to it.

Core Features - JAX

Don’t take just my word for it, Francois Chollet(Keras founder) tweeted recently that almost all players in Generative AI are pivoting to JAX for the because it is fast, scales really well and there is TPU support too. A one line explanation of JAX would go something like this - JAX is basically NumPy on steroids, made for researchers to write efficient, performant and scalable workloads. JAX has a lot of unique propositions for a performant library, so let us look into what makes JAX special:

Automatic Differentiation

Autodiff keeps track of the grads and stuff, pretty important for ML workflows. We’ll cover Autodiff in a lot of detail in the coming up sections.

Just-In-Time Compilation

JAX uses a JIT compiler for speeding up entire blocks of code by exploiting any parallelism between them. Initially, we compile the function on its first use and later re-using the optimized version later, allowing efficient computations and lookup .

VMap(Auto Vectorization)

VMap(advanced vectorization) allows us to apply some function on one or more axes of a tensor. VMap vectorizes a function by adding a batch dimension to every primitive operation in the function.

PMap(SPMD programming)

JAX has built in support for Single Program Multiple Data Programming, allowing the same function to be run in parallel on it’s own XLA device.

In the introductory JAX blog post, it was mentioned that “JAX has una anima di pura programmazione funzionale”(has a soul of pure functional programming), so let us now try to understand why!

Understanding JAX on a deeper level

There is difference in understanding high level overview and low level understanding of the system, and in order to fully nail the concepts, we are actually going to implement the core JAX structure from scratch, completely in Python. A high level overview of what we are going to cover is given in the diagram below. Originally, this blog is inspired by the Autodiadax - which felt a little difficult to grasp, so I wrote it in simpler terms. Gear up, cause there is going to be a lot of code, and a lot more explanation. This is a work in process, so a lot can be added in the future version of this blog too.

Pasted image 20241017131715

For now, we are dividing this blog into 4 parts(all of which are going to covered in great detail), which are :

  1. Part 1: Transformations and Interpretation.
  2. Part 2: JaxPrs(Jax Expressions)
  3. Part 3: JIT(Just In Time Compilation)
  4. Part 4: VJP(Vectorized Jacobian Product)

We’ll implement everything from grounds up, in pure python(apart from the standard library, we are just going to use the XLA package to transfer computational workloads). Let us start tinkering.

Part 1 - Transformations and Interpretation

Let us start from the atomic unit of JAX - functions and operators. Traditionally, we apply these operations to numerical inputs, which gives us numerical answers. However, in JAX we want to override this behavior of operators and functions(which we treat as atomic units of processing - called primitives, rather than compositions) to be converted into JaxPrs(going to be covered in the next part, in great detail, but basically these are Jax Expressions, which is intermediate representation of program, and what JAX uses instead of pure Python code). Converting functions and operators into JaxPrs allows JAX to represent the function into a small, well-behaved intermediate form that is then interpreted with transformation specific interpretation rules. Transformations are basically high order functions transforming Jaxprs. Not all Python programs can be converted to Jaxprs, but for many scientific computing and ML workflows, we can do it. The examples of Transformations include:

  • jax.grad():A function to evaluate the gradient on the input function.
  • jax.vmap():A function to implement automatic vectorization.
  • jax.pmap():A function to implement data parallelism across processing units.

Let us try to define some primitives so we can understand their application:


import numpy as np

from typing import NamedTuple, Any

# An object with a name, to which we attach interpretation rules.
class Primitive(NamedTuple):
    name: str

# Define the primitives - we'll start with the basic ones.
add_p = Primitive("add")
mul_p = Primitive("mul")
neg_p = Primitive("neg")
sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")

#  Bind is the interception point.
def bind1(prim, *args, **kwargs):
    out, = bind(prim, *args, **kwargs)
    return out

# Values as positional args, Metadata as kwargs.
def add(x,y): return bind1(add_p, x, y)
def mul(x,y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x) : return bind1(sin_p, x)
def cos(x) : return bind1(cos_p, x)
def greater(x,y): return bind1(greater_p, x, y)
def less(x,y): return bind1(less_p, x, y)
def transpose(x,perm) : return bind1(transpose_p, x, perm=perm)
def broadcast(x,shape,axes) : return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x,axis) : return bind1(reduce_sum_p, x, axis=axis)

We’ll attach our interpolation rules to the Primitive object - one for each transformation. The interception point is bind, which will figure out which transformations to apply(based on certain rules which we are going to cover later)

All the pure python function arguments are wrapped around Tracer objects - which records all the operations performed on it, creating a Jaxpr. The tracer object contains information about the shape, dtype of the initial arguments(not their value), which allows JAX to use the cached compiled program directly. Any change in shape/dtype triggers tracing, but not the value. This is the reason why only “functionally pure” functions(functions without side effects and which do not rely on values outside their arguments) should be used with JAX.

In the below code, MainTrace is basically a interpreter, and we are representing the active interpreters as a stack. When we are about to apply any transformation, we’ll push another interpreter into the stack using the new_main. At the bottom of the stack, there is a evaluation interpreter(or EvalTrace - which we are going to see later in this section)


from contextlib import contextmanager
from typing import Any

class MainTrace(NamedTuple):
    level: int
    trace_type: type['Trace']
    global_data: Any | None

trace_stack: list[MainTrace] = []
dynamic_trace: MainTrace | None = None  # Later  

@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):

    level = len(trace_stack)
    main = MainTrace(level, trace_type, global_data)
    trace_stack.append(main)

    try:
        yield main
    finally:
        trace_stack.pop()

We’ll implement the Trace and Tracer base classes. A Tracer is basically an object that flows through the Python program that we are transforming. It represents a boxed up value with data to be used by the interpreter(or MainTrace in this case). Trace on the other hand, boxes up Tracer objects and also handles primitive applications.


class Trace:
    main: MainTrace

    def __init__(self, main) -> None:
        self.main = main

    def pure(self, val):
        raise NotImplementedError()

    def lift(self, val):
        raise NotImplementedError()

    def process_primitive(self, primitive, tracers, params):
        raise NotImplementedError()

# One tracer per transformation
class Tracer:
	# stores Trace.
    _trace: Trace
    __array_priority__ = 69

    @property
    def aval(self):
	    # Tracer carries an abstract value. One abstract value per base type.
        raise NotImplementedError()

    def full_lower(self):
        return self

    def __neg__(self): return self.aval._neg(self)

    def __add__(self, other): return self.aval._add(self, other)

    def __radd__(self, other): return self.aval._radd(self, other)

    def __mul__(self, other): return self.aval._mul(self, other)

    def __rmul__(self, other): return self.aval._rmul(self, other)

    def __gt__(self, other): return self.aval._gt(self, other)

    def __lt__(self, other): return self.aval._lt(self, other)

    def __bool__(self): return self.aval._bool(self)

    def __nonzero__(self): return self.aval._nonzero(self)

    def __getattr__(self, name: str) -> Any:
        try:
            return getattr(self.aval, name)
        except AttributeError:
            raise AttributeError(f"No attribute exists : {name}")

def swap(f): return lambda x, y : f(y, x)

For our use case, we are going to focus on abstract values that wrap arrays divided into two classes based on different levels of abstraction. These are:

  • ShapedArray: This class represents the set of all possible arrays with a given shape and datatype.
  • ConcreteArray: This class represents a singleton set consisting of a single array value.

class ShapedArray:

    """ Set of all possible arrays with a given shape and dtype. """

    array_abstraction_level = 1
    shape: tuple[int, ...]
    dtype: np.dtype

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

    @property
    def ndim(self):
        return len(self.shape)

    _neg = staticmethod(neg)
    _add = staticmethod(add)
    _radd = staticmethod(swap(add))
    _mul = staticmethod(mul)
    _rmul = staticmethod(swap(mul))
    _gt = staticmethod(greater)
    _lt = staticmethod(less)

    @staticmethod
    def _bool(tracer):
        raise Exception("Can't convert to bool")

    @staticmethod
    def _nonzero(tracer):
        raise Exception("Can't convert to bool")

    def str_short(self):
        return f'{self.dtype.name}[{",".join(str(d) for d in self.shape)}]'

    def __eq__(self, other):
        return (type(self) == type(other) and self.shape == other.shape and self.dtype == other.dtype)

class ConcreteArray(ShapedArray):
    """ Singleton set consisting of a single array value. """

    array_abstraction_level = 2
    val: np.ndarray

    def __init__(self, val):
        self.val = val
        self.shape = val.shape
        self.dtype = val.dtype

    @staticmethod
    def _bool(tracer):
        return bool(tracer.aval.val)

    @staticmethod
    def _nonzero(tracer):
        return bool(tracer.aval.val)

def get_aval(x):
    if isinstance(x, Tracer):
        return x.aval

    elif type(x) in jax_types:
        return ConcreteArray(np.asarray(x))

    else:
        raise TypeError(x)

jax_types = {bool, int, float, np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}

After setting up the interpreter stack, the base classes for Trace/Tracer, and base classes for abstract values, we should come back and implement the bind function - which, if you remember is our interception point to figure out which transformation rules to apply.

The steps performed by the bind function are:

  1. Find Top Trace ; figure out which interpreter should handle the primitive application.
  2. Call the top trace’s process primitive so the trace can apply interpretation rule
  3. Full raise ensures that inputs are boxed in the tracer instances.
  4. Full lower for optional optimization, so that we unbox values out of Tracer as much as possible.

The main action is that we figure out which interpreter should handle this primitive application. We then call the top trace’s process_primitive so that the trace can apply it’s interpretation rules. The calls to full_raise just ensure that inputs are boxed in the top trace’s Tracer instances(full_lower is for optional optimization so that we unbox values out of Tracer as much as possible).

def bind(prim, *args, **params):
    top_trace = find_top_trace(args)
    tracers = [full_raise(top_trace, arg) for arg in args]
    outs = top_trace.process_primitive(prim, tracers, params)
    return [full_lower(out) for out in outs]

import operator as op

# Returns the highest level interpreter associated with Tracer, otherwise returns the EvalTrace.
def find_top_trace(xs) -> Trace:
  top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),default=trace_stack[0], key=op.attrgetter('level'))

  if dynamic_trace and dynamic_trace.level > top_main.level:
    top_main = dynamic_trace

  return top_main.trace_type(top_main)

  
def full_lower(val):
    if isinstance(val, Tracer):
        return val.full_lower()
    return val
  
# Boxing up values into Tracer's for a particular Trace.
# Trace.pure is called for non-tracer constants, and Trace.lift called for values that are already Tracer's from a lower level interpreter.
def full_raise(trace, val):
    if not isinstance(val, Tracer):
        assert type(val) in jax_types
        return trace.pure(val)

    level = trace.main.level

    if val._trace.main is trace.main:
        return val

    elif val._trace.min.level < level :
        return trace.lift(val)  

    elif val._trace.min.level > level :
        raise Exception("Cannot lift level")

    else:
        raise Exception("Different traces at same level.")

Evaluation Interpreter

As explained earlier, the Evaluation Interpreter will sit at the bottom of the interpreter stack. Since this is the easiest to implement, we’ll start with this.

EvalTrace extends from the Trace base class, and implements the process_primitive function, which basically applies the implementation rule of the primitive.

As mentioned, the trace_stack(which is basically a list) has the EvalTrace at the bottom. After that, we implement all the primitive functions(remember, we are just doing the vector operation, only we have a interception point which will figure out which transformations to apply)

class EvalTrace(Trace):
    pure = lift = lambda self, x: x

    def process_primitive(self, primitive, tracers, params):
        return impl_rules[primitive](*tracers, **params)

trace_stack.append(MainTrace(0, EvalTrace, None))

impl_rules = {}

impl_rules[add_p] = lambda x, y: [np.add(x, y)]
impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]
impl_rules[neg_p] = lambda x: [np.negative(x)]
impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x,y)]
impl_rules[less_p] = lambda x, y: [np.less(x,y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]

def broadcast_impl(x, *, shape, axes):
    for axis in sorted(axes):
        x = np.expand_dims(x, axis)
    return [np.broadcast_to(x, shape)] # read broadcasting rules!

impl_rules[broadcast_p] = broadcast_impl

We mentioned earlier that JAX is well suited for ML, and that means JAX has a good(and general) support for automatic differentiation(AD). AD can obtain gradients of numerical programs very efficiently, which we generally use to calculate loss and backpropagate it to minimize the calculated loss. Automatic differentiation basically applies a set of elementary operations on a function, and automatically computes the gradients by application of the chain rule. It makes a set of equations that include intermediate variables to create a computational graph, and then computes the gradients.

To accommodate generality in it’s AD system, JAX implements both forward and reverse mode automatic differentiation. The ever-so used grad function is built on reverse mode AD, while for forward mode, JAX uses JVP(Jacobian Vector Product). JVPs are evaluated on the fly, so they are memory efficient, but in ML, we don’t see forward mode differentiation.1

Let us implement a JVP based Tracer that calculates both primals(basically the value of the function at any point) and the tangent(the forward mode gradient value associated with the function at that particular point). Before that, let us define some helper function we are going to use.

import builtins

# Get a vector full of zeros like the abstract value of array
def zeros_like(val):
    aval = get_aval(val)
    return np.zeros(aval.shape, aval.dtype)

# Given a pair of values, unpack them into two lists.
def unzip2(pairs):
    lst1, lst2 = [], []
    for x1, x2 in pairs:
        lst1.append(x1)
        lst2.append(x2)
    return lst1, lst2

# Map values and wrap them in a list.
def map(f, *xs):
    return list(builtins.map(f, *xs))

# Returns a list of pairs of values.
def zip(*args):
    fst, *rest = args = map(list, args)
    n = len(fst)
    for arg in rest:
        assert len(arg) == n
    return list(builtins.zip(*args))

For forward mode differentiation, the JVPTracer carries the boxed up primal-tangent pair of values, while the JVPTrace applies the JVP rules. For initialization, we want to “package” the pure and lift values with zero tangent. After doing that, let us add some JVP rules for the primitives2. In the end, let us add a “Transformation API”(jvp_v1) which pushes another interpreter into the stack using the new_main and gives us the primal and the tangent associated with the primitive.

class JVPTracer(Tracer):
    def __init__(self, trace, primal, tangent):
        self._trace = trace
        self.primal = primal
        self.tangent = tangent

    @property
    def aval(self):
        return get_aval(self.primal)

class JVPTrace(Trace):
    pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))

    def process_primitive(self, primitive, tracers, params):
        primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
        jvp_rule = jvp_rules[primitive]
        primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)
        return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]

jvp_rules = {}

# Add some JVP rules
def add_jvp(primals, tangents):
    (x,y), (x_dot, y_dot) = primals, tangents
    return [x + y], [x_dot + y_dot]

jvp_rules[add_p] = add_jvp

def mul_jvp(primals, tangents):
    (x,y), (x_dot, y_dot) = primals, tangents
    return [x * y], [x_dot*y + x*y_dot]

jvp_rules[mul_p] = mul_jvp

def sin_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return [sin(x)], [cos(x) * x_dot]
  
jvp_rules[sin_p] = sin_jvp

def cos_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return [cos(x)], [-sin(x) * x_dot]

jvp_rules[cos_p] = cos_jvp

def neg_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return [neg(x)], [neg(x_dot)]

jvp_rules[neg_p] = neg_jvp

def reduce_sum_jvp(primals, tangents, *, axis):
    (x,), (x_dot, ) = primals, tangents
    return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]

jvp_rules[reduce_sum_p] = reduce_sum_jvp

def greater_jvp(primals, tangents):
    (x, y), _ = primals, tangents
    out_primal = greater(x, y)
    return [out_primal], [zeros_like(out_primal)]

jvp_rules[greater_p] = greater_jvp

def less_jvp(primals, tangents):
    (x, y), _ = primals, tangents
    out_primal = less(x, y)
    return [out_primal], [zeros_like(out_primal)]

jvp_rules[less_p] = less_jvp

# Transformation API.
def jvp_v1(f, primals, tangents):
    with new_main(JVPTrace) as main:
        trace = JVPTrace(main)
        tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
        out = f(*tracers_in)
        tracer_out = full_raise(trace, out)
        primal_out, tangent_out = tracer_out.primal, tracer_out.tangent

    return primal_out, tangent_out

With all these in place, we can now differentiate function in JAX. Here’s an example code:

#Example to demonstrate JVP_V1
x = 3.0
y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
print(sin_deriv_at_3)
print(cos(3.0))

There’s a reason we named our Transformation API as jvp_v1. A limitation in this is that it accepts arrays as input and gives a single array as output. In this next iteration of jvp, we will deal with nested inputs and nested outputs. This could mean that at each layer of the stack, we might have to deal with nested inputs.

In order to deal with this, we are going to wrap the user function so that it accepts arrays as input, and returns a flat list of arrays as output. Since we are accepting user functions that have arbitrary “containers” in the inputs and outputs, just flattening a list inside of a list isn’t going to be very general, and we are actually going to need reference to a tree data structure. Let us try to understand why do we need a tree for this.

In our use case, we mentioned that the inputs and outputs containers can be of arbitrary depth. If we represent the topmost level as the parent node(where each node can have multiple children, and can represent both lists(internal nodes) and the value(leaf node)). In any scenario where the data is arranged in a hierarchal manner, a tree is the optimal data structure to use. A real life example where a hierarchal data is represented by a tree would be a folder/directory structure. In that case as well:

  • Each node can have multiple children(either subfolders or files)
  • Nodes can represent both folders and files.
def jvp_flat(f, primals, tangents):
    with new_main(JVPTrace) as main:
        trace = JVPTrace(main)
        tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
        outs = f(*tracers_in)
        tracers_out = [full_raise(trace, out) for out in outs]
        primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)    
    return primals_out, tangents_out


def jvp(f, primals, tangents):
    primals_flat, in_tree = tree_flatten(primals)
    tangents_flat, in_tree2 = tree_flatten(tangents)
    if in_tree != in_tree2: raise TypeError
    f, out_tree = flatten_fun(f, in_tree)
    primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
    primals_out = tree_unflatten(out_tree(), primals_out_flat)
    tangents_out = tree_unflatten(out_tree(), tangents_out_flat)

    return primals_out, tangents_out

  

Now that we have understood why we are going to need a need a tree flatten/unflatten function, let us implement those.

Notice in the flatten_fun function that the function might actually require the unflatten version of the arguments, we are going to provide it in the unflattened version only(this information isn’t available until we run the function, that is also the reason why we return the reference to flat_fun)

# Notice that we need to provide the unflattened version of the args.
def flatten_fun(f, in_tree):
    store = Store()
    
    def flat_fun(*args_flat):
        pytree_args = tree_unflatten(in_tree, args_flat)
        out = f(*pytree_args)
        out_flat, out_tree = tree_flatten(out)
        store.set_value(out_tree)
        return out_flat

    return flat_fun, store

# Helper classes.

class Empty: pass
empty = Empty()
  
class Store:
    val = empty
    def set_value(self, val):
        assert self.val is empty
        self.val = val
    
    def __call__(self):
        return self.val

# PyTree handling of JVP implementation.  

from collections.abc import Hashable
import itertools as it
from collections.abc import Callable

# Represents the node of the data container.
class NodeType(NamedTuple):
    name:str
    to_iterable: Callable
    from_iterable: Callable

# Basically a typecheck function to register only lists, tuples and dicts as valid tree DS.   
def register_pytree_node(ty, to_iter, from_iter):
    node_types[ty] = NodeType(str(ty), to_iter, from_iter)

node_types = {}

# Only dict, tuple and list can represent PyTree. This also acts as a typecheck.
register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))

register_pytree_node(list, lambda l: (None, l), lambda _, xs: list(xs))

register_pytree_node(dict, lambda d: map(tuple, unzip2(sorted(d.items()))), lambda keys, vals: dict(zip(keys, vals)))


class PyTreeDef(NamedTuple):
    node_type: NodeType
    node_metadata:Hashable
    child_treedefs: tuple['PyTreeDef', ...]

class Leaf: pass
leaf = Leaf()

# Flatten the tree.  
def tree_flatten(x):
    children_iter, tree_def = _tree_flatten(x)
    return list(children_iter), tree_def

def _tree_flatten(x):
    node_type = node_types.get(type(x))
    if node_type:
        node_metadata, children = node_type.to_iterable(x)
        children_flat, child_trees = unzip2(map(_tree_flatten, children))
        flattened = it.chain.from_iterable(children_flat)
        return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))

    else:
        return [x], leaf

# Unflatten the tree.  
def tree_unflatten(tree_def, xs):
    return _tree_unflatten(tree_def, iter(xs))

  
def _tree_unflatten(tree_def, xs):
    if tree_def is leaf:
        return next(xs)
    else:
        children = (_tree_unflatten(t, xs) for t in tree_def.child_treedefs)

        return tree_def.node_type.from_iterable(tree_def.node_metadata, children)

We have successfully implemented arbitrary depth input/output containers. These will be helpful with future transformations as well. Here’s an example code to understand how it works:

# Define some arbitrary Pythonic function
def f(x):
  y = 3.0 * sin(x) * cos(x) 
  z = x*x + y*y
  # We can handle arbitrary depth inputs.
  return {'Rick': z, 'Astley': [x, y]}

# Evaluate the functions at specific values using JVP.
x, xdot = 1.0, 1.5
y, ydot = jvp(f, (x,), (xdot,))

# Print the results
print(y)
print(ydot)

With the improved JVP implementation now in place, we can move to a another important bottleneck that JAX tries to resolve, especially when it comes to machine learning workflows - and that is vectorized batching. Traditionally, when batching data for ML workflows,(a reason to batch data is to parallelly process a chunk of data, and not process element by element) loop can be slow and can become a bottleneck for performance critical code. JAX vmap addresses these issues, and provides the following functionalities:

  • Applying operations to the entire array at once, not individual elements.
  • Processing multiple inputs simultaneously.
  • Seamless integration with other JAX functions, including jit and grad.

Let us implement vectorized batching with vmap, but before that let us implement some helper functions:

  • mapped_aval: It produces mapped abstract values from the unmapped ones, by removing an axis from it.
  • move_batch_axis: It is used to move the batch dimensions around(by basically moving the axis).
# Produce mapped values from the unmapped ones.
def mapped_aval(batch_dim, aval):
    shape = list(aval.shape)
    del shape[batch_dim]
    return ShapedArray(tuple(shape), aval.dtype)

# Move the batch axis by mving the axis.
def move_batch_axis(axis_size, src, dst, x):
    if src is not_mapped:
        target_shape = list(np.shape(x))
        target_shape.insert(dst, axis_size)
        return broadcast(x, target_shape, [dst])

    elif src == dst:
        return x
        
    else:
        return move_axis(x, src, dst)
  
# Move the axis from src to dst.
def move_axis(x, src, dst):
    perm = [i for i in range(np.dim(x)) if i != src]
    perm.insert(dst, src)
    return transpose(x, perm)

With the helper function implementation in place, let us shift our focus to implementing a BatchTracer for vectorized batching. The tracer carries a batched value and an optional integer indicating which axis is the batch axis. Similar to other trace classes, BatchTrace implements the pure and lift methods containing the boxed up values in a BatchTracer instance. We use the MainTrace’s global data field to store the batch axis size.

from typing import Union

# Wrapper class.
class NotMapped: pass
not_mapped = NotMapped()

# Wrapper class. Apart from BatchAxis, all are mapped. 
BatchAxis = Union[NotMapped, int] 

# Tracer to accomadate batching of data.
class BatchTracer(Tracer):
    def __init__(self, trace, val, batch_dim):
        self._trace = trace
        self.val = val
        self.batch_dim = batch_dim

    @property
    def aval(self):
        if self.batch_dim is not_mapped:
            return get_aval(self.val)
        return mapped_aval(self.batch_dim, get_aval(self.val))

    def full_lower(self):
        if self.batch_dim is not_mapped:
            return full_lower(self.val)
        return self

# Tracer
class BatchTrace(Trace):
    pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)

    def process_primitive(self, primitive, tracers, params):
        vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)
        vmap_rule = vmap_rules[primitive]
        val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)

        return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]

  

    @property
    def axis_size(self):
        return self.main.global_data

vmap_rules = {}

The next step is to implement the batched interpreter rules for each primitive. The implementation is divided into three classes of implementations, one for binary operators(addition and multiplication primitives), one for unary operator(such as sin, cos and negation primitives), and a separate one for reduce sum. With all these in place, we add a transformation API to start the trace(see the vmap_flat and vmap implementation)

from functools import partial

# primitive rules - addition, multiplication 
def binop_batching_rule(op, axis_size, vals_in, dims_in):
    (x, y), (x_bdim, y_bdim) = vals_in, dims_in
    if x_bdim != y_bdim:
        if x_bdim is not_mapped:
            x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
            x_bdim = y_bdim
        else:
            y = move_batch_axis(axis_size, y_bdim, x_bdim, y)

    return [op(x,y)], [x_bdim]

vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)

# primitive rules - sin, cos, negation
def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
    (x,), (x_bdim,) = vals_in, dims_in
    return [op(x)], [x_bdim]

vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)

# primitive rules - reduce sum
def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
    (x,), (x_bdim,) = vals_in, dims_in
    new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
    out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
    return [reduce_sum(x, new_axis)], [out_bdim]

vmap_rules[reduce_sum_p] = reduce_sum_batching_rule

# Transformation API.
def vmap_flat(f, in_axes, *args):
    axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes) if ax is not not_mapped}

    with new_main(BatchTrace, axis_size) as main:
        trace = BatchTrace(main)
        tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x for x, ax in zip(args, in_axes)]

        outs = f(*tracers_in)

        tracers_out = [full_raise(trace,out) for out in outs]
        vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)

    outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out) for val_out, bdim in zip(vals_out, bdims_out)]

    return outs_transposed

def vmap(f, in_axes):
    def batched_f(*args):
        args_flat, in_tree = tree_flatten(args)
        in_axes_flat, in_tree2 = tree_flatten(in_axes)
        if in_tree != in_tree2 : raise TypeError

        f_flat, out_tree = flatten_fun(f, in_tree)
        outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
        return tree_unflatten(out_tree(), outs_flat)    

    return batched_f

Let’s see our implementation in action!

# Pythonic function
def add_to_a_scalar(scalar):
  assert np.ndim(scalar) == 0
  return 69 + scalar

# Vectorized operation using VMAP
vector_in = np.arange(420.0)
vector_out = vmap(add_to_a_scalar, (0,))(vector_in)

# Output
print(vector_in)
print(vector_out)

With the implementations of VMap and JVP(basically, Autodiff) in place, the next transformations in place are JIT and VJP(for reverse mode autodiff). The implemented transformations only needed each Tracer to carry an extra bit of context, but for JIT and VJP, we need much richer context(the next few sections are going to explain how), and for that, we need to represent Pythonic programs as JaxPrs.

Part 2 - JaxPrs

JaxPrs are JAX’s internal representation of programs. For JIT implementation, we need JaxPrs because JIT need to stage out any computation out of Python(mostly to XLA), and therefore to represent the data using JaxPrs helps in tracing the python function back up. In the case of VJP, JaxPrs provide a way to represent the computation for the backward pass of the reverse mode autodiff. To represent JaxPrs as Python data structure, we re-use the ShapedArray class defined before(for types) and can represent the term syntax with a few Python structs.

Pasted image 20241017142150

# Class to hold abstract value as ShapedArray.
class Var:
    aval: ShapedArray
    def __init__(self, aval): self.aval = aval

# Class for holding value(both normal and abstract)
class Lit:
    val: Any
    aval: ShapedArray

    def __init__(self, val):
        self.aval = aval = raise_to_shaped(get_aval(val))
        self.val = np.array(val, aval.dtype)

# Atom is the building block for JaxPrs.
Atom = Union[Var, Lit]

# A JaxprEqn is basically a class holding the primitive, and the inputs and outputs associated with it. 
class JaxprEqn(NamedTuple):
    primitive: Primitive
    inputs: list[Atom]
    params: dict[str, Any]
    out_binders: list[Var]

# A JaxPr can hold multiple JaxprEqn 
class Jaxpr(NamedTuple):
    in_binders: list[Var]
    eqns: list[JaxprEqn]
    outs: list[Atom]

  
    def __hash__(self): return id(self)
    __eq__ = op.is_
  

def raise_to_shaped(aval):
    return ShapedArray(aval.shape, aval.dtype)

Type checking is very strict in JAX, which is crucial for speeding up computational workflows. Strict type checking allows JAX to perform type specialization, and optimize code for specific data types. For JaxPrs, type checking involves checking whether there are no unbound variables, and that variables are bound only once, and the equation of the type of primitive matches the type of output binders. JaxPrs are platform-agnostic, so type checking ensures consistency across platforms.

class JaxprsType(NamedTuple):

    in_types: list[ShapedArray]
    out_types: list[ShapedArray]

    def __repr__(self):
        in_types = ", ".join(aval.str_short() for aval in self.in_types)
        out_types = ", ".join(aval.str_short() for aval in self.out_types)

        return f'({in_types}) -> ({out_types})'

# Typechcek for reasons mentioned above.
def typecheck_jaxpr(jaxpr):
    env: set[Var] = set()

    for v in jaxpr.in_binders:
        if v in env: raise TypeError
        env.add(v)

    for eqn in jaxpr.eqns:

        in_types = [typecheck_atom(env, x) for x in eqn.inputs]
        out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)

        for out_binder, out_type in zip(eqn.out_binders, out_types):
            if not out_type == out_binder.aval: raise TypeError

        for out_binder in eqn.out_binders:
            if out_binder in env: raise TypeError
            env.add(out_binder)

        in_types = [v.aval for v in jaxpr.in_binders]
        out_types = [typecheck_atom(env, x) for x in jaxpr.outs]
        return JaxprsType(in_types, out_types)

def typecheck_atom(env, x):
    if isinstance(x, Var):
        if x not in env: raise TypeError("Unbound Variable")
        return x.aval

    elif isinstance(x, Lit):
        return raise_to_shaped(get_aval(x.val))
    
    else:
        assert False

# This is a simple JaxPr interpreter, with type checking.
def eval_jaxpr(jaxpr, args):
    env = {}

    def read(x):
        return env[x] if type(x) is Var else x.val

    def write(v, val):
        assert v not in env
        env[v] = val
        
    map(write, jaxpr.in_binders, args)

    for eqn in jaxpr.eqns:
        in_vals = map(read, eqn.inputs)
        outs = bind(eqn.primitive, *in_vals, **eqn.params) # Using bind makes this interpreter traceable too.
        map(write, eqn.out_binders, outs)
        
    return map(read, jaxpr.outs)

def jaxpr_as_fun(jaxpr):
    return lambda *args: eval_jaxpr(jaxpr, args)

Similarly to what we did with other interpreters, we’ll now enable tracing for JaxPrs. There are two ways in which we can do this, and we’ll start with what jit uses. But first, let us define some helper functions and then build up the JaxPrTrace and JaxPrTracer.

The JaxPrTrace class implements a new_arg function to return a Tracer instance after adding it to the builder. The get_or_make_tracer method add a tracer to the builder, or if it doesn’t exists(checked using the id of the Tracer instance). The pure and lift variables of the Tracer return the reference to this function. The process_primitive function is similar to the ones described before, with the only difference being the use of JaxPrs.

# Helper functions: Jaxprs with Tracing.

# Split a list and outputs the partitions.
def split_list(lst, n):
    assert 0 <= n <= len(lst)
    return lst[:n], lst[n:]

# Partition a list and return the components. 
def partition_list(bs, l):
    assert len(bs) == len(l)
    lists = lst1, lst2 = [], []
    for b, x in zip(bs, l):
        lists[b].append(x)
    return lst1, lst2

# Tracer, as mentioned contains the boxed-up values.
class JaxprTracer(Tracer):
    __slots__ = ['aval']
    aval: ShapedArray

    def __init__(self, trace, aval):
        self._trace = trace
        self.aval = aval
  
# Main JaxPrTrace class.
class JaxprTrace(Trace):

    def new_arg(self, aval):
        aval = raise_to_shaped(aval)
        tracer = self.builder.new_tracer(self, aval)
        self.builder.tracer_to_var[id(tracer)] = Var(aval)
        return tracer

    def get_or_make_const_tracer(self, val):
        tracer = self.builder.const_tracers.get(id(val))
        if tracer is None:
            tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))
            self.builder.add_const(tracer, val)
        return tracer
    pure = lift = get_or_make_const_tracer

    def process_primitive(self, primitive, tracers, params):
        avals_in = [t.aval for t in tracers]
        avals_out = abstract_eval_rules[primitive](*avals_in, **params)
        out_tracers = [self.builder.new_tracer(self, a) for a  in avals_out]
        inputs = [self.builder.getvar(t) for t in tracers]
        outvars = [self.builder.add_var(t) for t in out_tracers]

        self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))
        return out_tracers

    @property
    def builder(self):
        return self.main.global_data

abstract_eval_rules = {}

JaxPrBuilder is the container we use to keeps track of the variables, constants and equations, - as the interpreter global data and will be referenced later as we build up the JaxPrs. The implementation is followed.

# Container class to hold up data.
class JaxprBuilder:

    eqns: list[JaxprEqn]
    tracer_to_var: dict[int, Var]
    const_tracers: dict[int, JaxprTracer]
    constvals: dict[Var, Any]
    tracers: list[JaxprTracer]

    def __init__(self):
        self.eqns = []
        self.tracer_to_var = {}
        self.const_tracers = {}
        self.constvals = {}
        self.tracers = []
        
	# Add a new tracer with a given aval
    def new_tracer(self, trace, aval):
        tracer = JaxprTracer(trace, aval)
        self.tracers.append(tracer)
        return tracer
	## Other getter and setters method for the class, self explanatory.

    def add_eqn(self, eqn):
        self.eqns.append(eqn)

    def add_var(self, tracer):
        assert id(tracer) not in self.tracer_to_var
        var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)
        return var

    def getvar(self, tracer):
        var = self.tracer_to_var.get(id(tracer))
        assert var is not None
        return var

    def add_const(self, tracer, val):
        var = self.add_var(tracer)
        self.const_tracers[id(val)] = tracer
        self.constvals[var] = val
        return var

    def build(self, in_tracers, out_tracers):
        constvars, constvals = unzip2(self.constvals.items())
        t2v = lambda t: self.tracer_to_var[id(t)]
        in_binders = constvars + [t2v(t) for t in in_tracers]
        out_vars = [t2v(t) for t in out_tracers]
        jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
        typecheck_jaxpr(jaxpr) # important step!
        jaxpr, constvals = _inline_literals(jaxpr, constvals)
        
        return jaxpr, constvals

def _inline_literals(jaxpr, consts):
    const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
    scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]

    new_const_binders, lit_binders = partition_list(scalars, const_binders)
    new_consts, lit_vals = partition_list(scalars, consts)
    literals = dict(zip(lit_binders, map(Lit, lit_vals)))
    new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs], eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
    new_outs = [literals.get(x, x) for x in jaxpr.outs]
    new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
    typecheck_jaxpr(new_jaxpr)
    
    return new_jaxpr, new_consts

With the Tracer, Trace implementations in place, let us implement the eval_rules as we did for other cases as well. Most of these are very general, with the intention that these abstraction will be reused for other JaxPr-producing trace methods.

# Binop for Add, Multiply.
def binop_abstract_eval(x, y):
    if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
        raise TypeError
    if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
    return [ShapedArray(x.shape, x.dtype)]

abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval

# Compare for less than, greater than.
def compare_abstract_eval(x, y):
    if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
        raise TypeError
    if x.shape != y.shape: raise TypeError
    return [ShapedArray(x.shape, np.dtype('bool'))]

abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval

# Vectorized Op for Sin, Cosine and Negation.
def vectorized_unop_abstract_eval(x):
    return [ShapedArray(x.shape, x.dtype)]

abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval

# Different eval for reduce_sum.
def reduce_sum_abstract_eval(x, *, axis):
    axis_ = set(axis)
    new_shape = [d for i,d in enumerate(x.shape) if i not in axis_]
    return [ShapedArray(tuple(new_shape), x.dtype)]

abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval

# One for broadcast as well.
def broadcast_abstract_eval(x, *, shape, axes):
    return [ShapedArray(tuple(shape), x.dtype)]

abstract_eval_rules[broadcast_p] = broadcast_abstract_eval

With all the things in place, we can kick off our Transformation API. There is however a really fundamental flaw in make_jaxpr_v1, which maybe deserves a blog post on it’s own3 . In short, the input which were not boxed up in JaxprTracer instances ended up wasting memory, time dispatching and maybe even fragmenting memory.

This “omnistagging” issue ensures that JaxprTrace started by make_jaxpr is always applied. Conceptually, the dynamic trace is identical to stashing the current interpreter stack and starting a new one with the JaxprTrace at the bottom. The new transformation API(make_jaxpr) uses the dynamic_trace global(see Part 1) for this reason.

from functools import lru_cache

def make_jaxpr_v1(f, *avals_in):
    avals_in, in_tree = tree_flatten(avals_in)
    f, out_tree = flatten_fun(f, in_tree)
    builder = JaxprBuilder()

    with new_main(JaxprTrace, builder) as main:
        trace = JaxprTrace(main)
        tracers_in = [trace.new_arg(aval) for aval in avals_in]
        outs = f(*tracers_in)
        tracers_out = [full_raise(trace, out) for out in outs]
        jaxpr, consts = builder.build(tracers_in, tracers_out)
    return jaxpr, consts, out_tree()
  

# There is a limitations tho. This version can't stage out all the primitve opeations performed by the Python Callable.

@contextmanager
def new_dynamic(main):
    global dynamic_trace
    prev_dynamic_trace, dynamic_trace = dynamic_trace, main
    try:
        yield
    finally:
        dynamic_trace = prev_dynamic_trace


def make_jaxpr(f, *avals_in):
    avals_in, in_tree = tree_flatten(avals_in)
    f, out_tree = flatten_fun(f, in_tree)
    builder = JaxprBuilder()

    with new_main(JaxprTrace, builder) as main:
        with new_dynamic(main):
            trace = JaxprTrace(main)
            tracers_in = [trace.new_arg(aval) for aval in avals_in]
            outs = f(*tracers_in)
            tracers_out = [full_raise(trace, out) for  out in outs]
            jaxpr, consts = builder.build(tracers_in, tracers_out)
        
        return jaxpr, consts, out_tree()

Part 3 - JIT

After converting the Pythonic functions into JaxPrs, the next step is taken up by JIT compiler. JIT, or Just-In-Time compiler analyzes the JaxPrs and identifies the specific operations needed, then the optimized machine code is generated for those operations. JIT only compiles the necessary code(and caches the compiled machine code for future use) - giving significant speedups for computationally expensive workflows.

Similar to JaxPrs and JVP, even JIT has a transformation like API transforming a Python function, but conceptually, JIT under the hood is a high-order primitive(basically, a high order primitive is parameterized by a function) rather than a transformation. Similar to a primitive, JIT take JaxPrs as input, returns a “transformed” function(in this case, an optimized version of that function) as output, and operates on functional level, transforming it’s execution.

In order to handle high order primitives, JIT uses a staged processing approach, where we can just use make_jaxpr in the primitive wrapper to form JaxPrs up-front and skip the python function entirely4 - which is what we need to stage these computation to XLA(Accelerated Linear Algebra) for ML workflows.

Since JIT is a high-level primitive, we need to give it transformation rules. When we evaluate any xla_primitive application, we stage out the computation to XLA by translating the JaxPrs into an XLA HLO program ; including transferring the argument values to the XLA device, executing the XLA program(and cache the results as per the shape and dtype signature), and transferring back the results.

Pasted image 20241017142406

# This is the staged JIT wrapper, with computation done by XLA.
def jit(f):

    def f_jitted(*args):
        avals_in = [raise_to_shaped(get_aval(x)) for x in args]
        jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
        outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts = len(consts))
        return tree_unflatten(out_tree, outs)
    
    return f_jitted

xla_call_p = Primitive('xla_call')

# Utility for XLA call.
class IDhashable:
    val: Any
    def __init__(self, val):
        self.val = val

    def __hash__(self):
        return id(self.val)

    def __eq__(self, other):
        return type(other) is IDhashable and id(self.val) == id(other.val)

from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc

xe = xc._xla
xops = xc._xla.ops

def xla_call_impl(*args, jaxpr, num_consts):
    consts, args = args[:num_consts], args[num_consts:]
    hashable_consts = tuple(map(IDhashable, consts))
    execute = xla_callable(IDhashable(jaxpr), hashable_consts)
    return execute(*args)

impl_rules[xla_call_p] = xla_call_impl

@lru_cache
def xla_callable(hashable_jaxpr, hashable_consts):
    jaxpr = hashable_jaxpr.val
    typecheck_jaxpr(jaxpr)

    consts = [x.val for x in hashable_consts]
    in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]

    c = xc.XlaBuilder('xla_call')
    xla_consts = _xla_consts(c, consts)
    xla_params = _xla_params(c, in_avals)
    outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)
    out = xops.Tuple(c, outs)
    compiled = xb.get_backend(None).compile(
        xc._xla.mlir.xla_computation_to_mlir_module(c.build(out)))
    return partial(execure_compiled, compiled, [v.aval for v in jaxpr.outs])

def _xla_consts(c, consts):
    unique_consts = {id(cnst): cnst for cnst in consts}
    xla_consts = {

        id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()

    }
    return [xla_consts[id(cnst)] for cnst in consts]

def _xla_params(c, avals_in):
    return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]

def _xla_shape(aval):
    return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)

Let us now define the transformations for xla_call_p, other than its evaluation rule.

# JVP rule for XLA call.
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
  del num_consts
  new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
  outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  n = len(outs) // 2
  primals_out, tangents_out = outs[:n], outs[n:]
  return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule

# JVP for the JaxPrs.

@lru_cache
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
  def jvp_traceable(*primals_and_tangents):
    n = len(primals_and_tangents) // 2
    primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
    return jvp(jaxpr_as_fun(jaxpr), primals, tangents)

  in_avals = [v.aval for v in jaxpr.in_binders]
  new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)
  return new_jaxpr, new_consts

# VMAP rule for XLA call.
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
  outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule

@lru_cache
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
               ) -> tuple[Jaxpr, list[Any]]:
  vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
  in_avals = [unmapped_aval(axis_size, d, v.aval)
              for v, d in zip(jaxpr.in_binders, bdims_in)]
  new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)
  return new_jaxpr, new_consts

def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
                  ) -> ShapedArray:
  if batch_dim is not_mapped:
    return aval
  else:
    shape = list(aval.shape)
    shape.insert(batch_dim, axis_size)
    return ShapedArray(tuple(shape), aval.dtype)


# Abstract Eval XLA call rule.
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
  del num_consts  # Unused
  jaxpr_type = typecheck_jaxpr(jaxpr)
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule


def xla_call_translation(c, in_avals, out_avals, in_vals, *, jaxpr, num_consts):
  del num_consts, out_avals
  # Calling jaxpr_subcomp directly would inline.
  with ir.InsertionPoint(c.module.body):
    @func.func(*(aval_to_ir_type(aval) for aval in in_avals))
    def inner_xla_call(*params):
      return jaxpr_subcomp(c, jaxpr, params)
    name = c.symbol_table.insert(inner_xla_call.func_op)
  return func.CallOp(inner_xla_call.func_op, in_vals).results
hlo_translations[xla_call_p] = xla_call_translation

With all the rules in place, we turn our attention to an important issue - memory persistence for arrays. After the XLA operation is done, we transferred the results back to CPU memory as np.array, but most of the time, we need to transfer these results back for the next operation. For that, we’ll introduce an Array class to wrap up XLA buffers.

def handle_result(aval: ShapedArray, buf):  # noqa: F811
  return Array(aval, buf)

class Array:
  buf: Any
  aval: ShapedArray

  def __init__(self, aval, buf):
    self.aval = aval
    self.buf = buf

  dtype = property(lambda self: self.aval.dtype)
  shape = property(lambda self: self.aval.shape)
  ndim  = property(lambda self: self.aval.ndim)

  def __array__(self): return np.asarray(self.buf)
  def __repr__(self):  return repr(np.asarray(self.buf))
  def __str__(self):   return str(np.asarray(self.buf))

  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(add)
  _mul = staticmethod(mul)
  _rmul = staticmethod(mul)
  _gt = staticmethod(greater)
  _lt = staticmethod(less)
input_handlers[Array] = lambda x: x.buf

jax_types.add(Array)

With that, we implemented another core feature of JAX. Let’s move on to the next part, where we implement some special autodiff functions - linearize and vjp, which have some caveats.

Part 4 - linearize and vjp

Here’s an diagram summarizing key points of linearize and VJP in JAX. Let us implement them!!

Pasted image 20241017142629

Implementing linearize JaxPr

Linearize computes the linear approximation of a function, and it operates in the tangent space. Therefore, we want to stage out the linear part of JVP computation to build a JaxPr from a JVP. To do this, we need to perform partial evaluation - to evaluate all the primal values as a tarce, but stage the tangent computations into a Jaxpr. Unlike the previous make_jaxpr functions, this approach stages out only those primitive binds with a dependence on tangent inputs.

def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
  assert not len(lst) % 2
  return split_list(lst, len(lst) // 2)

def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
  l1, l2 = iter(l1), iter(l2)
  out = [next(l2) if b else next(l1) for b in which]
  assert next(l1, None) is next(l2, None) is None
  return out

def linearize_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
  return primals_out, f_lin

# This linearize function has JVp and partial evaluation combined.
def linearize(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_lin(*tangents_in):
    tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
    if in_tree != in_tree2: raise TypeError
    tangents_out_flat = f_lin_flat(*tangents_in_flat)
    return tree_unflatten(out_tree(), tangents_out_flat)

  return primals_out, f_lin

def vspace(aval: ShapedArray) -> ShapedArray:
  return raise_to_shaped(aval)  # TODO handle integers?

As mentioned, in the linearize, there is JVP + general partial information transformation. The workflow is simple, turn a Python callable into outputs of two types - one where all the outputs can be computed from the known outputs, and a partial JaxPr which can only be performed after its required inputs are known5. Think of partial evaluation as “unzipping” one computation into two(one for primal, and one for tangent jaxpr). We kind of only want to form a JaxPr for those operations whose operations must be delayed due to dependence on unknown inputs, which reduces unnecessary evaluations.

For the reasons mentioned above, let us start our implementation by creating a PartialVal class. Partial evaluation will take a list of PartialVal representing inputs, and return a list of PartialVal outputs along with a jaxpr representing the delayed computation:

class PartialVal(NamedTuple):
  aval: ShapedArray
  const: Any | None
 
  @classmethod
  def known(cls, val: Any):
    return PartialVal(get_aval(val), val)

  @classmethod
  def unknown(cls, aval: ShapedArray):
    return PartialVal(aval, None)

  # To check whether the inputs are known or unknown, so we distribute accordingly. 
  is_known   = property(lambda self: self.const is not None)
  is_unknown = property(lambda self: self.const is     None)

# The transformation API.
def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
                      ) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
  with new_main(PartialEvalTrace) as main:
    trace = PartialEvalTrace(main)
    tracers_in = [trace.new_arg(pval) for pval in pvals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    pvals_out = [t.pval for t in tracers_out]
    unk_tracers_in  = [t for t in tracers_in  if t.pval.is_unknown]
    unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
    jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)
  return jaxpr, pvals_out, consts

Now, we’ll implement the PartialEvalTrace and PartialEvalTracer. The difference with the previous versions of the Trace, Tracer classes is that the interpreter will build JaxPrs on the fly and will keep track of data dependencies. In order to do so, they implement a Bipartite DAG(Directed Acyclic Graph) between the nodes of the PartialEvalTracer (representing staged out values) and JaxprRecipe nodes (representing formulas for how to compute some values from others). The reason to choose Bipartite graph structure is to simplify dependency management(modular updates and extensions)

These recipe’s can be of several types - JaxprEqnRecipe (corresponding to JaxPrEqn’s primitive application), and constants, lambda binders.


from weakref import ref, ReferenceType

# Wrapper for a lambda recipe.
class LambdaBindingRecipe(NamedTuple):
  pass

# Wrapper for a const recipe.
class ConstRecipe(NamedTuple):
  val: Any

# Wrapper for Jaxpr recipe
class JaxprEqnRecipe(NamedTuple):
  prim: Primitive
  tracers_in: list['PartialEvalTracer']
  params: dict[str, Any]
  avals_out: list[ShapedArray]
  tracer_refs_out: list['ReferenceType[PartialEvalTracer]']

JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]

# Partial Eval Tracer - contains boxedup values.
class PartialEvalTracer(Tracer):
  pval: PartialVal
  recipe: JaxprRecipe | None

  def __init__(self, trace, pval, recipe):
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  aval = property(lambda self: self.pval.aval)

  def full_lower(self):
    if self.pval.is_known:
      return full_lower(self.pval.const)
    return self

With these implementations in place, let us now implement PartialEvalTrace. Each argument in it corresponds to a LambdaBindingRecipe leaf node, and each constant is a ConstRecipe leaf node holding a reference to the constant.

The implementation for process_primitive is also straightforward. If all inputs are known then we can bind the primitive to the known values(basically evaluate it in Python). On the other hand, if any inputs is unknown, then we stage out into a JaxprEqnRecipe representing the primitive application. All but the call to xla_call_primitive works on this logic(in the XLA primitive, we require recursive treatment)


class PartialEvalTrace(Trace):
  def new_arg(self, pval: PartialVal) -> Any:
    return PartialEvalTracer(self, pval, LambdaBindingRecipe())

  def lift(self, val: Any) -> PartialEvalTracer:
    return PartialEvalTracer(self, PartialVal.known(val), None)
  pure = lift

  def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
    if tracer.pval.is_unknown:
      return tracer
    else:
      pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
      return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))

  def process_primitive(self, primitive, tracers, params):
    if all(t.pval.is_known for t in tracers):
      return bind(primitive, *map(full_lower, tracers), **params)
    rule = partial_eval_rules.get(primitive)
    if rule: return rule(self, tracers, **params)
    tracers_in = [self.instantiate_const(t) for t in tracers]
    avals_in = [t.aval for t in tracers_in]
    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
    tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
                   for aval in avals_out]
    eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
                         map(ref, tracers_out))
    for t in tracers_out: t.recipe = eqn
    return tracers_out

partial_eval_rules = {}

Now, we can build graph representations of JaxPrs with PartialEvalTrace - we just need a mechanism to convert graph representation to standard JaxPr(corresponds to a topological sort of the graph).


def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
                     tracers_out: list[PartialEvalTracer]):
  tracer_to_var: dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
                                   for t in tracers_in}
  constvar_to_val: dict[int, Any] = {}
  constid_to_var: dict[int, Var] = {}
  processed_eqns: set[int] = set()
  eqns: list[JaxprEqn] = []
  for t in toposort(tracers_out, tracer_parents):
    if isinstance(t.recipe, LambdaBindingRecipe):
      assert id(t) in set(map(id, tracers_in))
    elif isinstance(t.recipe, ConstRecipe):
      val = t.recipe.val
      var = constid_to_var.get(id(val))
      if var is None:
        aval = raise_to_shaped(get_aval(val))
        var = constid_to_var[id(val)] = Var(aval)
        constvar_to_val[var] = val
      tracer_to_var[id(t)] = var
    elif isinstance(t.recipe, JaxprEqnRecipe):
      if id(t.recipe) not in processed_eqns:
        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
        processed_eqns.add(id(t.recipe))
    else:
      raise TypeError(t.recipe)

  constvars, constvals = unzip2(constvar_to_val.items())
  in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
  out_vars = [tracer_to_var[id(t)] for t in tracers_out]
  jaxpr = Jaxpr(in_binders, eqns, out_vars)
  typecheck_jaxpr(jaxpr)
  return jaxpr, constvals

def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
                  ) -> JaxprEqn:
  inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
  out_binders = [Var(aval) for aval in recipe.avals_out]
  for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
    if t_ref() is not None: tracer_to_var[id(t_ref())] = var
  return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)

def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []

## Toposort and stuff

def toposort(out_nodes: list[Any], parents: Callable[[Any], list[Any]]):
  if not out_nodes: return []
  out_nodes = remove_duplicates(out_nodes)

  child_counts = {}
  stack = list(out_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(parents(node))
  for node in out_nodes:
    child_counts[id(node)] -= 1

  sorted_nodes = []
  childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in parents(node):
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1

  sorted_nodes = sorted_nodes[::-1]
  check_toposort(sorted_nodes, parents)
  return sorted_nodes

def remove_duplicates(lst):
  seen = set()
  return [x for x in lst if id(x) not in seen and not seen.add(id(x))]

def check_toposort(nodes: list[Any], parents: Callable[[Any], list[Any]]):
  seen = set()
  for node in nodes:
    assert all(id(parent) in seen for parent in parents(node))
    seen.add(id(node))

Let us test it in action. We also need to implement the partial evaluation rule xla_call_p (to handle JIT and related functions). There are two rules to write, one for trace-time partial evaluation(xla_call_partial_eval), and one for partial evaluation of Jaxprs(xla_call_peval_eqn)

# Example usage.
y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))

## Rules for XLA Primitive(trace time and Partial Eval)

def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
  del num_consts  # Unused
  in_unknowns = [not t.pval.is_known for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
  outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in jaxpr2.outs]
  eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
                       dict(jaxpr=jaxpr2, num_consts=0),
                       [v.aval for v in jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval

def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
                       instantiate: list[bool] | None = None,
                       ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
  env: dict[Var, bool] = {}
  residuals: set[Var] = set()

  def read(x: Atom) -> bool:
    return type(x) is Var and env[x]

  def write(unk: bool, v: Var) -> None:
    env[v] = unk

  def new_res(x: Atom) -> Atom:
    if type(x) is Var: residuals.add(x)
    return x

  eqns1, eqns2 = [], []
  map(write, in_unknowns, jaxpr.in_binders)
  for eqn in jaxpr.eqns:
    unks_in = map(read, eqn.inputs)
    rule = partial_eval_jaxpr_rules.get(eqn.primitive)
    if rule:
      eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
      eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
      map(write, unks_out, eqn.out_binders)
    elif any(unks_in):
      inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
      eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
      map(partial(write, True), eqn.out_binders)
    else:
      eqns1.append(eqn)
      map(partial(write, False), eqn.out_binders)
  out_unknowns = map(read, jaxpr.outs)
  if instantiate is not None:
    for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
      if inst and not uk: new_res(v)
    out_unknowns = map(op.or_, out_unknowns, instantiate)

  residuals, num_res = list(residuals), len(residuals)
  assert all(type(v) is Var for v in residuals), residuals

  ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
  outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)

  jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
  jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
  typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)

  return jaxpr1, jaxpr2, out_unknowns, num_res

def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
  jaxprty = typecheck_jaxpr(jaxpr)    # (a1,  a2) -> (b1, b2 )
  jaxpr1ty = typecheck_jaxpr(jaxpr1)  #  a1       -> (b1, res)
  jaxpr2ty = typecheck_jaxpr(jaxpr2)  # (res, a2) -> b2

  a1, a2 = partition_list(unks_in, jaxprty.in_types)
  b1, b2 = partition_list(unks_out, jaxprty.out_types)
  b1_, res = split_list(jaxpr1ty.out_types, len(b1))
  res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
  b2_ = jaxpr2ty.out_types

  if jaxpr1ty.in_types != a1: raise TypeError
  if jaxpr2ty.out_types != b2: raise TypeError
  if b1 != b1_: raise TypeError
  if res != res_: raise TypeError
  if a2 != a2_: raise TypeError
  if b2 != b2_: raise TypeError

partial_eval_jaxpr_rules = {}

def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                       ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
  jaxpr = eqn.params['jaxpr']
  jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
  ins1, ins2 = partition_list(unks_in, eqn.inputs)
  out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
  residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
  eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
                  out_binders1 + residuals)
  eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
                  dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
  return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn

Now, we can compose linearize and jit .


#Example usage.

@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)

Implementing vjp and grad

The vjp transformation is very similar to linearize, with the only difference being the fact that we transpose the linear part of the computation before returning it, so our implementation is pretty straightforward. Also, since we have the linear computation as a JaxPr, we can implement the transpose transformation as a JaxPr interpreter. The use of UndefPrimal instance is to indicate which arguments we want to transpose(and with respect to what). We also register this as a pytree node as that gives us a handy way to prune these placeholders out of argument lists.

# VJP implementation
def vjp(f, x):
  y, f_lin = linearize(f, x)
  f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
  return y, f_vjp

# Transpose transformation

def vjp_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)  # linearize
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]
  f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)
  return primals_out, f_vjp

def vjp(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_vjp(*cotangents_out):
    cotangents_out_flat, _ = tree_flatten(cotangents_out)
    cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
    return tree_unflatten(in_tree, cotangents_in_flat)

  return primals_out, f_vjp

class UndefPrimal(NamedTuple):
  aval: ShapedArray

register_pytree_node(UndefPrimal,
                     lambda u: (u.aval, ()),
                     lambda aval, _: UndefPrimal(aval))

Next, we can write eval_jaxpr_transposed, along with the transpose rules for all the primitives(which can be linear in at least one argument).


# NB: the analogous function in JAX is called 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
                          ) -> list[Any]:
  primal_env: dict[Var, Any] = {}
  ct_env: dict[Var, Any] = {}

  def read_primal(x: Atom) -> Any:
    return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val

  def write_primal(v: Var, val: Any) -> None:
    if type(val) is not UndefPrimal:
      primal_env[v] = val

  def read_cotangent(v: Var) -> Any:
    return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))

  def write_cotangent(x: Atom, val: Any):
    if type(x) is Var and val is not None:
      ct_env[x] = add(ct_env[x], val) if x in ct_env else val

  map(write_primal, jaxpr.in_binders, args)
  map(write_cotangent, jaxpr.outs, cotangents)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.inputs)
    cts_in = map(read_cotangent, eqn.out_binders)
    rule = transpose_rules[eqn.primitive]
    cts_out = rule(cts_in, *primals_in, **eqn.params)
    map(write_cotangent, eqn.inputs, cts_out)

  return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)
          if type(x) is UndefPrimal]

transpose_rules = {}

# Rules

def mul_transpose_rule(cts, x, y):
  z_bar, = cts
  assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)
  return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]
transpose_rules[mul_p] = mul_transpose_rule

def neg_transpose_rule(cts, x):
  ybar, = cts
  assert type(x) is UndefPrimal
  return [neg(ybar)]
transpose_rules[neg_p] = neg_transpose_rule

def add_transpose_rule(cts, x, y):
  z_bar, = cts
  return [z_bar, z_bar]
transpose_rules[add_p] = add_transpose_rule

def reduce_sum_transpose_rule(cts, x, *, axis):
  y_bar, = cts
  return [broadcast(y_bar, x.aval.shape, axis)]
transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule

def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
  del num_consts  # Unused
  undef_primals = [type(x) is UndefPrimal for x in invals]
  transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
  residuals, _ = partition_list(undef_primals, invals)
  outs = bind(xla_call_p, *new_consts, *residuals, *cts,
              jaxpr=transposed_jaxpr, num_consts=len(new_consts))
  outs = iter(outs)
  return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule

@lru_cache
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
                    ) -> tuple[Jaxpr, list[Any]]:
  avals_in, avals_out = typecheck_jaxpr(jaxpr)
  traceable = partial(eval_jaxpr_transposed, jaxpr)
  args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
  trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
  typecheck_jaxpr(trans_jaxpr)
  return trans_jaxpr, consts

Now that we can linearize and transpose, we can finally write grad.

# Grad implementation.
def grad(f):
  def gradfun(x, *xs):
    y, f_vjp = vjp(f, x, *xs)
    if np.shape(y) != (): raise TypeError
    x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
    return x_bar
  return gradfun

# Example usage.
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

print(grad(f)(3.))

Finally, we are done with our implementation!!! Give yourself a pat on the back, you now have your version of JAX, in Python, spelled out completely. I’ll be maintaining a repository for this particular blog, and will update things as I learn more about this amazing library. Would like to thank Lucas Beyer for replying to my tweet to motivate me to understand more about this framework.

Resources

  1. Autodiadax - Implementing the JAX Core from scratch.
  2. PyTorch is Dead, Long Live JAX - Neel Gupta.
  3. Is JAX better than PyTorch - Reddit Discussion.
  4. JAX - The Sharp Bits - JAX Docs

Footnotes:

  1. A better and more detailed explanation(without abstracting any math) goes something like this(according to JAX Docs):

    To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with “tall” Jacobians, but inefficient for “wide” Jacobians.

    If you’re doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^N$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial{f(x)} \in \mathbb{R}^{1 \times n}$, which we often identify with the Gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where f is a training loss function and n can be in the millions or billions, this approach just won’t scale. To do better for functions like this, we just need to use reverse-mode autodiff.

  2. For a refresher on autodiff, refer this Hackernews post with some really awesome explanations by really awesome people(and possible pitfalls). 

  3. This is often referred to as the “omnistaging” issue in the JAX-ML repo. Even if I try, I can’t explain in the detail this PR is described. Highly recommend to read it. 

  4. As per the official documentation, this is what happens in staged processing.

    In this case, make_jaxpr puts its JaxprTrace at the top of the interpreter stack, and no transformations lower in the stack, which might enter via closed-over Tracers, are applied to the Python callable as we trace it. (Transformations applied within the Python callable are applied as usual, being added to the stack above the JaxprTrace.) Instead, the transformations lower in the stack are later applied to the call primitive, and the call primitive’s rules must then transform the jaxpr itself. Because we trace to a jaxpr up-front, this approach can’t support data-dependent Python control flow, but it is more straightforward to implement. We refer to this kind of higher-order primitive as an “initial-style higher-order primitive”, and say that its jaxpr-processing transformation rules are “initial-style transformation rules.”

  5. This transformation is tricky to summarize in a type signature. If we assume the input function’s type signature is (a1, a2) -> (b1, b2), where a1 and a2 represent the known and unknown inputs, respectively, and where b1 only has a data dependency on a1 while b2 has some data dependency on a2, then we might write: partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)