| | import gradio as gr |
| | import torch |
| |
|
| | import numpy as np |
| | import random |
| | import os |
| | import yaml |
| | from pathlib import Path |
| | import imageio |
| | import tempfile |
| | from PIL import Image |
| | from huggingface_hub import hf_hub_download |
| |
|
| | from inference import ( |
| | create_ltx_video_pipeline, |
| | create_latent_upsampler, |
| | load_image_to_tensor_with_resize_and_crop, |
| | seed_everething, |
| | calculate_padding, |
| | load_media_file |
| | ) |
| | from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline |
| | from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy |
| |
|
| | |
| | CONFIG_YAML = "configs/ltxv-13b-0.9.7-distilled.yaml" |
| | with open(CONFIG_YAML, "r") as f: |
| | CFG = yaml.safe_load(f) |
| |
|
| | HF_REPO = "LTTEAM/VideoAI" |
| | MODELS_DIR = "downloaded_models" |
| | Path(MODELS_DIR).mkdir(exist_ok=True) |
| |
|
| | print("Đang tải mô hình (nếu chưa có)…") |
| | ckpt = hf_hub_download(repo_id=HF_REPO, filename=CFG["checkpoint_path"], local_dir=MODELS_DIR) |
| | CFG["checkpoint_path"] = ckpt |
| | upscaler = hf_hub_download(repo_id=HF_REPO, filename=CFG["spatial_upscaler_model_path"], local_dir=MODELS_DIR) |
| | CFG["spatial_upscaler_model_path"] = upscaler |
| |
|
| | |
| | print("Khởi tạo pipeline trên CPU…") |
| | pipeline = create_ltx_video_pipeline( |
| | ckpt_path=CFG["checkpoint_path"], |
| | precision=CFG["precision"], |
| | text_encoder_model_name_or_path=CFG["text_encoder_model_name_or_path"], |
| | sampler=CFG["sampler"], |
| | device="cpu", |
| | enhance_prompt=False, |
| | prompt_enhancer_image_caption_model_name_or_path=CFG["prompt_enhancer_image_caption_model_name_or_path"], |
| | prompt_enhancer_llm_model_name_or_path=CFG["prompt_enhancer_llm_model_name_or_path"], |
| | ) |
| | print("Pipeline sẵn sàng.") |
| | print("Khởi tạo latent upsampler trên CPU…") |
| | upsampler = create_latent_upsampler(CFG["spatial_upscaler_model_path"], device="cpu") |
| | print("Upsampler sẵn sàng.") |
| |
|
| | |
FPS = 30.0        # output video frame rate
MAX_FRAMES = 257  # upper bound on generated frames (of the form 8*k + 1)
MIN_DIM = 256     # smallest allowed spatial dimension after clamping
FIXED_SIDE = 768  # size assigned to the shorter side in calc_new_dims
MAX_RES = CFG.get("max_resolution", 1280)  # largest allowed spatial dimension
| |
|
def calc_new_dims(w, h):
    """Return a (height, width) pair whose shorter side equals FIXED_SIDE.

    The longer side is scaled to preserve the w:h aspect ratio and snapped to
    the nearest multiple of 32; both sides are then clamped to the
    [MIN_DIM, MAX_RES] range. Zero-sized input falls back to a square.
    """
    if w == 0 or h == 0:
        return FIXED_SIDE, FIXED_SIDE

    def _clamp(v):
        # Keep the dimension inside the allowed range as a plain int.
        return int(max(MIN_DIM, min(v, MAX_RES)))

    if w >= h:
        # Landscape (or square): fix the height, scale the width.
        new_h = FIXED_SIDE
        new_w = round((FIXED_SIDE * w / h) / 32) * 32
    else:
        # Portrait: fix the width, scale the height.
        new_w = FIXED_SIDE
        new_h = round((FIXED_SIDE * h / w) / 32) * 32

    return _clamp(new_h), _clamp(new_w)
| |
|
def get_duration(*args, duration_ui=0, **kwargs):
    """Return the GPU-time budget in seconds: longer clips get a larger allowance."""
    return 60 if duration_ui <= 7 else 75
