|
|
|
|
|
import os |
|
|
|
|
|
|
|
|
# Process-wide settings, applied before any third-party library is imported
# further down, presumably so those libraries see them when they initialise.
# NOTE(review): confirm the installed PaddleOCR version actually reads
# PADDLEOCR_HOME for its model directory.
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "DISABLE_MODEL_SOURCE_CHECK": "True",
    "PADDLEOCR_HOME": "./paddleocr_models",
})
|
|
|
|
|
from fastapi import FastAPI, UploadFile, File |
|
|
from fastapi.responses import JSONResponse |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
import tempfile |
|
|
|
|
|
|
|
|
# Handles for optional OCR dependencies. Each starts unset and is filled in
# lazily by get_paddleocr(), get_pytesseract() and get_opencv(). The matching
# *_available flag records whether that import succeeded. For pytesseract, the
# flag also depends on the tesseract-binary probe.
PaddleOCR = None
paddleocr_available = False

pytesseract = None
pytesseract_available = False

cv2 = None
cv2_available = False
|
|
|
|
|
|
|
|
def get_opencv():
    """Import OpenCV on first use and cache it in the module-level ``cv2``.

    Once a module is cached, it is returned without re-importing. If the
    import fails, ``cv2_available`` is set False, the failure is printed, and
    None is returned; the next call tries again.

    Returns:
        The ``cv2`` module, or None when it cannot be imported.
    """
    global cv2, cv2_available
    if cv2 is not None:
        return cv2
    try:
        import cv2 as _cv2
    except ImportError as e:
        print(f"❌ OpenCV动态导入失败: {e}")
        cv2_available = False
    else:
        cv2 = _cv2
        cv2_available = True
        print("✅ OpenCV动态导入成功")
    return cv2
|
|
|
|
|
|
|
|
def get_pytesseract():
    """Import pytesseract on first use and probe the tesseract executable.

    The Python wrapper alone cannot do OCR. ``pytesseract_available`` is only
    set True by ``check_tesseract_availability()``, which runs the tesseract
    binary. That probe (and the OpenCV import) runs once, right after the
    first successful import of pytesseract.

    Returns:
        The pytesseract module, or None if it cannot be imported. In that
        case ``pytesseract_available`` is False and the next call retries.
    """
    global pytesseract, pytesseract_available
    if pytesseract is None:
        try:
            import pytesseract as _pytesseract
        except ImportError as e:
            print(f"❌ pytesseract动态导入失败: {e}")
            pytesseract_available = False
            return pytesseract
        pytesseract = _pytesseract
        get_opencv()
        # Without this probe pytesseract_available would stay False, and
        # callers that gate on it would treat an installed tesseract as missing.
        check_tesseract_availability()
    return pytesseract
|
|
|
|
|
def check_tesseract_availability():
    """Probe the tesseract executable and record the result in ``pytesseract_available``.

    Runs ``<tesseract_cmd> --version`` with a 5 second timeout. ``tesseract_cmd``
    comes from the imported pytesseract wrapper. The flag becomes True only
    when that command exits with status 0. It stays False in these cases:
    pytesseract has not been imported, ``tesseract_cmd`` is missing, the
    executable cannot be found, the command fails, or any other error occurs.
    The outcome is reported with ``print``.
    """
    global pytesseract_available
    import subprocess

    # Assume unavailable until the probe proves otherwise.
    pytesseract_available = False

    if pytesseract is None:
        return

    try:
        wrapper = pytesseract.pytesseract
        if not hasattr(wrapper, 'tesseract_cmd'):
            print("⚠️ tesseract_cmd未配置")
            return

        tesseract_cmd = wrapper.tesseract_cmd
        # No check=True: a non-zero exit status is inspected via returncode.
        result = subprocess.run([tesseract_cmd, '--version'],
                                capture_output=True, text=True, timeout=5)
    except FileNotFoundError:
        print(f"⚠️ tesseract可执行文件未找到: {tesseract_cmd}")
        return
    except Exception as e:
        # Includes subprocess.TimeoutExpired and permission errors.
        print(f"⚠️ 测试tesseract可执行文件时出错: {e}")
        return

    if result.returncode != 0:
        print(f"⚠️ tesseract命令执行失败,返回码: {result.returncode}")
        print(f"错误输出: {result.stderr.strip()}")
        return

    # Some tesseract releases print the version banner on stderr, not stdout.
    banner = result.stdout or result.stderr
    version = banner.strip().split('\n')[0] if banner else "unknown"
    pytesseract_available = True
    print(f"✅ tesseract可执行文件可用,版本: {version}")
