Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from address_extractor import AddressExtractor | |
| import tempfile | |
| import os | |
| import librosa | |
| import soundfile as sf | |
# Single shared AddressExtractor instance, created once at import time and
# reused by both request handlers below (they read its tokenizer, models and
# system prompts).
address_extractor = AddressExtractor()
def extract_from_text(input_text):
    """Send user text through the chat model with the text-mode system prompt.

    Args:
        input_text: Raw text entered by the user. ``None`` and
            whitespace-only strings are treated as missing input.

    Returns:
        The decoded model reply with surrounding whitespace removed,
        ``"No address detected."`` when that reply is empty, or
        ``"Error: No text provided."`` when there is no usable input.
    """
    # Check for None explicitly: calling .strip() on None would raise
    # AttributeError instead of returning the friendly error message.
    if input_text is None or not input_text.strip():
        return "Error: No text provided."

    tokenizer = address_extractor.tokenizer
    model = address_extractor.bitnet_model

    messages = [
        {"role": "system", "content": address_extractor.system_prompt_text},
        {"role": "user", "content": input_text},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    chat_input = tokenizer(prompt, return_tensors="pt").to(model.device)
    chat_outputs = model.generate(**chat_input, max_new_tokens=256)

    # Decode only the tokens that come after the prompt's input_ids.
    prompt_len = chat_input["input_ids"].shape[-1]
    generated_text = tokenizer.decode(
        chat_outputs[0][prompt_len:], skip_special_tokens=True
    )
    return generated_text.strip() or "No address detected."
def extract_from_audio(audio_file):
    """Transcribe an audio file and send the transcript through the chat model.

    The audio is loaded at 16 kHz, written to a temporary ``.wav`` file,
    transcribed, preprocessed, and then sent to the chat model together with
    the speech-mode system prompt.

    Args:
        audio_file: Filesystem path to the audio file. ``None`` or an empty
            path is treated as missing input.

    Returns:
        The decoded model reply with surrounding whitespace removed,
        ``"No address detected."`` when that reply is empty, or
        ``"Error: No audio provided."`` when there is no usable input.

    Raises:
        Any exception raised while loading, transcribing or generating is
        propagated to the caller unchanged.
    """
    if not audio_file:
        return "Error: No audio provided."

    target_sr = 16000

    # Write the resampled copy into a scratch directory instead of overwriting
    # the caller's file. The directory (and the copy) is removed automatically
    # when the ``with`` block exits, even if an exception is raised.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "resampled.wav")
        audio, _ = librosa.load(audio_file, sr=target_sr)
        sf.write(wav_path, audio, target_sr)

        segments = address_extractor.whisper_model.transcribe(wav_path)
        # NOTE(review): assumes iterating the transcribe() result yields
        # objects with a `.text` attribute -- confirm against the whisper
        # backend wrapped by AddressExtractor.
        # The join stays inside the ``with`` block in case the segments are
        # produced lazily from the temporary file.
        input_text = " ".join(seg.text.strip() for seg in segments)

    input_text = address_extractor.preprocess_text(input_text)

    tokenizer = address_extractor.tokenizer
    model = address_extractor.bitnet_model

    messages = [
        {"role": "system", "content": address_extractor.system_prompt_speech},
        {"role": "user", "content": input_text},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    chat_input = tokenizer(prompt, return_tensors="pt").to(model.device)
    chat_outputs = model.generate(**chat_input, max_new_tokens=256)

    # Decode only the tokens that come after the prompt's input_ids.
    prompt_len = chat_input["input_ids"].shape[-1]
    generated_text = tokenizer.decode(
        chat_outputs[0][prompt_len:], skip_special_tokens=True
    )
    return generated_text.strip() or "No address detected."
# ---------------------------------------------------------------------------
# Gradio front end: one tab per input mode, each wired to its handler.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 📦 US Address Extractor")

    # Tab 1: free-form text in, model reply out.
    with gr.Tab("Text Input"):
        text_input = gr.Textbox(lines=3, label="Enter Text")
        text_output = gr.Textbox(label="Extracted Address")
        text_button = gr.Button("Extract Address")
        text_button.click(
            fn=extract_from_text,
            inputs=text_input,
            outputs=text_output,
        )

    # Tab 2: audio is handed to the handler as a file path (type="filepath").
    with gr.Tab("Audio Input (.wav)"):
        audio_input = gr.Audio(type="filepath", label="Upload a .wav Audio File")
        audio_output = gr.Textbox(label="Extracted Address")
        audio_button = gr.Button("Extract Address")
        audio_button.click(
            fn=extract_from_audio,
            inputs=audio_input,
            outputs=audio_output,
        )

# Start the server on all interfaces at port 7860, with a share link requested.
demo.launch(server_name="0.0.0.0", share=True, server_port=7860)