Spaces:
Paused
Paused
fix: use FP32 dtype on CPU (no bfloat16 autocast with INT8 model)
Browse files- infer/inference.py +4 -3
infer/inference.py
CHANGED
|
@@ -442,10 +442,11 @@ class ImageGenerator:
|
|
| 442 |
self.ae = self.ae.cpu()
|
| 443 |
cudagc()
|
| 444 |
|
|
|
|
| 445 |
if args is not None and hasattr(args, 'single_denoise') and not args.single_denoise:
|
| 446 |
-
x = torch.randn(num_samples,16,height // 8,width // 8,device=self.device,dtype=
|
| 447 |
else:
|
| 448 |
-
x= torch.zeros(num_samples,16,height // 8,width // 8,device=self.device,dtype=
|
| 449 |
|
| 450 |
|
| 451 |
timesteps = sampling.get_schedule(num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True)
|
|
@@ -470,7 +471,7 @@ class ImageGenerator:
|
|
| 470 |
ref_image_raw=ref_images_raw,
|
| 471 |
empty_llm=empty_llm)
|
| 472 |
|
| 473 |
-
with torch.autocast(device_type=self.device.type,dtype=torch.bfloat16):
|
| 474 |
# Lpred,Rpred = self.double_denoise(**inputs,cfg_guidance=cfg_guidance,timesteps=timesteps,height=height,width=width)#图像中包括ref image
|
| 475 |
Lpred,Rpred = self.denoise(**inputs,cfg_guidance=cfg_guidance,timesteps=timesteps,show_progress=show_progress,timesteps_truncate=1.0,)#图像中包括ref image
|
| 476 |
Lpred=self.unpack(Lpred.float(),height,width)
|
|
|
|
| 442 |
self.ae = self.ae.cpu()
|
| 443 |
cudagc()
|
| 444 |
|
| 445 |
+
_dtype = torch.float32 if self.device.type == "cpu" else torch.bfloat16
|
| 446 |
if args is not None and hasattr(args, 'single_denoise') and not args.single_denoise:
|
| 447 |
+
x = torch.randn(num_samples,16,height // 8,width // 8,device=self.device,dtype=_dtype,generator=torch.Generator(device=self.device).manual_seed(seed),)
|
| 448 |
else:
|
| 449 |
+
x= torch.zeros(num_samples,16,height // 8,width // 8,device=self.device,dtype=_dtype,)
|
| 450 |
|
| 451 |
|
| 452 |
timesteps = sampling.get_schedule(num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True)
|
|
|
|
| 471 |
ref_image_raw=ref_images_raw,
|
| 472 |
empty_llm=empty_llm)
|
| 473 |
|
| 474 |
+
with torch.autocast(device_type=self.device.type, dtype=torch.float32) if self.device.type == "cpu" else torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
|
| 475 |
# Lpred,Rpred = self.double_denoise(**inputs,cfg_guidance=cfg_guidance,timesteps=timesteps,height=height,width=width)#图像中包括ref image
|
| 476 |
Lpred,Rpred = self.denoise(**inputs,cfg_guidance=cfg_guidance,timesteps=timesteps,show_progress=show_progress,timesteps_truncate=1.0,)#图像中包括ref image
|
| 477 |
Lpred=self.unpack(Lpred.float(),height,width)
|