| import sys
|
| sys.path.append('.')
|
| from train import *
|
| from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count, parameter_count
|
| from rscd.models.backbones.lamba_util.csms6s import flops_selective_scan_fn, flops_selective_scan_ref, selective_scan_flop_jit
|
|
|
def parse_args():
    """Read the command-line options for the params/FLOPs counting script.

    Options:
        -c / --config: path to the config file (default ``configs/cdlama.py``).
        --size: spatial size H = W of the dummy input images (default 256).

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description='count params and flops')
    cli.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
    cli.add_argument("--size", type=int, default=256)
    return cli.parse_args()
|
|
|
def flops_mamba(model, shape=(3, 224, 224)):
    """Count parameters and FLOPs of a two-input model, including selective scans.

    Unlike a bare ``FlopCountAnalysis``, this registers FLOP handlers for the
    selective-scan ``prim::PythonOp`` nodes, so Mamba-style scans are counted.

    Side effects: the caller's ``model`` is moved to CUDA and switched to eval
    mode in place.

    Args:
        model: ``torch.nn.Module`` whose forward takes two image tensors.
        shape: per-sample input shape ``(C, H, W)``. Each of the two inputs is
            a random batch of size 1 with this shape.

    Returns:
        A string ``"params <P> GFLOPs <G>"``. ``P`` is the total parameter
        count in millions. ``G`` is the sum of the per-operator GFLOP counts
        returned by ``fvcore.nn.flop_count``.
    """
    supported_ops = {
        # A ``None`` handle tells fvcore to ignore these elementwise ops
        # instead of reporting them as unsupported.
        "aten::silu": None,
        "aten::neg": None,
        "aten::exp": None,
        "aten::flip": None,
        # The selective-scan autograd Functions show up in the trace as
        # PythonOps; count their FLOPs with the dedicated handler.
        "prim::PythonOp.SelectiveScanCuda": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanMamba": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanOflex": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanCore": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanNRow": selective_scan_flop_jit,
    }

    model.cuda().eval()
    device = next(model.parameters()).device

    input1 = torch.randn((1, *shape), device=device)
    input2 = torch.randn((1, *shape), device=device)
    # fvcore's parameter_count keys the whole-model total under "".
    params = parameter_count(model)[""]

    # Tracing only needs forward values; disabling autograd avoids keeping
    # activations alive for a backward pass that never happens.
    with torch.no_grad():
        gflops, _unsupported = flop_count(
            model=model, inputs=(input1, input2), supported_ops=supported_ops
        )

    return f"params {params / 1e6} GFLOPs {sum(gflops.values())}"
|
|
|
if __name__ == "__main__":
    # Build the network described by the config and report its size/FLOPs
    # for a pair of random size x size RGB inputs.
    args = parse_args()
    cfg = Config.fromfile(args.config)
    net = myTrain(cfg).net.cuda()
    net.eval()

    size = args.size
    # Named `dummy` rather than `input` to avoid shadowing the builtin.
    dummy = torch.rand((1, 3, size, size)).cuda()

    # FlopCountAnalysis traces lazily, when the table is built, so the
    # no_grad context must enclose flop_count_table as well.
    # NOTE(review): this table uses fvcore's default op handles only, so the
    # selective-scan PythonOps are likely reported as unsupported and excluded
    # here; flops_mamba below registers handlers for them. Confirm.
    with torch.no_grad():
        flops = FlopCountAnalysis(net, (dummy, dummy))
        print(flop_count_table(flops, max_depth=2))

    print(flops_mamba(net, (3, size, size)))
|
|
|