|
|
|
|
|
|
|
|
def get_paddleocr():
    """Import the PaddleOCR class on first use and cache it in ``PaddleOCR``.

    Every call also calls ``get_pytesseract()``, so the pytesseract fallback
    is initialised whenever PaddleOCR is requested.

    Returns:
        The ``PaddleOCR`` class, or None when ``paddleocr`` cannot be
        imported. In that case ``paddleocr_available`` is False and the next
        call retries.
    """
    global PaddleOCR, paddleocr_available
    if PaddleOCR is None:
        try:
            from paddleocr import PaddleOCR as _PaddleOCR
        except ImportError as e:
            print(f"❌ PaddleOCR动态导入失败: {e}")
            paddleocr_available = False
        else:
            PaddleOCR = _PaddleOCR
            paddleocr_available = True
            print("✅ PaddleOCR动态导入成功")

    get_pytesseract()
    return PaddleOCR
|
|
|
|
|
# Optional PDF rendering back-end (PyMuPDF).
try:
    import fitz
except ImportError:
    fitz_available = False
else:
    fitz_available = True

# OpenCV and numpy are probed together: cv2_available is False if either
# import fails.
try:
    import cv2
    import numpy as np
except ImportError:
    cv2_available = False
else:
    cv2_available = True

# Optional ONNX runtime; only its availability is reported (by /health).
try:
    from onnxruntime import InferenceSession
except ImportError:
    onnx_available = False
else:
    onnx_available = True
|
|
|
|
|
# The FastAPI application; every route in this module is registered on it.
app = FastAPI(title="智能音频转录与摘要服务")

# Serve the front-end files from ./static (relative to the working directory)
# under /front; html=True makes StaticFiles answer directory requests with
# index.html.
app.mount("/front", StaticFiles(directory="static", html=True), name="static")

# TOKENIZERS_PARALLELISM and DISABLE_MODEL_SOURCE_CHECK are set once, at the
# top of the module, before any third-party library is imported.
|
|
|
|
|
|
|
|
# Lazily-created model handles, filled in by the load_* functions below.
whisper_model = summarizer = ocr_model = None

# Per-model load status, exposed by the /health endpoint.
models_loaded = {name: False for name in ("whisper", "summarizer", "ocr")}

# Description of the most recent OCR load failure, or None.
ocr_load_error = None
|
|
|
|
|
|
|
|
|
|
|
# Optional speech-to-text back-end (faster-whisper).
try:
    from faster_whisper import WhisperModel
except ImportError:
    whisper_available = False
else:
    whisper_available = True

# Optional summarization back-end (Hugging Face transformers).
try:
    from transformers import pipeline
except ImportError:
    transformers_available = False
else:
    transformers_available = True
|
|
|
|
|
|
|
|
|
|
|
def load_whisper_model():
    """Return the shared faster-whisper model, loading it on first use.

    Loads the "medium" model on CPU with int8 compute. Returns None when
    faster_whisper is not installed (and nothing is cached) or when loading
    fails. A failure is printed, ``models_loaded["whisper"]`` is set False,
    and the next call tries again.
    """
    global whisper_model, models_loaded
    if whisper_model is not None or not whisper_available:
        return whisper_model
    try:
        whisper_model = WhisperModel("medium", device="cpu", compute_type="int8")
    except Exception as e:
        print(f"Error loading Whisper model: {e}")
        models_loaded["whisper"] = False
    else:
        models_loaded["whisper"] = True
        print("Successfully loaded Whisper model")
    return whisper_model
