| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | from typing import Any, Callable |
| |
|
| | from transformers import is_torch_available, is_torch_mlu_available |
| | from transformers.testing_utils import ( |
| | TestCasePlus, |
| | execute_subprocess_async, |
| | get_torch_dist_unique_port, |
| | require_torch_multi_accelerator, |
| | ) |
| |
|
| |
|
| | if is_torch_available(): |
| | import functools |
| |
|
| | import torch |
| | import torch.distributed |
| | from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method |
| | from torch.distributed.device_mesh import init_device_mesh |
| | from torch.distributed.fsdp import FullyShardedDataParallel |
| | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy |
| |
|
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | from transformers.models.gpt2.modeling_gpt2 import GPT2Block |
| |
|
# Prompts consumed by the generation routines below; the two sentences are
# repeated four times, giving eight entries that are indexed by process rank.
data = [
    "Hello world!",
    "The quick brown fox jumps over the lazy dog.",
] * 4
| |
|
def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:
    """Manage the creation and destruction of the distributed process group for the wrapped function.

    The returned callable initialises the default process group with a world
    size equal to the number of visible accelerators (MLU devices when
    ``is_torch_mlu_available()`` is true, CUDA devices otherwise), calls
    ``func`` with the given arguments, and always destroys the process group
    afterwards, even if ``func`` raises.

    Args:
        func: The function to run inside an initialised process group.

    Returns:
        A wrapper with the same call signature that forwards ``func``'s return
        value and keeps its metadata (``__name__``, ``__doc__``, ...).
    """
    # Local import keeps the decorator self-contained.
    import functools

    @functools.wraps(func)  # preserve the wrapped function's name/docstring for debugging and introspection
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_torch_mlu_available():
            device_count = torch.mlu.device_count()
        else:
            device_count = torch.cuda.device_count()
        torch.distributed.init_process_group(world_size=device_count)
        try:
            return func(*args, **kwargs)
        finally:
            # Tear the group down on both the success and the failure path.
            torch.distributed.destroy_process_group()

    return wrapped
| |
|
@manage_process_group
def fsdp_generate():
    """Run a short generation through a tiny GPT-2 wrapped in ``FullyShardedDataParallel``.

    Each process binds to the accelerator whose index equals its rank, loads
    ``hf-internal-testing/tiny-random-gpt2`` onto it, wraps the model with an
    auto-wrap policy targeting ``GPT2Block``, and calls ``generate`` on the
    wrapped module inside ``summon_full_params``. The generated tokens are
    discarded; the routine only checks that generation completes.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(rank)
    if is_torch_mlu_available():
        torch.mlu.set_device(device)
    else:
        torch.cuda.set_device(device)

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

    fsdp_model = FullyShardedDataParallel(
        model,
        auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={GPT2Block}),
        limit_all_gathers=True,
        use_orig_params=True,
    )

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    # Wrap around the prompt list so ranks beyond its length still get an input
    # instead of raising IndexError on machines with many accelerators.
    prompt = data[rank % len(data)]
    batch = tokenizer(prompt, return_tensors="pt", return_attention_mask=True).to(device)

    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        _ = fsdp_model.module.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_length=30,
        )
| |
|
@manage_process_group
def fsdp2_generate():
    """Run a short generation through a tiny GPT-2 sharded with ``fully_shard``.

    Each process binds to the accelerator whose index equals its rank, loads
    ``hf-internal-testing/tiny-random-gpt2`` onto it, applies ``fully_shard``
    to every ``GPT2Block`` and then to the root model over a one-dimensional
    device mesh sized to the world size, registers ``generate`` as an FSDP
    forward method, and calls it. The generated tokens are discarded; the
    routine only checks that generation completes.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(rank)
    if is_torch_mlu_available():
        torch.mlu.set_device(device)
    else:
        torch.cuda.set_device(device)

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

    mesh = init_device_mesh(device.type, (torch.distributed.get_world_size(),))
    # Shard the transformer blocks first, then the root module.
    for submodule in model.modules():
        if isinstance(submodule, GPT2Block):
            fully_shard(submodule, mesh=mesh)
    fully_shard(model, mesh=mesh)

    register_fsdp_forward_method(model, "generate")

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    # Wrap around the prompt list so ranks beyond its length still get an input
    # instead of raising IndexError on machines with many accelerators.
    prompt = data[rank % len(data)]
    batch = tokenizer(prompt, return_tensors="pt", return_attention_mask=True).to(device)

    _ = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_length=30,
    )
| |
|
| |
|
class TestFSDPGeneration(TestCasePlus):
    """Launch this file's FSDP generation routines through ``torchrun``, one process per accelerator."""

    def _launch_distributed(self, selection_flag: str) -> None:
        """Run ``test_fsdp.py`` under ``torchrun`` with the given selection flag.

        Args:
            selection_flag: Command-line flag forwarded to the script
                (``"--fsdp"`` or ``"--fsdp2"``).
        """
        if is_torch_mlu_available():
            device_count = torch.mlu.device_count()
        else:
            device_count = torch.cuda.device_count()

        # Build argv as a list instead of whitespace-splitting a formatted
        # string, so a test directory whose path contains spaces stays a
        # single argument.
        cmd = [
            "torchrun",
            f"--nproc_per_node={device_count}",
            f"--master_port={get_torch_dist_unique_port()}",
            f"{self.test_file_dir}/test_fsdp.py",
            selection_flag,
        ]
        execute_subprocess_async(cmd, env=self.get_env())

    @require_torch_multi_accelerator
    def test_fsdp_generate(self):
        """Generation with the ``--fsdp`` routine completes in a subprocess launch."""
        self._launch_distributed("--fsdp")

    @require_torch_multi_accelerator
    def test_fsdp2_generate(self):
        """Generation with the ``--fsdp2`` routine completes in a subprocess launch."""
        self._launch_distributed("--fsdp2")
| | |
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: exactly one of --fsdp / --fsdp2 selects which
    # generation routine this process runs.

    class CLIArgs(argparse.Namespace):
        """Typed namespace for the two mutually exclusive selection flags."""

        fsdp: bool
        fsdp2: bool

    cli = argparse.ArgumentParser()
    selection = cli.add_mutually_exclusive_group()
    for flag in ("--fsdp", "--fsdp2"):
        selection.add_argument(flag, action="store_true")
    parsed = cli.parse_args(namespace=CLIArgs())

    # Pick the routine first, then fail loudly if neither flag was given.
    if parsed.fsdp:
        runner = fsdp_generate
    elif parsed.fsdp2:
        runner = fsdp2_generate
    else:
        runner = None

    if runner is None:
        raise ValueError("Missing test selection")
    runner()
| |
|