Counting nanoGPT FLOPs in PyTorch
October 26, 2023
Recently PyTorch added the `export` function, which AOT compiles your models. The resulting output is a static analysis of how data moves through the ATen IR.
For this module:
```python
import torch
from torch import nn
from torch.export import export

class M(torch.nn.Module):
    def __init__(self, d_in, d_out):
        super(M, self).__init__()
        self.layer = nn.Linear(d_in, d_out)

    def forward(self, x):
        return self.layer(x)

arg = torch.randn(2, 10)
model = M(10, 20)
exported = export(model, args=(arg,))
print(exported)
```
The output is:
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[20, 10], arg1_1: f32[20], arg2_1: f32[2, 10]):
            permute: f32[10, 20] = torch.ops.aten.permute.default(arg0_1, [1, 0]);  arg0_1 = None
            addmm: f32[2, 20] = torch.ops.aten.addmm.default(arg1_1, arg2_1, permute);  arg1_1 = arg2_1 = permute = None
            return (addmm,)

Graph Signature: ExportGraphSignature(parameters=['L__self___layer.weight', 'L__self___layer.bias'], buffers=[], user_inputs=['arg2_1'], user_outputs=['addmm'], inputs_to_parameters={'arg0_1': 'L__self___layer.weight', 'arg1_1': 'L__self___layer.bias'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
```
We know exactly how data is moving through the model, meaning we can precisely allocate memory, etc.
Notice how the shapes are tracked as well. This is crucial for counting FLOPs, and a FLOP count is useful for estimating hardware utilization.
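To make that last point concrete, here is a minimal sketch of how a FLOP count feeds into a model FLOPs utilization (MFU) estimate. The function name, the measured time, and the 312e12 peak (A100 bf16) below are illustrative assumptions on my part:

```python
# Hypothetical MFU sketch: flops_per_forward comes from a counter like the one built below,
# seconds_per_forward is measured wall-clock time, peak_flops_per_sec is the accelerator's spec number.
def model_flops_utilization(flops_per_forward, seconds_per_forward, peak_flops_per_sec):
    achieved_flops_per_sec = flops_per_forward / seconds_per_forward
    return achieved_flops_per_sec / peak_flops_per_sec

# e.g. a 212678148096-FLOP forward pass measured at 10 ms on a 312e12 FLOP/s accelerator:
# model_flops_utilization(212678148096, 0.010, 312e12) ~= 0.068, i.e. about 6.8% of peak
```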
Counting FLOPs
The API is pretty good, but it takes a fair bit of meddling with nodes and referencing the ATen instruction set to get the hang of things.
Ok, so what do we want to do? Loop over the nodes and for certain nodes count the flops in that node. Which nodes do we care about? Anything that involves matrix/tensor multiplication. There are other operations such as scalar multiplication and addition, vector addition, softmax, etc. But the bulk of FLOPs will come from matrix multiplications so we'll focus on those.
Looking at the ATen IR:
- aten.mm : mm(Tensor self, Tensor mat2) -> Tensor
- aten.addmm : addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
- aten.bmm : bmm(Tensor self, Tensor mat2) -> Tensor
are the instructions we care about.
NOTE: the FLOPs for a matrix multiplication of t1 (m x k) and t2 (k x n) are roughly 2 * m * k * n.
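For example, the addmm in the graph above multiplies a (2, 10) input with a (10, 20) weight, so it costs roughly 2 * 2 * 10 * 20 = 800 FLOPs; we should see that number again once the counter is running.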
Now we just need to get these values from the nodes in the graph.
```python
for n in exported.graph.nodes:
    print(n)
```
```
arg0_1
arg1_1
arg2_1
permute
addmm
output
```
The nodes are of type `fx.Node`. We need to get the inputs for the `addmm` instruction. We can access these through the `all_input_nodes` attribute.
```python
for n in exported.graph.nodes:
    print(n, n.all_input_nodes)
```
```
arg0_1 []
arg1_1 []
arg2_1 []
permute [arg0_1]
addmm [arg1_1, arg2_1, permute]
output [addmm]
```
`[arg1_1, arg2_1, permute]` roughly matches the signature above. Let's go back to the exported program output.
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[20, 10], arg1_1: f32[20], arg2_1: f32[2, 10]):
            permute: f32[10, 20] = torch.ops.aten.permute.default(arg0_1, [1, 0]);  arg0_1 = None
            addmm: f32[2, 20] = torch.ops.aten.addmm.default(arg1_1, arg2_1, permute);  arg1_1 = arg2_1 = permute = None
            return (addmm,)
```
arg1_1 maps to f32[20]. So that's the bias. We can ignore that argument. arg2_1 is f32[2, 10] which maps to t1. permute is f32[10, 20] which maps to t2. We now know which input nodes we will be using. The next step is to retrieve the dimensions of the tensors.
This part is where the API gets a bit odd. AFAIK the way to get the tensor shapes is the `meta` attribute of a node. Specifically `meta['val']`, which returns a FakeTensor.
```python
for n in exported.graph.nodes:
    print(n, n.all_input_nodes)
    for a in n.all_input_nodes:
        print('fake tensor', a, a.meta['val'])
```
```
arg0_1 []
arg1_1 []
arg2_1 []
permute [arg0_1]
fake tensor arg0_1 FakeTensor(..., size=(20, 10))
addmm [arg1_1, arg2_1, permute]
fake tensor arg1_1 FakeTensor(..., size=(20,))
fake tensor arg2_1 FakeTensor(..., size=(2, 10))
fake tensor permute FakeTensor(..., size=(10, 20))
output [addmm]
fake tensor addmm FakeTensor(..., size=(2, 20))
```
`meta['val'].size()` is what we want.
Let's implement the flop count for the addmm instruction.
```python
def count_addmm_flops(node):
    # ignore bias
    t1 = node.all_input_nodes[1].meta['val']
    t2 = node.all_input_nodes[2].meta['val']
    m, k = t1.size()
    n = t2.size(1)
    return 2 * m * k * n

nodes = list(exported.graph.nodes)
addmm_node = nodes[-2]
print(addmm_node)
print('addmm flops', count_addmm_flops(addmm_node))
```
Getting the right node when looping is a bit tricky. Using `node.name` isn't specific enough. The correct way to do this is `node.target.name()`. But `target` isn't always an op object, sometimes it's a string, so we need to add an `isinstance` check.
```python
addmm_node.target.name()  # output is 'aten::addmm'

total_flops = 0
for n in exported.graph.nodes:
    if isinstance(n.target, torch._ops.OpOverload):
        if n.target.name() == 'aten::addmm':
            total_flops += count_addmm_flops(n)
print('total flops', total_flops)  # output is 800
```
Now let's implement flop counts for the other instructions.
```python
def count_bmm_flops(node):
    t1 = node.all_input_nodes[0].meta['val']
    t2 = node.all_input_nodes[1].meta['val']
    batch_size, m, k = t1.size()
    n = t2.size(2)
    return 2 * batch_size * m * k * n

def count_mm_flops(node):
    t1 = node.all_input_nodes[0].meta['val']
    t2 = node.all_input_nodes[1].meta['val']
    m, k = t1.size()
    n = t2.size(1)
    return 2 * m * k * n

total_flops = 0
for n in exported.graph.nodes:
    if isinstance(n.target, torch._ops.OpOverload):
        if n.target.name() == 'aten::addmm':
            total_flops += count_addmm_flops(n)
        elif n.target.name() == 'aten::mm':
            total_flops += count_mm_flops(n)
        elif n.target.name() == 'aten::bmm':
            total_flops += count_bmm_flops(n)
print('total flops', total_flops)
```
Our current graph doesn't call these ops so let's adjust that.
```python
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.layer = nn.Linear(10, 20)
        self.param = nn.Parameter(torch.zeros(20, 20))
        self.param2 = nn.Parameter(torch.zeros(2, 10, 2))

    def forward(self, x):
        x = self.layer(x)
        x = x @ self.param
        x = x.view(2, 2, 10)
        x = x @ self.param2
        return x.view(2, -1)

arg = torch.randn(2, 10)
model = M()
exported = export(model, args=(arg,))
print(exported)
```
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[20, 20], arg1_1: f32[2, 10, 2], arg2_1: f32[20, 10], arg3_1: f32[20], arg4_1: f32[2, 10]):
            permute: f32[10, 20] = torch.ops.aten.permute.default(arg2_1, [1, 0]);  arg2_1 = None
            addmm: f32[2, 20] = torch.ops.aten.addmm.default(arg3_1, arg4_1, permute);  arg3_1 = arg4_1 = permute = None
            mm: f32[2, 20] = torch.ops.aten.mm.default(addmm, arg0_1);  addmm = arg0_1 = None
            view: f32[2, 2, 10] = torch.ops.aten.view.default(mm, [2, 2, 10]);  mm = None
            expand: f32[2, 2, 10] = torch.ops.aten.expand.default(view, [2, 2, 10]);  view = None
            view_1: f32[2, 2, 10] = torch.ops.aten.view.default(expand, [2, 2, 10]);  expand = None
            expand_1: f32[2, 10, 2] = torch.ops.aten.expand.default(arg1_1, [2, 10, 2]);  arg1_1 = None
            view_2: f32[2, 10, 2] = torch.ops.aten.view.default(expand_1, [2, 10, 2]);  expand_1 = None
            bmm: f32[2, 2, 2] = torch.ops.aten.bmm.default(view_1, view_2);  view_1 = view_2 = None
            view_3: f32[2, 2, 2] = torch.ops.aten.view.default(bmm, [2, 2, 2]);  bmm = None
            view_4: f32[2, 4] = torch.ops.aten.view.default(view_3, [2, -1]);  view_3 = None
            return (view_4,)
```
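As a quick hand check: the addmm multiplies (2, 10) by (10, 20) for 2 * 2 * 10 * 20 = 800 FLOPs, the mm multiplies (2, 20) by (20, 20) for 2 * 2 * 20 * 20 = 1600 FLOPs, and the bmm does 2 batched multiplies of (2, 10) by (10, 2) for 2 * 2 * 2 * 10 * 2 = 160 FLOPs, so we expect 2560 in total.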
Let's turn our flop counting into an extensible function before we go further.
```python
from collections import defaultdict

def flops_report(exported_program):
    flops = 0
    flops_by_type = defaultdict(int)  # Store FLOPs for each type of operation
    print('Counting FLOPs for the exported graph...')
    op_to_count_func = {
        "aten::bmm": count_bmm_flops,
        "aten::addmm": count_addmm_flops,
        "aten::mm": count_mm_flops,
    }
    print(f"Number of nodes in the exported graph: {len(exported_program.graph.nodes)}")
    print(f"Tracking operations {list(op_to_count_func.keys())}")
    all_ops = set()
    for n in exported_program.graph.nodes:
        if not isinstance(n.target, torch._ops.OpOverload):
            continue
        op_name = n.target.name()
        all_ops.add(op_name)
        if op_name in op_to_count_func:
            node_flops = op_to_count_func[op_name](n)
            flops += node_flops
            flops_by_type[op_name] += node_flops
    print(f"All operations seen: {list(all_ops)}")
    print("FLOPs by Operation Type:")
    for op_name, op_flops in flops_by_type.items():
        print(f"  {op_name}: {op_flops} FLOPs ({op_flops / flops * 100:.4f}%)")
    print(f"Total FLOPs: {flops}")
```
Now let's call it on our model.
```
>>> flops_report(exported)
Counting FLOPs for the exported graph...
Number of nodes in the exported graph: 17
Tracking operations ['aten::bmm', 'aten::addmm', 'aten::mm']
All operations seen: ['aten::bmm', 'aten::addmm', 'aten::view', 'aten::permute', 'aten::expand', 'aten::mm']
FLOPs by Operation Type:
  aten::addmm: 800 FLOPs (31.2500%)
  aten::mm: 1600 FLOPs (62.5000%)
  aten::bmm: 160 FLOPs (6.2500%)
Total FLOPs: 2560
```
Ok! This looks good. Now let's use it to estimate FLOPs for nanoGPT. We'll add most of the `model.py` file to our notebook.
Luckily, Karpathy has already calculated reference FLOPs by hand that we can use for comparison.
```
name                      flops    ratio (%)
attention/kqv        3623878656       1.2426
attention/scores     1610612736       0.5522
attention/reduce     1610612736       0.5522
attention/proj       1207959552       0.4142
attention            8053063680       2.7612
mlp/ffw1             4831838208       1.6567
mlp/ffw2             4831838208       1.6567
mlp                  9663676416       3.3135
block               17716740096       6.0747
transformer        212600881152      72.8963
dense               79047426048      27.1037
forward_total      291648307200     100.0000
backward_total     583296614400     200.0000
total              874944921600     300.0000
```
block, transformer, and dense are the ones we care about.
- block is the FLOPs for a single transformer block.
- transformer is the FLOPs of block * number of blocks.
- dense is the linear transform back to the vocab embedding. Note that this calculation assumes every time position is used (training). For inference this value would be 77266944 instead.
212600881152 + 77266944 = 212678148096 (this is the value we are aiming for).
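(Quick sanity check on that inference-time dense number: with n_embd = 768 and vocab_size = 50304, projecting a single position back to the vocabulary costs roughly 2 * 768 * 50304 = 77266944 FLOPs.)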
```python
config = GPTConfig()
print(config)
# GPTConfig(block_size=1024, vocab_size=50304, n_layer=12, n_head=12, n_embd=768, dropout=0.0, bias=True)

gpt = GPT(config)
print(gpt)
# GPT(
#   (transformer): ModuleDict(
#     (wte): Embedding(50304, 768)
#     (wpe): Embedding(1024, 768)
#     (drop): Dropout(p=0.0, inplace=False)
#     (h): ModuleList(
#       (0-11): 12 x Block(
#         (ln_1): LayerNorm()
#         (attn): CausalSelfAttention(
#           (c_attn): Linear(in_features=768, out_features=2304, bias=True)
#           (c_proj): Linear(in_features=768, out_features=768, bias=True)
#           (attn_dropout): Dropout(p=0.0, inplace=False)
#           ...
```
The output of `exported` is super long so I'm not going to show that here.
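The export call itself looks the same as before. Here is a minimal sketch of what I mean; the dummy input (a single batch of block_size token ids, with no targets so nanoGPT only runs the LM head on the last position) is my assumption about how it gets invoked:

```python
# Hypothetical export call for nanoGPT, assuming GPTConfig/GPT from model.py are defined above
# and that GPT.forward(idx, targets=None) accepts a LongTensor of token ids.
idx = torch.randint(0, config.vocab_size, (1, config.block_size))
exported = export(gpt, args=(idx,))
```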
```python
flops_report(exported)
```
```
Counting FLOPs for the exported graph...
Number of nodes in the exported graph: 887
Tracking operations ['aten::bmm', 'aten::addmm', 'aten::mm']
All operations seen: ['aten::add.Tensor', 'aten::addmm', 'aten::where.self', 'aten::view', 'aten::permute', 'aten::lift_fresh_copy', 'aten::embedding', 'aten::clone', 'aten::bmm', 'aten::split.Tensor', 'aten::mul.Tensor', 'aten::slice.Tensor', 'aten::mm', 'aten::eq.Scalar', 'aten::gelu', 'aten::_softmax', 'aten::arange.start_step', 'aten::expand', 'aten::scalar_tensor', 'aten::index.Tensor', 'aten::native_layer_norm']
FLOPs by Operation Type:
  aten::addmm: 173946175488 FLOPs (81.7885%)
  aten::bmm: 38654705664 FLOPs (18.1752%)
  aten::mm: 77266944 FLOPs (0.0363%)
Total FLOPs: 212678148096
```
This is exactly the value we were aiming for!