Lernapparat

Python graph operations in the JIT

Jan. 19, 2021

As you may know from my recent blog posts, I think that the PyTorch JIT is a great but under-used technology.

One part of achieving TorchScript's full potential is improving the developer experience; in particular, I would like to make it more accessible from Python. But to know the gap, we need to assess the (near-future, hopefully) status quo. I invite you to join me in this exploration.

So here is the challenge: Can we do nontrivial graph operations in Python?

Introducing vmap

[Image: PyTorch assembly line]

To pick a particular operation, let us look at vmap (vectorizing map) -- one of the hallmark features of JAX.

The idea is to take some function operating on tensors and produce a function that does the same thing, but where some tensors have an additional batch dimension. Often, this just amounts to broadcasting, but that can get tricky, too, and it is tedious to verify correctness, so why not do the right thing automatically?
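
To make the idea concrete before we build anything, here is a tiny plain-PyTorch illustration (my own example, not part of any vmap API): vmapping torch.dot over dimension 0 of both inputs should give one dot product per batch item.

import torch

a = torch.randn(5, 3)
b = torch.randn(5, 3)
# what a vmapped torch.dot with batch dimension 0 on both inputs should compute:
# one dot product per batch item, i.e. a tensor of shape (5,)
manual = torch.stack([torch.dot(a_i, b_i) for a_i, b_i in zip(a, b)])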

Of course, we can produce a (not terribly exciting) version of this in Python with relative ease. We let the user specify which dimensions should be the batch dimension (one per tensor, None for an input that is fixed across all batch items), and we ask for the dimensions of the outputs. We make our life easy by just considering tensors.

So our Python vmapped function takes the inputs, splits them by batch item, runs the individual batch items through the function in a for loop, and then stacks the result list to hand back. Here we go:

def python_vmap(fn, in_dims, out_dims):
    assert len(out_dims) == 1 and out_dims[0] is not None
    out_dim = out_dims[0]
    batch_in_dims = [(i, d) for i, d in enumerate(in_dims) if d is not None]
    assert len(batch_in_dims) > 0
    size_i, size_d = batch_in_dims[0]
    def vmapped_fn(*inputs):
        res = []
        for idx in range(inputs[size_i].size(size_d)):
            one_inp = []
            for i, d in enumerate(in_dims):
                if d is None:
                    one_inp.append(inputs[i])
                else:
                    one_inp.append(inputs[i].select(d, idx))
            res.append(fn(*one_inp))
        return torch.stack(res, out_dim)
    return vmapped_fn

As an exercise, you might extend it to more than one output tensor.
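
For reference, one way the exercise might go is sketched below (my own sketch, assuming fn returns a tuple of tensors and out_dims holds one dimension per output):

def python_vmap_multi(fn, in_dims, out_dims):
    # sketch: like python_vmap, but fn may return several tensors
    batch_in_dims = [(i, d) for i, d in enumerate(in_dims) if d is not None]
    assert len(batch_in_dims) > 0
    size_i, size_d = batch_in_dims[0]
    def vmapped_fn(*inputs):
        res = []  # one tuple of outputs per batch item
        for idx in range(inputs[size_i].size(size_d)):
            one_inp = [inp if d is None else inp.select(d, idx)
                       for inp, d in zip(inputs, in_dims)]
            outs = fn(*one_inp)
            res.append(outs if isinstance(outs, tuple) else (outs,))
        # stack each output position along its requested output dimension
        return tuple(torch.stack([r[k] for r in res], d)
                     for k, d in enumerate(out_dims))
    return vmapped_fn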

Let us define a small two-layer perceptron function and try it. Just for fun, we do not use transposed weights - they would allow us to just use broadcasting for the batching and this probably is the reason why PyTorch uses them. To make our function more interesting, we run the weights through sigmoid before applying them.

Because this blog post is about the JIT, we decorate our function with torch.jit.script:

@torch.jit.script
def two_layers(x, w1, b1, w2, b2):
    w1 = w1.sigmoid()
    x = x @ w1 + b1
    x = x.tanh()
    w2 = w2.sigmoid()
    x = x @ w2 + b2
    return x

We also need some dummy inputs. $x$ will have a batch dimension, the weights and biases not so much.

batch_x = torch.randn(5, 1)  # batch 5, input feature 1
w1 = torch.randn(1, 10)
b1 = torch.randn(10)
w2 = torch.randn(10, 1)
b2 = torch.randn(1)

We can run a single example through this function:

two_layers(batch_x[1], w1, b2, w2, b2)
tensor([-4.8923])

But now that we want to run a batch through it, what do we do?

We run it through our vmap, of course.

python_vmapped_two_layers = python_vmap(two_layers, [0, None, None, None, None], [0])

python_vmapped_two_layers(batch_x, w1, b2, w2, b2)
tensor([[-0.6070],
        [-4.8923],
        [-0.1973],
        [-2.2427],
        [-0.2574]])

Source to source vmap in the JIT

But we wanted to learn something about the JIT. So let us look at our function's JIT graph.

def make_graph(gr):
    import graphviz
    dot = graphviz.Digraph(format='svg', graph_attr={'labelloc': 't'})

    nodes = {}
    for i in gr.inputs():
        nname = i.debugName()
        label = nname.split('.')[0]
        nodes[nname] = (nname, dot)
        dot.node(nname, label, color='blue')

    unseen_ops = {'prim::ListConstruct', 'aten::index', 
                  'aten::size', 'aten::slice', 'aten::unsqueeze', 'aten::squeeze',
                  'aten::to', 'aten::view', 'aten::permute', 'aten::transpose', 'aten::contiguous',
                  'aten::Int', 'prim::TupleUnpack', 'prim::ListUnpack', 'aten::unbind',
                  'aten::select', 'aten::detach', 'aten::stack', 'aten::reshape', 'aten::split_with_sizes',
                  'aten::cat', 'aten::expand', 'aten::expand_as', 'aten::_shape_as_tensor',
                  'aten::_size_if_not_equal', 'prim::BroadcastSizes',
                  'prim::Constant',
                  }

    def process_block(nodeit, dot):
        firstnode = None
        lastnode = None
        for n in nodeit:
            k = n.kind()
            outs = list(n.outputs())
            inps = list(n.inputs())
            type_outs = [o.type().kind() for o in outs]
            type_inps = [o.type().kind() for o in inps]
            if k == 'prim::If':
                label = k.split('::')[1]
                nname = outs[0].debugName()
                for i in inps:
                    src, srcdot = nodes.get(i.debugName(), (None, None))
                    if src is not None:
                        srcdot.edge(src, nname + '_in')
                dot.node(nname + '_in', 'If', shape='diamond')
                dot.node(nname, '', width='0.1', height='0.1')
                dot.edge(nname + '_in', nname, style='invis')
                nodes[nname] = (nname, dot)
                bl = list(n.blocks())
                for i, b in enumerate(bl):
                    with dot.subgraph(name=f"cluster_{nname}_{i}", graph_attr={'label':''}) as sub_dot:
                        firstnode, lastnode = process_block(b.nodes(), sub_dot)
                    dot.edge(nname + '_in', firstnode, label="yn"[i])
                    dot.edge(lastnode, nname)
                if firstnode is None:
                    firstnode = nname + '_in'
                lastnode = nname
            elif k == 'prim::DifferentiableGraph':
                label = 'DifferentiableGraph'
                nname = outs[0].debugName()
                nodes[nname] = (nname, dot)
                sg = n.g('Subgraph')
                nis = list(n.inputs())
                sgis = list(sg.inputs())
                assert len(nis) == len(sgis)
                for ni, sgi in zip(nis, sgis):
                    if ni.debugName() in nodes:
                        nodes[sgi.debugName()] = nodes[ni.debugName()]
                with dot.subgraph(name=f"cluster_{nname}", graph_attr={
                    'label': 'DifferentiableGraph', 'labelloc':'b', 'labeljust':'r'}) as sub_dot:
                    firstnode, lastnode = process_block(sg.nodes(), sub_dot)
                nos = list(n.outputs())
                sgos = list(sg.outputs())
                assert len(nos) <= len(sgos)
                for no, sgo in zip(nos, sgos):
                    if sgo.debugName() in nodes:
                        nodes[no.debugName()] = (nodes[sgo.debugName()][0], dot)
            elif k == 'vmapper::vmap':
                label = 'vmap'
                nname = outs[0].debugName()
                b, = n.blocks()
                for i, ib in zip(inps, b.inputs()):
                    src, srcdot = nodes.get(i.debugName(), (None, None))
                    if src is not None:
                        srcdot.edge(src, nname + '_batch')
                        nodes[ib.debugName()] = (nname + '_batch', dot)
                nodes[nname] = (nname, dot)
                with dot.subgraph(name=f"cluster_{nname}_{i}", graph_attr={'label':'vmap'}) as sub_dot:
                    firstnode, lastnode = process_block(b.nodes(), sub_dot)
                #dot.edge(nname + '_in', firstnode, label="yn"[i])
                #dot.edge(lastnode, nname)
                for o, ob in zip(n.outputs(), b.outputs()):
                    nodes[o.debugName()] = (nodes[ob.debugName()][0], dot)
                if firstnode is None:
                    firstnode = nname + '_in'
                lastnode = nname
            elif k not in unseen_ops and outs:
                if k == 'prim::CallFunction':
                    label = 'call ' + next(n.inputs()).node().s("name")
                else:
                    label = k.replace('aten::', '').replace('prim::', '')
                nname = outs[0].debugName()
                dot.node(nname, label, shape='box', style='rounded')
                for o in outs:
                    nodes[o.debugName()] = (nname, dot)
                for i in inps:
                    src, srcdot = nodes.get(i.debugName(), (None, None))
                    if src is not None:
                        srcdot.edge(src, nname)
                if firstnode is None:
                    firstnode = nname
                lastnode = nname
        return firstnode, lastnode

    process_block(gr.nodes(), dot)
    dot.node('.outputs', 'outputs', color='blue')
    for i, o in enumerate(gr.outputs()):
        src, srcdot = nodes.get(o.debugName(), (None, None))
        if src is not None:
            dot.edge(src, '.outputs')

    return dot
make_graph(two_layers.graph)

[svg: visualization of the two_layers graph]

Nothing particularly exciting here.

But how can we go about defining a vmap for this? Well, a good first step might be to get the information into the graph.

Declaring vmap in the JIT

We thus introduce a vmapper::vmap node. (Using prim as a prefix would make some internal checks more lenient.) Similar to, say, loops, vmapper::vmap nodes have a body (a block in JIT terminology) that comprises the instructions to be vectorized.

Here is our first twist: As inputs and outputs of the block, we use those input and output tensors that have a batch dimension (and take the other inputs from outside the block, which is allowed).

As we need to tell the vmapper::vmap node where our batch dimensions are, we introduce two integer list attributes for them. Mind you, you should not expect the graph manipulation to work with stock PyTorch; I have an experimental branch for it.

What should go inside the vmap? Well, we wanted to vmap the entire graph, so everything. (Actually, this is not true because functions returning multiple values will build a tuple of them. But I ignore this difficulty for the purpose of the exposition.)

So our plan is to take (a copy of) the graph and

  • add a vmapper::vmap node with in_dims and out_dims attributes and one block.
  • move all nodes into the block
  • make the graph inputs with batch dimensions inputs to vmapper::vmap and rewire all direct uses of them to use the block's inputs instead
  • do the analogous thing for the outputs.

When creating new values (block inputs, outputs of the vmapper::vmap node), we need to set the type (well, actually Tensor is the default, but still...).
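
Before we write the code, here is roughly what we are aiming for with our example, sketched schematically (this is my sketch of the intended structure, not actual printed IR):

graph(%x, %w1, %b1, %w2, %b2):
  %out : Tensor = vmapper::vmap[in_dims=[0], out_dims=[0]](%x)
    block0(%x_item : Tensor):
      ... the original two_layers computation, using %x_item and taking
          %w1, %b1, %w2, %b2 from outside the block ...
      -> (%result)
  return (%out)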

def vmap_graph(gr, in_dims, out_dims):
    gr = gr.copy()

    # treat final prim::TupleConstruct or the "hidden" prim::Return node as output
    node_list = list(gr.nodes())
    if node_list[-1].kind() == 'prim::TupleConstruct':
        output_node = node_list.pop(-1)
    else:
        output_node = gr.return_node()

    assert len(in_dims) == len(list(gr.inputs()))
    assert len(out_dims) == len(list(output_node.inputs()))

    # add a `vmapper::vmap` node with `in_dims` and `out_dims` attribute and one block.
    n_vmap = gr.create('vmapper::vmap', 0)
    n_vmap.is_("in_dims", [d for d in in_dims if d is not None])
    n_vmap.is_("out_dims", out_dims)
    n_vmap.insertAfter(gr.param_node()) # insert it at the top
    bl = n_vmap.addBlock() # add the block

    # move all nodes (except ourselves...) into the block
    for n in node_list:
        if n != n_vmap:
            n.moveBefore(bl.returnNode())

    # make the graph inputs with batch dimensions to be inputs to `vmapper::vmap` and
    # rewire all direct uses of them to use the block's inputs instead
    for i, d in zip(gr.inputs(), in_dims):
        if d is not None:
            # create a block input and rewire
            bl_i = bl.addInputToBlock().setType(i.type())
            i.replaceAllUsesWith(bl_i)
            # add the vmapper::vmap input
            n_vmap.addInput(i)
    # fix outputs (as we are changing them, we first build the list)
    for idx, o in enumerate(list(output_node.inputs())):
        # register the output as block output
        bl.registerOutput(o)
        # create vmapper::vmap output
        n_vmap_o = n_vmap.addOutput().setType(o.type())
        # replace the graph output (=the input to the hidden return node) with the vmapper::vmap output
        output_node.replaceInput(idx, n_vmap_o)
    return gr

Neat! So it's a bit of work, but it's not too bad.

Let's look at our vmapped graph:

vgr = vmap_graph(two_layers.graph, [0, None, None, None, None], [0])
make_graph(vgr)

[svg: visualization of the vmapped graph]

Cool! But we are not done yet; running our vmapped function won't work:

# need new compilation unit as create_function will complain if function already exists (need to fix this)
vmapped_two_layers = torch.jit.CompilationUnit().create_function("vmapped_two_layers", vgr)

vmapped_two_layers(batch_x, w1, b1, w2, b2)
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-11-c9db2eb144e3> in <module>
      2 vmapped_two_layers = torch.jit.CompilationUnit().create_function("vmapped_two_layers", vgr)
      3 
----> 4 vmapped_two_layers(batch_x, w1, b1, w2, b2)


RuntimeError: 0 INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/alias_analysis.cpp":461, please report a bug to PyTorch. We don't have an op for vmapper::vmap but it isn't a special case.  Argument types: Tensor,

So we need to get rid of the vmap again.

We do this by implementing the construct we had in Python in TorchScript.

Implementing a simplistic (fallback) vmap

Because manipulating the graph can be tedious, we implement a function inserting a constant definition. This is part of the exploration I want to do here, to find out how much busywork it is to work with TorchScript graphs and what utility functions might help us. C++ has many of these createFoo functions, but they're not exposed to Python, and it's not immediately clear to me that we would follow the same API in Python.

def insert_const(gr, value, typ=None, before=None, after=None):
    assert int(before is not None) + int(after is not None) == 1

    if typ is None:
        typ = torch._C._jit_try_infer_type(value)

    n = gr.create('prim::Constant')
    if value is None:
        pass
    elif isinstance(value, (int, bool)):
        n.i_("value", value)
        if isinstance(typ, torch._C.OptionalType):
            typ = typ.getElementType()
    else:
        raise NotImplementedError()
    n.output().setType(typ)
    if before is not None:
        n.insertBefore(before)
    else:
        n.insertAfter(after)
    return n.output()

So now onto building our fallback:

For all vmapper::vmap nodes we need to

  • define a prim::Loop node (TorchScript's generalization of for and while):
    - figure out the number of iterations (size of the batch dimension of the first input tensor),
    - add a True constant (for the while part),
    - add a loop node taking those as inputs,
  • create empty lists of Tensors for all (batched) outputs,
  • in the loop body:
    - take the batched inputs and select the current batch item (the indexing being done with torch.select),
    - hook them to the computation,
    - run the operation (move all nodes from the vmapper::vmap),
    - append the results to the respective lists,
  • after the loop body, torch.stack the result lists to get the outputs and hook them up,
  • remove the now deserted vmapper::vmap.

Let's go.

def vmap_to_fallback(gr):
    gr = gr.copy()
    for n_vmap in gr.findAllNodes('vmapper::vmap'):
        # True bool for the loop "while" condition
        loop_true = insert_const(gr, True, before=n_vmap)
        # consts for batching dimensions
        inp_dims_vals = []
        for d in n_vmap._is("in_dims"):
            inp_dims_vals.append(insert_const(gr, d, before=n_vmap))
        out_dims_vals = []
        for d in n_vmap._is("out_dims"):
            out_dims_vals.append(insert_const(gr, d, before=n_vmap))

        # figure out the number of iterations (size of the batch dimension of the first input tensor)
        n = gr.create('aten::size')
        n.addInput(list(n_vmap.inputs())[0])
        n.addInput(inp_dims_vals[0])
        n.insertBefore(n_vmap)
        loop_range = n.output().setType(torch._C.IntType.get())

        # create lists for the outputs
        output_lists = []
        for o in n_vmap.outputs():
            n = gr.create('prim::ListConstruct')
            output_lists.append(n.output().setType(torch._C.ListType.ofTensors()))
            n.insertBefore(n_vmap)

        # insert the loop node
        n  = gr.create('prim::Loop', 0)
        n.addInput(loop_range)
        n.addInput(loop_true)
        n.insertBefore(n_vmap)

        # build the loop body
        bl = n.addBlock()
        # loop counter = index into the batch = the (only) input of the loop block
        loop_counter = bl.addInputToBlock().setType(torch._C.IntType.get())
        # the while condition again
        bl.registerOutput(loop_true)


        bl_vmap, = n_vmap.blocks()

        # for all batched inputs to `vmapper::vmap` select the current batch item and hook it to where the
        # `vmapper::vmap` block input was used
        for i, ib, d in zip(n_vmap.inputs(), bl_vmap.inputs(), inp_dims_vals):
            n = gr.create('aten::select')
            n.addInput(i)
            n.addInput(d)
            n.addInput(loop_counter)
            n.insertBefore(bl.returnNode())
            ib.replaceAllUsesWith(n.output())

        # move all the computation nodes over
        for n in list(bl_vmap.nodes()):
            n.moveBefore(bl.returnNode())

        # append the results  (= `vmapper::vmap` block outputs) to their respective lists
        for o, ol in zip(bl_vmap.outputs(), output_lists):
            n = gr.create('aten::append')
            n.output().setType(torch._C.ListType.ofTensors())
            n.addInput(ol)
            n.addInput(o)
            n.insertBefore(bl.returnNode())

        # stack the outputs after the loop
        for o, ol, d in zip(n_vmap.outputs(), output_lists, out_dims_vals):
            n = gr.create('aten::stack')
            n.addInput(ol)
            n.addInput(d)
            o.replaceAllUsesWith(n.output())
            n.insertBefore(n_vmap)
        # get rid of the `vmapper::vmap`
        n_vmap.destroy()

    return gr

So here we see that it would be really neat if we could get the node creation for PyTorch functions to be closer to how we write PyTorch. We will have to look into that.
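
To make that concrete, a small convenience wrapper along the following lines (just a sketch of mine, not an existing API) would already remove much of the create/addInput/insertBefore boilerplate above:

def insert_op(gr, kind, inputs, output_type, before):
    # hypothetical helper: create a node of the given kind, wire up its inputs,
    # set its (single) output type, insert it before the given node and
    # return the output value
    n = gr.create(kind)
    for i in inputs:
        n.addInput(i)
    n.insertBefore(before)
    return n.output().setType(output_type)

# with it, the size computation in vmap_to_fallback would shrink to something like
# loop_range = insert_op(gr, 'aten::size',
#                        [list(n_vmap.inputs())[0], inp_dims_vals[0]],
#                        torch._C.IntType.get(), n_vmap)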

But first let us try our function.

processed_gr = vmap_to_fallback(vgr)
processed_gr # not handled well by the ad hoc visualizer
graph(%x.1 : Tensor,
      %w1.1 : Tensor,
      %b1.1 : Tensor,
      %w2.1 : Tensor,
      %b2.1 : Tensor):
  %15 : bool = prim::Constant[value=1]()
  %16 : int = prim::Constant[value=0]()
  %17 : int = prim::Constant[value=0]()
  %18 : int = aten::size(%x.1, %16)
  %19 : Tensor[] = prim::ListConstruct()
   = prim::Loop(%18, %15)
    block0(%20 : int):
      %21 : Tensor = aten::select(%x.1, %16, %20)
      %7 : int = prim::Constant[value=1]()
      %w1.3 : Tensor = aten::sigmoid(%w1.1) # <ipython-input-3-7a3aa41c5a43>:3:9
      %9 : Tensor = aten::matmul(%21, %w1.3) # <ipython-input-3-7a3aa41c5a43>:4:8
      %x.3 : Tensor = aten::add(%9, %b1.1, %7) # <ipython-input-3-7a3aa41c5a43>:4:8
      %x.5 : Tensor = aten::tanh(%x.3) # <ipython-input-3-7a3aa41c5a43>:5:8
      %w2.3 : Tensor = aten::sigmoid(%w2.1) # <ipython-input-3-7a3aa41c5a43>:6:9
      %13 : Tensor = aten::matmul(%x.5, %w2.3) # <ipython-input-3-7a3aa41c5a43>:7:8
      %x.7 : Tensor = aten::add(%13, %b2.1, %7) # <ipython-input-3-7a3aa41c5a43>:7:8
      %22 : Tensor[] = aten::append(%19, %x.7)
      -> (%15)
  %23 : Tensor = aten::stack(%19, %17)
  return (%23)

That looks approximately as expected (and neatly, the source lines are still there for the loop body). Let's try to run it.

vmapped_two_layers = torch.jit.CompilationUnit().create_function("vmapped_two_layers", processed_gr)
print(vmapped_two_layers.code)
(vmapped_two_layers(batch_x, w1, b1, w2, b2),
 torch.allclose(vmapped_two_layers(batch_x, w1, b1, w2, b2), python_vmapped_two_layers(batch_x, w1, b1, w2, b2)))
def vmapped_two_layers(x: Tensor,
    w1: Tensor,
    b1: Tensor,
    w2: Tensor,
    b2: Tensor) -> Tensor:
  _0 = torch.size(x, 0)
  _1 = annotate(List[Tensor], [])
  for _2 in range(_0):
    _3 = torch.select(x, 0, _2)
    w10 = torch.sigmoid(w1)
    x0 = torch.add(torch.matmul(_3, w10), b1, alpha=1)
    x1 = torch.tanh(x0)
    w20 = torch.sigmoid(w2)
    x2 = torch.add(torch.matmul(x1, w20), b2, alpha=1)
    _4 = torch.append(_1, x2)
  return torch.stack(_1, 0)






(tensor([[ 0.6605],
         [-2.9427],
         [ 0.8782],
         [-0.3098],
         [ 0.8467]]),
 True)

So awesome, by combining those two transformations, we have a working vmap!

But the for loop is very inefficient - the point of vmap is to get rid of it. Also, so far, the two steps wouldn't have been needed, right?

Optimizations: Only compute non-batched values once

The first thing we might stop doing is computing the same non-batched values over and over again, namely the sigmoids of the weights.

So we check which nodes do not (even indirectly) depend on batched inputs and move them to before the vmapper::vmap node. This works because their inputs come from outside the block anyway.

def move_out_non_batched(gr):
    gr = gr.copy()
    for n_vmap in gr.findAllNodes('vmapper::vmap'):
        bl, = n_vmap.blocks()

        batch_dim_values = {}
        for i,d in zip(bl.inputs(), n_vmap._is('in_dims')):
            batch_dim_values[i] = d
        for n in list(bl.nodes()):
            if not any(i in batch_dim_values for i in n.inputs()):
                n.moveBefore(n_vmap)
            else:
                for o in n.outputs():
                    # it might not be true that all outputs are actually batched; we conservatively mark them
                    batch_dim_values[o] = 'unk'
    return gr

vgr2 = move_out_non_batched(vgr)
make_graph(vgr2)

[svg: graph after moving the non-batched computation out of the vmap]

That works! So let us pick a nontrivial function to optimize. Matmul!

Optimizing matmul

So how would batching in matmul work? Let's look at the input and output conventions. From the documentation:

help(torch.matmul)
Help on built-in function matmul:

matmul(...)
    matmul(input, other, *, out=None) -> Tensor

    Matrix product of two tensors.

    The behavior depends on the dimensionality of the tensors as follows:

    - If both tensors are 1-dimensional, the dot product (scalar) is returned.
    - If both arguments are 2-dimensional, the matrix-matrix product is returned.
    - If the first argument is 1-dimensional and the second argument is 2-dimensional,
      a 1 is prepended to its dimension for the purpose of the matrix multiply.
      After the matrix multiply, the prepended dimension is removed.
    - If the first argument is 2-dimensional and the second argument is 1-dimensional,
      the matrix-vector product is returned.
    - If both arguments are at least 1-dimensional and at least one argument is
      N-dimensional (where N > 2), then a batched matrix multiply is returned.  If the first
      argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the
      batched matrix multiply and removed after.  If the second argument is 1-dimensional, a
      1 is appended to its dimension for the purpose of the batched matrix multiple and removed after.
      The non-matrix (i.e. batch) dimensions are :ref:`broadcasted <broadcasting-semantics>` (and thus
      must be broadcastable).  For example, if :attr:`input` is a
      :math:`(j \times 1 \times n \times n)` tensor and :attr:`other` is a :math:`(k \times n \times n)`
      tensor, :attr:`out` will be a :math:`(j \times k \times n \times n)` tensor.

      Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs
      are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a
      :math:`(j \times 1 \times n \times m)` tensor and :attr:`other` is a :math:`(k \times m \times p)`
      tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the
      matrix dimensions) are different. :attr:`out` will be a :math:`(j \times k \times n \times p)` tensor.

    This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.

    .. note::

        The 1-dimensional dot product version of this function does not support an :attr:`out` parameter.

    Arguments:
        input (Tensor): the first tensor to be multiplied
        other (Tensor): the second tensor to be multiplied

    Keyword args:
        out (Tensor, optional): the output tensor.

    Example::

        >>> # vector x vector
        >>> tensor1 = torch.randn(3)
        >>> tensor2 = torch.randn(3)
        >>> torch.matmul(tensor1, tensor2).size()
        torch.Size([])
        >>> # matrix x vector
        >>> tensor1 = torch.randn(3, 4)
        >>> tensor2 = torch.randn(4)
        >>> torch.matmul(tensor1, tensor2).size()
        torch.Size([3])
        >>> # batched matrix x broadcasted vector
        >>> tensor1 = torch.randn(10, 3, 4)
        >>> tensor2 = torch.randn(4)
        >>> torch.matmul(tensor1, tensor2).size()
        torch.Size([10, 3])
        >>> # batched matrix x batched matrix
        >>> tensor1 = torch.randn(10, 3, 4)
        >>> tensor2 = torch.randn(10, 4, 5)
        >>> torch.matmul(tensor1, tensor2).size()
        torch.Size([10, 3, 5])
        >>> # batched matrix x broadcasted matrix
        >>> tensor1 = torch.randn(10, 3, 4)
        >>> tensor2 = torch.randn(4, 5)
        >>> torch.matmul(tensor1, tensor2).size()
        torch.Size([10, 3, 5])

Oh, wow, that was a lot. But I have the hypothesis that we might simplify that description if we state it more algorithmically. So we try to define an equivalent matmul that always reduces to $batchdims \times n \times k , batchdims \times k \times m \mapsto batchdims \times n \times m$ with the same number of batch dimensions in both operands.

def my_matmul(a, b):
    unsqueeze_a = a.dim() == 1
    unsqueeze_b = b.dim() == 1
    if unsqueeze_a:
        a = a.unsqueeze(-2)
    if unsqueeze_b:
        b = b.unsqueeze(-1)
    while a.dim() < b.dim():
        a = a.unsqueeze(0)
    while a.dim() > b.dim():
        b = b.unsqueeze(0)
    assert a.dim() == b.dim()  # check that we're really using the same number of dims
    res = a @ b
    if unsqueeze_a:
        res = res.squeeze(-2)
    if unsqueeze_b:
        res = res.squeeze(-1)
    return res


# tests from the docs
# vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
assert torch.matmul(tensor1, tensor2).size() == my_matmul(tensor1, tensor2).size()

# matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
assert torch.matmul(tensor1, tensor2).size() == my_matmul(tensor1, tensor2).size()

# batched matrix x broadcasted vector
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
assert torch.matmul(tensor1, tensor2).size() == my_matmul(tensor1, tensor2).size()

# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
assert torch.matmul(tensor1, tensor2).size() == my_matmul(tensor1, tensor2).size()

# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
assert torch.matmul(tensor1, tensor2).size() == my_matmul(tensor1, tensor2).size()

OK, this wasn't too bad. But there is a twist: if the operands might be vmap-batched, we need to treat that dimension differently.

So here is a version that allows the leading dimension to optionally be the batch dimension.

@torch.jit.script
def standardized_matmul(a, b, a_is_batch: bool, b_is_batch: bool):
    # from how I understand the (too) many cases in matmul, this is equivalent when both is_batch are False
    # both need to have batch as the leading dim
    unsqueeze_a = a.dim() == 1 + int(a_is_batch)
    unsqueeze_b = b.dim() == 1 + int(b_is_batch)
    if unsqueeze_a:
        a = a.unsqueeze(-2)
    if unsqueeze_b:
        b = b.unsqueeze(-1)
    while a.dim() + int(b_is_batch) < b.dim() + int(a_is_batch):
        a = a.unsqueeze(int(a_is_batch))
    while a.dim() + int(b_is_batch) > b.dim() + int(a_is_batch):
        b = b.unsqueeze(int(b_is_batch))
    res = a @ b
    if unsqueeze_a:
        res = res.squeeze(-2)
    if unsqueeze_b:
        res = res.squeeze(-1)
    return res

Well, but the batch dimension might not be the leading dimension, and in that case we would need to move it there. It's easy to do by adding a dimension where we want it to go, transposing and removing the extra dimension. Then we can make a matmul_with_batch taking optional batch dimensions.

@torch.jit.script
def move_dim_to(t, dim_from: int, dim_to: int):
    if dim_to < dim_from:
        t = t.unsqueeze(dim_to)
        t = t.transpose(dim_from + 1, dim_to)
        t = t.squeeze(dim_from + 1)
    elif dim_to > dim_from:
        t = t.unsqueeze(dim_to + 1)
        t = t.transpose(dim_from, dim_to + 1)
        t = t.squeeze(dim_from)
    return t

from typing import Optional

@torch.jit.script
def matmul_with_batch(a, b, a_batch_dim: Optional[int], b_batch_dim: Optional[int]):
    if a_batch_dim is not None and a_batch_dim != 0:
        a = move_dim_to(a, a_batch_dim, 0)
    if b_batch_dim is not None and b_batch_dim != 0:
        b = move_dim_to(b, b_batch_dim, 0)
    return standardized_matmul(a, b, a_batch_dim is not None, b_batch_dim is not None)
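
As a quick sanity check (my own check, not from the original run; it assumes torch and the functions above are in scope), we can verify the dimension shuffling and compare against a plain Python loop:

# move a (2, 3, 4) tensor's last dimension to the front: expect shape (4, 2, 3)
t = torch.randn(2, 3, 4)
assert move_dim_to(t, 2, 0).shape == (4, 2, 3)

# batched vector-matrix product: batch dimension 0 on the first operand only
a = torch.randn(5, 1)   # batch of 5 vectors with 1 feature
b = torch.randn(1, 10)
expected = torch.stack([a_i @ b for a_i in a])
assert torch.allclose(matmul_with_batch(a, b, 0, None), expected)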

So now we have a function and we can call it from a script function. But how would we do that during graph surgery?

Well, in lieu of a more principled approach (which we would need), we might look at what goes on under the hood when we call a script function from another script function:

@torch.jit.script
def dummy(a):
  b = matmul_with_batch(a, a, 1, 1)
  return b

dummy.graph
graph(%a.1 : Tensor):
  %4 : Function = prim::Constant[name="matmul_with_batch"]()
  %3 : int = prim::Constant[value=1]() # <ipython-input-21-9d3b27cad730>:3:30
  %b.1 : Tensor = prim::CallFunction(%4, %a.1, %a.1, %3, %3) # <ipython-input-21-9d3b27cad730>:3:6
  return (%b.1)

Aha, so there is a Function constant. But which function? Avid readers of my blog might remember from my fuser post that we created a Python binding for this FunctionType and that the function type is specific to the function that contains the graph.

So we write a little hack function that inserts the function constant at the top of the graph and gives us the output value.

We also write a helper that inserts a call to a function (and returns the node).

def insert_matmul_with_batch_function_constant(gr):
    @torch.jit.script
    def dummy(a):
        b = matmul_with_batch(a, a, 1, 1)
        return b
    n_template = [n for n in dummy.graph.findAllNodes('prim::Constant') if n.output().type().kind() == 'FunctionType'][0]
    n = gr.create('prim::Constant')
    n.s_('name', n_template.s('name'))
    n.output().setType(n_template.output().type())
    n.insertAfter(gr.param_node())
    dummy # avoid early destruction
    return n.output()

def insert_call(gr, fn_value, params, before=None, after=None):
    assert int(before is not None) + int(after is not None) == 1
    n = gr.create('prim::CallFunction', 0)
    n.addInput(fn_value)
    for p in params:
        n.addInput(p)
    for o in fn_value.type().optimized_graph().outputs():
        n.addOutput().setType(o.type())
    if before is not None:
        n.insertBefore(before)
    else:
        n.insertAfter(after)
    return n

But before we can apply it, we need to make sure we have the inputs to the matmul as batch matrices. This is the case at the top of the vmapper::vmap but not (usually) in the middle.

So we need to make sure matmul is always at the top, i.e. we need to saw the graph apart.

This is quite a complex operation. Moving part of the computation nodes is easy, but the rewiring of the values requires some care as there are many cases to consider. We'll only deal with the simplest ones.

known_node_kinds = {'aten::matmul'}
def split_vmap_before_known(gr):
    gr = gr.copy()

    vmap_to_process = gr.findAllNodes('vmapper::vmap')

    while vmap_to_process:
        n_vmap = vmap_to_process.pop(0)
        bl, = n_vmap.blocks()

        split_nodes = [n for idx, n in enumerate(bl.nodes()) if idx > 0 and n.kind() in known_node_kinds]
        if len(split_nodes) > 0:
            split_at = split_nodes[0]
            idxes = {n: idx for idx, n in enumerate(bl.nodes())}
            idxes[bl.returnNode()] = len(idxes)
            created_at = {v: -1 for v in bl.inputs()}
            first_used_at = {}
            last_used_at = {v: -1 for v in bl.inputs()}

            for n in bl.nodes():
                idx = idxes[n]
                for i in n.inputs():
                    if i in last_used_at:
                        last_used_at[i] = idx
                        if i not in first_used_at:
                            first_used_at[i] = idx
                for o in n.outputs():
                    created_at[o] = idx
                    last_used_at[o] = -1

            idx = idxes[bl.returnNode()]
            for o in bl.outputs():
                if o in last_used_at:
                    last_used_at[o] = idx
                    if o not in first_used_at:
                        first_used_at[o] = idx


            split_idx = idxes[split_at]

            n_vmap2 = gr.create('vmapper::vmap', 0)
            n_vmap2.insertAfter(n_vmap)
            bl2 = n_vmap2.addBlock()
            for n in list(bl.nodes())[split_idx:]:
                n.moveBefore(bl2.returnNode())

            outputs_to_move = [(idx, o, d) for idx, (o, d) in enumerate(zip(bl.outputs(), n_vmap._is("out_dims"))) if created_at[o] >= split_idx]
            n_vmap_os = list(n_vmap.outputs())
            new_out_dims = []
            for idx_o, o, d in outputs_to_move:
                bl2.registerOutput(o)
                o2 = n_vmap2.addOutput()
                o2.setType(o.type())
                n_vmap_os[idx_o].replaceAllUsesWith(o2)
                new_out_dims.append(d)
            n_vmap2.is_("out_dims", new_out_dims)

            out_dims =  n_vmap._is("out_dims")
            for idx, _, _ in reversed(outputs_to_move):
                bl.returnNode().removeInput(idx)
                n_vmap.eraseOutput(idx)
                del out_dims[idx]

            new_in_dims = []
            for k, cr_idx in created_at.items():
                lu_idx = last_used_at[k]
                if cr_idx < split_idx and lu_idx >= split_idx and k.type().kind() == 'TensorType':
                    # we might have to deal with intermediate scalars, too; those should be put in a list or similar
                    # also, inputs that are used in the bottom half should be taken from the outside instead
                    typ = k.type()
                    bl.registerOutput(k)
                    out_dims.append(0)
                    new_in_dims.append(0)
                    o_new = n_vmap.addOutput().setType(typ)
                    n_vmap2.addInput(o_new)
                    i_new = bl2.addInputToBlock().setType(typ)
                    for u in k.uses():
                        if idxes[u.user] >= split_idx and u.user != bl.returnNode():
                            u.user.replaceInputWith(k, i_new)

            n_vmap.is_("out_dims", out_dims)
            n_vmap2.is_("in_dims", new_in_dims)
            vmap_to_process.insert(0, n_vmap2)
    return gr

vgr3 = split_vmap_before_known(vgr2)
make_graph(vgr3)

[svg: graph after splitting the vmap at the matmuls]

So far so good. Let's see if it still runs:

processed_gr = vmap_to_fallback(vgr3)
vmapped_two_layers = torch.jit.CompilationUnit().create_function("vmapped_two_layers", processed_gr)

(vmapped_two_layers(batch_x, w1, b1, w2, b2),
 torch.allclose(vmapped_two_layers(batch_x, w1, b1, w2, b2), python_vmapped_two_layers(batch_x, w1, b1, w2, b2)))
(tensor([[ 0.6605],
         [-2.9427],
         [ 0.8782],
         [-0.3098],
         [ 0.8467]]),
 True)

Now we can optimize aten::matmul nodes by pulling them out of the vmap block and using our batch-aware matmul.

We introduce a utility function that removes vmap inputs that are unused (after moving out the matmul). We probably should also remove inputs that are immediately used as outputs, but well...

def remove_unused_vmap_inputs(n_vmap):
    assert n_vmap.kind() == 'vmapper::vmap'
    bl, = n_vmap.blocks()
    dims = n_vmap._is("in_dims")
    inputs_to_remove = [idx for idx, i in enumerate(bl.inputs()) if len(i.uses()) == 0]
    n_param = bl.paramNode()
    for i in reversed(inputs_to_remove):
        # need to go from end to fix indices, also O(N**2) alarm from removing things from vector
        n_param.eraseOutput(i)
        n_vmap.removeInput(i)
        del dims[i]
    n_vmap.is_("in_dims", dims)

With these, we can finally put in our generalized matmul.

def optimize_matmuls(gr):
    gr = gr.copy()
    val_matmul_with_batch = insert_matmul_with_batch_function_constant(gr) # DCE-eliminated, but we could try harder...

    for n_vmap in gr.findAllNodes('vmapper::vmap'):
        bl, = n_vmap.blocks()

        batch_dim_values = {}
        for i,d in zip(bl.inputs(), n_vmap._is('in_dims')):
            batch_dim_values[i] = d

        block_nodes = list(bl.nodes())
        if block_nodes and block_nodes[0].kind() == 'aten::matmul':
            n = block_nodes[0]
            batch_dims = [batch_dim_values.get(i) for i in n.inputs()]
            if n.kind() == 'aten::matmul':
                batch_dim_consts = [insert_const(gr, d, typ=torch._C.OptionalType(torch._C.IntType.get()),
                                                 before=n_vmap) for d in batch_dims]
                n_vmap_inputs = list(n_vmap.inputs()) # these might change, so we reconstruct
                val_inputs = []
                for i in n.inputs():
                    if i.node().kind() == 'prim::Param':
                        val_inputs.append(n_vmap_inputs[i.offset()])
                    else:
                        val_inputs.append(i)
                n_new_mm = insert_call(gr, val_matmul_with_batch, val_inputs+batch_dim_consts, before=n_vmap)
                n_vmap.addInput(n_new_mm.output())
                n_vmap.is_("in_dims", n_vmap._is("in_dims") + [0])
                new_bl_input = bl.addInputToBlock().setType(n_new_mm.output().type())
                n.output().replaceAllUsesWith(new_bl_input)  # check for out dim update...
                batch_dim_values[new_bl_input] = 0
                n.destroy()
                remove_unused_vmap_inputs(n_vmap)
                #fix_inputs_as_outputs(n_vmap)
                # TODO: check outputs if we have to set dim to 0!!!!
    return gr

vgr4 = optimize_matmuls(vgr3)
make_graph(vgr4)

[svg: graph after replacing the matmuls with the batch-aware matmul]

processed_gr = vmap_to_fallback(vgr4)
vmapped_two_layers = torch.jit.CompilationUnit().create_function("vmapped_two_layers", processed_gr)

(vmapped_two_layers(batch_x, w1, b1, w2, b2),
 torch.allclose(vmapped_two_layers(batch_x, w1, b1, w2, b2), python_vmapped_two_layers(batch_x, w1, b1, w2, b2)))
    (tensor([[ 0.6605],
             [-2.9427],
             [ 0.8782],
             [-0.3098],
             [ 0.8467]]),
     True)

We can also look at the code of our vmapped function:

print(vmapped_two_layers.code)
    def vmapped_two_layers(x: Tensor,
        w1: Tensor,
        b1: Tensor,
        w2: Tensor,
        b2: Tensor) -> Tensor:
      w10 = torch.sigmoid(w1)
      w20 = torch.sigmoid(w2)
      _0 = __torch__.matmul_with_batch(x, w10, 0, None, )
      _1 = torch.size(_0, 0)
      _2 = annotate(List[Tensor], [])
      for _3 in range(_1):
        x0 = torch.add(torch.select(_0, 0, _3), b1, alpha=1)
        x1 = torch.tanh(x0)
        _4 = torch.append(_2, x1)
      _5 = __torch__.matmul_with_batch(torch.stack(_2, 0), w20, 0, None, )
      _6 = torch.size(_5, 0)
      _7 = annotate(List[Tensor], [])
      for _8 in range(_6):
        x2 = torch.add(torch.select(_5, 0, _8), b2, alpha=1)
        _9 = torch.append(_7, x2)
      return torch.stack(_7, 0)

And if we really want - but it gets messy due to the loop unrolling that the JIT does - we can also look at the inlined graph.

vmapped_two_layers.inlined_graph
graph(%x.1 : Tensor,
      %w1.1 : Tensor,
      %b1.1 : Tensor,
      %w2.1 : Tensor,
      %b2.1 : Tensor):
  %5 : Function = prim::Constant[name="matmul_with_batch"]()
  %6 : int = prim::Constant[value=1]()
  %w1.3 : Tensor = aten::sigmoid(%w1.1) # <ipython-input-3-7a3aa41c5a43>:3:9
  %w2.3 : Tensor = aten::sigmoid(%w2.1) # <ipython-input-3-7a3aa41c5a43>:6:9
  %9 : int = prim::Constant[value=0]()
  %10 : int? = prim::Constant()
  %36 : int = prim::Constant[value=-1]() # <ipython-input-19-8e17841f14bc>:10:24
  %37 : int = prim::Constant[value=-2]() # <ipython-input-19-8e17841f14bc>:8:24
  %38 : int = prim::Constant[value=9223372036854775807]() # <ipython-input-19-8e17841f14bc>:11:4
  %39 : int = prim::Constant[value=1]() # <ipython-input-20-c5e8db4281bb>:5:35
  %40 : int = prim::Constant[value=0]() # <ipython-input-20-c5e8db4281bb>:15:50
  %41 : None = prim::Constant() # <ipython-input-20-c5e8db4281bb>:15:26
  %42 : bool = prim::Constant[value=0]() # <ipython-input-20-c5e8db4281bb>:15:7
  %43 : bool = aten::__isnot__(%9, %41) # <ipython-input-20-c5e8db4281bb>:15:7
  %44 : bool, %a_batch_dim.16 : int? = prim::If(%43) # <ipython-input-20-c5e8db4281bb>:15:7
    block0():
      %a_batch_dim.5 : int = prim::unchecked_cast(%9)
      %47 : bool = aten::ne(%a_batch_dim.5, %40) # <ipython-input-20-c5e8db4281bb>:15:35
      -> (%47, %a_batch_dim.5)
    block1():
      -> (%42, %9)
  %a_batch_dim.1 : int?, %a.5 : Tensor = prim::If(%44) # <ipython-input-20-c5e8db4281bb>:15:4
    block0():
      %a_batch_dim.9 : int = prim::unchecked_cast(%a_batch_dim.16)
      %51 : bool = aten::lt(%40, %a_batch_dim.9) # <ipython-input-20-c5e8db4281bb>:3:7
      %a.6 : Tensor = prim::If(%51) # <ipython-input-20-c5e8db4281bb>:3:4
        block0():
          %t.5 : Tensor = aten::unsqueeze(%x.1, %40) # <ipython-input-20-c5e8db4281bb>:4:12
          %54 : int = aten::add(%a_batch_dim.9, %39) # <ipython-input-20-c5e8db4281bb>:5:24
          %t.8 : Tensor = aten::transpose(%t.5, %54, %40) # <ipython-input-20-c5e8db4281bb>:5:12
          %56 : int = aten::add(%a_batch_dim.9, %39) # <ipython-input-20-c5e8db4281bb>:6:22
          %t.11 : Tensor = aten::squeeze(%t.8, %56) # <ipython-input-20-c5e8db4281bb>:6:12
          -> (%t.11)
        block1():
          %58 : bool = aten::gt(%40, %a_batch_dim.9) # <ipython-input-20-c5e8db4281bb>:7:9
          %t.30 : Tensor = prim::If(%58) # <ipython-input-20-c5e8db4281bb>:7:4
            block0():
              %t.14 : Tensor = aten::unsqueeze(%x.1, %39) # <ipython-input-20-c5e8db4281bb>:8:12
              %t.17 : Tensor = aten::transpose(%t.14, %a_batch_dim.9, %39) # <ipython-input-20-c5e8db4281bb>:9:12
              %t.20 : Tensor = aten::squeeze(%t.17, %a_batch_dim.9) # <ipython-input-20-c5e8db4281bb>:10:12
              -> (%t.20)
            block1():
              -> (%x.1)
          -> (%t.30)
      -> (%a_batch_dim.9, %a.6)
    block1():
      -> (%a_batch_dim.16, %x.1)
  %63 : bool = aten::__isnot__(%10, %41) # <ipython-input-20-c5e8db4281bb>:17:7
  %64 : bool, %b_batch_dim.16 : int? = prim::If(%63) # <ipython-input-20-c5e8db4281bb>:17:7
    block0():
      %b_batch_dim.5 : int = prim::unchecked_cast(%10)
      %67 : bool = aten::ne(%b_batch_dim.5, %40) # <ipython-input-20-c5e8db4281bb>:17:35
      -> (%67, %b_batch_dim.5)
    block1():
      -> (%42, %10)
  %b_batch_dim.1 : int?, %b.5 : Tensor = prim::If(%64) # <ipython-input-20-c5e8db4281bb>:17:4
    block0():
      %b_batch_dim.9 : int = prim::unchecked_cast(%b_batch_dim.16)
      %71 : bool = aten::lt(%40, %b_batch_dim.9) # <ipython-input-20-c5e8db4281bb>:3:7
      %b.6 : Tensor = prim::If(%71) # <ipython-input-20-c5e8db4281bb>:3:4
        block0():
          %t.21 : Tensor = aten::unsqueeze(%w1.3, %40) # <ipython-input-20-c5e8db4281bb>:4:12
          %74 : int = aten::add(%b_batch_dim.9, %39) # <ipython-input-20-c5e8db4281bb>:5:24
          %t.22 : Tensor = aten::transpose(%t.21, %74, %40) # <ipython-input-20-c5e8db4281bb>:5:12
          %76 : int = aten::add(%b_batch_dim.9, %39) # <ipython-input-20-c5e8db4281bb>:6:22
          %t.23 : Tensor = aten::squeeze(%t.22, %76) # <ipython-input-20-c5e8db4281bb>:6:12
          -> (%t.23)
        block1():
          %78 : bool = aten::gt(%40, %b_batch_dim.9) # <ipython-input-20-c5e8db4281bb>:7:9
          %t.31 : Tensor = prim::If(%78) # <ipython-input-20-c5e8db4281bb>:7:4
            block0():
              %t.24 : Tensor = aten::unsqueeze(%w1.3, %39) # <ipython-input-20-c5e8db4281bb>:8:12
              %t.25 : Tensor = aten::transpose(%t.24, %b_batch_dim.9, %39) # <ipython-input-20-c5e8db4281bb>:9:12
              %t.26 : Tensor = aten::squeeze(%t.25, %b_batch_dim.9) # <ipython-input-20-c5e8db4281bb>:10:12
              -> (%t.26)
            block1():
              -> (%w1.3)
          -> (%t.31)
      -> (%b_batch_dim.9, %b.6)
    block1():
      -> (%b_batch_dim.16, %w1.3)
  %83 : bool = aten::__isnot__(%a_batch_dim.1, %41) # <ipython-input-20-c5e8db4281bb>:19:37
  %84 : bool = aten::__isnot__(%b_batch_dim.1, %41) # <ipython-input-20-c5e8db4281bb>:19:62
  %85 : int = aten::dim(%a.5) # <ipython-input-19-8e17841f14bc>:5:18
  %86 : int = aten::Int(%83) # <ipython-input-19-8e17841f14bc>:5:33
  %87 : int = aten::add(%39, %86) # <ipython-input-19-8e17841f14bc>:5:29
  %unsqueeze_a.2 : bool = aten::eq(%85, %87) # <ipython-input-19-8e17841f14bc>:5:18
  %89 : int = aten::dim(%b.5) # <ipython-input-19-8e17841f14bc>:6:18
  %90 : int = aten::Int(%84) # <ipython-input-19-8e17841f14bc>:6:33
  %91 : int = aten::add(%39, %90) # <ipython-input-19-8e17841f14bc>:6:29
  %unsqueeze_b.2 : bool = aten::eq(%89, %91) # <ipython-input-19-8e17841f14bc>:6:18
  %a.17 : Tensor = prim::If(%unsqueeze_a.2) # <ipython-input-19-8e17841f14bc>:7:4
    block0():
      %a.7 : Tensor = aten::unsqueeze(%a.5, %37) # <ipython-input-19-8e17841f14bc>:8:12
      -> (%a.7)
    block1():
      -> (%a.5)
  %b.17 : Tensor = prim::If(%unsqueeze_b.2) # <ipython-input-19-8e17841f14bc>:9:4
    block0():
      %b.7 : Tensor = aten::unsqueeze(%b.5, %36) # <ipython-input-19-8e17841f14bc>:10:12
      -> (%b.7)
    block1():
      -> (%b.5)
  %97 : int = aten::dim(%a.17) # <ipython-input-19-8e17841f14bc>:11:10
  %98 : int = aten::Int(%84) # <ipython-input-19-8e17841f14bc>:11:20
  %99 : int = aten::add(%97, %98) # <ipython-input-19-8e17841f14bc>:11:10
  %100 : int = aten::dim(%b.17) # <ipython-input-19-8e17841f14bc>:11:38
  %101 : int = aten::Int(%83) # <ipython-input-19-8e17841f14bc>:11:48
  %102 : int = aten::add(%100, %101) # <ipython-input-19-8e17841f14bc>:11:38
  %103 : bool = aten::lt(%99, %102) # <ipython-input-19-8e17841f14bc>:11:10
  %a.1 : Tensor = prim::Loop(%38, %103, %a.17) # <ipython-input-19-8e17841f14bc>:11:4
    block0(%105 : int, %a.18 : Tensor):
      %107 : int = aten::Int(%83) # <ipython-input-19-8e17841f14bc>:12:24
      %a.12 : Tensor = aten::unsqueeze(%a.18, %107) # <ipython-input-19-8e17841f14bc>:12:12
      %109 : int = aten::dim(%a.12) # <ipython-input-19-8e17841f14bc>:11:10
      %110 : int = aten::Int(%84) # <ipython-input-19-8e17841f14bc>:11:20
      %111 : int = aten::add(%109, %110) # <ipython-input-19-8e17841f14bc>:11:10
      %112 : int = aten::dim(%b.17) # <ipython-input-19-8e17841f14bc>:11:38
      %113 : int = aten::Int(%83) # <ipython-input-19-8e17841f14bc>:11:48
      %114 : int = aten::add(%112, %113) # <ipython-input-19-8e17841f14bc>:11:38
      %115 : bool = aten::lt(%111, %114) # <ipython-input-19-8e17841f14bc>:11:10
      -> (%115, %a.12)
  %116 : int = aten::dim(%a.1) # <ipython-input-19-8e17841f14bc>:13:10
  %117 : int = aten::Int(%84) # <ipython-input-19-8e17841f14bc>:13:20
  %118 : int = aten::add(%116, %117) # <ipython-input-19-8e17841f14bc>:13:10
  %119 : int = aten::dim(%b.17) # <ipython-input-19-8e17841f14bc>:13:38
  %120 : int = aten::Int(%83) # <ipython-input-19-8e17841f14bc>:13:48
  %121 : int = aten::add(%119, %120) # <ipython-input-19-8e17841f14bc>:13:38
  %122 : bool = aten::gt(%118, %121) # <ipython-input-19-8e17841f14bc>:13:10
  %b.1 : Tensor = prim::Loop(%38, %122, %b.17) # <ipython-input-19-8e17841f14bc>:13:4
    block0(%124 : int, %b.18 : Tensor):
      %126 : int = aten::Int(%84) # <ipython-input-19-8e17841f14bc>:14:24
      %b.13 : Tensor = aten::unsqueeze(%b.18, %126) # <ipython-input-19-8e17841f14bc>:14:12
      %128 : int = aten::dim(%a.1) # <ipython-input-19-8e17841f14bc>:13:10
      %129 : int = aten::Int(%84) # <ipython-input-19-8e17841f14bc>:13:20
      %130 : int = aten::add(%128, %129) # <ipython-input-19-8e17841f14bc>:13:10
      %131 : int = aten::dim(%b.13) # <ipython-input-19-8e17841f14bc>:13:38
      %132 : int = aten::Int(%83) # <ipython-input-19-8e17841f14bc>:13:48
      %133 : int = aten::add(%131, %132) # <ipython-input-19-8e17841f14bc>:13:38
      %134 : bool = aten::gt(%130, %133) # <ipython-input-19-8e17841f14bc>:13:10
      -> (%134, %b.13)
  %res.2 : Tensor = aten::matmul(%a.1, %b.1) # <ipython-input-19-8e17841f14bc>:15:10
  %res.16 : Tensor = prim::If(%unsqueeze_a.2) # <ipython-input-19-8e17841f14bc>:16:4
    block0():
      %res.4 : Tensor = aten::squeeze(%res.2, %37) # <ipython-input-19-8e17841f14bc>:17:14
      -> (%res.4)
    block1():
      -> (%res.2)
  %res.5 : Tensor = prim::If(%unsqueeze_b.2) # <ipython-input-19-8e17841f14bc>:18:4
    block0():
      %res.10 : Tensor = aten::squeeze(%res.16, %36) # <ipython-input-19-8e17841f14bc>:19:14
      -> (%res.10)
    block1():
      -> (%res.16)
  %12 : bool = prim::Constant[value=1]()
  %13 : int = prim::Constant[value=0]()
  %14 : int = prim::Constant[value=0]()
  %15 : int = aten::size(%res.5, %13)
  %16 : Tensor[] = prim::ListConstruct()
   = prim::Loop(%15, %12)
    block0(%17 : int):
      %18 : Tensor = aten::select(%res.5, %13, %17)
      %x.3 : Tensor = aten::add(%18, %b1.1, %6) # <ipython-input-3-7a3aa41c5a43>:4:8
      %x.5 : Tensor = aten::tanh(%x.3) # <ipython-input-3-7a3aa41c5a43>:5:8
      %21 : Tensor[] = aten::append(%16, %x.5)
      -> (%12)
  %22 : Tensor = aten::stack(%16, %14)
  %23 : int = prim::Constant[value=0]()
  %24 : int? = prim::Constant()
  %140 : int = prim::Constant[value=-1]() # <ipython-input-19-8e17841f14bc>:10:24
  %141 : int = prim::Constant[value=-2]() # <ipython-input-19-8e17841f14bc>:8:24
  %142 : int = prim::Constant[value=9223372036854775807]() # <ipython-input-19-8e17841f14bc>:11:4
  %143 : int = prim::Constant[value=1]() # <ipython-input-20-c5e8db4281bb>:5:35
  %144 : int = prim::Constant[value=0]() # <ipython-input-20-c5e8db4281bb>:15:50
  %145 : None = prim::Constant() # <ipython-input-20-c5e8db4281bb>:15:26
  %146 : bool = prim::Constant[value=0]() # <ipython-input-20-c5e8db4281bb>:15:7
  %147 : bool = aten::__isnot__(%23, %145) # <ipython-input-20-c5e8db4281bb>:15:7
  %148 : bool, %a_batch_dim.15 : int? = prim::If(%147) # <ipython-input-20-c5e8db4281bb>:15:7
    block0():
      %a_batch_dim.4 : int = prim::unchecked_cast(%23)
      %151 : bool = aten::ne(%a_batch_dim.4, %144) # <ipython-input-20-c5e8db4281bb>:15:35
      -> (%151, %a_batch_dim.4)
    block1():
      -> (%146, %23)
  %a_batch_dim : int?, %a.2 : Tensor = prim::If(%148) # <ipython-input-20-c5e8db4281bb>:15:4
    block0():
      %a_batch_dim.8 : int = prim::unchecked_cast(%a_batch_dim.15)
      %155 : bool = aten::lt(%144, %a_batch_dim.8) # <ipython-input-20-c5e8db4281bb>:3:7
      %a.3 : Tensor = prim::If(%155) # <ipython-input-20-c5e8db4281bb>:3:4
        block0():
          %t.4 : Tensor = aten::unsqueeze(%22, %144) # <ipython-input-20-c5e8db4281bb>:4:12
          %158 : int = aten::add(%a_batch_dim.8, %143) # <ipython-input-20-c5e8db4281bb>:5:24
          %t.7 : Tensor = aten::transpose(%t.4, %158, %144) # <ipython-input-20-c5e8db4281bb>:5:12
          %160 : int = aten::add(%a_batch_dim.8, %143) # <ipython-input-20-c5e8db4281bb>:6:22
          %t.10 : Tensor = aten::squeeze(%t.7, %160) # <ipython-input-20-c5e8db4281bb>:6:12
          -> (%t.10)
        block1():
          %162 : bool = aten::gt(%144, %a_batch_dim.8) # <ipython-input-20-c5e8db4281bb>:7:9
          %t.29 : Tensor = prim::If(%162) # <ipython-input-20-c5e8db4281bb>:7:4
            block0():
              %t.13 : Tensor = aten::unsqueeze(%22, %143) # <ipython-input-20-c5e8db4281bb>:8:12
              %t.16 : Tensor = aten::transpose(%t.13, %a_batch_dim.8, %143) # <ipython-input-20-c5e8db4281bb>:9:12
              %t.19 : Tensor = aten::squeeze(%t.16, %a_batch_dim.8) # <ipython-input-20-c5e8db4281bb>:10:12
              -> (%t.19)
            block1():
              -> (%22)
          -> (%t.29)
      -> (%a_batch_dim.8, %a.3)
    block1():
      -> (%a_batch_dim.15, %22)
  %167 : bool = aten::__isnot__(%24, %145) # <ipython-input-20-c5e8db4281bb>:17:7
  %168 : bool, %b_batch_dim.15 : int? = prim::If(%167) # <ipython-input-20-c5e8db4281bb>:17:7
    block0():
      %b_batch_dim.4 : int = prim::unchecked_cast(%24)
      %171 : bool = aten::ne(%b_batch_dim.4, %144) # <ipython-input-20-c5e8db4281bb>:17:35
      -> (%171, %b_batch_dim.4)
    block1():
      -> (%146, %24)
  %b_batch_dim : int?, %b.2 : Tensor = prim::If(%168) # <ipython-input-20-c5e8db4281bb>:17:4
    block0():
      %b_batch_dim.8 : int = prim::unchecked_cast(%b_batch_dim.15)
      %175 : bool = aten::lt(%144, %b_batch_dim.8) # <ipython-input-20-c5e8db4281bb>:3:7
      %b.3 : Tensor = prim::If(%175) # <ipython-input-20-c5e8db4281bb>:3:4
        block0():
          %t.3 : Tensor = aten::unsqueeze(%w2.3, %144) # <ipython-input-20-c5e8db4281bb>:4:12
          %178 : int = aten::add(%b_batch_dim.8, %143) # <ipython-input-20-c5e8db4281bb>:5:24
          %t.6 : Tensor = aten::transpose(%t.3, %178, %144) # <ipython-input-20-c5e8db4281bb>:5:12
          %180 : int = aten::add(%b_batch_dim.8, %143) # <ipython-input-20-c5e8db4281bb>:6:22
          %t.9 : Tensor = aten::squeeze(%t.6, %180) # <ipython-input-20-c5e8db4281bb>:6:12
          -> (%t.9)
        block1():
          %182 : bool = aten::gt(%144, %b_batch_dim.8) # <ipython-input-20-c5e8db4281bb>:7:9
          %t.28 : Tensor = prim::If(%182) # <ipython-input-20-c5e8db4281bb>:7:4
            block0():
              %t.12 : Tensor = aten::unsqueeze(%w2.3, %143) # <ipython-input-20-c5e8db4281bb>:8:12
              %t.15 : Tensor = aten::transpose(%t.12, %b_batch_dim.8, %143) # <ipython-input-20-c5e8db4281bb>:9:12
              %t.18 : Tensor = aten::squeeze(%t.15, %b_batch_dim.8) # <ipython-input-20-c5e8db4281bb>:10:12
              -> (%t.18)
            block1():
              -> (%w2.3)
          -> (%t.28)
      -> (%b_batch_dim.8, %b.3)
    block1():
      -> (%b_batch_dim.15, %w2.3)
  %187 : bool = aten::__isnot__(%a_batch_dim, %145) # <ipython-input-20-c5e8db4281bb>:19:37
  %188 : bool = aten::__isnot__(%b_batch_dim, %145) # <ipython-input-20-c5e8db4281bb>:19:62
  %189 : int = aten::dim(%a.2) # <ipython-input-19-8e17841f14bc>:5:18
  %190 : int = aten::Int(%187) # <ipython-input-19-8e17841f14bc>:5:33
  %191 : int = aten::add(%143, %190) # <ipython-input-19-8e17841f14bc>:5:29
  %unsqueeze_a.1 : bool = aten::eq(%189, %191) # <ipython-input-19-8e17841f14bc>:5:18
  %193 : int = aten::dim(%b.2) # <ipython-input-19-8e17841f14bc>:6:18
  %194 : int = aten::Int(%188) # <ipython-input-19-8e17841f14bc>:6:33
  %195 : int = aten::add(%143, %194) # <ipython-input-19-8e17841f14bc>:6:29
  %unsqueeze_b.1 : bool = aten::eq(%193, %195) # <ipython-input-19-8e17841f14bc>:6:18
  %a.15 : Tensor = prim::If(%unsqueeze_a.1) # <ipython-input-19-8e17841f14bc>:7:4
    block0():
      %a.4 : Tensor = aten::unsqueeze(%a.2, %141) # <ipython-input-19-8e17841f14bc>:8:12
      -> (%a.4)
    block1():
      -> (%a.2)
  %b.15 : Tensor = prim::If(%unsqueeze_b.1) # <ipython-input-19-8e17841f14bc>:9:4
    block0():
      %b.4 : Tensor = aten::unsqueeze(%b.2, %140) # <ipython-input-19-8e17841f14bc>:10:12
      -> (%b.4)
    block1():
      -> (%b.2)
  %201 : int = aten::dim(%a.15) # <ipython-input-19-8e17841f14bc>:11:10
  %202 : int = aten::Int(%188) # <ipython-input-19-8e17841f14bc>:11:20
  %203 : int = aten::add(%201, %202) # <ipython-input-19-8e17841f14bc>:11:10
  %204 : int = aten::dim(%b.15) # <ipython-input-19-8e17841f14bc>:11:38
  %205 : int = aten::Int(%187) # <ipython-input-19-8e17841f14bc>:11:48
  %206 : int = aten::add(%204, %205) # <ipython-input-19-8e17841f14bc>:11:38
  %207 : bool = aten::lt(%203, %206) # <ipython-input-19-8e17841f14bc>:11:10
  %a : Tensor = prim::Loop(%142, %207, %a.15) # <ipython-input-19-8e17841f14bc>:11:4
    block0(%209 : int, %a.16 : Tensor):
      %211 : int = aten::Int(%187) # <ipython-input-19-8e17841f14bc>:12:24
      %a.11 : Tensor = aten::unsqueeze(%a.16, %211) # <ipython-input-19-8e17841f14bc>:12:12
      %213 : int = aten::dim(%a.11) # <ipython-input-19-8e17841f14bc>:11:10
      %214 : int = aten::Int(%188) # <ipython-input-19-8e17841f14bc>:11:20
      %215 : int = aten::add(%213, %214) # <ipython-input-19-8e17841f14bc>:11:10
      %216 : int = aten::dim(%b.15) # <ipython-input-19-8e17841f14bc>:11:38
      %217 : int = aten::Int(%187) # <ipython-input-19-8e17841f14bc>:11:48
      %218 : int = aten::add(%216, %217) # <ipython-input-19-8e17841f14bc>:11:38
      %219 : bool = aten::lt(%215, %218) # <ipython-input-19-8e17841f14bc>:11:10
      -> (%219, %a.11)
  %220 : int = aten::dim(%a) # <ipython-input-19-8e17841f14bc>:13:10
  %221 : int = aten::Int(%188) # <ipython-input-19-8e17841f14bc>:13:20
  %222 : int = aten::add(%220, %221) # <ipython-input-19-8e17841f14bc>:13:10
  %223 : int = aten::dim(%b.15) # <ipython-input-19-8e17841f14bc>:13:38
  %224 : int = aten::Int(%187) # <ipython-input-19-8e17841f14bc>:13:48
  %225 : int = aten::add(%223, %224) # <ipython-input-19-8e17841f14bc>:13:38
  %226 : bool = aten::gt(%222, %225) # <ipython-input-19-8e17841f14bc>:13:10
  %b : Tensor = prim::Loop(%142, %226, %b.15) # <ipython-input-19-8e17841f14bc>:13:4
    block0(%228 : int, %b.16 : Tensor):
      %230 : int = aten::Int(%188) # <ipython-input-19-8e17841f14bc>:14:24
      %b.12 : Tensor = aten::unsqueeze(%b.16, %230) # <ipython-input-19-8e17841f14bc>:14:12
      %232 : int = aten::dim(%a) # <ipython-input-19-8e17841f14bc>:13:10
      %233 : int = aten::Int(%188) # <ipython-input-19-8e17841f14bc>:13:20
      %234 : int = aten::add(%232, %233) # <ipython-input-19-8e17841f14bc>:13:10
      %235 : int = aten::dim(%b.12) # <ipython-input-19-8e17841f14bc>:13:38
      %236 : int = aten::Int(%187) # <ipython-input-19-8e17841f14bc>:13:48
      %237 : int = aten::add(%235, %236) # <ipython-input-19-8e17841f14bc>:13:38
      %238 : bool = aten::gt(%234, %237) # <ipython-input-19-8e17841f14bc>:13:10
      -> (%238, %b.12)
  %res.1 : Tensor = aten::matmul(%a, %b) # <ipython-input-19-8e17841f14bc>:15:10
  %res.15 : Tensor = prim::If(%unsqueeze_a.1) # <ipython-input-19-8e17841f14bc>:16:4
    block0():
      %res.3 : Tensor = aten::squeeze(%res.1, %141) # <ipython-input-19-8e17841f14bc>:17:14
      -> (%res.3)
    block1():
      -> (%res.1)
  %res : Tensor = prim::If(%unsqueeze_b.1) # <ipython-input-19-8e17841f14bc>:18:4
    block0():
      %res.9 : Tensor = aten::squeeze(%res.15, %140) # <ipython-input-19-8e17841f14bc>:19:14
      -> (%res.9)
    block1():
      -> (%res.15)
  %26 : bool = prim::Constant[value=1]()
  %27 : int = prim::Constant[value=0]()
  %28 : int = prim::Constant[value=0]()
  %29 : int = aten::size(%res, %27)
  %30 : Tensor[] = prim::ListConstruct()
   = prim::Loop(%29, %26)
    block0(%31 : int):
      %32 : Tensor = aten::select(%res, %27, %31)
      %x.7 : Tensor = aten::add(%32, %b2.1, %6) # <ipython-input-3-7a3aa41c5a43>:7:8
      %34 : Tensor[] = aten::append(%30, %x.7)
      -> (%26)
  %35 : Tensor = aten::stack(%30, %28)
  return (%35)

So here we end our experiment (if you really wanted - and used my experimental branch - you could also implement pulling pointwise ops like tanh out of the graph or experiment with pairwise ops like addition).

Conclusion

I think we have gone quite far with an overall reasonable amount of work.

Our main goal was to get a feel for where we stand (after a few experimental patches, notably my PR 49969 regarding safety of blocks, nodes and values and some API expansion) in terms of graph manipulation from Python. It works, but there are quite a few rough edges.

  • Convenience functions - we saw some, but more would probably be better.
  • A more principled way around generating short bits of code to insert into the graph.
  • A better way to "find patterns", e.g. TVM has a nifty pattern matching facility for transformations of the computational graph and we might learn from that.

I should say that the vmap used as an example here is far from complete:

  • obviously, the operator-specific parts are missing,
  • while I'd be hopeful that the vmapper::vmap blocks could be nested, the matmul optimization would probably not be able to reasonably stack dimensions,
  • one big thing one would want is to include gradient computation and get a batch of gradients. But this currently leaves the JIT for anything not handled by the JIT's own differentiation mechanism, autodiff.

My view is (and I had this as a "lofty idea" in the blog post on ScriptTorch) that we should see to having source-to-source derivatives for all torch operations (aten:: nodes), and possibly even for simple (if) control flow. Then it would be easy to do differentiation first and then vmap and get gradients. There are some obstacles to this (like traced functions in ATen that don't show up in derivatives.yaml), but we also have some ideas for how to deal with them.

So I hope you enjoyed this little demo of what is possible (almost) today. With a bit of luck, you got some appetite for better developer tooling inside the JIT.

As with my other ScriptTorch work, this is meant to facilitate discussion. Please do reach out on Slack or via mail at tv@lernapparat.de with your comments and ideas.