Pipeline: Enhance inference pipelines with new features

* Adaptive instance normalization (AdaIN) after latent upsampling
* CFG-Star rescale of the unconditional prediction
* Per-step variation of STG/CFG parameters
* Support for skipping the initial and/or the final diffusion steps
* CRF compression for the image condition (useful for getting more motion in image-to-video)
Yoav HaCohen
2025-05-14 12:25:08 +03:00
parent cb6f842770
commit 93af6864b9
10 changed files with 220 additions and 43 deletions
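
The features listed above surface as new arguments on LTXVideoPipeline.__call__ and new keys in the YAML configs. Below is a minimal usage sketch assembled from the argument names and values that appear in the diffs that follow; the pipeline object and prompt are assumed to be set up as in inference.py, and this call site is illustrative rather than part of the commit:

# Hedged sketch: per-step CFG/STG schedules, step skipping, and CFG-Star,
# mirroring the new LTXVideoPipeline.__call__ arguments added in this commit.
# `pipeline` is assumed to be an LTXVideoPipeline built elsewhere.
images = pipeline(
    prompt=prompt,                  # assumed to be defined by the caller
    num_inference_steps=30,
    skip_final_inference_steps=3,   # drop the lowest-noise steps entirely
    cfg_star_rescale=True,
    # One schedule entry per guidance_timesteps threshold; values switch as t decreases:
    guidance_timesteps=[1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180],
    guidance_scale=[1, 1, 6, 8, 6, 1, 1],
    stg_scale=[0, 0, 4, 4, 4, 2, 1],
    rescaling_scale=[1, 1, 0.5, 0.5, 1, 1, 1],
    skip_block_list=[[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]],
)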


@@ -1,4 +1,3 @@
 pipeline_type: multi-scale
 checkpoint_path: "ltxv-13b-0.9.7-dev.safetensors"
 downscale_factor: 0.6666666
@@ -14,20 +13,22 @@ prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
 prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
 stochastic_sampling: false
 first_pass:
-  guidance_scale: [3]
-  stg_scale: [1]
-  rescaling_scale: [0.7]
-  guidance_timesteps: [1.0]
-  skip_block_list: [19] # [[1], [1,2], [1,2,3], [27], [28], [28]]
+  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
+  stg_scale: [0, 0, 4, 4, 4, 2, 1]
+  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
+  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
+  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
   num_inference_steps: 30
+  skip_final_inference_steps: 3
+  cfg_star_rescale: true
 second_pass:
-  guidance_scale: [3]
+  guidance_scale: [1]
   stg_scale: [1]
-  rescaling_scale: [0.7]
+  rescaling_scale: [1]
   guidance_timesteps: [1.0]
-  skip_block_list: [19] # [[1], [1,2], [1,2,3], [27], [28], [28]]
-  num_inference_steps: 10
-  strength: 0.85
+  skip_block_list: [27]
+  num_inference_steps: 30
+  skip_initial_inference_steps: 17
+  cfg_star_rescale: true
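
The first_pass block now carries one schedule entry per guidance_timesteps threshold, so guidance, STG, rescaling, and the skipped transformer blocks can all vary as denoising progresses. The lookup that maps a step's timestep to a schedule index is not part of this diff; the helper below is a hypothetical illustration of one consistent reading (descending thresholds, where the first threshold at or below the current timestep wins):

def pick_schedule_index(t, guidance_timesteps):
    # Hypothetical helper (not from the commit): map a normalized timestep t
    # to the index of the schedule entry that should apply at this step.
    for i, threshold in enumerate(guidance_timesteps):
        if threshold <= t:
            return i
    return len(guidance_timesteps) - 1

# With the first_pass schedule above, t = 0.98 falls at index 4, so that step
# would run with guidance_scale 6, stg_scale 4, and skip_block_list [28].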


@@ -0,0 +1,17 @@
+pipeline_type: base
+checkpoint_path: "ltx-video-2b-v0.9.1.safetensors"
+guidance_scale: 3
+stg_scale: 1
+rescaling_scale: 0.7
+skip_block_list: [19]
+num_inference_steps: 40
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "bfloat16"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false


@@ -0,0 +1,17 @@
+pipeline_type: base
+checkpoint_path: "ltx-video-2b-v0.9.5.safetensors"
+guidance_scale: 3
+stg_scale: 1
+rescaling_scale: 0.7
+skip_block_list: [19]
+num_inference_steps: 40
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "bfloat16"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false


@@ -1,9 +1,8 @@
 pipeline_type: base
 checkpoint_path: "ltxv-2b-0.9.6-distilled-04-25.safetensors"
-guidance_scale: 3
-stg_scale: 1
-rescaling_scale: 0.7
-skip_block_list: [19]
+guidance_scale: 1
+stg_scale: 0
+rescaling_scale: 1
 num_inference_steps: 8
 stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
 decode_timestep: 0.05

configs/ltxv-2b-0.9.yaml (new file, 17 lines)

@@ -0,0 +1,17 @@
+pipeline_type: base
+checkpoint_path: "ltx-video-2b-v0.9.safetensors"
+guidance_scale: 3
+stg_scale: 1
+rescaling_scale: 0.7
+skip_block_list: [19]
+num_inference_steps: 40
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "bfloat16"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false


@@ -11,6 +11,7 @@ import imageio
 import json
 import numpy as np
 import torch
+import cv2
 from safetensors import safe_open
 from PIL import Image
 from transformers import (
@@ -35,6 +36,7 @@ from ltx_video.pipelines.pipeline_ltx_video import (
 from ltx_video.schedulers.rf import RectifiedFlowScheduler
 from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
 from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
+import ltx_video.pipelines.crf_compressor as crf_compressor
 
 MAX_HEIGHT = 720
 MAX_WIDTH = 1280
@@ -96,7 +98,12 @@ def load_image_to_tensor_with_resize_and_crop(
     image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
     if not just_crop:
         image = image.resize((target_width, target_height))
-    frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
+
+    image = np.array(image)
+    image = cv2.GaussianBlur(image, (3, 3), 0)
+    frame_tensor = torch.from_numpy(image).float()
+    frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
+    frame_tensor = frame_tensor.permute(2, 0, 1)
     frame_tensor = (frame_tensor / 127.5) - 1.0
     # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
     return frame_tensor.unsqueeze(0).unsqueeze(2)
@@ -266,13 +273,6 @@ def main():
         help="Path to the input video (or image) to be modified using the video-to-video pipeline",
     )
-    parser.add_argument(
-        "--strength",
-        type=float,
-        default=1.0,
-        help="Editing strength (noising level) for video-to-video pipeline.",
-    )
 
     # Conditioning arguments
     parser.add_argument(
         "--conditioning_media_paths",
@@ -407,7 +407,6 @@ def infer(
     negative_prompt: str,
     offload_to_cpu: bool,
     input_media_path: Optional[str] = None,
-    strength: Optional[float] = 1.0,
     conditioning_media_paths: Optional[List[str]] = None,
     conditioning_strengths: Optional[List[float]] = None,
     conditioning_start_frames: Optional[List[int]] = None,
@@ -614,7 +613,6 @@ def infer(
         frame_rate=frame_rate,
         **sample,
         media_items=media_item,
-        strength=strength,
         conditioning_items=conditioning_items,
         is_video=True,
         vae_per_channel_normalize=True,


@@ -0,0 +1,50 @@
+import av
+import torch
+import io
+import numpy as np
+
+
+def _encode_single_frame(output_file, image_array: np.ndarray, crf):
+    container = av.open(output_file, "w", format="mp4")
+    try:
+        stream = container.add_stream(
+            "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
+        )
+        stream.height = image_array.shape[0]
+        stream.width = image_array.shape[1]
+        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
+            format="yuv420p"
+        )
+        container.mux(stream.encode(av_frame))
+        container.mux(stream.encode())
+    finally:
+        container.close()
+
+
+def _decode_single_frame(video_file):
+    container = av.open(video_file)
+    try:
+        stream = next(s for s in container.streams if s.type == "video")
+        frame = next(container.decode(stream))
+    finally:
+        container.close()
+    return frame.to_ndarray(format="rgb24")
+
+
+def compress(image: torch.Tensor, crf=29):
+    if crf == 0:
+        return image
+
+    image_array = (
+        (image[: (image.shape[0] // 2) * 2, : (image.shape[1] // 2) * 2] * 255.0)
+        .byte()
+        .cpu()
+        .numpy()
+    )
+    with io.BytesIO() as output_file:
+        _encode_single_frame(output_file, image_array, crf)
+        video_bytes = output_file.getvalue()
+    with io.BytesIO(video_bytes) as video_file:
+        image_array = _decode_single_frame(video_file)
+    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
+    return tensor
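
The new module degrades a single conditioning frame by round-tripping it through one H.264-encoded MP4 frame at the given CRF; per the commit message, the added compression artifacts help image-to-video generations produce more motion. A small usage sketch (the tensor shape is chosen for illustration; the av dependency comes in via the pyproject change below):

import torch
import ltx_video.pipelines.crf_compressor as crf_compressor

frame = torch.rand(480, 640, 3)                    # H x W x C image in [0, 1]
degraded = crf_compressor.compress(frame, crf=29)
# Height and width are cropped to even sizes before encoding (yuv420p needs
# even dimensions), so the output can be up to one pixel smaller per axis.
assert degraded.shape[-1] == 3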


@@ -45,6 +45,11 @@ from ltx_video.models.autoencoders.vae_encode import (
 )
 
+try:
+    import torch_xla.distributed.spmd as xs
+except ImportError:
+    xs = None
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -127,7 +132,8 @@ def retrieve_timesteps(
     num_inference_steps: Optional[int] = None,
     device: Optional[Union[str, torch.device]] = None,
     timesteps: Optional[List[int]] = None,
-    max_timestep: Optional[float] = 1.0,
+    skip_initial_inference_steps: int = 0,
+    skip_final_inference_steps: int = 0,
     **kwargs,
 ):
     """
@@ -170,14 +176,21 @@
         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         timesteps = scheduler.timesteps
-    if max_timestep < 1.0:
-        if max_timestep < timesteps.min():
-            raise ValueError(
-                f"max_timestep {max_timestep} is smaller than the minimum timestep {timesteps.min()}"
-            )
-        timesteps = timesteps[timesteps <= max_timestep]
-        num_inference_steps = len(timesteps)
+    if (
+        skip_initial_inference_steps < 0
+        or skip_final_inference_steps < 0
+        or skip_initial_inference_steps + skip_final_inference_steps
+        >= num_inference_steps
+    ):
+        raise ValueError(
+            "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps"
+        )
+
+    timesteps = timesteps[
+        skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps
+    ]
+    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+    num_inference_steps = len(timesteps)
 
     return timesteps, num_inference_steps
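
In isolation, the new skipping logic just slices the scheduler's descending timestep list and re-registers it. A toy illustration of the slicing semantics (values invented, not the scheduler's real timesteps):

import torch

timesteps = torch.linspace(1.0, 0.1, steps=10)  # descending, as in the pipeline
skip_initial_inference_steps, skip_final_inference_steps = 2, 3
kept = timesteps[
    skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps
]
# The 2 noisiest and the 3 cleanest steps are dropped; 5 steps remain.
assert len(kept) == 10 - 2 - 3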
@@ -752,8 +765,11 @@ class LTXVideoPipeline(DiffusionPipeline):
         prompt: Union[str, List[str]] = None,
         negative_prompt: str = "",
         num_inference_steps: int = 20,
+        skip_initial_inference_steps: int = 0,
+        skip_final_inference_steps: int = 0,
         timesteps: List[int] = None,
         guidance_scale: Union[float, List[float]] = 4.5,
+        cfg_star_rescale: bool = False,
         skip_layer_strategy: Optional[SkipLayerStrategy] = None,
         skip_block_list: Optional[Union[List[List[int]], List[int]]] = None,
         stg_scale: Union[float, List[float]] = 1.0,
@@ -779,7 +795,6 @@ class LTXVideoPipeline(DiffusionPipeline):
         text_encoder_max_tokens: int = 256,
         stochastic_sampling: bool = False,
         media_items: Optional[torch.Tensor] = None,
-        strength: Optional[float] = 1.0,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
@@ -796,6 +811,12 @@ class LTXVideoPipeline(DiffusionPipeline):
             num_inference_steps (`int`, *optional*, defaults to 100):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference. If `timesteps` is provided, this parameter is ignored.
+            skip_initial_inference_steps (`int`, *optional*, defaults to 0):
+                The number of initial timesteps to skip. After the timesteps are calculated, this many entries are
+                removed from the beginning of the list, so the highest (noisiest) timesteps are not run.
+            skip_final_inference_steps (`int`, *optional*, defaults to 0):
+                The number of final timesteps to skip. After the timesteps are calculated, this many entries are
+                removed from the end of the list, so the lowest timesteps are not run.
             timesteps (`List[int]`, *optional*):
                 Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
                 timesteps are used. Must be in descending order.
@@ -805,6 +826,9 @@ class LTXVideoPipeline(DiffusionPipeline):
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            cfg_star_rescale (`bool`, *optional*, defaults to `False`):
+                If set to `True`, applies the CFG-Star rescale: the negative (unconditional) prediction is scaled
+                by the projection coefficient of the positive prediction onto it before CFG is applied.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             height (`int`, *optional*, defaults to self.unet.config.sample_size):
@@ -852,10 +876,6 @@ class LTXVideoPipeline(DiffusionPipeline):
                 If set to `True`, the sampling is stochastic. If set to `False`, the sampling is deterministic.
             media_items ('torch.Tensor', *optional*):
                 The input media item used for image-to-image / video-to-video.
-                When provided, they will be noised according to 'strength' and then fully denoised.
-            strength ('floaty', *optional* defaults to 1.0):
-                The editing level in image-to-image / video-to-video. The provided input will be noised
-                to this level.
 
         Examples:
 
         Returns:
@@ -912,8 +932,12 @@ class LTXVideoPipeline(DiffusionPipeline):
         if isinstance(self.scheduler, TimestepShifter):
             retrieve_timesteps_kwargs["samples_shape"] = latent_shape
 
-        assert strength == 1.0 or latents is not None or media_items is not None, (
-            "strength < 1 is used for image-to-image/video-to-video - "
+        assert (
+            skip_initial_inference_steps == 0
+            or latents is not None
+            or media_items is not None
+        ), (
+            f"skip_initial_inference_steps ({skip_initial_inference_steps}) is used for image-to-image/video-to-video - "
             "media_item or latents should be provided."
         )
@@ -922,9 +946,11 @@ class LTXVideoPipeline(DiffusionPipeline):
             num_inference_steps,
             device,
             timesteps,
-            max_timestep=strength,
+            skip_initial_inference_steps=skip_initial_inference_steps,
+            skip_final_inference_steps=skip_final_inference_steps,
             **retrieve_timesteps_kwargs,
         )
 
         if self.allowed_inference_steps is not None:
             for timestep in [round(x, 4) for x in timesteps.tolist()]:
                 assert (
@@ -981,7 +1007,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         # Normalize skip_block_list to always be None or a list of lists matching timesteps
         if skip_block_list is not None:
             # Convert single list to list of lists if needed
-            if not isinstance(skip_block_list[0], list):
+            if len(skip_block_list) == 0 or not isinstance(skip_block_list[0], list):
                 skip_block_list = [skip_block_list] * len(timesteps)
             else:
                 new_skip_block_list = []
@@ -1189,6 +1215,22 @@ class LTXVideoPipeline(DiffusionPipeline):
                     )[-2:]
 
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_conds)[:2]
+
+                    if cfg_star_rescale:
+                        # Rescales the unconditional noise prediction using the projection of the conditional prediction onto it:
+                        # α = (⟨ε_text, ε_uncond⟩ / ||ε_uncond||²), then ε_uncond ← α * ε_uncond
+                        # where ε_text is the conditional noise prediction and ε_uncond is the unconditional one.
+                        positive_flat = noise_pred_text.view(batch_size, -1)
+                        negative_flat = noise_pred_uncond.view(batch_size, -1)
+                        dot_product = torch.sum(
+                            positive_flat * negative_flat, dim=1, keepdim=True
+                        )
+                        squared_norm = (
+                            torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
+                        )
+                        alpha = dot_product / squared_norm
+                        noise_pred_uncond = alpha * noise_pred_uncond
+
                     noise_pred = noise_pred_uncond + guidance_scale[i] * (
                         noise_pred_text - noise_pred_uncond
                     )
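
For reference, the CFG-Star rescale can be run standalone on dummy predictions. Here the (batch_size, 1) alpha is reshaped explicitly because the dummy tensors are 3D; the shapes and broadcasting inside the denoising loop may differ:

import torch

batch_size, guidance_scale = 2, 6.0
noise_pred_text = torch.randn(batch_size, 1024, 128)    # (batch, tokens, channels), illustrative
noise_pred_uncond = torch.randn(batch_size, 1024, 128)

positive_flat = noise_pred_text.view(batch_size, -1)
negative_flat = noise_pred_uncond.view(batch_size, -1)
alpha = (positive_flat * negative_flat).sum(dim=1, keepdim=True) / (
    (negative_flat**2).sum(dim=1, keepdim=True) + 1e-8
)
noise_pred_uncond = alpha.view(batch_size, 1, 1) * noise_pred_uncond
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# With alpha == 1 this reduces to plain CFG; when the two predictions are
# poorly aligned, alpha is small and the unconditional term is damped.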
@@ -1695,6 +1737,37 @@ class LTXVideoPipeline(DiffusionPipeline):
         return num_frames
 
+
+def adain_filter_latent(
+    latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0
+):
+    """
+    Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on
+    statistics from a reference latent tensor.
+
+    Args:
+        latents (torch.Tensor): Input latents to normalize.
+        reference_latents (torch.Tensor): The reference latents providing style statistics.
+        factor (float): Blending factor between original and transformed latents.
+            Range: -10.0 to 10.0, Default: 1.0
+
+    Returns:
+        torch.Tensor: The transformed latent tensor.
+    """
+    result = latents.clone()
+    for i in range(latents.size(0)):
+        for c in range(latents.size(1)):
+            r_sd, r_mean = torch.std_mean(
+                reference_latents[i, c], dim=None
+            )  # index by original dim order
+            i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
+
+            result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
+    result = torch.lerp(latents, result, factor)
+    return result
+
+
 class LTXMultiScalePipeline:
     def _upsample_latents(
         self, latest_upsampler: LatentUpsampler, latents: torch.Tensor
@@ -1743,6 +1816,9 @@ class LTXMultiScalePipeline:
         latents = result.images
         upsampled_latents = self._upsample_latents(self.latent_upsampler, latents)
+        upsampled_latents = adain_filter_latent(
+            latents=upsampled_latents, reference_latents=latents
+        )
 
         kwargs = original_kwargs
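
The multi-scale pipeline now renormalizes the upsampled latents to the per-channel statistics of the pre-upsample result before the second pass. A usage sketch of the new function on dummy latents (shapes are illustrative):

import torch
from ltx_video.pipelines.pipeline_ltx_video import adain_filter_latent

reference = torch.randn(1, 128, 8, 16, 16)   # (B, C, F, H, W) before upsampling
upsampled = torch.randn(1, 128, 8, 32, 32)   # after the latent upsampler
blended = adain_filter_latent(latents=upsampled, reference_latents=reference)
# factor defaults to 1.0, i.e. fully adopt the reference statistics;
# factor=0.0 would return the upsampled latents unchanged.
assert blended.shape == upsampled.shape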


@@ -30,7 +30,9 @@ dependencies = [
 inference-script = [
     "accelerate",
     "matplotlib",
-    "imageio[ffmpeg]"
+    "imageio[ffmpeg]",
+    "av",
+    "opencv-python"
 ]
 test = [
     "pytest",


@@ -101,13 +101,13 @@ def test_vid2vid(tmp_path, test_paths):
         "frame_rate": 25,
         "prompt": "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood.",
         "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
-        "strength": 0.95,
         "offload_to_cpu": False,
         "input_media_path": test_paths["input_video_path"],
     }
     config = {
         "num_inference_steps": 3,
+        "skip_initial_inference_steps": 1,
         "guidance_scale": 2.5,
         "stg_scale": 1,
         "stg_rescale": 0.7,