# Scraped from the Hugging Face file view: author sayakpaul (HF Staff),
# last commit 5ca41bd ("up"), file size 1.04 kB.
import time
from datetime import datetime

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline

from aoti import aoti_load_
# --- Model Loading ---
# Precision used for the whole pipeline; referenced below so there is a single
# place to change it.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/Flux.1-Dev", torch_dtype=dtype
).to(device)

# NOTE(review): fusing the QKV projections presumably has to happen before the
# compiled artifact is loaded so the transformer's module layout matches what
# was exported — confirm against aoti.py / the export script.
pipeline.transformer.fuse_qkv_projections()

# Load the "flux-dev-aot.pt2" artifact from the "sayakpaul/flux-dev-aot" Hub
# repo into pipeline.transformer. The trailing underscore suggests it modifies
# the transformer in place — behavior lives in aoti.py; confirm there.
# NOTE(review): the artifact was presumably compiled for CUDA, so the "cpu"
# fallback above is unlikely to work end-to-end — verify.
aoti_load_(pipeline.transformer, "sayakpaul/flux-dev-aot", "flux-dev-aot.pt2")
@spaces.GPU
def generate_image(prompt: str, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for ``prompt`` and report how long generation took.

    Runs the module-level FLUX ``pipeline`` for 28 inference steps with a
    generator seeded to 42, so results are meant to be reproducible for a
    given prompt.

    Args:
        prompt: Text prompt passed to the pipeline.
        progress: Gradio progress tracker. ``track_tqdm=True`` asks Gradio to
            surface tqdm progress bars in the UI; the argument is not read
            inside this function.

    Returns:
        A one-element list holding an ``(image, caption)`` tuple for
        ``gr.Gallery``. The caption is the duration of the pipeline call,
        formatted like ``"3.41s"``.
    """
    # Seed on the same device the pipeline was moved to. A hard-coded 'cuda'
    # here raised whenever the module-level `device` fell back to "cpu".
    generator = torch.Generator(device=device).manual_seed(42)

    # perf_counter is monotonic and high-resolution, unlike datetime.now(), so
    # the reported duration is not skewed by system clock adjustments.
    t0 = time.perf_counter()
    output = pipeline(
        prompt=prompt,
        num_inference_steps=28,
        generator=generator,
    )
    elapsed = time.perf_counter() - t0
    return [(output.images[0], f"{elapsed:.2f}s")]
# Minimal Gradio UI: a single prompt text box in, a gallery of
# (image, timing caption) pairs out, as returned by generate_image.
prompt_box = gr.Text(label="Prompt")
result_gallery = gr.Gallery()

demo = gr.Interface(
    fn=generate_image,
    inputs=prompt_box,
    outputs=result_gallery,
    examples=["A cat playing with a ball of yarn"],
    # Don't pre-compute example outputs at startup (that would run the model).
    cache_examples=False,
)
demo.launch()