# app.py — Gradio demo: extract US addresses from text or .wav audio.
# (Hugging Face page residue removed: "Pavanb / Update app.py / bf0ff35 verified")
import gradio as gr
from address_extractor import AddressExtractor
import tempfile
import os
import librosa
import soundfile as sf
# Shared AddressExtractor instance used by both handlers below; it exposes the
# tokenizer, bitnet_model, whisper_model, prompts, and preprocess_text helper.
address_extractor = AddressExtractor()
def extract_from_text(input_text):
    """Extract a US address from free-form text with the BitNet chat model.

    Args:
        input_text: Raw user text. May be None or blank — Gradio can pass an
            empty value when the textbox is cleared.

    Returns:
        The extracted address string, "No address detected." when the model
        produces nothing, or an error message for empty input.
    """
    # Guard None as well as whitespace-only input; the original called
    # .strip() directly and would raise AttributeError on None.
    if not input_text or not input_text.strip():
        return "Error: No text provided."

    messages = [
        {"role": "system", "content": address_extractor.system_prompt_text},
        {"role": "user", "content": input_text},
    ]
    prompt = address_extractor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    chat_input = address_extractor.tokenizer(prompt, return_tensors="pt").to(
        address_extractor.bitnet_model.device
    )
    chat_outputs = address_extractor.bitnet_model.generate(**chat_input, max_new_tokens=256)
    # Decode only the tokens generated after the prompt.
    generated_text = address_extractor.tokenizer.decode(
        chat_outputs[0][chat_input["input_ids"].shape[-1]:], skip_special_tokens=True
    )
    return generated_text.strip() or "No address detected."
def extract_from_audio(audio_file):
    """Transcribe an uploaded .wav file and extract a US address from it.

    Args:
        audio_file: Filesystem path to the uploaded audio (Gradio
            ``type="filepath"``), or None when nothing was uploaded.

    Returns:
        The extracted address string, "No address detected." when the model
        produces nothing, or an error message for missing input.
    """
    if audio_file is None:
        return "Error: No audio provided."

    # Resample to 16 kHz for the Whisper model, writing the result to a
    # temporary file instead of overwriting the caller's uploaded file
    # in place as the original did.
    audio, _ = librosa.load(audio_file, sr=16000)
    tmp_fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(tmp_fd)  # sf.write reopens by path; release the raw descriptor
    try:
        sf.write(tmp_path, audio, 16000)

        segments = address_extractor.whisper_model.transcribe(tmp_path)
        input_text = " ".join(seg.text.strip() for seg in segments)
        input_text = address_extractor.preprocess_text(input_text)

        messages = [
            {"role": "system", "content": address_extractor.system_prompt_speech},
            {"role": "user", "content": input_text},
        ]
        prompt = address_extractor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        chat_input = address_extractor.tokenizer(prompt, return_tensors="pt").to(
            address_extractor.bitnet_model.device
        )
        chat_outputs = address_extractor.bitnet_model.generate(**chat_input, max_new_tokens=256)
        # Decode only the tokens generated after the prompt.
        generated_text = address_extractor.tokenizer.decode(
            chat_outputs[0][chat_input["input_ids"].shape[-1]:], skip_special_tokens=True
        )
        return generated_text.strip() or "No address detected."
    finally:
        # Always clean up the temporary resampled file.
        os.remove(tmp_path)
# ---- Gradio UI: one tab per input modality, wired to the handlers above ----
with gr.Blocks() as demo:
    gr.Markdown("## 📦 US Address Extractor")

    with gr.Tab("Text Input"):
        txt_in = gr.Textbox(lines=3, label="Enter Text")
        txt_out = gr.Textbox(label="Extracted Address")
        gr.Button("Extract Address").click(
            fn=extract_from_text, inputs=txt_in, outputs=txt_out
        )

    with gr.Tab("Audio Input (.wav)"):
        wav_in = gr.Audio(type="filepath", label="Upload a .wav Audio File")
        wav_out = gr.Textbox(label="Extracted Address")
        gr.Button("Extract Address").click(
            fn=extract_from_audio, inputs=wav_in, outputs=wav_out
        )

# Bind to all interfaces on port 7860 and request a public share link.
demo.launch(server_name="0.0.0.0", share=True, server_port=7860)