|
|
|
|
|
def load_summarizer_model():
    """Return the shared summarization pipeline, building it on first use.

    The pipeline is built from the "csebuetnlp/mT5_multilingual_XLSum"
    checkpoint and runs on CPU (``device=-1``). Returns None when transformers
    is unavailable (and nothing is cached) or when loading fails. Failures are
    printed with a hint and ``models_loaded["summarizer"]`` is set False.
    """
    global summarizer, models_loaded
    if summarizer is not None or not transformers_available:
        return summarizer

    model_name = "csebuetnlp/mT5_multilingual_XLSum"
    try:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        tokenizer = AutoTokenizer.from_pretrained(model_name, legacy=False)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1)
        models_loaded["summarizer"] = True
        print(f"✅ 成功加载模型: {model_name}")
        return summarizer
    except ImportError:
        # Typically a missing tokenizer dependency.
        print("❌ 错误: 缺少依赖库。请运行: pip install protobuf sentencepiece")
    except Exception as e:
        print(f"❌ 加载模型出错 (类型: {type(e).__name__}): {e}")
        print("建议尝试清理 HuggingFace 缓存目录: rm -rf ~/.cache/huggingface/hub")

    print("Failed to load any Chinese summarization model. Summarization functionality will be unavailable.")
    models_loaded["summarizer"] = False
    return None
|
|
|
|
|
|
|
|
def load_ocr_model():
    """Lazily create the shared OCR engine: PaddleOCR first, pytesseract as fallback.

    Returns:
        One of the following:
        - the cached engine, if one was already loaded;
        - a new PaddleOCR instance;
        - ``{'type': 'pytesseract', 'engine': module}``, when only pytesseract
          (with a working tesseract binary) is usable;
        - None, if no engine can be set up.

    Side effects:
        Resets the global ``ocr_load_error`` to None on every call. On
        failure, stores a description there, including the PaddleOCR
        initialisation errors if PaddleOCR was importable but could not be
        constructed, and sets ``models_loaded["ocr"]`` to False.
    """
    global ocr_model, models_loaded, ocr_load_error

    # Clear the error left by any previous failed attempt.
    ocr_load_error = None

    if ocr_model is not None:
        return ocr_model

    paddle_errors = []
    _PaddleOCR = get_paddleocr()
    if _PaddleOCR:
        print("Starting to load PaddleOCR model...")
        # Candidate constructor kwargs, tried in order until one succeeds.
        configs = [
            {
                'lang': 'ch',
                'device': 'cpu',
            },
        ]
        for i, config in enumerate(configs, start=1):
            try:
                print(f"Trying PaddleOCR config {i}: {config}")
                ocr_model = _PaddleOCR(**config)
            except Exception as e:
                error_msg = f"Config {i} failed: {str(e)}"
                print(error_msg)
                paddle_errors.append(error_msg)
            else:
                models_loaded["ocr"] = True
                print(f"Successfully loaded PaddleOCR model with config {i}")
                return ocr_model

        print("All PaddleOCR configurations failed. Trying pytesseract as fallback...")

    _pytesseract = get_pytesseract()
    # pytesseract_available only reflects the real binary after this probe.
    check_tesseract_availability()

    if _pytesseract and pytesseract_available:
        print("Using pytesseract as OCR solution...")
        ocr_model = {
            'type': 'pytesseract',
            'engine': _pytesseract,
        }
        models_loaded["ocr"] = True
        print("Successfully configured pytesseract OCR")
        return ocr_model

    if _pytesseract:
        error_details = "pytesseract库已安装,但tesseract可执行文件不可用,OCR功能无法使用。请安装tesseract可执行文件并确保其在系统PATH中。"
    else:
        error_details = "所有OCR方案均不可用(PaddleOCR和pytesseract均未安装或导入失败)"
    if paddle_errors:
        # Report why PaddleOCR failed; otherwise these reasons would only reach stdout.
        error_details += " PaddleOCR初始化失败: " + "; ".join(paddle_errors)

    print(f"❌ {error_details}")
    models_loaded["ocr"] = False
    ocr_load_error = error_details
    return None
|
|
|
|
|
|
|
|
def pdf_to_images(pdf_path, dpi=300):
    """Render every page of a PDF into a BGR numpy image.

    Args:
        pdf_path: Path of the PDF file on disk.
        dpi: Rendering resolution passed to PyMuPDF's ``get_pixmap``.

    Returns:
        A list of page images in page order. 3- and 4-channel pixmaps are
        converted to BGR; other channel counts are returned as rendered. The
        list is empty when PyMuPDF is not installed (``fitz_available`` is
        False).
    """
    images = []
    if not fitz_available:
        return images

    # The upload may be stored under a name without a ".pdf" extension, so
    # state the document type explicitly. The context manager closes the
    # document even if rendering a page raises.
    with fitz.open(pdf_path, filetype="pdf") as doc:
        for page in doc:
            pix = page.get_pixmap(dpi=dpi)
            # pix.samples is a flat buffer of pix.h rows x pix.w pixels x pix.n channels.
            img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.h, pix.w, pix.n)
            if pix.n == 3:
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            elif pix.n == 4:
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
            images.append(img)
    return images
