File size: 5,939 Bytes
74c8926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
patch_pshuman_vram.py
=====================
Apply VRAM-reduction optimisations to /root/PSHuman/inference.py.

Patches applied:
  1. load_pshuman_pipeline — adds VAE slicing + CPU offload on top of
     the fp16 + xformers setup that is already in the file.
  2. run_inference — inserts torch.cuda.empty_cache() immediately before the
     autocast block that calls the pipeline, so fragmented VRAM is reclaimed
     before multi-view denoising starts.

fp32 @ 768 res ≈ 40 GB.  fp16 ≈ 20 GB.  fp16 + xformers ≈ 16-18 GB.
fp16 + xformers + VAE slicing + CPU offload ≈ 14-16 GB peak → fits 24 GB.

Run:
    /root/miniconda/envs/pshuman/bin/python /root/MeshForge/scripts/patch_pshuman_vram.py
"""
import pathlib, sys

# The single file this script rewrites: inference.py inside the PSHuman checkout.
TARGET = pathlib.Path("/root/PSHuman/inference.py")
if not TARGET.exists():
    # Abort early with a non-zero exit status; nothing can be patched yet.
    sys.exit(f"ERROR: {TARGET} not found — run after PSHuman is cloned")

# Python source files are UTF-8 by default (PEP 3120), so decode explicitly
# instead of relying on the locale's preferred encoding, which may differ
# (for example under a C/POSIX locale).
src = TARGET.read_text(encoding="utf-8")
original = src  # untouched copy: used for change detection and for the backup

# ─────────────────────────────────────────────────────────────────
# Patch 1: load_pshuman_pipeline
# ─────────────────────────────────────────────────────────────────
# Exact text of the upstream function this patch expects to find.
OLD_LOAD = """\
def load_pshuman_pipeline(cfg):
    pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(cfg.pretrained_model_name_or_path, torch_dtype=weight_dtype)
    pipeline.unet.enable_xformers_memory_efficient_attention()
    if torch.cuda.is_available():
        pipeline.to('cuda')
    return pipeline"""

# Replacement function body written into inference.py.
NEW_LOAD = """\
def load_pshuman_pipeline(cfg):
    pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
        cfg.pretrained_model_name_or_path,
        torch_dtype=weight_dtype,          # float16 — halves VRAM vs fp32
    )

    # xformers: reduces peak VRAM during multi-head denoising attention
    try:
        pipeline.unet.enable_xformers_memory_efficient_attention()
        print("[PSHuman] xformers memory-efficient attention enabled")
    except Exception as _xe:
        print(f"[PSHuman] xformers unavailable ({_xe}) — falling back to attention slicing")
        pipeline.unet.enable_attention_slicing(1)

    # VAE slicing: prevents OOM when decoding a 7-view 768-res batch at once
    if hasattr(pipeline, "enable_vae_slicing"):
        pipeline.enable_vae_slicing()
        print("[PSHuman] VAE slicing enabled")

    # CPU offload: idle pipeline components (text encoder, VAE, safety checker)
    # move to RAM when not actively used, freeing ~3-4 GB of static VRAM.
    # pipeline() is called via standard diffusers __call__, so hooks work.
    if torch.cuda.is_available():
        try:
            pipeline.enable_model_cpu_offload()
            print("[PSHuman] model CPU offload enabled")
        except Exception as _oe:
            print(f"[PSHuman] CPU offload unavailable ({_oe}) — loading to CUDA directly")
            pipeline.to("cuda")

    return pipeline"""

# Three outcomes, tried in order:
#   1. exact upstream text found      -> straight string replacement
#   2. file already mentions the new  -> nothing to do (idempotent re-run)
#      enable_vae_slicing call
#   3. otherwise                      -> anchored regex fallback, or a warning
if OLD_LOAD in src:
    src = src.replace(OLD_LOAD, NEW_LOAD)
    print("[patch 1] load_pshuman_pipeline — VRAM optimisations applied")
elif "enable_vae_slicing" in src:
    print("[patch 1] load_pshuman_pipeline — already patched, skipping")
else:
    # Looser match for minor whitespace/version differences.
    import re

    # Anchor both ends to whole lines: the def must start at column 0 and the
    # match must end on a line that is exactly `return pipeline` (optionally
    # followed by a comment).  An unanchored `.*?return pipeline` would also
    # stop inside e.g. `return pipeline.to(...)` and leave the tail of that
    # statement spliced onto the replacement.
    loose_match = re.search(
        r"^def load_pshuman_pipeline\(cfg\):.*?"
        r"^[ \t]+return pipeline[ \t]*(?:#[^\n]*)?$",
        src,
        re.DOTALL | re.MULTILINE,
    )
    if loose_match:
        src = src[:loose_match.start()] + NEW_LOAD + src[loose_match.end():]
        print("[patch 1] load_pshuman_pipeline — applied via regex fallback")
    else:
        print("[patch 1] WARNING: could not locate load_pshuman_pipeline — skipping")

# ─────────────────────────────────────────────────────────────────
# Patch 2: empty CUDA cache after pipeline call in run_inference
# ─────────────────────────────────────────────────────────────────
# Insert torch.cuda.empty_cache() immediately before the autocast block that
# wraps the pipeline __call__.
# The existing code already has `torch.cuda.empty_cache()` at the bottom of
# the batch loop, so the extra call is only added when this script's marker
# comment is not already present in the file.

# Text that locates the pipeline call inside run_inference.
OLD_CACHE_ANCHOR = """\
            with torch.autocast("cuda"):
                # B*Nv images
                guidance_scale = cfg.validation_guidance_scales
                unet_out = pipeline("""

# Same text with the cache flush inserted on the line before the autocast block.
NEW_CACHE_ANCHOR = """\
            torch.cuda.empty_cache()  # free fragmented VRAM before denoising
            with torch.autocast("cuda"):
                # B*Nv images
                guidance_scale = cfg.validation_guidance_scales
                unet_out = pipeline("""

# Fragment of the inserted line used to detect an earlier run of this script.
# It must be checked first: NEW_CACHE_ANCHOR still contains OLD_CACHE_ANCHOR,
# so the anchor test alone would insert the flush again on every run.
_CACHE_FLUSH_MARKER = "empty_cache()  # free fragmented"

if _CACHE_FLUSH_MARKER in src:
    print("[patch 2] run_inference — cache flush already present, skipping")
elif OLD_CACHE_ANCHOR in src:
    src = src.replace(OLD_CACHE_ANCHOR, NEW_CACHE_ANCHOR)
    print("[patch 2] run_inference — added pre-denoising cache flush")
else:
    # Report a missing anchor as a warning (as patch 1 does) so a patch that
    # could not be applied is not mistaken for an idempotent no-op.
    print("[patch 2] WARNING: could not locate pipeline call anchor in run_inference — skipping")

# ─────────────────────────────────────────────────────────────────
# Write back only if changed
# ─────────────────────────────────────────────────────────────────
if src != original:
    # An existing backup is never overwritten, so a later run cannot replace
    # the first saved copy with text that has already been patched.
    backup = TARGET.with_suffix(".py.orig")
    if not backup.exists():
        backup.write_text(original, encoding="utf-8")
        print(f"[patch] Backup saved → {backup}")
    # The patched text contains non-ASCII characters; encode explicitly as
    # UTF-8 (the default Python source encoding) rather than with the locale
    # encoding, which may be unable to represent them.
    TARGET.write_text(src, encoding="utf-8")
    print(f"[patch] Written → {TARGET}")
else:
    print("[patch] No changes made.")