"""
patch_pshuman_vram.py
=====================

Apply VRAM-reduction optimisations to /root/PSHuman/inference.py.

Patches applied:
  1. load_pshuman_pipeline — adds VAE slicing + CPU offload on top of the
     fp16 + xformers settings that are already in the file.
  2. run_inference — adds torch.cuda.empty_cache() just before the pipeline
     call so fragmented VRAM is reclaimed before multi-view denoising starts.

Approximate peak VRAM:
  fp32 @ 768 res ≈ 40 GB. fp16 ≈ 20 GB. fp16 + xformers ≈ 16-18 GB.
  fp16 + xformers + VAE slicing + CPU offload ≈ 14-16 GB peak → fits in 24 GB.

Re-running is safe: patches that are already present are skipped, and the
pre-patch file is saved once as inference.py.orig.

Run:
    /root/miniconda/envs/pshuman/bin/python /root/MeshForge/scripts/patch_pshuman_vram.py
"""
import pathlib
import re
import sys
# File that the patches below are applied to.
TARGET = pathlib.Path("/root/PSHuman/inference.py")

# is_file() rather than exists(): a directory at this path would otherwise
# slip through and crash in read_text() with a less helpful error.
if not TARGET.is_file():
    sys.exit(f"ERROR: {TARGET} not found — run after PSHuman is cloned")

src = TARGET.read_text()
# Untouched copy of the file contents: compared against `src` at the end to
# decide whether anything changed, and written out as the .orig backup.
original = src
# ─────────────────────────────────────────────────────────────────
# Patch 1: load_pshuman_pipeline
# ─────────────────────────────────────────────────────────────────
# Expected text of the unpatched function in PSHuman's inference.py. It is
# matched exactly first; the regex fallback below handles small differences.
OLD_LOAD = """\
def load_pshuman_pipeline(cfg):
    pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(cfg.pretrained_model_name_or_path, torch_dtype=weight_dtype)
    pipeline.unet.enable_xformers_memory_efficient_attention()
    if torch.cuda.is_available():
        pipeline.to('cuda')
    return pipeline"""

# Replacement function source. It is pasted into inference.py and never run by
# this script, so it relies on names that file provides
# (StableUnCLIPImg2ImgPipeline, weight_dtype, torch).
NEW_LOAD = """\
def load_pshuman_pipeline(cfg):
    pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
        cfg.pretrained_model_name_or_path,
        torch_dtype=weight_dtype,  # float16 — halves VRAM vs fp32
    )
    # xformers: reduces peak VRAM during multi-head denoising attention
    try:
        pipeline.unet.enable_xformers_memory_efficient_attention()
        print("[PSHuman] xformers memory-efficient attention enabled")
    except Exception as _xe:
        print(f"[PSHuman] xformers unavailable ({_xe}) — falling back to attention slicing")
        # enable_attention_slicing is a DiffusionPipeline method; diffusers
        # models such as the UNet only expose set_attention_slice().
        pipeline.enable_attention_slicing(1)
    # VAE slicing: prevents OOM when decoding a 7-view 768-res batch at once
    if hasattr(pipeline, "enable_vae_slicing"):
        pipeline.enable_vae_slicing()
        print("[PSHuman] VAE slicing enabled")
    # CPU offload: idle pipeline components (text encoder, VAE, safety checker)
    # move to RAM when not actively used, freeing ~3-4 GB of static VRAM.
    # pipeline() is called via standard diffusers __call__, so hooks work.
    if torch.cuda.is_available():
        try:
            pipeline.enable_model_cpu_offload()
            print("[PSHuman] model CPU offload enabled")
        except Exception as _oe:
            print(f"[PSHuman] CPU offload unavailable ({_oe}) — loading to CUDA directly")
            pipeline.to("cuda")
    return pipeline"""

if OLD_LOAD in src:
    src = src.replace(OLD_LOAD, NEW_LOAD)
    print("[patch 1] load_pshuman_pipeline — VRAM optimisations applied")
elif "enable_vae_slicing" in src:
    # NEW_LOAD contains this call, so its presence is taken to mean that
    # patch 1 has already been applied.
    print("[patch 1] load_pshuman_pipeline — already patched, skipping")
else:
    # Looser match for minor whitespace/version differences. It spans from the
    # `def` line to the first whole `return pipeline` line. The tempered step
    # `(?:(?!\n\S).)` refuses to cross into a new column-0 line, so the match
    # can never run past the end of this function and swallow later code.
    m = re.search(
        r"^def load_pshuman_pipeline\(cfg\):(?:(?!\n\S).)*?\n[ \t]+return pipeline[ \t]*$",
        src,
        re.DOTALL | re.MULTILINE,
    )
    if m:
        src = src[: m.start()] + NEW_LOAD + src[m.end():]
        print("[patch 1] load_pshuman_pipeline — applied via regex fallback")
    else:
        print("[patch 1] WARNING: could not locate load_pshuman_pipeline — skipping")
# ─────────────────────────────────────────────────────────────────
# Patch 2: empty the CUDA cache just before the pipeline call in run_inference
# ─────────────────────────────────────────────────────────────────
# Insert torch.cuda.empty_cache() on the line directly BEFORE the
# `with torch.autocast("cuda"):` block that calls the pipeline. (The existing
# code already calls torch.cuda.empty_cache() at the bottom of the batch loop.)
#
# The anchor is located with a regex that captures the `with` line's
# indentation. The inserted line reuses that indentation, so the patch works
# whatever indent width the target file uses.
CACHE_ANCHOR_RE = re.compile(
    r'^(?P<indent>[ \t]*)with torch\.autocast\("cuda"\):\n'
    r'[ \t]*# B\*Nv images\n'
    r'[ \t]*guidance_scale = cfg\.validation_guidance_scales\n'
    r'[ \t]*unet_out = pipeline\(',
    re.MULTILINE,
)

# Comment carried by the inserted line. Its presence means patch 2 already ran;
# files patched by the previous version of this script contain it too.
CACHE_FLUSH_MARKER = "# free fragmented VRAM before denoising"


def _insert_cache_flush(match: re.Match) -> str:
    """Return the matched anchor preceded by an equally indented empty_cache()."""
    indent = match["indent"]
    return f"{indent}torch.cuda.empty_cache()  {CACHE_FLUSH_MARKER}\n{match[0]}"


if CACHE_FLUSH_MARKER in src:
    print("[patch 2] run_inference — cache flush already present, skipping")
else:
    src, n_inserted = CACHE_ANCHOR_RE.subn(_insert_cache_flush, src)
    if n_inserted:
        print("[patch 2] run_inference — added pre-denoising cache flush")
    else:
        print("[patch 2] WARNING: could not locate the pipeline call in run_inference — skipping")
# ─────────────────────────────────────────────────────────────────
# Write back only if changed
# ─────────────────────────────────────────────────────────────────
if src != original:
    # "inference.py" -> "inference.py.orig". The backup is written only when it
    # does not exist yet, so later runs never replace the first pre-patch copy
    # with already-patched text.
    backup = TARGET.with_suffix(".py.orig")
    if not backup.exists():
        backup.write_text(original)
        print(f"[patch] Backup saved → {backup}")
    TARGET.write_text(src)
    print(f"[patch] Written → {TARGET}")
else:
    print("[patch] No changes made.")