Coercer committed on
Commit
f0cc4e1
·
verified ·
1 Parent(s): f59682f

Upload extract_sdxl_embeddings.py

Browse files
Files changed (1) hide show
  1. extract_sdxl_embeddings.py +100 -0
extract_sdxl_embeddings.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # extract_sdxl_embeddings.py
2
+ import argparse
3
+ from pathlib import Path
4
+ from typing import List
5
+
6
+ import torch
7
+ from safetensors.torch import save_file
8
+ from diffusers import StableDiffusionXLPipeline
9
+
10
+
11
def read_prompts(txt_path: str) -> List[str]:
    """Read a UTF-8 text file and return one prompt per line.

    Only the trailing newline of each line is removed; other whitespace is
    kept, and blank lines are returned as empty strings so that a prompt's
    position in the list always equals its line index in the file.

    Args:
        txt_path: Path to the prompts text file.

    Returns:
        The lines of the file, in order, without their trailing newline.
    """
    prompts: List[str] = []
    with open(txt_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            prompts.append(raw_line.rstrip("\n"))
    return prompts
14
+
15
+
16
def load_sdxl(checkpoint_path: str, precision: str):
    """Load an SDXL pipeline from a Diffusers directory or a single checkpoint file.

    Args:
        checkpoint_path: A Diffusers-format directory, or a single-file
            checkpoint (e.g. ``.safetensors`` / ``.ckpt``).
        precision: ``"bf16"`` (case-insensitive) selects ``torch.bfloat16``;
            any other value selects ``torch.float16``.

    Returns:
        The pipeline, moved to CUDA when available (CPU otherwise), with
        progress bars disabled and its ``torch.nn.Module`` components in
        eval mode.
    """
    # fp16 is the fallback: the original author notes it tends to work
    # better for SDXL on a T4.
    dtype = torch.bfloat16 if precision.lower() == "bf16" else torch.float16

    if Path(checkpoint_path).is_dir():
        pipe = StableDiffusionXLPipeline.from_pretrained(
            checkpoint_path,
            torch_dtype=dtype,
            use_safetensors=True,
        )
    else:
        # Handles single-file .safetensors / .ckpt checkpoints.
        pipe = StableDiffusionXLPipeline.from_single_file(
            checkpoint_path,
            torch_dtype=dtype,
        )

    pipe.to("cuda" if torch.cuda.is_available() else "cpu")
    pipe.set_progress_bar_config(disable=True)

    # A Diffusers pipeline is not an nn.Module and has no .eval(); calling it
    # raises AttributeError. Put the individual module components in eval
    # mode instead.
    for component in pipe.components.values():
        if isinstance(component, torch.nn.Module):
            component.eval()
    return pipe
41
+
42
+
43
@torch.no_grad()
def encode_batch(pipe: "StableDiffusionXLPipeline", batch_prompts: List[str]):
    """Encode a batch of prompts with the pipeline's text encoders.

    The same prompts are passed as both ``prompt`` and ``prompt_2``, and
    classifier-free guidance is disabled so no negative embeddings are built.

    Args:
        pipe: Loaded SDXL pipeline exposing ``encode_prompt``.
        batch_prompts: Prompts to encode together.

    Returns:
        A ``(prompt_embeds, pooled_prompt_embeds)`` tuple of detached CPU
        tensors.
    """
    device = getattr(pipe, "_execution_device", None)
    if device is None:
        # Fall back to wherever the first text encoder's weights live.
        device = next(pipe.text_encoder.parameters()).device

    # encode_prompt returns (prompt_embeds, negative_prompt_embeds,
    # pooled_prompt_embeds, negative_pooled_prompt_embeds). Slicing the first
    # two would pair the prompt embeddings with the *negative* embeddings
    # (None when guidance is off), so unpack all four and keep the positives.
    (
        prompt_embeds,
        _negative_prompt_embeds,
        pooled_prompt_embeds,
        _negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=batch_prompts,
        prompt_2=batch_prompts,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=False,
    )

    return prompt_embeds.detach().cpu(), pooled_prompt_embeds.detach().cpu()
57
+
58
+
59
def main():
    """CLI entry point: encode every prompt line and save one file per prompt.

    Reads prompts from ``--prompts_txt``, encodes them in batches of
    ``--batch_size`` and writes ``<index>.safetensors`` into ``--out_dir``,
    where ``<index>`` is the prompt's line index zero-padded to
    ``--pad_width`` digits. Each file holds ``text_embeds`` and
    ``pooled_text_embeds``, both cast to float16 and sliced so that they keep
    a leading dimension of size 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--sdxl_checkpoint", type=str, required=True,
                        help="Ruta al .safetensors / .ckpt o directorio Diffusers de SDXL.")
    parser.add_argument("--prompts_txt", type=str, required=True)
    parser.add_argument("--out_dir", type=str, default="output_embeddings")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--precision", type=str, default="fp16", choices=["fp16", "bf16"])
    parser.add_argument("--pad_width", type=int, default=5)
    args = parser.parse_args()

    # Fail fast, before the expensive checkpoint load: a zero batch size makes
    # range() raise, a negative one silently processes nothing, and a negative
    # pad width is an invalid format spec.
    if args.batch_size < 1:
        parser.error("--batch_size debe ser >= 1")
    if args.pad_width < 0:
        parser.error("--pad_width no puede ser negativo")

    batch_size = args.batch_size
    pad_width = args.pad_width

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    prompts = read_prompts(args.prompts_txt)
    n = len(prompts)

    # Only pay for loading the checkpoint when there is something to encode.
    pipe = load_sdxl(args.sdxl_checkpoint, args.precision) if n else None

    print(f"Procesando {n} prompts...")

    for start in range(0, n, batch_size):
        batch_prompts = prompts[start:start + batch_size]
        text_embeds, pooled_text_embeds = encode_batch(pipe, batch_prompts)

        for offset in range(text_embeds.shape[0]):
            # File index equals the prompt's line index in the input file.
            file_idx = start + offset
            save_path = out_dir / f"{file_idx:0{pad_width}d}.safetensors"

            # Slice with offset:offset+1 to keep the leading dimension.
            sample = {
                "text_embeds": text_embeds[offset:offset + 1].contiguous().to(torch.float16),
                "pooled_text_embeds": pooled_text_embeds[offset:offset + 1].contiguous().to(torch.float16),
            }
            save_file(sample, str(save_path))

        print(f"Guardado hasta {min(start + batch_size - 1, n - 1):0{pad_width}d}")

    print(f"Listo. Salida en: {out_dir}")
97
+
98
+
99
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()