Inference: Integrate fp8 kernels
This commit is contained in:
@@ -182,6 +182,28 @@ def seed_everething(seed: int):
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
|
||||
def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
    """Load the transformer from ``ckpt_path`` in the requested precision.

    Args:
        ckpt_path: Checkpoint location forwarded to
            ``Transformer3DModel.from_pretrained``.
        precision: ``"float8_e4m3fn"`` loads the checkpoint with
            ``dtype=torch.float8_e4m3fn`` and patches the model with the Q8
            kernels; ``"bfloat16"`` loads the checkpoint and converts it with
            ``.to(torch.bfloat16)``; any other value loads the checkpoint
            without conversion.

    Returns:
        The loaded transformer.

    Raises:
        ValueError: If ``precision`` is ``"float8_e4m3fn"`` and the
            ``q8_kernels`` package cannot be imported.
    """
    if precision == "float8_e4m3fn":
        # Only the optional import is guarded: an ImportError raised while
        # loading or patching the model is a different failure and must not
        # be reported as "Q8-Kernels not found".
        try:
            from q8_kernels.integration.patch_transformer import (
                patch_diffusers_transformer as patch_transformer_for_q8_kernels,
            )
        except ImportError as err:
            # Chain the original error so the real import failure stays
            # visible in the traceback.
            raise ValueError(
                "Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from https://github.com/Lightricks/LTXVideo-Q8-Kernels"
            ) from err

        transformer = Transformer3DModel.from_pretrained(
            ckpt_path, dtype=torch.float8_e4m3fn
        )
        patch_transformer_for_q8_kernels(transformer)
        return transformer

    if precision == "bfloat16":
        return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)

    return Transformer3DModel.from_pretrained(ckpt_path)
|
||||
|
||||
|
||||
def create_ltx_video_pipeline(
|
||||
ckpt_path: str,
|
||||
precision: str,
|
||||
@@ -204,7 +226,7 @@ def create_ltx_video_pipeline(
|
||||
allowed_inference_steps = configs.get("allowed_inference_steps", None)
|
||||
|
||||
vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
|
||||
transformer = Transformer3DModel.from_pretrained(ckpt_path)
|
||||
transformer = create_transformer(ckpt_path, precision)
|
||||
|
||||
# Use constructor if sampler is specified, otherwise use from_pretrained
|
||||
if sampler == "from_checkpoint" or not sampler:
|
||||
@@ -247,8 +269,6 @@ def create_ltx_video_pipeline(
|
||||
prompt_enhancer_llm_tokenizer = None
|
||||
|
||||
vae = vae.to(torch.bfloat16)
|
||||
if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
|
||||
transformer = transformer.to(torch.bfloat16)
|
||||
text_encoder = text_encoder.to(torch.bfloat16)
|
||||
|
||||
# Use submodels for the pipeline
|
||||
|
||||
Reference in New Issue
Block a user