The notebook can be found here.
Recently PyTorch added the export function, which AOT-compiles your model. The resulting output is a static description of how data moves through the ATen IR.
For this module:
import torch
from torch import nn
from torch.export import export

class M(torch.nn.Module):
    def __init__(self, d_in, d_out):
        super(M, self).__init__()
        self.layer = nn.Linear(d_in, d_out)

    def forward(self, x):
        return self.layer(x)

arg = torch.randn(2, 10)
model = M(10, 20)
exported = export(model, args=(arg,))
print(exported)
The output is:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[20, 10], arg1_1: f32[20], arg2_1: f32[2, 10]):
            #
            permute: f32[10, 20] = torch.ops.aten.permute.default(arg0_1, [1, 0]); arg0_1 = None
            addmm: f32[2, 20] = torch.ops.aten.addmm.default(arg1_1, arg2_1, permute); arg1_1 = arg2_1 = permute = None
            return (addmm,)
Graph Signature: ExportGraphSignature(parameters=['L__self___layer.weight', 'L__self___layer.bias'], buffers=[], user_inputs=['arg2_1'], user_outputs=['addmm'], inputs_to_parameters={'arg0_1': 'L__self___layer.weight', 'arg1_1': 'L__self___layer.bias'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
We know exactly how data moves through the model, which means memory can be allocated precisely, among other things.
Notice that the tensor shapes are tracked as well. This is crucial for counting FLOPs, and FLOP counts are useful for estimating hardware utilization.
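As an aside, here is a minimal sketch of what such a utilization estimate looks like once you have a FLOP count. Every number below is an assumption for illustration: the FLOP count, the measured time, and the peak throughput (roughly an A100's advertised bf16 rate).
# Minimal sketch with assumed numbers: hardware utilization from a FLOP count.
flops_per_forward = 2.1e11   # assumed: FLOPs for one forward pass
forward_time_s = 0.002       # assumed: measured wall-clock time for that pass
peak_flops = 312e12          # assumed: peak throughput of the accelerator (A100, bf16)
utilization = (flops_per_forward / forward_time_s) / peak_flops
print(f"Estimated hardware utilization: {utilization:.1%}")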
Counting FLOPs
The API is pretty good, but it takes a fair bit of meddling with nodes and referencing the ATen instruction set to get the hang of things.
Ok, so what do we want to do? Loop over the nodes and, for certain nodes, count the FLOPs in that node. Which nodes do we care about? Anything that involves matrix/tensor multiplication. There are other operations such as scalar multiplication and addition, vector addition, softmax, etc., but the bulk of the FLOPs will come from matrix multiplications, so we'll focus on those.
Looking at the ATen IR:
- aten.mm : mm(Tensor self, Tensor mat2) -> Tensor
- aten.addmm : addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
- aten.bmm : bmm(Tensor self, Tensor mat2) -> Tensor
are the instructions we care about.
NOTE: the FLOPs for a matrix multiplication of t1 (m x k) and t2 (k x n) are roughly 2 * m * k * n, since each of the m * n output elements takes k multiplications and about k additions. For the Linear(10, 20) layer above applied to a batch of 2, that's 2 * 2 * 10 * 20 = 800 FLOPs.
Now we just need to get these values from the nodes in the graph.
for n in exported.graph.nodes:
    print(n)
arg0_1
arg1_1
arg2_1
permute
addmm
output
The nodes are of type fx.Node. We need to get the inputs for the addmm instruction. We can access these through the all_input_nodes attribute.
for n in exported.graph.nodes:
    print(n, n.all_input_nodes)
arg0_1 []
arg1_1 []
arg2_1 []
permute [arg0_1]
addmm [arg1_1, arg2_1, permute]
output [addmm]
[arg1_1, arg2_1, permute] roughly matches the addmm signature above. Let's go back to the exported program output.
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[20, 10], arg1_1: f32[20], arg2_1: f32[2, 10]):
            #
            permute: f32[10, 20] = torch.ops.aten.permute.default(arg0_1, [1, 0]); arg0_1 = None
            addmm: f32[2, 20] = torch.ops.aten.addmm.default(arg1_1, arg2_1, permute); arg1_1 = arg2_1 = permute = None
            return (addmm,)
arg1_1 maps to f32[20]. So that's the bias. We can ignore that argument. arg2_1 is f32[2, 10] which maps to t1. permute is f32[10, 20] which maps to t2. We now know which input nodes we will be using. The next step is to retrieve the dimensions of the tensors.
This part is where the API gets a bit odd. AFAIK the way to get the shapes is through the meta attribute of a node, specifically meta['val'], which returns a FakeTensor.
for n in exported.graph.nodes:
    print(n, n.all_input_nodes)
    for a in n.all_input_nodes:
        print('fake tensor', a, a.meta['val'])
arg0_1 []
arg1_1 []
arg2_1 []
permute [arg0_1]
fake tensor arg0_1 FakeTensor(..., size=(20, 10))
addmm [arg1_1, arg2_1, permute]
fake tensor arg1_1 FakeTensor(..., size=(20,))
fake tensor arg2_1 FakeTensor(..., size=(2, 10))
fake tensor permute FakeTensor(..., size=(10, 20))
output [addmm]
fake tensor addmm FakeTensor(..., size=(2, 20))
meta['val'].size() is what we want.
Let's implement the flop count for the addmm instruction.
def count_addmm_flops(node):
    # ignore bias
    t1 = node.all_input_nodes[1].meta['val']
    t2 = node.all_input_nodes[2].meta['val']
    m, k = t1.size()
    n = t2.size(1)
    return 2 * m * k * n
nodes = list(exported.graph.nodes)
addmm_node = nodes[-2]
print(addmm_node)
print('addmm flops', count_addmm_flops(addmm_node))
Identifying the node when looping is a bit tricky. Using node.name isn't specific enough. The correct way to do this is node.target.name(). But target isn't always an op object; sometimes it's a string (e.g. for placeholder nodes), so we need to add an isinstance check.
addmm_node.target.name()
# output is 'aten::addmm'
total_flops = 0
for n in exported.graph.nodes:
    if isinstance(n.target, torch._ops.OpOverload):
        if n.target.name() == 'aten::addmm':
            total_flops += count_addmm_flops(n)
print('total flops', total_flops)
# output is 800
Now let's implement flop counts for the other instructions.
def count_bmm_flops(node):
    t1 = node.all_input_nodes[0].meta['val']
    t2 = node.all_input_nodes[1].meta['val']
    batch_size, m, k = t1.size()
    n = t2.size(2)
    return 2 * batch_size * m * k * n

def count_mm_flops(node):
    t1 = node.all_input_nodes[0].meta['val']
    t2 = node.all_input_nodes[1].meta['val']
    m, k = t1.size()
    n = t2.size(1)
    return 2 * m * k * n
total_flops = 0
for n in exported.graph.nodes:
    if isinstance(n.target, torch._ops.OpOverload):
        if n.target.name() == 'aten::addmm':
            total_flops += count_addmm_flops(n)
        elif n.target.name() == 'aten::mm':
            total_flops += count_mm_flops(n)
        elif n.target.name() == 'aten::bmm':
            total_flops += count_bmm_flops(n)
print('total flops', total_flops)
Our current graph doesn't call mm or bmm, so let's adjust the model.
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.layer = nn.Linear(10, 20)
        self.param = nn.Parameter(torch.zeros(20, 20))
        self.param2 = nn.Parameter(torch.zeros(2, 10, 2))

    def forward(self, x):
        x = self.layer(x)
        x = x @ self.param
        x = x.view(2, 2, 10)
        x = x @ self.param2
        return x.view(2, -1)
arg = torch.randn(2, 10)
model = M()
exported = export(model, args=(arg,))
print(exported)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[20, 20], arg1_1: f32[2, 10, 2], arg2_1: f32[20, 10], arg3_1: f32[20], arg4_1: f32[2, 10]):
            #
            permute: f32[10, 20] = torch.ops.aten.permute.default(arg2_1, [1, 0]); arg2_1 = None
            addmm: f32[2, 20] = torch.ops.aten.addmm.default(arg3_1, arg4_1, permute); arg3_1 = arg4_1 = permute = None
            mm: f32[2, 20] = torch.ops.aten.mm.default(addmm, arg0_1); addmm = arg0_1 = None
            view: f32[2, 2, 10] = torch.ops.aten.view.default(mm, [2, 2, 10]); mm = None
            expand: f32[2, 2, 10] = torch.ops.aten.expand.default(view, [2, 2, 10]); view = None
            view_1: f32[2, 2, 10] = torch.ops.aten.view.default(expand, [2, 2, 10]); expand = None
            expand_1: f32[2, 10, 2] = torch.ops.aten.expand.default(arg1_1, [2, 10, 2]); arg1_1 = None
            view_2: f32[2, 10, 2] = torch.ops.aten.view.default(expand_1, [2, 10, 2]); expand_1 = None
            bmm: f32[2, 2, 2] = torch.ops.aten.bmm.default(view_1, view_2); view_1 = view_2 = None
            view_3: f32[2, 2, 2] = torch.ops.aten.view.default(bmm, [2, 2, 2]); bmm = None
            view_4: f32[2, 4] = torch.ops.aten.view.default(view_3, [2, -1]); view_3 = None
            return (view_4,)
Let's turn our flop counting into an extensible function before we go further.
from collections import defaultdict

def flops_report(exported_program):
    flops = 0
    flops_by_type = defaultdict(int)  # Store FLOPs for each type of operation
    print('Counting FLOPs for the exported graph...')
    op_to_count_func = {
        "aten::bmm": count_bmm_flops,
        "aten::addmm": count_addmm_flops,
        "aten::mm": count_mm_flops,
    }
    print(f"Number of nodes in the exported graph: {len(exported_program.graph.nodes)}")
    print(f"Tracking operations {list(op_to_count_func.keys())}")
    all_ops = set()
    for n in exported_program.graph.nodes:
        if isinstance(n.target, torch._ops.OpOverload):
            op_type = n.target.name()
            all_ops.add(op_type)
            flops_for_node = 0
            if op_type in op_to_count_func:
                flops_for_node = op_to_count_func[op_type](n)
                flops_by_type[op_type] += flops_for_node
            flops += flops_for_node
    print(f"All operations seen: {list(all_ops)}")
    print("\nFLOPs by Operation Type:")
    for op_type, flops_for_type in flops_by_type.items():
        percentage = (flops_for_type / flops) * 100 if flops != 0 else 0
        print(f"{op_type}: {flops_for_type} FLOPs ({percentage:.4f}%)")
    print(f"\nTotal FLOPs: {flops}")
    return flops
Now let's call it on our model.
>>> flops_report(exported)
Counting FLOPs for the exported graph...
Number of nodes in the exported graph: 17
Tracking operations ['aten::bmm', 'aten::addmm', 'aten::mm']
All operations seen: ['aten::bmm', 'aten::addmm', 'aten::view', 'aten::permute', 'aten::expand', 'aten::mm']
FLOPs by Operation Type:
aten::addmm: 800 FLOPs (31.2500%)
aten::mm: 1600 FLOPs (62.5000%)
aten::bmm: 160 FLOPs (6.2500%)
Total FLOPs: 2560
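As a sanity check, these numbers match the formula by hand: the addmm is a (2 x 10) @ (10 x 20) product, 2 * 2 * 10 * 20 = 800; the mm is (2 x 20) @ (20 x 20), 2 * 2 * 20 * 20 = 1600; and the bmm is a batch of 2 products of (2 x 10) @ (10 x 2), 2 * 2 * 2 * 10 * 2 = 160.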
Ok! This looks good. Now let's use it to estimate FLOPs for nanoGPT. We'll add most of the model.py file to our notebook.
Luckily, Karpathy has already calculated a reference FLOP count by hand that we can use for comparison.
name                flops          ratio (%)
attention/kqv       3623878656     1.2426
attention/scores    1610612736     0.5522
attention/reduce    1610612736     0.5522
attention/proj      1207959552     0.4142
attention           8053063680     2.7612
mlp/ffw1            4831838208     1.6567
mlp/ffw2            4831838208     1.6567
mlp                 9663676416     3.3135
block               17716740096    6.0747
transformer         212600881152   72.8963
dense               79047426048    27.1037
forward_total       291648307200   100.0000
backward_total      583296614400   200.0000
total               874944921600   300.0000
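A couple of rows of this table can be reproduced directly from the model config. Here is a small sketch; the shorthand names T and C are my own (for block_size and n_embd), not from nanoGPT.
# Reproducing two rows of the reference table from the config values.
# T = block_size, C = n_embd (shorthand assumed here).
T, C = 1024, 768
kqv_flops = 2 * T * C * (3 * C)   # qkv projection, (T x C) @ (C x 3C) -> 3623878656
ffw1_flops = 2 * T * C * (4 * C)  # first MLP layer, (T x C) @ (C x 4C) -> 4831838208
print(kqv_flops, ffw1_flops)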
block, transformer, and dense are the ones we care about.
- block is the FLOPs for a single transformer block.
- transformer is the FLOPs of block times the number of blocks.
- dense is the linear transform back to the vocabulary. Note that this calculation assumes every time position is used (training). At inference nanoGPT only computes logits for the last position, so this value would be 2 * 768 * 50304 = 77266944 instead.
212600881152 + 77266944 = 212678148096 (this is the value we are aiming for).
config = GPTConfig()
print(config)
# GPTConfig(block_size=1024, vocab_size=50304, n_layer=12, n_head=12, n_embd=768, dropout=0.0, bias=True)
gpt = GPT(config)
print(gpt)
# GPT(
#   (transformer): ModuleDict(
#     (wte): Embedding(50304, 768)
#     (wpe): Embedding(1024, 768)
#     (drop): Dropout(p=0.0, inplace=False)
#     (h): ModuleList(
#       (0-11): 12 x Block(
#         (ln_1): LayerNorm()
#         (attn): CausalSelfAttention(
#           (c_attn): Linear(in_features=768, out_features=2304, bias=True)
#           (c_proj): Linear(in_features=768, out_features=768, bias=True)
#           (attn_dropout): Dropout(p=0.0, inplace=False)
#           (resid_dropout): Dropout(p=0.0, inplace=False)
#         )
#         (ln_2): LayerNorm()
#         (mlp): MLP(
#           (c_fc): Linear(in_features=768, out_features=3072, bias=True)
#           (gelu): GELU(approximate='none')
#           (c_proj): Linear(in_features=3072, out_features=768, bias=True)
#           (dropout): Dropout(p=0.0, inplace=False)
#         )
#       )
#     )
#     (ln_f): LayerNorm()
#   )
#   (lm_head): Linear(in_features=768, out_features=50304, bias=False)
# )
gpt_input = torch.zeros((1, config.block_size), dtype=torch.long)
print(gpt_input)
print(gpt_input.size())
# tensor([[0, 0, 0, ..., 0, 0, 0]])
# torch.Size([1, 1024])
exported = export(gpt, args=(gpt_input,))
The output of exported is super long so I'm not going to show that here.
flops_report(exported)
Counting FLOPs for the exported graph...
Number of nodes in the exported graph: 887
Tracking operations ['aten::bmm', 'aten::addmm', 'aten::mm']
All operations seen: ['aten::add.Tensor', 'aten::addmm', 'aten::where.self', 'aten::view', 'aten::permute', 'aten::lift_fresh_copy', 'aten::embedding', 'aten::clone', 'aten::bmm', 'aten::split.Tensor', 'aten::mul.Tensor', 'aten::slice.Tensor', 'aten::mm', 'aten::eq.Scalar', 'aten::gelu', 'aten::_softmax', 'aten::arange.start_step', 'aten::expand', 'aten::scalar_tensor', 'aten::index.Tensor', 'aten::native_layer_norm']
FLOPs by Operation Type:
aten::addmm: 173946175488 FLOPs (81.7885%)
aten::bmm: 38654705664 FLOPs (18.1752%)
aten::mm: 77266944 FLOPs (0.0363%)
Total FLOPs: 212678148096
This is exactly the value we were aiming for!