import os
import subprocess
from pathlib import Path

import gradio as gr

from config import hparams as hp
from config import hparams_gradio as hp_gradio
from nota_wav2lip import Wav2LipModelComparisonGradio

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = hp_gradio.device
print(f'Using {device} for inference.')
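# A defensive fallback is possible here (a sketch, assuming the torch import
# from the commented-out line above were restored; not part of the original app):
#   import torch
#   device = hp_gradio.device if torch.cuda.is_available() else 'cpu'
# This would keep the demo usable on CPU-only hosts even if the config requests CUDA.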
video_label_dict = hp_gradio.sample.video
audio_label_dict = hp_gradio.sample.audio

LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)
LRS_INFERENCE_SAMPLE = os.getenv('LRS_INFERENCE_SAMPLE', None)

# Fetch the model checkpoints and the sample archive if they are missing locally
if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None:
    subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.wav2lip.checkpoint} {LRS_ORIGINAL_URL}", shell=True)
if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None:
    subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.nota_wav2lip.checkpoint} {LRS_COMPRESSED_URL}", shell=True)

path_inference_sample = "sample.tar.gz"
if not Path(path_inference_sample).exists() and LRS_INFERENCE_SAMPLE is not None:
    subprocess.call(f"wget --no-check-certificate -O {path_inference_sample} {LRS_INFERENCE_SAMPLE}", shell=True)
    subprocess.call(f"tar -zxvf {path_inference_sample}", shell=True)
if __name__ == "__main__":
    servicer = Wav2LipModelComparisonGradio(
        device=device,
        video_label_dict=video_label_dict,
        audio_label_list=audio_label_dict,
        default_video='v1',
        default_audio='a1'
    )
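    # Note: the keyword is `audio_label_list` although a dict is passed; the
    # call is kept as written, on the assumption that
    # Wav2LipModelComparisonGradio accepts either form.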
    # Register the bundled samples: each video is paired with a .json file of
    # the same stem, and each audio is registered by name
    for video_name in sorted(video_label_dict):
        video_stem = Path(video_label_dict[video_name])
        servicer.update_video(video_stem, video_stem.with_suffix('.json'),
                              name=video_name)
    for audio_name in sorted(audio_label_dict):
        audio_path = Path(audio_label_dict[audio_name])
        servicer.update_audio(audio_path, name=audio_name)

    with gr.Blocks(theme='nota-ai/theme', css=Path('docs/main.css').read_text()) as demo:
        gr.Markdown(Path('docs/header.md').read_text())
        gr.Markdown(Path('docs/description.md').read_text())

        with gr.Row():
            with gr.Column(variant='panel'):
                gr.Markdown('## Select or upload input video and audio', sanitize_html=False)
                # Preview slots for the currently selected or uploaded inputs
                sample_video = gr.Video(interactive=False, label="Input Video")
                sample_audio = gr.Audio(interactive=False, label="Input Audio")
                # Upload inputs (gr.Video takes no `type` argument; it always returns a filepath)
                video_upload = gr.Video(source="upload", label="Upload Video")
                audio_upload = gr.Audio(source="upload", type="filepath", label="Upload Audio")
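                # Note: `source=` is the Gradio 3.x API; Gradio 4.x renamed it to
                # `sources=["upload"]` for both gr.Video and gr.Audio (an assumption
                # about the installed version, which this file does not pin).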
                # Define radio inputs
                video_selection = gr.Radio(video_label_dict,
                                           type='value', label="Select an input video:")
                audio_selection = gr.Radio(audio_label_dict,
                                           type='value', label="Select an input audio:")
                # Define button inputs
                with gr.Row(equal_height=True):
                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")
            with gr.Column(variant='panel'):
                # Define original model output components
                gr.Markdown('## Original Wav2Lip')
                original_model_output = gr.Video(label="Original Model", interactive=False)
                with gr.Column():
                    with gr.Row(equal_height=True):
                        original_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)")
                        original_model_fps = gr.Textbox(value="", label="FPS")
                        original_model_params = gr.Textbox(value=servicer.params['wav2lip'], label="# Parameters")
            with gr.Column(variant='panel'):
                # Define compressed model output components
                gr.Markdown('## Compressed Wav2Lip (Ours)')
                compressed_model_output = gr.Video(label="Compressed Model", interactive=False)
                with gr.Column():
                    with gr.Row(equal_height=True):
                        compressed_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)")
                        compressed_model_fps = gr.Textbox(value="", label="FPS")
                        compressed_model_params = gr.Textbox(value=servicer.params['nota_wav2lip'], label="# Parameters")

        # Switch video and audio samples when selecting the radio button
        video_selection.change(fn=servicer.switch_video_samples, inputs=video_selection, outputs=sample_video)
        audio_selection.change(fn=servicer.switch_audio_samples, inputs=audio_selection, outputs=sample_audio)

        # Update the preview when a file is uploaded
        video_upload.change(fn=lambda x: x, inputs=video_upload, outputs=sample_video)
        audio_upload.change(fn=lambda x: x, inputs=audio_upload, outputs=sample_audio)

        # Helper: prefer an uploaded file over the radio selection
        def resolve_inputs(video_choice, audio_choice, video_file, audio_file):
            video_path = video_file if video_file else video_label_dict.get(video_choice)
            audio_path = audio_file if audio_file else audio_label_dict.get(audio_choice)
            return video_path, audio_path
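        # Illustrative behaviour (values are hypothetical, not part of the app):
        #   resolve_inputs('v1', 'a1', None, None)          -> bundled sample paths
        #   resolve_inputs('v1', 'a1', '/tmp/up.mp4', None) -> uploaded video, sample audio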

        # Click the generate button for the original model
        generate_original_button.click(
            fn=lambda v, a, vu, au: servicer.generate_original_model(*resolve_inputs(v, a, vu, au)),
            inputs=[video_selection, audio_selection, video_upload, audio_upload],
            outputs=[original_model_output, original_model_inference_time, original_model_fps]
        )
        # Click the generate button for the compressed model
        generate_compressed_button.click(
            fn=lambda v, a, vu, au: servicer.generate_compressed_model(*resolve_inputs(v, a, vu, au)),
            inputs=[video_selection, audio_selection, video_upload, audio_upload],
            outputs=[compressed_model_output, compressed_model_inference_time, compressed_model_fps]
        )

        gr.Markdown(Path('docs/footer.md').read_text())

    demo.queue().launch()
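    # launch() with defaults suits Hugging Face Spaces; for a local run one might
    # instead use, e.g., demo.queue().launch(server_name="0.0.0.0", server_port=7860)
    # (standard Gradio arguments; the values are an assumption about the
    # deployment, not part of the original app).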