|
|
|
|
|
|
|
|
def preprocess_image(image, block_size=11, c=2):
    """Binarize a page image to improve OCR accuracy.

    The image is converted to grayscale, then Gaussian adaptive thresholding
    is applied. 2-D, single-channel, BGR and BGRA inputs are all accepted.
    When OpenCV is unavailable the input is returned unchanged.

    Args:
        image: Page image as a numpy array (BGR, BGRA or grayscale).
        block_size: Odd neighbourhood size for ``cv2.adaptiveThreshold``.
        c: Constant subtracted from the weighted neighbourhood mean.

    Returns:
        The thresholded single-channel image, or ``image`` itself when
        ``cv2_available`` is False.
    """
    if not cv2_available:
        return image

    # cv2.COLOR_BGR2GRAY only accepts 3-channel input; pick the right conversion.
    if image.ndim == 2:
        gray = image
    elif image.shape[2] == 1:
        gray = image[:, :, 0]
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # A 1x1 Gaussian blur is an identity operation, so none is applied here.
    return cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                 cv2.THRESH_BINARY, block_size, c)
|
|
|
|
|
@app.get("/health")
def health_check():
    """Liveness probe that also reports optional-dependency and model status.

    The ``paddleocr`` and ``pytesseract`` flags come from globals that are
    set lazily by the OCR import helpers. They read False until something,
    such as an /ocr request, has attempted those imports.
    """
    ocr_backends = {
        "paddleocr": paddleocr_available,
        "pytesseract": pytesseract_available,
        "pymupdf": fitz_available,
        "opencv": cv2_available,
        "onnxruntime": onnx_available,
    }
    services = {
        "whisper_available": whisper_available,
        "transformers_available": transformers_available,
        "ocr_available": ocr_backends,
        "models_loaded": models_loaded,
    }
    return {"status": "healthy", "services": services, "message": "服务正常运行"}
|
|
|
|
|
# Image extensions (lower-case, without the dot) accepted by the /ocr endpoint.
_OCR_IMAGE_EXTENSIONS = ('jpg', 'jpeg', 'png', 'bmp', 'tiff', 'tif')


def _pytesseract_page_text(engine, img, preprocessed_img, file_ext, filename):
    """Run pytesseract on one page image.

    Args:
        engine: The imported ``pytesseract`` module.
        img: The original page image.
        preprocessed_img: ``preprocess_image(img)``.
        file_ext: Lower-case upload extension.
        filename: Upload filename; used for the screenshot special case.

    Returns:
        One ``{"text", "confidence"}`` dict per non-empty line. Tesseract
        reports per-word confidences on a 0-100 scale. Every line gets the
        page-wide mean of the positive ones, scaled to 0-1, or 0.5 when there
        are none.
    """
    source = preprocessed_img
    if isinstance(preprocessed_img, np.ndarray):
        from PIL import Image

        if file_ext in ['png', 'jpg', 'jpeg'] and 'screenshot' in filename.lower():
            # Screenshot uploads skip the binarized image and use a plain
            # grayscale copy of the original page instead.
            gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img
            source = Image.fromarray(gray_img)
        else:
            source = Image.fromarray(preprocessed_img)

    full_text = engine.image_to_string(source, lang='chi_sim+eng', config='--psm 6')
    result_data = engine.image_to_data(
        source,
        output_type=engine.Output.DICT,
        lang='chi_sim+eng',
        config='--psm 6',
    )

    confidences = [float(conf) for conf in result_data['conf'] if float(conf) > 0]
    avg_confidence = sum(confidences) / len(confidences) / 100.0 if confidences else 0.5
    return [
        {"text": line.strip(), "confidence": avg_confidence}
        for line in full_text.strip().split('\n')
        if line.strip()
    ]


