Daankular committed on
Commit
90db216
·
1 Parent(s): 6f79261

Update apply_texture to MV-Adapter new API

Browse files

- init_adapter() removed; replaced with init_custom_adapter(num_views=6)
  followed by load_custom_adapter(path, weight_name=...)
- Add ShiftSNRScheduler wrapping (required by new API)
- Move dtype to pipe.to() after adapter init; add cond_encoder.to()
- Replace image=/cameras= call args with reference_image= and control_image=
  (Plücker embeddings from get_plucker_embeds_from_cameras_ortho)
- Remove stale torch.autocast wrapper
- Import get_plucker_embeds_from_cameras_ortho

Files changed (1) hide show
  1. app.py +31 -22
app.py CHANGED
@@ -950,41 +950,50 @@ def apply_texture(glb_path, input_image, remove_background, variant, tex_seed,
950
 
951
  from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
952
  from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
953
- from mvadapter.utils import get_orthogonal_camera
954
  import torchvision.transforms.functional as TF
955
 
956
  progress(0.4, desc=f"Running MV-Adapter ({variant})...")
957
 
958
- pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(
959
- sd_id,
960
- torch_dtype=torch.float16,
961
- ).to(DEVICE)
962
-
963
- pipe.init_adapter(
964
- image_encoder_path="openai/clip-vit-large-patch14",
965
- ipa_weight_path=os.path.join(mvadapter_weights, "mvadapter_i2mv_sdxl.safetensors"),
966
- adapter_tokens=256,
967
  )
 
 
968
 
969
  ref_pil = Image.open(img_path).convert("RGB")
970
  cameras = get_orthogonal_camera(
971
  elevation_deg=[0, 0, 0, 0, 0, 0],
972
  distance=[1.8] * 6,
973
  left=-0.55, right=0.55, bottom=-0.55, top=0.55,
974
- azimuth_deg=[x - 90 for x in [0, 45, 90, 135, 180, 270]],
975
  device=DEVICE,
976
  )
977
-
978
- with torch.autocast(DEVICE):
979
- out = pipe(
980
- image=ref_pil,
981
- height=768, width=768,
982
- num_images_per_prompt=6,
983
- guidance_scale=3.0,
984
- num_inference_steps=30,
985
- generator=torch.Generator(device=DEVICE).manual_seed(int(tex_seed)),
986
- cameras=cameras,
987
- )
 
 
 
 
 
 
 
988
 
989
  mv_grid = out.images # list of 6 PIL images
990
  grid_w = mv_grid[0].width * len(mv_grid)
 
950
 
951
  from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
952
  from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
953
+ from mvadapter.utils import get_orthogonal_camera, get_plucker_embeds_from_cameras_ortho
954
  import torchvision.transforms.functional as TF
955
 
956
  progress(0.4, desc=f"Running MV-Adapter ({variant})...")
957
 
958
+ pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(sd_id)
959
+ pipe.scheduler = ShiftSNRScheduler.from_scheduler(
960
+ pipe.scheduler,
961
+ shift_mode="interpolated",
962
+ shift_scale=8.0,
963
+ )
964
+ pipe.init_custom_adapter(num_views=6)
965
+ pipe.load_custom_adapter(
966
+ mvadapter_weights, weight_name="mvadapter_i2mv_sdxl.safetensors"
967
  )
968
+ pipe.to(device=DEVICE, dtype=torch.float16)
969
+ pipe.cond_encoder.to(device=DEVICE, dtype=torch.float16)
970
 
971
  ref_pil = Image.open(img_path).convert("RGB")
972
  cameras = get_orthogonal_camera(
973
  elevation_deg=[0, 0, 0, 0, 0, 0],
974
  distance=[1.8] * 6,
975
  left=-0.55, right=0.55, bottom=-0.55, top=0.55,
976
+ azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]],
977
  device=DEVICE,
978
  )
979
+ plucker_embeds = get_plucker_embeds_from_cameras_ortho(
980
+ cameras.c2w, [1.1] * 6, width=768
981
+ )
982
+ control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)
983
+
984
+ out = pipe(
985
+ "high quality",
986
+ height=768, width=768,
987
+ num_images_per_prompt=6,
988
+ guidance_scale=3.0,
989
+ num_inference_steps=30,
990
+ generator=torch.Generator(device=DEVICE).manual_seed(int(tex_seed)),
991
+ control_image=control_images,
992
+ control_conditioning_scale=1.0,
993
+ reference_image=ref_pil,
994
+ reference_conditioning_scale=1.0,
995
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
996
+ )
997
 
998
  mv_grid = out.images # list of 6 PIL images
999
  grid_w = mv_grid[0].width * len(mv_grid)