| |
|
| |
|
def generate(prompt, neg_prompt, img_path, vid_path,
             height, width, mode, duration_ui, frames_to_use,
             seed, rand_seed, cfg_scale, improve_tex, device_choice,
             progress=gr.Progress(track_tqdm=True)):
    """Run one generation request and write the result to a temporary MP4.

    Args:
        prompt: text description of the desired video.
        neg_prompt: negative prompt forwarded to the pipeline.
        img_path: conditioning image path (used when mode == "image-to-video").
        vid_path: source video path (used when mode == "video-to-video").
        height, width: requested output size; padded up to multiples of 32.
        mode: "image-to-video" or "video-to-video"; any other value runs
            without conditioning (plain text-to-video).
        duration_ui: requested duration in seconds, converted to frames at FPS.
        frames_to_use: number of frames to read from the input video.
        seed: RNG seed; replaced when rand_seed is truthy.
        rand_seed: when truthy, draw a fresh random 32-bit seed.
        cfg_scale: classifier-free guidance scale applied to the pass configs.
        improve_tex: when truthy, run the two-pass multi-scale pipeline
            (first pass + latent upsampling + refinement second pass).
        device_choice: "GPU" or "CPU" from the UI radio button.
        progress: Gradio progress tracker (mirrors tqdm bars).

    Returns:
        Tuple of (path to the written MP4, seed that was actually used).
    """

    # Select the device per request (fall back to CPU when CUDA is absent)
    # and move both models there.
    dev = "cuda" if device_choice=="GPU" and torch.cuda.is_available() else "cpu"
    print(f"Sử dụng thiết bị: {dev}")
    pipeline.to(dev)
    upsampler.to(dev)

    if rand_seed:
        seed = random.randint(0, 2**32 - 1)
    seed_everething(int(seed))

    # Clamp the frame count to the form 8*k + 1 within [9, MAX_FRAMES].
    tf = max(1, round(duration_ui * FPS))
    n8 = round((tf-1)/8)
    n_frames = max(9, min(n8*8+1, MAX_FRAMES))

    # Round the spatial size up to multiples of 32; `pad` holds the padding
    # added around the requested area (assumed (left, right, top, bottom),
    # matching the unpack before cropping below — TODO confirm in inference.py).
    h, w = int(height), int(width)
    h_pad = ((h-1)//32+1)*32
    w_pad = ((w-1)//32+1)*32
    pad = calculate_padding(h, w, h_pad, w_pad)

    # Base sampling arguments shared by the single- and multi-scale paths.
    kwargs = {
        "prompt": prompt,
        "negative_prompt": neg_prompt,
        "height": h_pad,
        "width": w_pad,
        "num_frames": n_frames,
        "frame_rate": int(FPS),
        "generator": torch.Generator(device=dev).manual_seed(int(seed)),
        "output_type": "pt",
        "decode_timestep": CFG["decode_timestep"],
        "decode_noise_scale": CFG["decode_noise_scale"],
        "stochastic_sampling": CFG["stochastic_sampling"],
        "is_video": True,
        "vae_per_channel_normalize": True,
        "mixed_precision": CFG["precision"]=="mixed_precision",
        "offload_to_cpu": False,
        "enhance_prompt": False,
    }

    # Translate the configured STG mode name (several aliases accepted)
    # into the corresponding skip-layer strategy enum.
    mode_stg = CFG.get("stg_mode","attention_values").lower()
    stg_map = {
        "stg_av": SkipLayerStrategy.AttentionValues,
        "attention_values": SkipLayerStrategy.AttentionValues,
        "stg_as": SkipLayerStrategy.AttentionSkip,
        "attention_skip": SkipLayerStrategy.AttentionSkip,
        "stg_r": SkipLayerStrategy.Residual,
        "residual": SkipLayerStrategy.Residual,
        "stg_t": SkipLayerStrategy.TransformerBlock,
        "transformer_block": SkipLayerStrategy.TransformerBlock,
    }
    kwargs["skip_layer_strategy"] = stg_map.get(mode_stg, SkipLayerStrategy.AttentionValues)

    # Attach conditioning media depending on the selected mode.
    if mode=="image-to-video" and img_path:
        t = load_image_to_tensor_with_resize_and_crop(img_path, h, w)
        t = torch.nn.functional.pad(t, pad)
        kwargs["conditioning_items"] = [ConditioningItem(t.to(dev), 0, 1.0)]
    elif mode=="video-to-video" and vid_path:
        mi = load_media_file(vid_path, h, w, int(frames_to_use), pad).to(dev)
        kwargs["media_items"] = mi

    if improve_tex:
        # Two-pass multi-scale generation: low-res first pass, latent
        # upsampling, refinement second pass. Explicit num_inference_steps is
        # dropped so each pass uses its configured timesteps instead.
        pipe_ms = LTXMultiScalePipeline(pipeline, upsampler)
        fp = CFG.get("first_pass",{}).copy()
        fp["guidance_scale"] = float(cfg_scale)
        fp.pop("num_inference_steps", None)
        sp = CFG.get("second_pass",{}).copy()
        sp["guidance_scale"] = float(cfg_scale)
        sp.pop("num_inference_steps", None)
        kwargs.update({
            "downscale_factor": CFG["downscale_factor"],
            "first_pass": fp,
            "second_pass": sp
        })
        images = pipe_ms(**kwargs).images
    else:
        # Single-pass generation driven by the first-pass settings only.
        fp0 = CFG.get("first_pass",{})
        kwargs.update({
            "timesteps": fp0.get("timesteps"),
            "guidance_scale": float(cfg_scale),
            "stg_scale": fp0.get("stg_scale"),
            "rescaling_scale": fp0.get("rescaling_scale"),
            "skip_block_list": fp0.get("skip_block_list")
        })
        for k in ["first_pass","second_pass","downscale_factor","num_inference_steps"]:
            kwargs.pop(k, None)
        images = pipeline(**kwargs).images

    # Crop the padding back off and keep at most n_frames frames.
    l, r, t_, b = pad
    sh = None if b==0 else -b
    sw = None if r==0 else -r
    # BUG FIX: the original applied five slices to images[0]
    # (images[0][:, :, :n_frames, ...]). After taking the batch element the
    # tensor is (C, F, H, W) — the frame axis is dim 1 and only four slices
    # apply; the permute(1, 2, 3, 0) below confirms vid_t must be 4-D.
    vid_t = images[0][:, :n_frames, t_:sh, l:sw]
    arr = vid_t.permute(1,2,3,0).cpu().numpy()  # -> (frames, H, W, C)
    arr = (np.clip(arr,0,1)*255).astype(np.uint8)

    # Write frames to an MP4 in a fresh temp dir, reporting save progress.
    out_dir = tempfile.mkdtemp()
    out_path = os.path.join(out_dir, f"output_{random.randint(0,99999)}.mp4")
    with imageio.get_writer(out_path, fps=int(FPS), macro_block_size=1) as writer:
        for i in range(arr.shape[0]):
            progress(i/arr.shape[0], desc="Lưu video")
            writer.append_data(arr[i])

    return out_path, seed
| |
|
| | |
| | css = """ |
| | #col-container { margin:0 auto; max-width:900px; } |
| | """ |
| | with gr.Blocks(css=css) as demo: |
| | gr.Markdown("## Ứng dụng LTX Video 0.9.7 Distilled") |
| | gr.Markdown( |
| | "[Mô hình trên HF](https://huggingface.co/LTTEAM/VideoAI) · " |
| | "[GitHub](https://github.com/Lightricks/LTX-Video)" |
| | ) |
| |
|
| | with gr.Row(): |
| | with gr.Column(): |
| | device = gr.Radio(["CPU", "GPU"], label="Chạy trên thiết bị", value="CPU") |
| |
|
| | with gr.Tab("Ảnh→Video"): |
| | img_in = gr.Image(label="Ảnh đầu vào", type="filepath", sources=["upload","clipboard","webcam"]) |
| | prompt1 = gr.Textbox(label="Mô tả", lines=2, value="Con sinh vật di chuyển") |
| | btn1 = gr.Button("Tạo từ ảnh") |
| |
|
| | with gr.Tab("Văn bản→Video"): |
| | prompt2 = gr.Textbox(label="Mô tả", lines=2, value="Rồng bay trên lâu đài") |
| | btn2 = gr.Button("Tạo từ văn bản") |
| |
|
| | with gr.Tab("Video→Video"): |
| | vid_in = gr.Video(label="Video đầu vào", sources=["upload","webcam"]) |
| | frames = gr.Slider(label="Số frame dùng", minimum=9, maximum=MAX_FRAMES, step=8, value=9) |
| | prompt3 = gr.Textbox(label="Mô tả", lines=2, value="Chuyển phong cách anime") |
| | btn3 = gr.Button("Tạo từ video") |
| |
|
| | duration = gr.Slider(label="Thời lượng (giây)", minimum=0.3, maximum=8.5, step=0.1, value=2) |
| | improve = gr.Checkbox(label="Cải thiện chi tiết", value=True) |
| |
|
| | with gr.Column(): |
| | out_vid = gr.Video(label="Kết quả", interactive=False) |
| |
|
| | |
| | mode_state = gr.State("image-to-video") |
| | seed_state = gr.State(42) |
| | neg_state = gr.State("worst quality, inconsistent motion, blurry, jittery, distorted") |
| | cfg_state = gr.State(CFG["first_pass"]["guidance_scale"]) |
| | h_state = gr.State(512) |
| | w_state = gr.State(704) |
| |
|
| | btn1.click(fn=generate, |
| | inputs=[prompt1, neg_state, img_in, gr.State(""), h_state, w_state, |
| | mode_state, duration, frames, seed_state, gr.State(True), |
| | cfg_state, improve, device], |
| | outputs=[out_vid, seed_state]) |
| | btn2.click(fn=generate, |
| | inputs=[prompt2, neg_state, gr.State(""), gr.State(""), h_state, w_state, |
| | mode_state, duration, frames, seed_state, gr.State(True), |
| | cfg_state, improve, device], |
| | outputs=[out_vid, seed_state]) |
| | btn3.click(fn=generate, |
| | inputs=[prompt3, neg_state, gr.State(""), vid_in, h_state, w_state, |
| | mode_state, duration, frames, seed_state, gr.State(True), |
| | cfg_state, improve, device], |
| | outputs=[out_vid, seed_state]) |
| |
|
| | if __name__ == "__main__": |
| | demo.queue().launch(share=True) |
| |
|