def _paddleocr_page_text(model, preprocessed_img):
    """Run a PaddleOCR instance on one page image.

    Tries ``cls=True`` first and retries with ``cls=False`` if that raises.
    An error from the retry propagates to the caller.

    Returns:
        One ``{"text", "confidence"}`` dict per recognised line in ``result[0]``.
    """
    try:
        result = model.ocr(preprocessed_img, cls=True)
    except Exception:
        result = model.ocr(preprocessed_img, cls=False)
        print("OCR with cls=False succeeded after cls=True failed")
    # NOTE: PaddleOCR (2.x) can return [None] for a page with no detected
    # text; treat that, or an empty result, as an empty page.
    lines = result[0] if result and result[0] else []
    return [{"text": line[1][0], "confidence": line[1][1]} for line in lines]


def _readtext_page_text(model, preprocessed_img):
    """Run a ``readtext``-style engine; each detection holds text at [1] and confidence at [2]."""
    return [
        {"text": detection[1], "confidence": detection[2]}
        for detection in model.readtext(preprocessed_img)
    ]


@app.post("/ocr")
async def ocr_document(file: UploadFile = File(...)):
    """OCR a PDF or image upload and return the recognised text per page.

    Steps:
    1. Make sure at least one OCR back-end is usable.
    2. Validate the file type.
    3. Store the upload in a temp file, which is always deleted afterwards.
    4. Turn the file into page images.
    5. Preprocess each page and run the loaded OCR engine on it.

    Returns:
        On success, JSON with ``success``, ``filename``, ``page_count``,
        ``pages``, ``full_text`` and ``ocr_engine``. On failure, a JSON error
        with status 400 (unsupported type), 503 (missing dependency or model)
        or 500 (conversion or recognition failure).
    """
    try:
        # Trigger the lazy imports so the *_available flags are populated.
        get_paddleocr()
        get_pytesseract()
        if not paddleocr_available and not pytesseract_available:
            # pytesseract_available only becomes True after the tesseract
            # binary has been probed; run the probe before refusing the request.
            check_tesseract_availability()
        if not paddleocr_available and not pytesseract_available:
            return JSONResponse(content={
                "error": "所有OCR方案均不可用",
                "details": {
                    "paddleocr": "PaddleOCR模块未安装或不兼容Python 3.13",
                    "pytesseract": "pytesseract库已安装,但tesseract可执行文件不可用"
                },
                "suggestions": [
                    "对于Python 3.13用户:安装tesseract可执行文件后重试",
                    "对于Python 3.10-3.12用户:安装PaddleOCR: pip install paddleocr",
                    "tesseract可执行文件下载地址:https://github.com/tesseract-ocr/tesseract/wiki/Downloads"
                ]
            }, status_code=503)

        filename = file.filename or ""
        file_ext = filename.lower().split('.')[-1] if '.' in filename else ''

        # Validate before touching the disk. Because file_ext is whitelisted
        # here, it is also safe to reuse as the temp-file suffix.
        if file_ext == 'pdf':
            if not fitz_available:
                return JSONResponse(content={"error": "PyMuPDF模块未安装,PDF处理功能不可用"}, status_code=503)
        elif file_ext in _OCR_IMAGE_EXTENSIONS:
            if not cv2_available:
                return JSONResponse(content={"error": "OpenCV模块未安装,图片处理功能不可用"}, status_code=503)
        else:
            return JSONResponse(content={"error": "不支持的文件格式,仅支持PDF和图片"}, status_code=400)

        # Keep the real extension on the temp copy, in case a library picks a
        # handler from the file name.
        temp_file = tempfile.NamedTemporaryFile(suffix=f".{file_ext}", delete=False)
        temp_file_path = temp_file.name
        try:
            with temp_file:
                temp_file.write(await file.read())

            if file_ext == 'pdf':
                images = pdf_to_images(temp_file_path)
                if not images:
                    return JSONResponse(content={"error": "PDF转换图片失败"}, status_code=500)
            else:
                img = cv2.imread(temp_file_path)
                if img is None:
                    return JSONResponse(content={"error": "图片读取失败"}, status_code=500)
                images = [img]

            current_ocr_model = load_ocr_model()
            if current_ocr_model is None:
                error_msg = ocr_load_error or 'OCR模型加载失败'
                return JSONResponse(content={
                    "error": "OCR模型加载失败",
                    "details": error_msg,
                    "suggestions": [
                        "检查Python版本是否兼容(推荐3.10-3.12用于PaddleOCR)",
                        "如果使用Python 3.13,确保tesseract可执行文件已正确安装",
                        "查看服务器日志获取更多详细信息"
                    ]
                }, status_code=503)

            all_results = []
            for page_num, img in enumerate(images, start=1):
                preprocessed_img = preprocess_image(img)
                try:
                    if isinstance(current_ocr_model, dict) and current_ocr_model['type'] == 'pytesseract':
                        page_text = _pytesseract_page_text(
                            current_ocr_model['engine'], img, preprocessed_img, file_ext, filename
                        )
                    elif hasattr(current_ocr_model, 'ocr'):
                        page_text = _paddleocr_page_text(current_ocr_model, preprocessed_img)
                    elif hasattr(current_ocr_model, 'readtext'):
                        page_text = _readtext_page_text(current_ocr_model, preprocessed_img)
                    else:
                        return JSONResponse(content={"error": "未知的OCR模型类型"}, status_code=500)
                except Exception as ocr_err:
                    return JSONResponse(content={"error": f"OCR识别失败: {str(ocr_err)}"}, status_code=500)

                all_results.append({
                    "page": page_num,
                    "content": page_text,
                    "full_text": "\n".join(item["text"] for item in page_text),
                })

            full_document_text = "\n\n".join(page["full_text"] for page in all_results)
            return JSONResponse(content={
                "success": True,
                "filename": file.filename,
                "page_count": len(all_results),
                "pages": all_results,
                "full_text": full_document_text,
                "ocr_engine": "paddleocr" if hasattr(current_ocr_model, 'ocr') else "pytesseract",
            })
        finally:
            # Remove the uploaded copy on every exit path, including a failed write.
            if os.path.exists(temp_file_path):
                os.unlink(temp_file_path)

    except Exception as e:
        error_details = {
            "error": "OCR处理失败",
            "details": str(e),
            "services": {
                "paddleocr_available": paddleocr_available,
                "pytesseract_available": pytesseract_available,
                "fitz_available": fitz_available,
                "cv2_available": cv2_available,
                "models_loaded": models_loaded.get("ocr", False)
            },
            "suggestion": "查看服务器日志获取更多详细信息,或尝试使用兼容的Python版本"
        }
        return JSONResponse(content=error_details, status_code=500)
