#!/usr/bin/env python3
"""
Hybrid ONNX Inference for Rgveda Embedding Model
Uses:
- Base embeddinggemma-300m ONNX model (from onnx-community)
- Fine-tuned dense layer weights (from Ganaraj/rgveda-embedding-gemma)
This provides ONNX inference with Rigveda-specific fine-tuning.
"""
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from pathlib import Path
class RgvedaEmbeddingONNXHybrid:
"""
Hybrid ONNX inference using base model + fine-tuned weights.
"""
def __init__(self, model_dir="."):
"""Initialize the model."""
print("Loading Rgveda Embedding Model (Hybrid ONNX)...")
self.model_dir = Path(model_dir)
# Load base ONNX model
model_path = self.model_dir / "onnx" / "model.onnx"
print(f"Loading ONNX model: {model_path}")
self.session = ort.InferenceSession(str(model_path))
# Load tokenizer from local directory
print("Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(
str(self.model_dir)
)
# Load fine-tuned dense weights
print("Loading fine-tuned weights...")
weights_dir = self.model_dir / "weights"
self.dense1_weight = np.load(weights_dir / "dense1_weight.npy")
self.dense2_weight = np.load(weights_dir / "dense2_weight.npy")
print(f"\n✓ Model loaded successfully!")
print(f" Base model: ONNX (embeddinggemma-300m)")
print(f" Fine-tuning: Rigveda-specific dense layers")
print(f" Dense1: {self.dense1_weight.shape}")
print(f" Dense2: {self.dense2_weight.shape}")
def mean_pooling(self, token_embeddings, attention_mask):
"""Mean pooling with attention mask."""
input_mask_expanded = np.expand_dims(attention_mask, -1)
input_mask_expanded = np.broadcast_to(
input_mask_expanded, token_embeddings.shape
).astype(np.float32)
sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
return sum_embeddings / sum_mask
def encode(self, texts, batch_size=32, show_progress=False):
"""
Encode texts to embeddings using hybrid approach.
Args:
texts: List of strings or single string
batch_size: Batch size for processing
            show_progress: Print simple per-batch progress
Returns:
embeddings: numpy array of shape (num_texts, 768)
"""
if isinstance(texts, str):
texts = [texts]
all_embeddings = []
# Process in batches
        for i in range(0, len(texts), batch_size):
            if show_progress:
                num_batches = (len(texts) + batch_size - 1) // batch_size
                print(f"  Batch {i // batch_size + 1}/{num_batches}")
            batch_texts = texts[i:i + batch_size]
# Tokenize
inputs = self.tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=2048,
return_tensors="np"
)
            # Run the ONNX model and take last_hidden_state (the raw transformer
            # output) rather than sentence_embedding: sentence_embedding has already
            # passed through the base model's dense layers, which we replace below.
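            # If the output order is ever in doubt, the session's outputs can be
            # listed by name with [o.name for o in self.session.get_outputs()];
            # index 0 is assumed here to hold the token-level hidden states.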
last_hidden_state, _ = self.session.run(
None,
{
'input_ids': inputs['input_ids'].astype(np.int64),
'attention_mask': inputs['attention_mask'].astype(np.int64)
}
)
# Do mean pooling ourselves (like the Ganaraj model does)
pooled = self.mean_pooling(last_hidden_state, inputs['attention_mask'])
# Now apply fine-tuned dense layers on the pooled output
# Dense layer 1 (768 -> 3072)
dense1_out = pooled @ self.dense1_weight.T
# Dense layer 2 (3072 -> 768)
dense2_out = dense1_out @ self.dense2_weight.T
# L2 normalization
norms = np.linalg.norm(dense2_out, axis=1, keepdims=True)
normalized = dense2_out / np.clip(norms, a_min=1e-9, a_max=None)
all_embeddings.append(normalized)
return np.vstack(all_embeddings)
# Example usage
if __name__ == "__main__":
# Initialize model
model = RgvedaEmbeddingONNXHybrid(".")
# Test queries and documents with Devanagari script
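    # These prefixes match the retrieval prompt format used by EmbeddingGemma-style
    # models; queries and documents should carry the same prefixes at search time
    # as were used during fine-tuning.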
prefixes = {
"query": "task: search result | query: ",
"document": "title: none | text: ",
}
query = prefixes["query"] + "वृष्टि-विद्युत्-सदृशं दैविकं आगमनम्"
documents = [
prefixes["document"] + "असामि हि प्रयज्यवः कण्वं दद प्रचेतसः",
prefixes["document"] + "उत द्वार उशतीर् वि श्रयन्ताम् उत देवाṁ उशत आ वहेह",
prefixes["document"] + "प्राग्नये बृहते यज्ञियाय ऋतस्य वृष्णे असुराय मन्म",
]
# Encode
print("\nEncoding query...")
query_embedding = model.encode(query)
print(f"Query embedding shape: {query_embedding.shape}")
print("\nEncoding documents...")
doc_embeddings = model.encode(documents)
print(f"Document embeddings shape: {doc_embeddings.shape}")
    # Compute cosine similarities (the embeddings are L2-normalized, so a dot product suffices)
similarities = query_embedding @ doc_embeddings.T
print("\n" + "="*80)
print("Results")
print("="*80)
print(f"\nQuery: {query}\n")
print("Document similarities:")
for i, (doc, sim) in enumerate(zip(documents, similarities[0])):
print(f" {i+1}. {sim:.4f} - {doc[:70]}...")