Counting nanoGPT FLOPs in PyTorch

October 26, 2023

Recently PyTorch added the export function (torch.export), which captures your model ahead-of-time (AOT) as a graph. The resulting ExportedProgram is a static representation of how data moves through the model, expressed in the ATen IR.

For this module:

import torch
from torch import nn
from torch.export import export

class M(torch.nn.Module):
    def __init__(self, d_in, d_out):
        super(M, self).__init__()
        self.layer = nn.Linear(d_in, d_out)
    def forward(self, x):
        return self.layer(x)

arg = torch.randn(2, 10)
model = M(10, 20)

exported = export(model, args=(arg,))
print(exported)

The output is:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[20, 10], arg1_1: f32[20], arg2_1: f32[2, 10]):
            #
            permute: f32[10, 20] = torch.ops.aten.permute.default(arg0_1, [1, 0]);  arg0_1 = None
            addmm: f32[2, 20] = torch.ops.aten.addmm.default(arg1_1, arg2_1, permute);  arg1_1 = arg2_1 = permute = None
            return (addmm,)

Graph Signature: ExportGraphSignature(parameters=['L__self___layer.weight', 'L__self___layer.bias'], buffers=[], user_inputs=['arg2_1'], user_outputs=['addmm'], inputs_to_parameters={'arg0_1': 'L__self___layer.weight', 'arg1_1': 'L__self___layer.bias'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

We know exactly how data moves through the model, which means we can, for example, plan memory allocation precisely.

Notice that the shapes are tracked as well. This is crucial for counting FLOPs, and FLOP counts are useful for determining hardware utilization.
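As a rough illustration of how a FLOP count feeds into hardware utilization: divide the FLOPs performed per step by the step time to get the achieved FLOP/s, then divide that by the accelerator's peak. This is a minimal sketch; the helper name, the 10 ms step time, and the 312 TFLOP/s peak (A100 bf16 dense) are assumptions for illustration. The 3x factor for training (a backward pass costing about twice the forward) matches the reference table further down.

```python
def hardware_utilization(forward_flops: float, step_time_s: float,
                         peak_flops_per_s: float, training: bool = True) -> float:
    """Fraction of peak throughput achieved (hypothetical helper).

    Assumes a backward pass costs ~2x the forward, so one training
    step costs ~3x the forward FLOPs.
    """
    flops_per_step = 3 * forward_flops if training else forward_flops
    achieved_flops_per_s = flops_per_step / step_time_s
    return achieved_flops_per_s / peak_flops_per_s


# Hypothetical inputs: the reference table's forward_total for one
# 1024-token sequence, a made-up 10 ms step, and an assumed 312 TFLOP/s peak.
util = hardware_utilization(291_648_307_200, step_time_s=0.010, peak_flops_per_s=312e12)
print(f"utilization: {util:.1%}")
```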

Counting FLOPs

The API is pretty good, but it takes a fair bit of meddling with nodes and referencing the ATen instruction set to get the hang of things.

Ok, so what do we want to do? Loop over the nodes and, for certain nodes, count the FLOPs in that node. Which nodes do we care about? Anything that involves matrix/tensor multiplication. There are other operations, such as scalar multiplication and addition, vector addition, softmax, etc., but the bulk of the FLOPs comes from matrix multiplications, so we'll focus on those.

Looking at the ATen IR:

  • aten.mm : mm(Tensor self, Tensor mat2) -> Tensor
  • aten.addmm : addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
  • aten.bmm : bmm(Tensor self, Tensor mat2) -> Tensor

are the instructions we care about.

NOTE: the FLOPs for a matrix multiplication of t1 (m x k) and t2 (k x n) are roughly 2 * m * k * n.
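Where the factor of 2 comes from: each of the m * n output elements is a dot product of length k, which costs k multiplies and k adds when accumulating from zero (k - 1 adds if you start from the first product, hence "roughly"). The sketch below is an illustrative pure-Python naive matmul that counts its own operations and compares the tally with the closed form; the batched case (bmm) just repeats the work once per batch element. Both function names are hypothetical.

```python
def naive_matmul_flops(m: int, k: int, n: int, batch: int = 1) -> int:
    """Run a naive (batched) matmul and count the multiplies and adds it does."""
    a = [[[1.0] * k for _ in range(m)] for _ in range(batch)]  # batch x m x k
    b = [[[1.0] * n for _ in range(k)] for _ in range(batch)]  # batch x k x n
    c = [[[0.0] * n for _ in range(m)] for _ in range(batch)]  # batch x m x n
    flops = 0
    for bi in range(batch):
        for i in range(m):
            for j in range(n):
                acc = 0.0
                for p in range(k):
                    acc += a[bi][i][p] * b[bi][p][j]
                    flops += 2  # one multiply + one add
                c[bi][i][j] = acc
    return flops


def matmul_flops(m: int, k: int, n: int, batch: int = 1) -> int:
    """Closed form: 2 * m * k * n per batch element."""
    return 2 * batch * m * k * n


# Shapes of the addmm in the exported graph above: (2 x 10) @ (10 x 20).
print(naive_matmul_flops(2, 10, 20), matmul_flops(2, 10, 20))
```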

Now we just need to get these values from the nodes in the graph.

for n in exported.graph.nodes:
    print(n)
# arg0_1
# arg1_1
# arg2_1
# permute
# addmm
# output

The nodes are of type fx.Node. We need to get the inputs for the addmm instruction. We can access these through the all_input_nodes attribute.

for n in exported.graph.nodes:
    print(n, n.all_input_nodes)
# arg0_1 []
# arg1_1 []
# arg2_1 []
# permute [arg0_1]
# addmm [arg1_1, arg2_1, permute]
# output [addmm]

[arg1_1, arg2_1, permute] matches the positional arguments of the addmm signature above (self, mat1, mat2); beta and alpha are left at their defaults. Let's go back to the exported program output.

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[20, 10], arg1_1: f32[20], arg2_1: f32[2, 10]):
            #
            permute: f32[10, 20] = torch.ops.aten.permute.default(arg0_1, [1, 0]);  arg0_1 = None
            addmm: f32[2, 20] = torch.ops.aten.addmm.default(arg1_1, arg2_1, permute);  arg1_1 = arg2_1 = permute = None
            return (addmm,)

arg1_1 is f32[20], so that's the bias, and we can ignore it. arg2_1 is f32[2, 10], which maps to t1, and permute is f32[10, 20], which maps to t2. Now we know which input nodes we'll use; the next step is to retrieve the dimensions of those tensors.

This is where the API gets a bit odd. AFAIK the way to get the shapes is through a node's meta attribute, specifically meta['val'], which holds a FakeTensor.

for n in exported.graph.nodes:
    print(n, n.all_input_nodes)
    for a in n.all_input_nodes:
        print('fake tensor', a, a.meta['val'])
# arg0_1 []
# arg1_1 []
# arg2_1 []
# permute [arg0_1]
# fake tensor arg0_1 FakeTensor(..., size=(20, 10))
# addmm [arg1_1, arg2_1, permute]
# fake tensor arg1_1 FakeTensor(..., size=(20,))
# fake tensor arg2_1 FakeTensor(..., size=(2, 10))
# fake tensor permute FakeTensor(..., size=(10, 20))
# output [addmm]
# fake tensor addmm FakeTensor(..., size=(2, 20))

meta['val'].size() is what we want.

Let's implement the flop count for the addmm instruction.

def count_addmm_flops(node):
    # ignore bias
    t1 = node.all_input_nodes[1].meta['val']
    t2 = node.all_input_nodes[2].meta['val']
    m, k = t1.size()
    n = t2.size(1)
    return 2 * m * k * n

nodes = list(exported.graph.nodes)
addmm_node = nodes[-2]
print(addmm_node)

print('addmm flops', count_addmm_flops(addmm_node))

Identifying the right node while looping is a bit tricky. node.name isn't specific enough (it's a generated identifier such as addmm or addmm_1). What works is node.target.name(). But target isn't always an operator object; for placeholder and output nodes it's a string. So we need an isinstance check first.

addmm_node.target.name()
# output is 'aten::addmm'

total_flops = 0
for n in exported.graph.nodes:
    if isinstance(n.target, torch._ops.OpOverload):
        if n.target.name() == 'aten::addmm':
            total_flops += count_addmm_flops(n)
print('total flops', total_flops)
# output is 800
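Two properties of fx.Node are worth seeing in isolation. The sketch below uses torch.fx.symbolic_trace on a hypothetical toy function, square, instead of export, so it doesn't depend on the export IR of a particular PyTorch version. Placeholder and output nodes have plain string targets, which is why the isinstance check above is needed. Also, all_input_nodes lists each distinct input node only once, so when the same tensor feeds both operands of torch.mm(x, x) it has a single entry and positional indexing into it breaks, whereas node.args keeps every positional argument.

```python
import torch
import torch.fx


def square(x):
    # Hypothetical toy function: the same tensor is both matmul operands.
    return torch.mm(x, x)


gm = torch.fx.symbolic_trace(square)

for node in gm.graph.nodes:
    # op kind, type of target, deduplicated inputs, raw positional args
    print(node.op, type(node.target).__name__, node.all_input_nodes, node.args)

placeholder = next(n for n in gm.graph.nodes if n.op == "placeholder")
mm_node = next(n for n in gm.graph.nodes if n.op == "call_function")
```

The counters in this post index all_input_nodes by position, which is fine for the graphs here; reading operands from node.args instead would be a more defensive choice.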

Now let's implement flop counts for the other instructions.

def count_bmm_flops(node):
    t1 = node.all_input_nodes[0].meta['val']
    t2 = node.all_input_nodes[1].meta['val']
    batch_size, m, k = t1.size()
    n = t2.size(2)
    return 2 * batch_size * m * k * n

def count_mm_flops(node):
    t1 = node.all_input_nodes[0].meta['val']
    t2 = node.all_input_nodes[1].meta['val']
    m, k = t1.size()
    n = t2.size(1)
    return 2 * m * k * n

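total_flops = 0
for n in exported.graph.nodes:
    if isinstance(n.target, torch._ops.OpOverload):
        if n.target.name() == 'aten::addmm':
            total_flops += count_addmm_flops(n)
        elif n.target.name() == 'aten::mm':
            total_flops += count_mm_flops(n)
        elif n.target.name() == 'aten::bmm':
            total_flops += count_bmm_flops(n)
print('total flops', total_flops)
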
Our current graph doesn't call mm or bmm, so let's adjust the model so it does.

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.layer = nn.Linear(10, 20)
        self.param = nn.Parameter(torch.zeros(20, 20))
        self.param2 = nn.Parameter(torch.zeros(2, 10, 2))
    def forward(self, x):
        x = self.layer(x)
        x = x @ self.param
        x = x.view(2, 2, 10)
        x = x @ self.param2
        return x.view(2, -1)

arg = torch.randn(2, 10)
model = M()

exported = export(model, args=(arg,))
print(exported)

The output is:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[20, 20], arg1_1: f32[2, 10, 2], arg2_1: f32[20, 10], arg3_1: f32[20], arg4_1: f32[2, 10]):
            #
            permute: f32[10, 20] = torch.ops.aten.permute.default(arg2_1, [1, 0]);  arg2_1 = None
            addmm: f32[2, 20] = torch.ops.aten.addmm.default(arg3_1, arg4_1, permute);  arg3_1 = arg4_1 = permute = None
            mm: f32[2, 20] = torch.ops.aten.mm.default(addmm, arg0_1);  addmm = arg0_1 = None
            view: f32[2, 2, 10] = torch.ops.aten.view.default(mm, [2, 2, 10]);  mm = None
            expand: f32[2, 2, 10] = torch.ops.aten.expand.default(view, [2, 2, 10]);  view = None
            view_1: f32[2, 2, 10] = torch.ops.aten.view.default(expand, [2, 2, 10]);  expand = None
            expand_1: f32[2, 10, 2] = torch.ops.aten.expand.default(arg1_1, [2, 10, 2]);  arg1_1 = None
            view_2: f32[2, 10, 2] = torch.ops.aten.view.default(expand_1, [2, 10, 2]);  expand_1 = None
            bmm: f32[2, 2, 2] = torch.ops.aten.bmm.default(view_1, view_2);  view_1 = view_2 = None
            view_3: f32[2, 2, 2] = torch.ops.aten.view.default(bmm, [2, 2, 2]);  bmm = None
            view_4: f32[2, 4] = torch.ops.aten.view.default(view_3, [2, -1]);  view_3 = None
            return (view_4,)

Let's turn our flop counting into an extensible function before we go further.

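from collections import defaultdict

def flops_report(exported_program):
    flops = 0
    flops_by_type = defaultdict(int)  # Store FLOPs for each type of operation

    print('Counting FLOPs for the exported graph...')

    op_to_count_func = {
        "aten::bmm": count_bmm_flops,
        "aten::addmm": count_addmm_flops,
        "aten::mm": count_mm_flops,
    }

    print(f"Number of nodes in the exported graph: {len(exported_program.graph.nodes)}")
    print(f"Tracking operations {list(op_to_count_func.keys())}")

    all_ops = set()

    for n in exported_program.graph.nodes:
        # placeholder/output nodes have string targets; skip them
        if not isinstance(n.target, torch._ops.OpOverload):
            continue
        op_name = n.target.name()
        all_ops.add(op_name)
        if op_name in op_to_count_func:
            node_flops = op_to_count_func[op_name](n)
            flops += node_flops
            flops_by_type[op_name] += node_flops

    print(f"All operations seen: {list(all_ops)}")
    print("\nFLOPs by Operation Type:")
    for op_name, op_flops in flops_by_type.items():
        print(f"{op_name}: {op_flops} FLOPs ({100 * op_flops / flops:.4f}%)")
    print(f"\nTotal FLOPs: {flops}")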

Now let's call it on our model.

>>> flops_report(exported)

Counting FLOPs for the exported graph...
Number of nodes in the exported graph: 17
Tracking operations ['aten::bmm', 'aten::addmm', 'aten::mm']
All operations seen: ['aten::bmm', 'aten::addmm', 'aten::view', 'aten::permute', 'aten::expand', 'aten::mm']

FLOPs by Operation Type:
aten::addmm: 800 FLOPs (31.2500%)
aten::mm: 1600 FLOPs (62.5000%)
aten::bmm: 160 FLOPs (6.2500%)

Total FLOPs: 2560
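As a sanity check, the three numbers can be recomputed by hand from the shapes printed in the graph, using the same 2 * batch * m * k * n rule. This small sketch hard-codes those shapes; the dictionary layout is just for illustration.

```python
# (m, k, n, batch) for each matmul in the toy model's exported graph.
shapes = {
    "aten::addmm": (2, 10, 20, 1),  # (2 x 10) @ (10 x 20), bias ignored
    "aten::mm": (2, 20, 20, 1),     # (2 x 20) @ (20 x 20)
    "aten::bmm": (2, 10, 2, 2),     # 2 x [(2 x 10) @ (10 x 2)]
}

flops = {op: 2 * batch * m * k * n for op, (m, k, n, batch) in shapes.items()}
total = sum(flops.values())
for op, f in flops.items():
    print(f"{op}: {f} FLOPs ({100 * f / total:.4f}%)")
print(f"Total FLOPs: {total}")
```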

Ok! This looks good. Now let's use it to estimate FLOPs for nanoGPT. We'll add most of the model.py file to our notebook.

Luckily, Karpathy has already calculated the FLOPs by hand, so we have a reference to compare against.

name                          flops  ratio (%)
attention/kqv            3623878656     1.2426
attention/scores         1610612736     0.5522
attention/reduce         1610612736     0.5522
attention/proj           1207959552     0.4142
attention                8053063680     2.7612
mlp/ffw1                 4831838208     1.6567
mlp/ffw2                 4831838208     1.6567
mlp                      9663676416     3.3135
block                   17716740096     6.0747
transformer            212600881152    72.8963
dense                   79047426048    27.1037
forward_total          291648307200   100.0000
backward_total         583296614400   200.0000
total                  874944921600   300.0000

block, transformer, and dense are the rows we care about.

  • block is the FLOPs for a single transformer block.
  • transformer is block FLOPs * number of blocks (12).
  • dense is the final linear projection from the embedding dimension to the vocabulary. Note that this row assumes logits are computed at every position in the sequence, as in training. At inference time (no targets), nanoGPT only computes logits for the last position, so this value would be 77266944 instead.

212600881152 + 77266944 = 212678148096 (this is the value we are aiming for).
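The reference numbers can be reproduced with a few lines of arithmetic. This is a sketch that follows the per-row formulas implied by the table (each matmul costs 2 * rows * inner * cols, and the MLP hidden size is 4 * n_embd), not Karpathy's exact code; the function name is hypothetical. One detail the arithmetic exposes: the table's dense row works out to a vocabulary of 50257 (GPT-2's unpadded vocabulary), while the inference figure 77266944 uses nanoGPT's padded 50304. The transformer total doesn't depend on the vocabulary, so the target is unaffected.

```python
def reference_flops(seq_len: int, n_embd: int, n_layer: int, vocab_size: int) -> dict:
    """Forward-pass FLOPs for one sequence, counting only matmuls."""
    T, d = seq_len, n_embd
    attention = (
        2 * T * d * (3 * d)  # attention/kqv: project to q, k, v
        + 2 * T * T * d      # attention/scores: q @ k^T
        + 2 * T * T * d      # attention/reduce: att @ v
        + 2 * T * d * d      # attention/proj: output projection
    )
    mlp = 2 * T * d * (4 * d) + 2 * T * (4 * d) * d  # mlp/ffw1 + mlp/ffw2
    block = attention + mlp
    return {
        "block": block,
        "transformer": n_layer * block,
        "dense": 2 * T * d * vocab_size,  # projection to the vocabulary
    }


ref = reference_flops(seq_len=1024, n_embd=768, n_layer=12, vocab_size=50257)
print(ref)

# Inference: logits for the last position only, with nanoGPT's padded vocab.
dense_inference = 2 * 1 * 768 * 50304
print(ref["transformer"] + dense_inference)
```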

config = GPTConfig()
print(config)

# GPTConfig(block_size=1024, vocab_size=50304, n_layer=12, n_head=12, n_embd=768, dropout=0.0, bias=True)

gpt = GPT(config)
print(gpt)

# GPT(
#   (transformer): ModuleDict(
#     (wte): Embedding(50304, 768)
#     (wpe): Embedding(1024, 768)
#     (drop): Dropout(p=0.0, inplace=False)
#     (h): ModuleList(
#       (0-11): 12 x Block(
#         (ln_1): LayerNorm()
#         (attn): CausalSelfAttention(
#           (c_attn): Linear(in_features=768, out_features=2304, bias=True)
#           (c_proj): Linear(in_features=768, out_features=768, bias=True)
#           (attn_dropout): Dropout(p=0.0, inplace=False)

The output of exported is super long so I'm not going to show that here.

>>> flops_report(exported)

Counting FLOPs for the exported graph...
Number of nodes in the exported graph: 887
Tracking operations ['aten::bmm', 'aten::addmm', 'aten::mm']
All operations seen: ['aten::add.Tensor', 'aten::addmm', 'aten::where.self', 'aten::view', 'aten::permute', 'aten::lift_fresh_copy', 'aten::embedding', 'aten::clone', 'aten::bmm', 'aten::split.Tensor', 'aten::mul.Tensor', 'aten::slice.Tensor', 'aten::mm', 'aten::eq.Scalar', 'aten::gelu', 'aten::_softmax', 'aten::arange.start_step', 'aten::expand', 'aten::scalar_tensor', 'aten::index.Tensor', 'aten::native_layer_norm']

FLOPs by Operation Type:
aten::addmm: 173946175488 FLOPs (81.7885%)
aten::bmm: 38654705664 FLOPs (18.1752%)
aten::mm: 77266944 FLOPs (0.0363%)

Total FLOPs: 212678148096

This is exactly the value we were aiming for!