|
|
|
|
|
@app.get("/")
def root():
    """Root endpoint returning a short banner that describes the service."""
    banner = "智能文档处理服务已启动,支持音频转录和OCR文档解析"
    return {"message": banner}
|
|
|
|
|
|
|
|
def _import_audio_segment():
    """Import pydub's ``AudioSegment`` on demand.

    Returns:
        ``(AudioSegment, None)`` on success. Otherwise ``(None, response)``,
        where ``response`` is a 503 JSONResponse explaining why audio
        processing is unavailable.
    """
    try:
        from pydub import AudioSegment
    except ImportError as e:
        error_msg = str(e)
        if "audioop" in error_msg or "pyaudioop" in error_msg:
            return None, JSONResponse(content={"error": "Python 3.13+环境下pydub依赖的audioop模块已被移除,音频处理功能不可用"}, status_code=503)
        return None, JSONResponse(content={"error": "pydub模块未安装,音频处理功能不可用"}, status_code=503)
    return AudioSegment, None


def _remove_temp_files(*paths):
    """Delete each given path that exists; ``None`` entries are skipped."""
    for path in paths:
        if path is not None and os.path.exists(path):
            os.unlink(path)


def _transcribe_wav(whisper, wav_path):
    """Transcribe a WAV file as Chinese (beam size 3, VAD filter on) and join the segments."""
    segments, _info = whisper.transcribe(wav_path, beam_size=3, language="zh", vad_filter=True)
    return "".join(segment.text for segment in segments)


def _summarize(summarizer_pipeline, text):
    """Summarize ``text`` with the service's fixed generation settings.

    Newlines and tabs in the result are replaced by spaces. A summary shorter
    than 10 characters is replaced by a notice asking the user to check the
    source text. Exceptions from the pipeline propagate.
    """
    summary = summarizer_pipeline(
        text,
        max_length=150,
        min_length=30,
        do_sample=False,
        num_beams=2,
        length_penalty=0.8,
        no_repeat_ngram_size=3,
    )
    summary_text = summary[0]["summary_text"].strip().replace('\n', ' ').replace('\t', ' ')
    if len(summary_text) < 10:
        summary_text = "摘要生成过短,请检查原始文本"
    return summary_text


