Nekochu committed on
Commit
09dfa8a
·
1 Parent(s): 9d04602

fix: use FP32 dtype on CPU (no bfloat16 autocast with INT8 model)

Browse files
Files changed (1) hide show
  1. infer/inference.py +4 -3
infer/inference.py CHANGED
@@ -442,10 +442,11 @@ class ImageGenerator:
442
  self.ae = self.ae.cpu()
443
  cudagc()
444
 
 
445
  if args is not None and hasattr(args, 'single_denoise') and not args.single_denoise:
446
- x = torch.randn(num_samples,16,height // 8,width // 8,device=self.device,dtype=torch.bfloat16,generator=torch.Generator(device=self.device).manual_seed(seed),)
447
  else:
448
- x= torch.zeros(num_samples,16,height // 8,width // 8,device=self.device,dtype=torch.bfloat16,)
449
 
450
 
451
  timesteps = sampling.get_schedule(num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True)
@@ -470,7 +471,7 @@ class ImageGenerator:
470
  ref_image_raw=ref_images_raw,
471
  empty_llm=empty_llm)
472
 
473
- with torch.autocast(device_type=self.device.type,dtype=torch.bfloat16):
474
  # Lpred,Rpred = self.double_denoise(**inputs,cfg_guidance=cfg_guidance,timesteps=timesteps,height=height,width=width)#图像中包括ref image
475
  Lpred,Rpred = self.denoise(**inputs,cfg_guidance=cfg_guidance,timesteps=timesteps,show_progress=show_progress,timesteps_truncate=1.0,)#图像中包括ref image
476
  Lpred=self.unpack(Lpred.float(),height,width)
 
442
  self.ae = self.ae.cpu()
443
  cudagc()
444
 
445
+ _dtype = torch.float32 if self.device.type == "cpu" else torch.bfloat16
446
  if args is not None and hasattr(args, 'single_denoise') and not args.single_denoise:
447
+ x = torch.randn(num_samples,16,height // 8,width // 8,device=self.device,dtype=_dtype,generator=torch.Generator(device=self.device).manual_seed(seed),)
448
  else:
449
+ x= torch.zeros(num_samples,16,height // 8,width // 8,device=self.device,dtype=_dtype,)
450
 
451
 
452
  timesteps = sampling.get_schedule(num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True)
 
471
  ref_image_raw=ref_images_raw,
472
  empty_llm=empty_llm)
473
 
474
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32) if self.device.type == "cpu" else torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
475
  # Lpred,Rpred = self.double_denoise(**inputs,cfg_guidance=cfg_guidance,timesteps=timesteps,height=height,width=width)#图像中包括ref image
476
  Lpred,Rpred = self.denoise(**inputs,cfg_guidance=cfg_guidance,timesteps=timesteps,show_progress=show_progress,timesteps_truncate=1.0,)#图像中包括ref image
477
  Lpred=self.unpack(Lpred.float(),height,width)