| | import torch |
| |
|
def normalise2(tensor):
    """Rescale values from [0, 1] to [-1, 1], clamping the result to [-1, 1]."""
    rescaled = tensor * 2 - 1.
    return rescaled.clamp(min=-1, max=1)
| |
|
def tfg_data(dataloader, face_hide_percentage, use_ref, use_audio):
    """Yield (img_batch, model_kwargs) pairs from `dataloader` forever.

    The dataloader is restarted each time it is exhausted, so the generator
    never terminates; each raw batch is run through `tfg_process_batch`.
    """
    while True:
        for batch in dataloader:
            yield tfg_process_batch(batch, face_hide_percentage, use_ref, use_audio)
| | |
| |
|
def tfg_process_batch(batch, face_hide_percentage, use_ref=False, use_audio=False, sampling_use_gt_for_ref=False, noise = None):
    """Flatten a video batch into images and build conditioning kwargs.

    `batch["image"]` is flattened from (B, F, C, H, W) to (B*F, C, H, W)
    and normalised to [-1, 1]; masked conditioning inputs are always added,
    while reference frames and audio features are optional.
    Returns (img_batch, model_kwargs).
    """
    batch_size, n_frames, channels, height, width = batch["image"].shape
    flat = batch["image"].reshape(batch_size * n_frames, channels, height, width)
    img_batch = normalise2(flat.contiguous())

    model_kwargs = tfg_add_cond_inputs(img_batch, {}, face_hide_percentage, noise)
    if use_ref:
        model_kwargs = tfg_add_reference(batch, model_kwargs, sampling_use_gt_for_ref)
    if use_audio:
        model_kwargs = tfg_add_audio(batch, model_kwargs)
    return img_batch, model_kwargs
| |
|
def tfg_add_reference(batch, model_kwargs, sampling_use_gt_for_ref=False):
    """Attach a normalised reference image batch under `model_kwargs["ref_img"]`.

    When `sampling_use_gt_for_ref` is true the ground-truth frames in
    `batch["image"]` are used as the reference; otherwise `batch["ref_img"]`
    is used. Either way the tensor is flattened to (N, C, H, W) and mapped
    to [-1, 1].
    """
    source = batch["image"] if sampling_use_gt_for_ref else batch["ref_img"]
    channels, height, width = source.shape[-3:]
    flat = source.reshape(-1, channels, height, width).contiguous()
    model_kwargs["ref_img"] = normalise2(flat)
    return model_kwargs
| |
|
def tfg_add_audio(batch, model_kwargs):
    """Attach per-frame mel features (and the full mel, if present) to kwargs.

    `batch["indiv_mels"]` of shape (B, F, 1, h, w) is flattened to
    (B*F, h, w) and stored as `model_kwargs["indiv_mels"]`; an optional
    `batch["mel"]` is passed through unchanged.
    """
    mels = batch["indiv_mels"]
    n_videos, n_frames = mels.shape[:2]
    mel_h, mel_w = mels.shape[-2:]
    model_kwargs["indiv_mels"] = mels.squeeze(dim=2).reshape(n_videos * n_frames, mel_h, mel_w)

    if "mel" in batch:
        model_kwargs["mel"] = batch["mel"]
    return model_kwargs
| |
|
def tfg_add_cond_inputs(img_batch, model_kwargs, face_hide_percentage, noise=None):
    """Add the masked conditioning image and its mask to `model_kwargs`.

    The bottom `face_hide_percentage` fraction of each image is replaced by
    `noise` (fresh Gaussian noise when not supplied); the rest is kept.

    Args:
        img_batch: (B, C, H, W) tensor of images, presumably in [-1, 1]
            (callers pass the output of `normalise2`).
        model_kwargs: dict to populate; mutated and returned.
        face_hide_percentage: fraction in [0, 1] of image height to hide,
            measured from the bottom.
        noise: optional tensor with the same shape as `img_batch`.

    Returns:
        `model_kwargs` with "cond_img" (B, C, H, W) and "mask" (B, 1, H, W).
    """
    B, C, H, W = img_batch.shape
    # BUGFIX: build the mask on the same device/dtype as the images —
    # a bare torch.zeros(...) is CPU/float32 and fails for CUDA or
    # half-precision inputs.
    mask = torch.zeros(B, 1, H, W, device=img_batch.device, dtype=img_batch.dtype)
    # Rows at and below this index are hidden (1 in the mask).
    mask_start_idx = int(H * (1 - face_hide_percentage))
    mask[:, :, mask_start_idx:, :] = 1.
    if noise is None:
        noise = torch.randn_like(img_batch)
    assert noise.shape == img_batch.shape, "Noise shape != Image shape"
    # Keep the visible region, substitute noise in the hidden region.
    cond_img = img_batch * (1. - mask) + mask * noise

    model_kwargs["cond_img"] = cond_img
    model_kwargs["mask"] = mask
    return model_kwargs
| |
|
| |
|
def get_n_params(model):
    """Return the total number of scalar parameters in `model`.

    Replaces the hand-rolled nested size product with `Tensor.numel`,
    which torch provides for exactly this purpose.
    """
    return sum(p.numel() for p in model.parameters())