# docker-hf/main.py
import os
import time
import uuid
from typing import List, Literal, Optional

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI(title="OpenAI-compatible API")

# --- Model configuration ---
# MODEL_ID can be overridden via an environment variable on HF Spaces.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading model: {MODEL_ID} on {device}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    if device == "cpu":
        model.to(device)
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading the model: {e}")
    raise

# --- Pydantic models (OpenAI request/response format) ---
class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: str = "default-model"
    messages: List[ChatMessage]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    stream: Optional[bool] = False  # accepted for compatibility; streaming is not implemented


class Choice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Usage
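
# Illustrative request/response shapes for the models above (a sketch with
# example values only, not captured server output):
#
#   POST /v1/chat/completions
#   {
#     "model": "default-model",
#     "messages": [{"role": "user", "content": "Hello!"}],
#     "max_tokens": 128,
#     "temperature": 0.7
#   }
#
#   Response:
#   {
#     "id": "chatcmpl-...",
#     "object": "chat.completion",
#     "created": 1700000000,
#     "model": "Qwen/Qwen2.5-0.5B-Instruct",
#     "choices": [{"index": 0,
#                  "message": {"role": "assistant", "content": "Hi there!"},
#                  "finish_reason": "stop"}],
#     "usage": {"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}
#   }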


# --- API endpoint ---
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
    try:
        # 1. Apply the chat template (turns the message list into the prompt
        #    string the model expects). This makes the endpoint work with
        #    Llama, Mistral, Qwen, etc. automatically.
        input_text = tokenizer.apply_chat_template(
            [{"role": m.role, "content": m.content} for m in request.messages],
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        # 2. Generate the response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=True,  # sampling is always on; temperature/top_p come from the request
            )
        # 3. Decode only the newly generated tokens (the assistant reply)
        generated_ids = outputs[0][inputs.input_ids.shape[1]:]
        response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        # 4. Token counts (approximate)
        prompt_tokens = len(inputs.input_ids[0])
        completion_tokens = len(generated_ids)
        # 5. Format the result in the OpenAI response schema
        return ChatCompletionResponse(
            id=f"chatcmpl-{uuid.uuid4().hex}",
            created=int(time.time()),
            model=MODEL_ID,
            choices=[
                Choice(
                    index=0,
                    message=ChatMessage(role="assistant", content=response_text),
                    finish_reason="stop",
                )
            ],
            usage=Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
        )
    except Exception as e:
        print(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
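
# Example (sketch): querying the endpoint with curl, assuming the server is
# running locally on port 7860 (the port passed to uvicorn.run below):
#
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "default-model",
#          "messages": [{"role": "user", "content": "Say hello"}],
#          "max_tokens": 64}'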


@app.get("/")
def home():
    return {"message": "OpenAI-compatible API is running", "model": MODEL_ID}


# Allows running with `python main.py` for local testing
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
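
# Example (sketch, client side only): calling the same endpoint with the
# official `openai` Python client (v1+), assuming the server is reachable at
# http://localhost:7860. The api_key is a placeholder; this server does not
# validate it.
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:7860/v1", api_key="not-needed")
#   resp = client.chat.completions.create(
#       model="default-model",
#       messages=[{"role": "user", "content": "Hello!"}],
#       max_tokens=64,
#   )
#   print(resp.choices[0].message.content)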