diff --git a/ltx_video/inference.py b/ltx_video/inference.py index 1eddc8e..211f317 100644 --- a/ltx_video/inference.py +++ b/ltx_video/inference.py @@ -182,6 +182,28 @@ def seed_everething(seed: int): torch.mps.manual_seed(seed) +def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel: + if precision == "float8_e4m3fn": + try: + from q8_kernels.integration.patch_transformer import ( + patch_diffusers_transformer as patch_transformer_for_q8_kernels, + ) + + transformer = Transformer3DModel.from_pretrained( + ckpt_path, dtype=torch.float8_e4m3fn + ) + patch_transformer_for_q8_kernels(transformer) + return transformer + except ImportError: + raise ValueError( + "Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from https://github.com/Lightricks/LTXVideo-Q8-Kernels" + ) + elif precision == "bfloat16": + return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16) + else: + return Transformer3DModel.from_pretrained(ckpt_path) + + def create_ltx_video_pipeline( ckpt_path: str, precision: str, @@ -204,7 +226,7 @@ def create_ltx_video_pipeline( allowed_inference_steps = configs.get("allowed_inference_steps", None) vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) - transformer = Transformer3DModel.from_pretrained(ckpt_path) + transformer = create_transformer(ckpt_path, precision) # Use constructor if sampler is specified, otherwise use from_pretrained if sampler == "from_checkpoint" or not sampler: @@ -247,8 +269,6 @@ def create_ltx_video_pipeline( prompt_enhancer_llm_tokenizer = None vae = vae.to(torch.bfloat16) - if precision == "bfloat16" and transformer.dtype != torch.bfloat16: - transformer = transformer.to(torch.bfloat16) text_encoder = text_encoder.to(torch.bfloat16) # Use submodels for the pipeline