@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with faster-whisper (language "zh").

    The upload is converted to WAV with pydub first. Both temporary files
    are removed on every exit path.

    Returns:
        ``{"transcription": text}``. A 503 JSON error when faster_whisper,
        pydub or the Whisper model is unavailable; 500 on any other failure.
    """
    temp_file_path = None
    wav_path = None
    try:
        try:
            if not whisper_available:
                return JSONResponse(content={"error": "faster_whisper模块未安装,转录功能不可用"}, status_code=503)

            AudioSegment, error_response = _import_audio_segment()
            if error_response is not None:
                return error_response

            with tempfile.NamedTemporaryFile(suffix=".tmp", delete=False) as temp_file:
                temp_file_path = temp_file.name
                temp_file.write(await file.read())

            audio = AudioSegment.from_file(temp_file_path)
            wav_path = temp_file_path + ".wav"
            audio.export(wav_path, format="wav")

            current_whisper_model = load_whisper_model()
            if current_whisper_model is None:
                return JSONResponse(content={"error": "Whisper模型加载失败,转录功能不可用"}, status_code=503)

            transcription = _transcribe_wav(current_whisper_model, wav_path)
            return JSONResponse(content={"transcription": transcription})
        finally:
            _remove_temp_files(temp_file_path, wav_path)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


@app.post("/summarize")
async def summarize_text(text: dict):
    """Summarize ``text["text"]``.

    Returns:
        ``{"summary": ...}``. 400 if no text is given, 503 if the model
        cannot be loaded, 500 if generation fails.
    """
    try:
        transcription = text.get("text", "")
        if not transcription:
            return JSONResponse(content={"error": "没有提供文本"}, status_code=400)

        current_summarizer = load_summarizer_model()
        if current_summarizer is None:
            return JSONResponse(content={"error": "摘要模型加载失败,摘要功能不可用"}, status_code=503)

        try:
            summary_text = _summarize(current_summarizer, transcription)
        except Exception as e:
            return JSONResponse(content={"error": f"摘要生成失败: {str(e)}"}, status_code=500)
        return JSONResponse(content={"summary": summary_text})
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


@app.post("/process")
async def process_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file, then summarize the transcription.

    If the summarizer is unavailable or fails, the transcription is still
    returned, with ``summary`` None and a ``warning``.

    Returns:
        ``{"transcription", "summary"}`` (plus ``"warning"`` when the summary
        is missing). A 503 JSON error when transcription dependencies are
        unavailable; 500 on any other failure.
    """
    temp_file_path = None
    wav_path = None
    try:
        try:
            if not whisper_available:
                return JSONResponse(content={"error": "faster_whisper模块未安装,转录功能不可用"}, status_code=503)

            AudioSegment, error_response = _import_audio_segment()
            if error_response is not None:
                return error_response

            with tempfile.NamedTemporaryFile(suffix=".tmp", delete=False) as temp_file:
                temp_file_path = temp_file.name
                temp_file.write(await file.read())

            audio = AudioSegment.from_file(temp_file_path)
            wav_path = temp_file_path + ".wav"
            audio.export(wav_path, format="wav")

            current_whisper_model = load_whisper_model()
            if current_whisper_model is None:
                return JSONResponse(content={"error": "Whisper模型加载失败,转录功能不可用"}, status_code=503)

            transcription = _transcribe_wav(current_whisper_model, wav_path)
        finally:
            # Runs on every path, including the early 503 returns above.
            _remove_temp_files(temp_file_path, wav_path)

        current_summarizer = load_summarizer_model()
        summary = None
        warning = None
        if current_summarizer is None:
            warning = "摘要模型加载失败,仅返回转录结果"
        else:
            try:
                summary = _summarize(current_summarizer, transcription)
            except Exception as e:
                print(f"Error during summarization: {e}")
                warning = "摘要生成失败,仅返回转录结果"

        if warning:
            return JSONResponse(content={
                "transcription": transcription,
                "summary": None,
                "warning": warning
            }, status_code=200)
        return JSONResponse(content={
            "transcription": transcription,
            "summary": summary
        })
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


# Start the server only after every route above has been registered; calling
# uvicorn.run earlier in the module would serve an app missing later endpoints.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
|
|