| | import sys
|
| | sys.path.append('.')
|
| | from train import *
|
| | from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count, parameter_count
|
| | from rscd.models.backbones.lamba_util.csms6s import flops_selective_scan_fn, flops_selective_scan_ref, selective_scan_flop_jit
|
| |
|
def parse_args(argv=None):
    """Parse command-line options for the params/FLOPs counting script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the real command line; passing a list makes
            the function usable from tests or other scripts.

    Returns:
        argparse.Namespace with:
            config (str): path of the config file (default
                ``"configs/cdlama.py"``).
            size (int): height/width of the square dummy input
                (default ``256``).
    """
    parser = argparse.ArgumentParser(description='count params and flops')
    parser.add_argument(
        "-c", "--config", type=str, default="configs/cdlama.py",
        help="path to the config file to load",
    )
    parser.add_argument(
        "--size", type=int, default=256,
        help="height/width of the square random input image",
    )
    return parser.parse_args(argv)
|
| |
|
def flops_mamba(model, shape=(3, 224, 224), device="cuda"):
    """Return a params/GFLOPs summary for a two-input model with selective-scan ops.

    The custom selective-scan autograd functions are traced as
    ``prim::PythonOp.<Name>`` nodes, so they are explicitly mapped to
    ``selective_scan_flop_jit``. A few element-wise aten ops are mapped to
    ``None`` so fvcore ignores them.

    Args:
        model: ``torch.nn.Module`` invoked as ``model(input1, input2)``. It is
            moved to ``device`` and switched to eval mode in place (this
            mutates the caller's module).
        shape: Per-sample input shape without the batch dimension; a batch
            dimension of 1 is prepended. Both random inputs use this shape.
        device: Device the model is moved to before tracing. Defaults to
            ``"cuda"``, matching the previous hard-coded ``model.cuda()``.

    Returns:
        str: ``"params <P> GFLOPs <G>"`` where ``P`` is the parameter count in
        millions and ``G`` is the sum of fvcore's per-operator GFLOP counts.
    """
    supported_ops = {
        # Element-wise ops: a None handle tells fvcore to ignore them.
        "aten::silu": None,
        "aten::neg": None,
        "aten::exp": None,
        "aten::flip": None,
        # Custom selective-scan implementations, all costed by the same handle.
        "prim::PythonOp.SelectiveScanCuda": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanMamba": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanOflex": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanCore": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanNRow": selective_scan_flop_jit,
    }

    model.to(device).eval()

    # Put the inputs where the parameters live; fall back to `device` for a
    # parameter-free model instead of raising StopIteration.
    first_param = next(model.parameters(), None)
    input_device = first_param.device if first_param is not None else torch.device(device)
    input1 = torch.randn((1, *shape), device=input_device)
    input2 = torch.randn((1, *shape), device=input_device)

    # parameter_count maps module-name prefixes to counts; "" is the whole model.
    params = parameter_count(model)[""]
    # flop_count returns (per-operator GFLOPs, Counter of unsupported ops);
    # fvcore logs unsupported ops by default, so the Counter is not used here.
    gflops_by_op, _unsupported = flop_count(
        model=model, inputs=(input1, input2), supported_ops=supported_ops
    )

    return f"params {params / 1e6} GFLOPs {sum(gflops_by_op.values())}"
|
| |
|
if __name__ == "__main__":
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # Only the trainer's network is used below.
    net = myTrain(cfg).net.cuda()
    net.eval()

    size = args.size
    # Random dummy image created directly on the GPU (no CPU alloc + copy).
    # The same tensor is passed as both of the network's two inputs.
    # Named `dummy` to avoid shadowing the builtin `input`.
    dummy = torch.rand((1, 3, size, size), device="cuda")

    # Per-module breakdown with fvcore's default op handles.
    flops = FlopCountAnalysis(net, (dummy, dummy))
    print(flop_count_table(flops, max_depth=2))

    # Totals including the custom selective-scan ops.
    print(flops_mamba(net, (3, size, size)))
|
| |
|