"""
Hybrid ONNX Inference for Rgveda Embedding Model

Uses:
- Base embeddinggemma-300m ONNX model (from onnx-community)
- Fine-tuned dense layer weights (from Ganaraj/rgveda-embedding-gemma)

This provides ONNX inference with Rigveda-specific fine-tuning.
"""

import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from pathlib import Path

class RgvedaEmbeddingONNXHybrid:
    """
    Hybrid ONNX inference using base model + fine-tuned weights.

    Runs the base embeddinggemma-300m ONNX encoder, mean-pools the token
    embeddings, applies the Rigveda fine-tuned dense projection weights
    (loaded from .npy files), and L2-normalizes the result.
    """

    def __init__(self, model_dir="."):
        """Initialize the model.

        Args:
            model_dir: Directory expected to contain ``onnx/model.onnx``,
                the tokenizer files, and ``weights/dense1_weight.npy`` /
                ``weights/dense2_weight.npy``.
        """
        print("Loading Rgveda Embedding Model (Hybrid ONNX)...")
        self.model_dir = Path(model_dir)

        # Base encoder exported to ONNX.
        model_path = self.model_dir / "onnx" / "model.onnx"
        print(f"Loading ONNX model: {model_path}")
        self.session = ort.InferenceSession(str(model_path))

        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(self.model_dir)
        )

        # Fine-tuned dense projection weights. They are applied as
        # `x @ W.T` below, i.e. stored as (out_features, in_features)
        # — presumably torch.nn.Linear layout; TODO confirm against the
        # export script.
        print("Loading fine-tuned weights...")
        weights_dir = self.model_dir / "weights"
        self.dense1_weight = np.load(weights_dir / "dense1_weight.npy")
        self.dense2_weight = np.load(weights_dir / "dense2_weight.npy")

        print("\n✓ Model loaded successfully!")
        print(" Base model: ONNX (embeddinggemma-300m)")
        print(" Fine-tuning: Rigveda-specific dense layers")
        print(f" Dense1: {self.dense1_weight.shape}")
        print(f" Dense2: {self.dense2_weight.shape}")

    def mean_pooling(self, token_embeddings, attention_mask):
        """Mean pooling with attention mask.

        Args:
            token_embeddings: Array of shape (batch, seq_len, hidden).
            attention_mask: Array of shape (batch, seq_len); 1 marks real
                tokens, 0 marks padding.

        Returns:
            Array of shape (batch, hidden): mean of the unmasked token
            embeddings per sequence.
        """
        input_mask_expanded = np.expand_dims(attention_mask, -1)
        input_mask_expanded = np.broadcast_to(
            input_mask_expanded, token_embeddings.shape
        ).astype(np.float32)

        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
        # Clip so an all-padding row cannot cause a divide-by-zero.
        sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
        return sum_embeddings / sum_mask

    def encode(self, texts, batch_size=32, show_progress=False):
        """
        Encode texts to embeddings using hybrid approach.

        Args:
            texts: List of strings or single string
            batch_size: Batch size for processing
            show_progress: Print per-batch progress

        Returns:
            embeddings: numpy array of shape (num_texts, 768)
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            # Fix: np.vstack([]) raises on an empty batch list; return an
            # empty (0, dim) array instead. dim is dense2's output size.
            return np.empty((0, self.dense2_weight.shape[0]), dtype=np.float32)

        all_embeddings = []
        num_batches = (len(texts) + batch_size - 1) // batch_size

        for batch_num, start in enumerate(range(0, len(texts), batch_size), 1):
            if show_progress:
                # Fix: show_progress was previously accepted but ignored.
                print(f"Encoding batch {batch_num}/{num_batches}...")
            batch_texts = texts[start:start + batch_size]

            inputs = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=2048,
                return_tensors="np"
            )

            # Run the ONNX encoder. Only the first output (token-level
            # hidden states) is used; the second is discarded.
            last_hidden_state, _ = self.session.run(
                None,
                {
                    'input_ids': inputs['input_ids'].astype(np.int64),
                    'attention_mask': inputs['attention_mask'].astype(np.int64)
                }
            )

            pooled = self.mean_pooling(last_hidden_state, inputs['attention_mask'])

            # Apply the fine-tuned dense layers. No bias or activation is
            # applied here — assumes the exported Dense modules had none;
            # TODO confirm against the original fine-tuning config.
            dense1_out = pooled @ self.dense1_weight.T
            dense2_out = dense1_out @ self.dense2_weight.T

            # L2-normalize so dot products act as cosine similarities.
            norms = np.linalg.norm(dense2_out, axis=1, keepdims=True)
            normalized = dense2_out / np.clip(norms, a_min=1e-9, a_max=None)

            all_embeddings.append(normalized)

        return np.vstack(all_embeddings)
if __name__ == "__main__":
    # Demo: embed a Sanskrit query and rank three Rigveda verses by
    # similarity (the model's embeddings are L2-normalized, so a dot
    # product is a cosine similarity).
    model = RgvedaEmbeddingONNXHybrid(".")

    # Prompt prefixes for asymmetric retrieval (query vs. document).
    prefixes = {
        "query": "task: search result | query: ",
        "document": "title: none | text: ",
    }

    query = prefixes["query"] + "वृष्टि-विद्युत्-सदृशं दैविकं आगमनम्"

    verse_texts = [
        "असामि हि प्रयज्यवः कण्वं दद प्रचेतसः",
        "उत द्वार उशतीर् वि श्रयन्ताम् उत देवाṁ उशत आ वहेह",
        "प्राग्नये बृहते यज्ञियाय ऋतस्य वृष्णे असुराय मन्म",
    ]
    documents = [prefixes["document"] + verse for verse in verse_texts]

    print("\nEncoding query...")
    query_embedding = model.encode(query)
    print(f"Query embedding shape: {query_embedding.shape}")

    print("\nEncoding documents...")
    doc_embeddings = model.encode(documents)
    print(f"Document embeddings shape: {doc_embeddings.shape}")

    similarities = query_embedding @ doc_embeddings.T

    banner = "=" * 80
    print("\n" + banner)
    print("Results")
    print(banner)
    print(f"\nQuery: {query}\n")
    print("Document similarities:")
    for rank, (doc, sim) in enumerate(zip(documents, similarities[0]), start=1):
        print(f" {rank}. {sim:.4f} - {doc[:70]}...")