Inference: Integrate fp8 kernels

Authored by Yaki on 2025-07-17 10:57:13 +00:00, committed by Yoav HaCohen
parent bdc8f017f0
commit ccfc030932


@@ -182,6 +182,28 @@ def seed_everething(seed: int):
         torch.mps.manual_seed(seed)
 
 
+def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
+    if precision == "float8_e4m3fn":
+        try:
+            from q8_kernels.integration.patch_transformer import (
+                patch_diffusers_transformer as patch_transformer_for_q8_kernels,
+            )
+
+            transformer = Transformer3DModel.from_pretrained(
+                ckpt_path, dtype=torch.float8_e4m3fn
+            )
+            patch_transformer_for_q8_kernels(transformer)
+            return transformer
+        except ImportError:
+            raise ValueError(
+                "Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from https://github.com/Lightricks/LTXVideo-Q8-Kernels"
+            )
+    elif precision == "bfloat16":
+        return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
+    else:
+        return Transformer3DModel.from_pretrained(ckpt_path)
+
+
 def create_ltx_video_pipeline(
     ckpt_path: str,
     precision: str,
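
For reference, a minimal usage sketch of the new helper added above. The checkpoint path is hypothetical, and the FP8 branch only succeeds if the q8_kernels package from the linked repository is installed:

# Hypothetical checkpoint path; substitute a real LTX-Video checkpoint.
ckpt_path = "PATH/TO/ltx-video-checkpoint.safetensors"

# FP8 path: loads weights as float8_e4m3fn and patches the transformer with
# Q8 kernels; raises ValueError if q8_kernels is not installed.
transformer = create_transformer(ckpt_path, precision="float8_e4m3fn")

# bfloat16 path: plain from_pretrained followed by a dtype cast.
transformer = create_transformer(ckpt_path, precision="bfloat16")
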
@@ -204,7 +226,7 @@ def create_ltx_video_pipeline(
     allowed_inference_steps = configs.get("allowed_inference_steps", None)
 
     vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
-    transformer = Transformer3DModel.from_pretrained(ckpt_path)
+    transformer = create_transformer(ckpt_path, precision)
 
     # Use constructor if sampler is specified, otherwise use from_pretrained
     if sampler == "from_checkpoint" or not sampler:
@@ -247,8 +269,6 @@ def create_ltx_video_pipeline(
         prompt_enhancer_llm_tokenizer = None
 
     vae = vae.to(torch.bfloat16)
-    if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
-        transformer = transformer.to(torch.bfloat16)
     text_encoder = text_encoder.to(torch.bfloat16)
 
     # Use submodels for the pipeline
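
The pipeline-level bfloat16 cast removed in the last hunk is now redundant: create_transformer already returns the transformer in the requested dtype. A hedged end-to-end sketch follows; create_ltx_video_pipeline takes additional arguments not visible in this diff, and the checkpoint path is hypothetical:

# Sketch under the assumptions above, not the repository's documented entry point.
pipeline = create_ltx_video_pipeline(
    ckpt_path="PATH/TO/ltx-video-fp8-checkpoint.safetensors",
    precision="float8_e4m3fn",  # routes transformer loading through the Q8-kernel branch
)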