Counting nanoGPT FLOPs in PyTorch
October 26, 2023
Recently PyTorch added the export function (torch.export), which compiles your models ahead of time (AOT). The resulting output is a static representation of how data moves through the model, expressed in the ATen IR.
For this module:
```python
import torch
from torch import nn
from torch.export import export


class M(torch.nn.Module):
    def __init__(self, d_in, d_out):
        super(M, self).__init__()
        self.layer = nn.Linear(d_in, d_out)

    def forward(self, x):
        return self.layer(x)


arg = torch.randn(2, 10)
model = M(10, 20)
exported = export(model, args=(arg,))
print(exported)
```
The output is:
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[20, 10], arg1_1: f32[20], arg2_1: f32[2, 10]):
            #
            permute: f32[10, 20] = torch.ops.aten.permute.default(arg0_1, [1, 0]); arg0_1 = None
            addmm: f32[2, 20] = torch.ops.aten.addmm.default(arg1_1, arg2_1, permute); arg1_1 = arg2_1 = permute = None
            return (addmm,)

Graph Signature: ExportGraphSignature(parameters=['L__self___layer.weight', 'L__self___layer.bias'], buffers=[], user_inputs=['arg2_1'], user_outputs=['addmm'], inputs_to_parameters={'arg0_1': 'L__self___layer.weight', 'arg1_1': 'L__self___layer.bias'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
```
We know exactly how data moves through the model. That means we could, for example, plan memory allocation precisely.
Notice that the shapes are tracked as well. This is crucial for counting FLOPs, and FLOP counts are useful for determining hardware utilization.
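Here is a rough illustration of the utilization point. A per-step FLOP count can be turned into a model FLOPs utilization (MFU) figure by dividing the achieved FLOP rate by the hardware's peak. This is a minimal sketch. The function name `mfu` is hypothetical, and the step time and peak throughput in the usage line are placeholder assumptions, not measurements.

```python
def mfu(flops_per_step: float, step_time_s: float, peak_flops_per_s: float) -> float:
    """Fraction of the hardware's peak throughput actually achieved.

    flops_per_step: FLOPs executed in one step (e.g. from a FLOP counter).
    step_time_s: measured wall-clock time of one step, in seconds.
    peak_flops_per_s: the accelerator's advertised peak FLOP/s for the dtype in use.
    """
    achieved_flops_per_s = flops_per_step / step_time_s
    return achieved_flops_per_s / peak_flops_per_s


# Placeholder numbers (assumptions): 3e12 FLOPs per step, 0.1 s per step,
# and a 312 TFLOP/s peak (A100, dense bf16).
print(f"MFU: {mfu(3e12, 0.1, 312e12):.1%}")
```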
Counting FLOPs
The API is pretty good, but it takes a fair bit of meddling with nodes and referencing the ATen instruction set to get the hang of things.
Ok, so what do we want to do? Loop over the nodes and, for certain nodes, count the FLOPs in that node. Which nodes do we care about? Anything that involves matrix/tensor multiplication. There are other operations, such as scalar multiplication and addition, vector addition, and softmax. But the bulk of FLOPs comes from matrix multiplications, so we'll focus on those.
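To see why matrix multiplications dominate, compare one (m x k) by (k x n) product with one elementwise operation (a bias add, activation, or residual add) applied to its m x n output. The count below makes two simplifying assumptions. Every multiply and every add in the matmul counts as one FLOP. An elementwise op costs about one FLOP per element; GELU or softmax cost a small constant factor more.

$$
\frac{\text{matmul FLOPs}}{\text{elementwise FLOPs}} \approx \frac{2mkn}{mn} = 2k
$$

With k = 768 (nanoGPT's embedding width), the ratio is about 1536. Each elementwise pass over a matmul output therefore costs well under 0.1% of the matmul itself.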
Looking at the ATen IR, these are the instructions we care about:
- aten.mm : mm(Tensor self, Tensor mat2) -> Tensor
- aten.addmm : addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
- aten.bmm : bmm(Tensor self, Tensor mat2) -> Tensor
NOTE: the FLOPs for a matrix multiplication of t1 (m x k) and t2 (k x n) are roughly 2 * m * k * n.
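One way to convince yourself of the 2 * m * k * n figure is to count operations in a naive triple-loop matrix multiply. Each of the m * n outputs needs k multiplies and k adds into its accumulator. Below is a small pure-Python sketch; the function name is hypothetical, for illustration only.

```python
def naive_matmul_with_flops(a, b):
    """Multiply a (m x k) by b (k x n), given as nested lists, and count FLOPs.

    Every multiply and every accumulate-add counts as one FLOP.
    """
    m, k, n = len(a), len(b), len(b[0])
    assert all(len(row) == k for row in a), "inner dimensions must match"
    out = [[0.0] * n for _ in range(m)]
    flops = 0
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += a[i][p] * b[p][j]  # one multiply + one add
                flops += 2
            out[i][j] = acc
    return out, flops


# Same shapes as the addmm in the first example: (2 x 10) @ (10 x 20).
a = [[1.0] * 10 for _ in range(2)]
b = [[1.0] * 20 for _ in range(10)]
_, flops = naive_matmul_with_flops(a, b)
print(flops, 2 * 2 * 10 * 20)  # both are 800
```

The "roughly" comes down to convention. The first add into a zero accumulator isn't strictly needed, so a tighter count is m * n * (2k - 1).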
Now we just need to get these values from the nodes in the graph.
```python
for n in exported.graph.nodes:
    print(n)
```
```
arg0_1
arg1_1
arg2_1
permute
addmm
output
```
The nodes are of type torch.fx.Node. We need to get the inputs for the addmm instruction, which we can access through the all_input_nodes attribute.
```python
for n in exported.graph.nodes:
    print(n, n.all_input_nodes)
```
```
arg0_1 []
arg1_1 []
arg2_1 []
permute [arg0_1]
addmm [arg1_1, arg2_1, permute]
output [addmm]
```
[arg1_1, arg2_1, permute] lines up with the (self, mat1, mat2) arguments of the addmm signature above. Let's go back to the exported program output.
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[20, 10], arg1_1: f32[20], arg2_1: f32[2, 10]):
            #
            permute: f32[10, 20] = torch.ops.aten.permute.default(arg0_1, [1, 0]); arg0_1 = None
            addmm: f32[2, 20] = torch.ops.aten.addmm.default(arg1_1, arg2_1, permute); arg1_1 = arg2_1 = permute = None
            return (addmm,)
```
arg1_1 maps to f32[20]. So that's the bias. We can ignore that argument. arg2_1 is f32[2, 10] which maps to t1. permute is f32[10, 20] which maps to t2. We now know which input nodes we will be using. The next step is to retrieve the dimensions of the tensors.
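Ignoring the bias is a safe approximation. addmm computes beta * self + alpha * (mat1 @ mat2). With the default beta = alpha = 1, the bias adds only one broadcast addition per output element:

$$
\text{FLOPs}_{\text{addmm}} = \underbrace{2mkn}_{\text{mat1 @ mat2}} + \underbrace{mn}_{\text{bias add}} \approx 2mkn
$$

For this graph (m = 2, k = 10, n = 20) that is 800 + 40. The bias term is a 1/(2k) fraction of the matmul, so it shrinks further as k grows.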
This part is where the API gets a bit odd. AFAIK the way to get the dimensions is through a node's meta attribute, specifically meta['val'], which holds a FakeTensor.
```python
for n in exported.graph.nodes:
    print(n, n.all_input_nodes)
    for a in n.all_input_nodes:
        print('fake tensor', a, a.meta['val'])
```
```
arg0_1 []
arg1_1 []
arg2_1 []
permute [arg0_1]
fake tensor arg0_1 FakeTensor(..., size=(20, 10))
addmm [arg1_1, arg2_1, permute]
fake tensor arg1_1 FakeTensor(..., size=(20,))
fake tensor arg2_1 FakeTensor(..., size=(2, 10))
fake tensor permute FakeTensor(..., size=(10, 20))
output [addmm]
fake tensor addmm FakeTensor(..., size=(2, 20))
```
meta['val'].size() is what we want.
Let's implement the flop count for the addmm instruction.
```python
def count_addmm_flops(node):
    # ignore bias
    t1 = node.all_input_nodes[1].meta['val']
    t2 = node.all_input_nodes[2].meta['val']
    m, k = t1.size()
    n = t2.size(1)
    return 2 * m * k * n


nodes = list(exported.graph.nodes)
addmm_node = nodes[-2]
print(addmm_node)
print('addmm flops', count_addmm_flops(addmm_node))
```
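count_addmm_flops only reads node.all_input_nodes and meta['val'].size(), so it can be sanity-checked without exporting anything. The sketch below uses hypothetical stand-in classes (FakeVal, FakeNode; neither is part of PyTorch) that mimic just those attributes. It copies count_addmm_flops so that it runs on its own.

```python
from dataclasses import dataclass, field


@dataclass
class FakeVal:
    """Hypothetical stand-in for the FakeTensor in node.meta['val']; mimics only .size()."""
    shape: tuple

    def size(self, dim=None):
        # Like Tensor.size(): the full shape, or one dimension if dim is given.
        return self.shape if dim is None else self.shape[dim]


@dataclass
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node with just all_input_nodes and meta."""
    all_input_nodes: list = field(default_factory=list)
    meta: dict = field(default_factory=dict)


def shaped(*shape):
    # Build an input node whose meta['val'] reports the given shape.
    return FakeNode(meta={"val": FakeVal(shape)})


def count_addmm_flops(node):
    # Same logic as above: skip the bias (input 0), use mat1 and mat2.
    t1 = node.all_input_nodes[1].meta["val"]
    t2 = node.all_input_nodes[2].meta["val"]
    m, k = t1.size()
    n = t2.size(1)
    return 2 * m * k * n


# Shapes from the exported graph: arg1_1 f32[20], arg2_1 f32[2, 10], permute f32[10, 20].
addmm = FakeNode(all_input_nodes=[shaped(20), shaped(2, 10), shaped(10, 20)])
print(count_addmm_flops(addmm))  # 800
```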
Identifying the right nodes while looping is a bit tricky. node.name isn't specific enough, since it's just a generated identifier (e.g. view_1). A more reliable way is node.target.name(). But target isn't always an operator object: for placeholder and output nodes it's a plain string. So we need an isinstance check.
```python
addmm_node.target.name()  # output is 'aten::addmm'

total_flops = 0
for n in exported.graph.nodes:
    if isinstance(n.target, torch._ops.OpOverload):
        if n.target.name() == 'aten::addmm':
            total_flops += count_addmm_flops(n)
print('total flops', total_flops)  # output is 800
```
Now let's implement flop counts for the other instructions.
```python
def count_bmm_flops(node):
    t1 = node.all_input_nodes[0].meta['val']
    t2 = node.all_input_nodes[1].meta['val']
    batch_size, m, k = t1.size()
    n = t2.size(2)
    return 2 * batch_size * m * k * n


def count_mm_flops(node):
    t1 = node.all_input_nodes[0].meta['val']
    t2 = node.all_input_nodes[1].meta['val']
    m, k = t1.size()
    n = t2.size(1)
    return 2 * m * k * n


total_flops = 0
for n in exported.graph.nodes:
    if isinstance(n.target, torch._ops.OpOverload):
        if n.target.name() == 'aten::addmm':
            total_flops += count_addmm_flops(n)
        elif n.target.name() == 'aten::mm':
            total_flops += count_mm_flops(n)
        elif n.target.name() == 'aten::bmm':
            total_flops += count_bmm_flops(n)
print('total flops', total_flops)
```
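The three counters differ only in which inputs they read and whether there is a batch dimension. An alternative design is a single shape-based helper. This sketch assumes you've already pulled the two operand shapes out of meta['val'], as tuples or torch.Size objects; the name matmul_flops is hypothetical.

```python
from math import prod


def matmul_flops(a_shape, b_shape):
    """FLOPs for a (..., m, k) @ (..., k, n) product, counting 2 per multiply-add.

    Leading dimensions are treated as batch dimensions and taken from a_shape.
    This matches aten.bmm (both operands share the batch size) and aten.mm
    (no batch dimensions at all).
    """
    *batch, m, k = a_shape
    k2, n = b_shape[-2], b_shape[-1]
    if k != k2:
        raise ValueError(f"inner dimensions differ: {a_shape} @ {b_shape}")
    return 2 * prod(batch) * m * k * n  # prod([]) == 1 for plain 2D matmuls
```

For example, an mm node's count would be matmul_flops(node.args[0].meta['val'].size(), node.args[1].meta['val'].size()). For addmm, you would pass args[1] and args[2] to skip the bias. Reading operands from node.args by position mirrors the ATen schema directly. That seems a bit sturdier than all_input_nodes, which, as far as I can tell, lists each distinct input node only once (so x @ x would show a single input).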
Our current graph doesn't call these ops so let's adjust that.
```python
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.layer = nn.Linear(10, 20)
        self.param = nn.Parameter(torch.zeros(20, 20))
        self.param2 = nn.Parameter(torch.zeros(2, 10, 2))

    def forward(self, x):
        x = self.layer(x)
        x = x @ self.param
        x = x.view(2, 2, 10)
        x = x @ self.param2
        return x.view(2, -1)


arg = torch.randn(2, 10)
model = M()
exported = export(model, args=(arg,))
print(exported)
```
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[20, 20], arg1_1: f32[2, 10, 2], arg2_1: f32[20, 10], arg3_1: f32[20], arg4_1: f32[2, 10]):
            #
            permute: f32[10, 20] = torch.ops.aten.permute.default(arg2_1, [1, 0]); arg2_1 = None
            addmm: f32[2, 20] = torch.ops.aten.addmm.default(arg3_1, arg4_1, permute); arg3_1 = arg4_1 = permute = None
            mm: f32[2, 20] = torch.ops.aten.mm.default(addmm, arg0_1); addmm = arg0_1 = None
            view: f32[2, 2, 10] = torch.ops.aten.view.default(mm, [2, 2, 10]); mm = None
            expand: f32[2, 2, 10] = torch.ops.aten.expand.default(view, [2, 2, 10]); view = None
            view_1: f32[2, 2, 10] = torch.ops.aten.view.default(expand, [2, 2, 10]); expand = None
            expand_1: f32[2, 10, 2] = torch.ops.aten.expand.default(arg1_1, [2, 10, 2]); arg1_1 = None
            view_2: f32[2, 10, 2] = torch.ops.aten.view.default(expand_1, [2, 10, 2]); expand_1 = None
            bmm: f32[2, 2, 2] = torch.ops.aten.bmm.default(view_1, view_2); view_1 = view_2 = None
            view_3: f32[2, 2, 2] = torch.ops.aten.view.default(bmm, [2, 2, 2]); bmm = None
            view_4: f32[2, 4] = torch.ops.aten.view.default(view_3, [2, -1]); view_3 = None
            return (view_4,)
```
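Reading the shapes straight off this graph gives the counts we should expect, using 2mkn (times the batch size for bmm):

$$
\begin{aligned}
\text{addmm } [2,10] \,@\, [10,20]: &\quad 2 \cdot 2 \cdot 10 \cdot 20 = 800 \\
\text{mm } [2,20] \,@\, [20,20]: &\quad 2 \cdot 2 \cdot 20 \cdot 20 = 1600 \\
\text{bmm } [2,2,10] \,@\, [2,10,2]: &\quad 2 \cdot 2 \cdot 2 \cdot 10 \cdot 2 = 160 \\
\text{total}: &\quad 800 + 1600 + 160 = 2560
\end{aligned}
$$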
Let's turn our flop counting into an extensible function before we go further.
```python
from collections import defaultdict


def flops_report(exported_program):
    flops = 0
    flops_by_type = defaultdict(int)  # Store FLOPs for each type of operation
    print('Counting FLOPs for the exported graph...')
    op_to_count_func = {
        "aten::bmm": count_bmm_flops,
        "aten::addmm": count_addmm_flops,
        "aten::mm": count_mm_flops,
    }
    print(f"Number of nodes in the exported graph: {len(exported_program.graph.nodes)}")
    print(f"Tracking operations {list(op_to_count_func.keys())}")
    all_ops = set()
    for n in exported_program.graph.nodes:
        if isinstance(n.target, torch._ops.OpOverload):
            op_type = n.target.name()
            all_ops.add(op_type)
            flops_for_node = 0
            if op_type in op_to_count_func:
                flops_for_node = op_to_count_func[op_type](n)
                flops_by_type[op_type] += flops_for_node
            flops += flops_for_node
    print(f"All operations seen: {list(all_ops)}")
    print("\nFLOPs by Operation Type:")
    for op_type, flops_for_type in flops_by_type.items():
        percentage = (flops_for_type / flops) * 100 if flops != 0 else 0
        print(f"{op_type}: {flops_for_type} FLOPs ({percentage:.4f}%)")
    print(f"\nTotal FLOPs: {flops}")
    return flops
```
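The op_to_count_func dictionary is the extension point: supporting a new op means writing a counter and adding an entry. One variation lets each counter register itself next to its definition, so flops_report would only need to look names up in a shared registry. This sketch uses a hypothetical register_flop_counter decorator and FLOP_COUNTERS registry; neither is part of PyTorch.

```python
from typing import Callable, Dict

# Hypothetical registry: ATen op name (as returned by node.target.name()) -> counter.
FLOP_COUNTERS: Dict[str, Callable] = {}


def register_flop_counter(*op_names: str):
    """Decorator that registers a FLOP-counting function for one or more op names."""
    def decorator(fn: Callable) -> Callable:
        for name in op_names:
            if name in FLOP_COUNTERS:
                raise KeyError(f"a counter for {name} is already registered")
            FLOP_COUNTERS[name] = fn
        return fn
    return decorator


@register_flop_counter("aten::mm")
def count_mm_flops(node):
    t1 = node.all_input_nodes[0].meta["val"]
    t2 = node.all_input_nodes[1].meta["val"]
    m, k = t1.size()
    n = t2.size(1)
    return 2 * m * k * n


def count_node(op_name, node):
    # Unregistered ops (view, permute, ...) contribute 0, as in flops_report.
    counter = FLOP_COUNTERS.get(op_name)
    return counter(node) if counter is not None else 0
```

Raising on duplicate registration is a deliberate choice. It catches two counters silently competing for the same op name.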
Now let's call it on our model.
```
>>> flops_report(exported)
Counting FLOPs for the exported graph...
Number of nodes in the exported graph: 17
Tracking operations ['aten::bmm', 'aten::addmm', 'aten::mm']
All operations seen: ['aten::bmm', 'aten::addmm', 'aten::view', 'aten::permute', 'aten::expand', 'aten::mm']

FLOPs by Operation Type:
aten::addmm: 800 FLOPs (31.2500%)
aten::mm: 1600 FLOPs (62.5000%)
aten::bmm: 160 FLOPs (6.2500%)

Total FLOPs: 2560
```
Ok! This looks good. Now let's use it to estimate FLOPs for nanoGPT. We'll add most of the model.py file to our notebook.
Luckily, Karpathy already calculated reference FLOPs by hand, so we have something to compare against.
```
name                 flops           ratio (%)
attention/kqv        3623878656      1.2426
attention/scores     1610612736      0.5522
attention/reduce     1610612736      0.5522
attention/proj       1207959552      0.4142
attention            8053063680      2.7612
mlp/ffw1             4831838208      1.6567
mlp/ffw2             4831838208      1.6567
mlp                  9663676416      3.3135
block                17716740096     6.0747
transformer          212600881152    72.8963
dense                79047426048     27.1037
forward_total        291648307200    100.0000
backward_total       583296614400    200.0000
total                874944921600    300.0000
```
block, transformer, and dense are the ones we care about.
- block is the FLOPs for a single transformer block.
- transformer is block FLOPs * number of blocks.
- dense is the linear transform back to the vocabulary (lm_head). Note that this calculation assumes the projection is applied at every position in the sequence, as during training. At inference time nanoGPT applies lm_head only to the last position, so this value would be 77266944 instead.
212600881152 + 77266944 = 212678148096 (this is the value we are aiming for).
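These reference numbers follow directly from the config. The small sketch below recomputes them with the same 2 * m * k * n rule. It assumes T = block_size = 1024 tokens, C = n_embd = 768, 12 layers, and vocab_size = 50304; the function name is hypothetical.

```python
def nanogpt_forward_flops(T=1024, C=768, n_layer=12, vocab_size=50304, train=False):
    """Matmul FLOPs for one forward pass of a GPT-2-style model on T tokens."""
    kqv = 2 * T * C * (3 * C)        # c_attn: (T x C) @ (C x 3C)
    scores = 2 * T * T * C           # q @ k^T, summed over all heads
    reduce = 2 * T * T * C           # att @ v, summed over all heads
    proj = 2 * T * C * C             # attention c_proj
    mlp = 2 * (2 * T * C * (4 * C))  # MLP c_fc and c_proj
    block = kqv + scores + reduce + proj + mlp
    # Training applies lm_head to every position; inference (no targets) only to the last one.
    dense = 2 * (T if train else 1) * C * vocab_size
    return {"block": block, "transformer": n_layer * block, "dense": dense,
            "total": n_layer * block + dense}


flops = nanogpt_forward_flops()
print(flops["block"], flops["transformer"], flops["dense"], flops["total"])
# 17716740096 212600881152 77266944 212678148096
```

Passing train=True with vocab_size=50257 (the unpadded GPT-2 vocabulary) reproduces the table's dense value of 79047426048. That suggests the reference table used the unpadded vocabulary, while the config here pads it to 50304.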
```python
# GPTConfig and GPT come from nanoGPT's model.py (pasted into the notebook)
config = GPTConfig()
print(config)
# GPTConfig(block_size=1024, vocab_size=50304, n_layer=12, n_head=12, n_embd=768, dropout=0.0, bias=True)

gpt = GPT(config)
print(gpt)
# GPT(
#   (transformer): ModuleDict(
#     (wte): Embedding(50304, 768)
#     (wpe): Embedding(1024, 768)
#     (drop): Dropout(p=0.0, inplace=False)
#     (h): ModuleList(
#       (0-11): 12 x Block(
#         (ln_1): LayerNorm()
#         (attn): CausalSelfAttention(
#           (c_attn): Linear(in_features=768, out_features=2304, bias=True)
#           (c_proj): Linear(in_features=768, out_features=768, bias=True)
#           (attn_dropout): Dropout(p=0.0, inplace=False)
#           (resid_dropout): Dropout(p=0.0, inplace=False)
#         )
#         (ln_2): LayerNorm()
#         (mlp): MLP(
#           (c_fc): Linear(in_features=768, out_features=3072, bias=True)
#           (gelu): GELU(approximate='none')
#           (c_proj): Linear(in_features=3072, out_features=768, bias=True)
#           (dropout): Dropout(p=0.0, inplace=False)
#         )
#       )
#     )
#     (ln_f): LayerNorm()
#   )
#   (lm_head): Linear(in_features=768, out_features=50304, bias=False)
# )

gpt_input = torch.zeros((1, config.block_size), dtype=torch.long)
print(gpt_input)
print(gpt_input.size())

# Export the full model on a single block_size-long sequence
exported = export(gpt, args=(gpt_input,))
```
The output of exported is super long so I'm not going to show that here.
```python
flops_report(exported)
```
```
Counting FLOPs for the exported graph...
Number of nodes in the exported graph: 887
Tracking operations ['aten::bmm', 'aten::addmm', 'aten::mm']
All operations seen: ['aten::add.Tensor', 'aten::addmm', 'aten::where.self', 'aten::view', 'aten::permute', 'aten::lift_fresh_copy', 'aten::embedding', 'aten::clone', 'aten::bmm', 'aten::split.Tensor', 'aten::mul.Tensor', 'aten::slice.Tensor', 'aten::mm', 'aten::eq.Scalar', 'aten::gelu', 'aten::_softmax', 'aten::arange.start_step', 'aten::expand', 'aten::scalar_tensor', 'aten::index.Tensor', 'aten::native_layer_norm']

FLOPs by Operation Type:
aten::addmm: 173946175488 FLOPs (81.7885%)
aten::bmm: 38654705664 FLOPs (18.1752%)
aten::mm: 77266944 FLOPs (0.0363%)

Total FLOPs: 212678148096
```
This is exactly the value we were aiming for!
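The per-op split lines up with the reference table too. In this export, the biased nn.Linear layers lower to addmm and the two attention matmuls lower to bmm. lm_head (bias=False) lowers to a plain mm applied to the last position only. Here is a quick arithmetic check using the per-block rows from Karpathy's table and the config values above:

```python
# Per-block rows from Karpathy's reference table (forward pass, T = 1024).
kqv, scores, reduce, proj = 3623878656, 1610612736, 1610612736, 1207959552
ffw1, ffw2 = 4831838208, 4831838208
n_layer = 12

addmm = n_layer * (kqv + proj + ffw1 + ffw2)  # c_attn, attn c_proj, c_fc, mlp c_proj
bmm = n_layer * (scores + reduce)             # q @ k^T and att @ v
mm = 2 * 768 * 50304                          # lm_head on the last position only

print(addmm, bmm, mm)  # 173946175488 38654705664 77266944
```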