From 93af6864b942e931908fcf1a3f04feebf97dbb36 Mon Sep 17 00:00:00 2001 From: Yoav HaCohen Date: Wed, 14 May 2025 12:25:08 +0300 Subject: [PATCH] Pipeline: Enhance inference pipelines with new features * Adaptive normalization after latent upsampling * CFG Star Rescale * Varying STG/CFG parameters per step * Support skipping the initial and/or the final diffusion steps * CRF compression for image condition (useful for getting more motion in image-to-video) --- configs/ltxv-13b-0.9.7-dev.yaml | 25 ++--- configs/ltxv-2b-0.9.1.yaml | 17 ++++ configs/ltxv-2b-0.9.5.yaml | 17 ++++ configs/ltxv-2b-0.9.6-distilled.yaml | 7 +- configs/ltxv-2b-0.9.yaml | 17 ++++ inference.py | 18 ++-- ltx_video/pipelines/crf_compressor.py | 50 ++++++++++ ltx_video/pipelines/pipeline_ltx_video.py | 106 +++++++++++++++++++--- pyproject.toml | 4 +- tests/test_inference.py | 2 +- 10 files changed, 220 insertions(+), 43 deletions(-) create mode 100644 configs/ltxv-2b-0.9.1.yaml create mode 100644 configs/ltxv-2b-0.9.5.yaml create mode 100644 configs/ltxv-2b-0.9.yaml create mode 100644 ltx_video/pipelines/crf_compressor.py diff --git a/configs/ltxv-13b-0.9.7-dev.yaml b/configs/ltxv-13b-0.9.7-dev.yaml index 0a4724e..ae54825 100644 --- a/configs/ltxv-13b-0.9.7-dev.yaml +++ b/configs/ltxv-13b-0.9.7-dev.yaml @@ -1,4 +1,3 @@ - pipeline_type: multi-scale checkpoint_path: "ltxv-13b-0.9.7-dev.safetensors" downscale_factor: 0.6666666 @@ -14,20 +13,22 @@ prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-P prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" stochastic_sampling: false - first_pass: - guidance_scale: [3] - stg_scale: [1] - rescaling_scale: [0.7] - guidance_timesteps: [1.0] - skip_block_list: [19] # [[1], [1,2], [1,2,3], [27], [28], [28]] + guidance_scale: [1, 1, 6, 8, 6, 1, 1] + stg_scale: [0, 0, 4, 4, 4, 2, 1] + rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] + guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] + 
skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] num_inference_steps: 30 + skip_final_inference_steps: 3 + cfg_star_rescale: true second_pass: - guidance_scale: [3] + guidance_scale: [1] stg_scale: [1] - rescaling_scale: [0.7] + rescaling_scale: [1] guidance_timesteps: [1.0] - skip_block_list: [19] # [[1], [1,2], [1,2,3], [27], [28], [28]] - num_inference_steps: 10 - strength: 0.85 + skip_block_list: [27] + num_inference_steps: 30 + skip_initial_inference_steps: 17 + cfg_star_rescale: true \ No newline at end of file diff --git a/configs/ltxv-2b-0.9.1.yaml b/configs/ltxv-2b-0.9.1.yaml new file mode 100644 index 0000000..6e888de --- /dev/null +++ b/configs/ltxv-2b-0.9.1.yaml @@ -0,0 +1,17 @@ +pipeline_type: base +checkpoint_path: "ltx-video-2b-v0.9.1.safetensors" +guidance_scale: 3 +stg_scale: 1 +rescaling_scale: 0.7 +skip_block_list: [19] +num_inference_steps: 40 +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false \ No newline at end of file diff --git a/configs/ltxv-2b-0.9.5.yaml b/configs/ltxv-2b-0.9.5.yaml new file mode 100644 index 0000000..5998c60 --- /dev/null +++ b/configs/ltxv-2b-0.9.5.yaml @@ -0,0 +1,17 @@ +pipeline_type: base +checkpoint_path: "ltx-video-2b-v0.9.5.safetensors" +guidance_scale: 3 +stg_scale: 1 +rescaling_scale: 0.7 +skip_block_list: [19] +num_inference_steps: 40 +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 
+decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false \ No newline at end of file diff --git a/configs/ltxv-2b-0.9.6-distilled.yaml b/configs/ltxv-2b-0.9.6-distilled.yaml index 39fae26..328d929 100644 --- a/configs/ltxv-2b-0.9.6-distilled.yaml +++ b/configs/ltxv-2b-0.9.6-distilled.yaml @@ -1,9 +1,8 @@ pipeline_type: base checkpoint_path: "ltxv-2b-0.9.6-distilled-04-25.safetensors" -guidance_scale: 3 -stg_scale: 1 -rescaling_scale: 0.7 -skip_block_list: [19] +guidance_scale: 1 +stg_scale: 0 +rescaling_scale: 1 num_inference_steps: 8 stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" decode_timestep: 0.05 diff --git a/configs/ltxv-2b-0.9.yaml b/configs/ltxv-2b-0.9.yaml new file mode 100644 index 0000000..f501ca6 --- /dev/null +++ b/configs/ltxv-2b-0.9.yaml @@ -0,0 +1,17 @@ +pipeline_type: base +checkpoint_path: "ltx-video-2b-v0.9.safetensors" +guidance_scale: 3 +stg_scale: 1 +rescaling_scale: 0.7 +skip_block_list: [19] +num_inference_steps: 40 +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false \ No newline 
at end of file diff --git a/inference.py b/inference.py index 322809e..0642afa 100644 --- a/inference.py +++ b/inference.py @@ -11,6 +11,7 @@ import imageio import json import numpy as np import torch +import cv2 from safetensors import safe_open from PIL import Image from transformers import ( @@ -35,6 +36,7 @@ from ltx_video.pipelines.pipeline_ltx_video import ( from ltx_video.schedulers.rf import RectifiedFlowScheduler from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler +import ltx_video.pipelines.crf_compressor as crf_compressor MAX_HEIGHT = 720 MAX_WIDTH = 1280 @@ -96,7 +98,12 @@ def load_image_to_tensor_with_resize_and_crop( image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height)) if not just_crop: image = image.resize((target_width, target_height)) - frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float() + + image = np.array(image) + image = cv2.GaussianBlur(image, (3, 3), 0) + frame_tensor = torch.from_numpy(image).float() + frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0 + frame_tensor = frame_tensor.permute(2, 0, 1) frame_tensor = (frame_tensor / 127.5) - 1.0 # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width) return frame_tensor.unsqueeze(0).unsqueeze(2) @@ -266,13 +273,6 @@ def main(): help="Path to the input video (or imaage) to be modified using the video-to-video pipeline", ) - parser.add_argument( - "--strength", - type=float, - default=1.0, - help="Editing strength (noising level) for video-to-video pipeline.", - ) - # Conditioning arguments parser.add_argument( "--conditioning_media_paths", @@ -407,7 +407,6 @@ def infer( negative_prompt: str, offload_to_cpu: bool, input_media_path: Optional[str] = None, - strength: Optional[float] = 1.0, conditioning_media_paths: Optional[List[str]] = None, conditioning_strengths: Optional[List[float]] = None, conditioning_start_frames: 
Optional[List[int]] = None, @@ -614,7 +613,6 @@ def infer( frame_rate=frame_rate, **sample, media_items=media_item, - strength=strength, conditioning_items=conditioning_items, is_video=True, vae_per_channel_normalize=True, diff --git a/ltx_video/pipelines/crf_compressor.py b/ltx_video/pipelines/crf_compressor.py new file mode 100644 index 0000000..9b9380a --- /dev/null +++ b/ltx_video/pipelines/crf_compressor.py @@ -0,0 +1,50 @@ +import av +import torch +import io +import numpy as np + + +def _encode_single_frame(output_file, image_array: np.ndarray, crf): + container = av.open(output_file, "w", format="mp4") + try: + stream = container.add_stream( + "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"} + ) + stream.height = image_array.shape[0] + stream.width = image_array.shape[1] + av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat( + format="yuv420p" + ) + container.mux(stream.encode(av_frame)) + container.mux(stream.encode()) + finally: + container.close() + + +def _decode_single_frame(video_file): + container = av.open(video_file) + try: + stream = next(s for s in container.streams if s.type == "video") + frame = next(container.decode(stream)) + finally: + container.close() + return frame.to_ndarray(format="rgb24") + + +def compress(image: torch.Tensor, crf=29): + if crf == 0: + return image + + image_array = ( + (image[: (image.shape[0] // 2) * 2, : (image.shape[1] // 2) * 2] * 255.0) + .byte() + .cpu() + .numpy() + ) + with io.BytesIO() as output_file: + _encode_single_frame(output_file, image_array, crf) + video_bytes = output_file.getvalue() + with io.BytesIO(video_bytes) as video_file: + image_array = _decode_single_frame(video_file) + tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0 + return tensor diff --git a/ltx_video/pipelines/pipeline_ltx_video.py b/ltx_video/pipelines/pipeline_ltx_video.py index 85161fd..16ee39e 100644 --- a/ltx_video/pipelines/pipeline_ltx_video.py +++ 
b/ltx_video/pipelines/pipeline_ltx_video.py @@ -45,6 +45,11 @@ from ltx_video.models.autoencoders.vae_encode import ( ) +try: + import torch_xla.distributed.spmd as xs +except ImportError: + xs = None + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -127,7 +132,8 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, - max_timestep: Optional[float] = 1.0, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, **kwargs, ): """ @@ -170,14 +176,21 @@ def retrieve_timesteps( scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps - if max_timestep < 1.0: - if max_timestep < timesteps.min(): + if ( + skip_initial_inference_steps < 0 + or skip_final_inference_steps < 0 + or skip_initial_inference_steps + skip_final_inference_steps + >= num_inference_steps + ): raise ValueError( - f"max_timestep {max_timestep} is smaller than the minimum timestep {timesteps.min()}" + "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" ) - timesteps = timesteps[timesteps <= max_timestep] - num_inference_steps = len(timesteps) + + timesteps = timesteps[ + skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps + ] scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + num_inference_steps = len(timesteps) return timesteps, num_inference_steps @@ -752,8 +765,11 @@ class LTXVideoPipeline(DiffusionPipeline): prompt: Union[str, List[str]] = None, negative_prompt: str = "", num_inference_steps: int = 20, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, timesteps: List[int] = None, guidance_scale: Union[float, List[float]] = 4.5, + cfg_star_rescale: bool = False, skip_layer_strategy: 
Optional[SkipLayerStrategy] = None, skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, stg_scale: Union[float, List[float]] = 1.0, @@ -779,7 +795,6 @@ class LTXVideoPipeline(DiffusionPipeline): text_encoder_max_tokens: int = 256, stochastic_sampling: bool = False, media_items: Optional[torch.Tensor] = None, - strength: Optional[float] = 1.0, **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: """ @@ -796,6 +811,12 @@ class LTXVideoPipeline(DiffusionPipeline): num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. If `timesteps` is provided, this parameter is ignored. + skip_initial_inference_steps (`int`, *optional*, defaults to 0): + The number of initial timesteps to skip. After calculating the timesteps, this number of timesteps will + be removed from the beginning of the timesteps list. Meaning the highest-timesteps values will not run. + skip_final_inference_steps (`int`, *optional*, defaults to 0): + The number of final timesteps to skip. After calculating the timesteps, this number of timesteps will + be removed from the end of the timesteps list. Meaning the lowest-timesteps values will not run. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. @@ -805,6 +826,9 @@ class LTXVideoPipeline(DiffusionPipeline): Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + cfg_star_rescale (`bool`, *optional*, defaults to `False`): + If set to `True`, applies the CFG star rescale. Scales the negative prediction according to dot + product between positive and negative. 
num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. height (`int`, *optional*, defaults to self.unet.config.sample_size): @@ -852,10 +876,6 @@ class LTXVideoPipeline(DiffusionPipeline): If set to `True`, the sampling is stochastic. If set to `False`, the sampling is deterministic. media_items ('torch.Tensor', *optional*): The input media item used for image-to-image / video-to-video. - When provided, they will be noised according to 'strength' and then fully denoised. - strength ('floaty', *optional* defaults to 1.0): - The editing level in image-to-image / video-to-video. The provided input will be noised - to this level. Examples: Returns: @@ -912,8 +932,12 @@ class LTXVideoPipeline(DiffusionPipeline): if isinstance(self.scheduler, TimestepShifter): retrieve_timesteps_kwargs["samples_shape"] = latent_shape - assert strength == 1.0 or latents is not None or media_items is not None, ( - "strength < 1 is used for image-to-image/video-to-video - " + assert ( + skip_initial_inference_steps == 0 + or latents is not None + or media_items is not None + ), ( + f"skip_initial_inference_steps ({skip_initial_inference_steps}) is used for image-to-image/video-to-video - " "media_item or latents should be provided." 
) @@ -922,9 +946,11 @@ class LTXVideoPipeline(DiffusionPipeline): num_inference_steps, device, timesteps, - max_timestep=strength, + skip_initial_inference_steps=skip_initial_inference_steps, + skip_final_inference_steps=skip_final_inference_steps, **retrieve_timesteps_kwargs, ) + if self.allowed_inference_steps is not None: for timestep in [round(x, 4) for x in timesteps.tolist()]: assert ( @@ -981,7 +1007,7 @@ class LTXVideoPipeline(DiffusionPipeline): # Normalize skip_block_list to always be None or a list of lists matching timesteps if skip_block_list is not None: # Convert single list to list of lists if needed - if not isinstance(skip_block_list[0], list): + if len(skip_block_list) == 0 or not isinstance(skip_block_list[0], list): skip_block_list = [skip_block_list] * len(timesteps) else: new_skip_block_list = [] @@ -1189,6 +1215,22 @@ class LTXVideoPipeline(DiffusionPipeline): )[-2:] if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_conds)[:2] + + if cfg_star_rescale: + # Rescales the unconditional noise prediction using the projection of the conditional prediction onto it: + # α = (⟨ε_text, ε_uncond⟩ / ||ε_uncond||²), then ε_uncond ← α * ε_uncond + # where ε_text is the conditional noise prediction and ε_uncond is the unconditional one. 
+ positive_flat = noise_pred_text.view(batch_size, -1) + negative_flat = noise_pred_uncond.view(batch_size, -1) + dot_product = torch.sum( + positive_flat * negative_flat, dim=1, keepdim=True + ) + squared_norm = ( + torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 + ) + alpha = dot_product / squared_norm + noise_pred_uncond = alpha * noise_pred_uncond + noise_pred = noise_pred_uncond + guidance_scale[i] * ( noise_pred_text - noise_pred_uncond ) @@ -1695,6 +1737,37 @@ class LTXVideoPipeline(DiffusionPipeline): return num_frames +def adain_filter_latent( + latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0 +): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on + statistics from a reference latent tensor. + + Args: + latents (torch.Tensor): Input latents to normalize + reference_latents (torch.Tensor): The reference latents providing style statistics. + factor (float): Blending factor between original and transformed latent. + Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean( + reference_latents[i, c], dim=None + ) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result + + class LTXMultiScalePipeline: def _upsample_latents( self, latest_upsampler: LatentUpsampler, latents: torch.Tensor @@ -1743,6 +1816,9 @@ class LTXMultiScalePipeline: latents = result.images upsampled_latents = self._upsample_latents(self.latent_upsampler, latents) + upsampled_latents = adain_filter_latent( + latents=upsampled_latents, reference_latents=latents + ) kwargs = original_kwargs diff --git a/pyproject.toml b/pyproject.toml index 44db660..bcf8544 100644 --- a/pyproject.toml +++ b/pyproject.toml @@
-30,7 +30,9 @@ dependencies = [ inference-script = [ "accelerate", "matplotlib", - "imageio[ffmpeg]" + "imageio[ffmpeg]", + "av", + "opencv-python" ] test = [ "pytest", diff --git a/tests/test_inference.py b/tests/test_inference.py index d7e1c08..0a7cf91 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -101,13 +101,13 @@ def test_vid2vid(tmp_path, test_paths): "frame_rate": 25, "prompt": "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood.", "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", - "strength": 0.95, "offload_to_cpu": False, "input_media_path": test_paths["input_video_path"], } config = { "num_inference_steps": 3, + "skip_initial_inference_steps": 1, "guidance_scale": 2.5, "stg_scale": 1, "stg_rescale": 0.7,