#!/usr/bin/env python3
"""
Hybrid ONNX Inference for Rgveda Embedding Model

Uses:
- Base embeddinggemma-300m ONNX model (from onnx-community)
- Fine-tuned dense layer weights (from Ganaraj/rgveda-embedding-gemma)

This provides ONNX inference with Rigveda-specific fine-tuning.
"""

import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from pathlib import Path
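
# ---------------------------------------------------------------------------
# Optional helper (sketch, not part of the inference pipeline below): one way
# the weights/dense1_weight.npy and dense2_weight.npy files loaded by the
# class could be exported from the fine-tuned checkpoint. The repo id
# "Ganaraj/rgveda-embedding-gemma" comes from the module docstring; the layout
# (two sentence-transformers Dense modules, 768 -> 3072 -> 768, each exposing
# a `.linear` torch layer) is assumed from the standard embeddinggemma
# architecture -- verify against the actual checkpoint before relying on it.
# Requires sentence-transformers and torch, which the ONNX path does not.
def export_finetuned_dense_weights(repo_id="Ganaraj/rgveda-embedding-gemma",
                                   out_dir="weights"):
    """Export the fine-tuned Dense layer weights as .npy files (sketch)."""
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import Dense

    st_model = SentenceTransformer(repo_id)
    # Collect Dense modules in order; expected: 768 -> 3072, then 3072 -> 768.
    dense_layers = [m for m in st_model if isinstance(m, Dense)]

    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for i, layer in enumerate(dense_layers, start=1):
        weight = layer.linear.weight.detach().cpu().numpy()
        np.save(out / f"dense{i}_weight.npy", weight)
        print(f"Saved dense{i}_weight.npy with shape {weight.shape}")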


class RgvedaEmbeddingONNXHybrid:
    """
    Hybrid ONNX inference using base model + fine-tuned weights.
    """

    def __init__(self, model_dir="."):
        """Initialize the model."""
        print("Loading Rgveda Embedding Model (Hybrid ONNX)...")

        self.model_dir = Path(model_dir)

        # Load base ONNX model
        model_path = self.model_dir / "onnx" / "model.onnx"
        print(f"Loading ONNX model: {model_path}")
        self.session = ort.InferenceSession(str(model_path))

        # Load tokenizer from local directory
        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(str(self.model_dir))

        # Load fine-tuned dense weights
        print("Loading fine-tuned weights...")
        weights_dir = self.model_dir / "weights"
        self.dense1_weight = np.load(weights_dir / "dense1_weight.npy")
        self.dense2_weight = np.load(weights_dir / "dense2_weight.npy")

        print("\n✓ Model loaded successfully!")
        print("  Base model: ONNX (embeddinggemma-300m)")
        print("  Fine-tuning: Rigveda-specific dense layers")
        print(f"  Dense1: {self.dense1_weight.shape}")
        print(f"  Dense2: {self.dense2_weight.shape}")

    def mean_pooling(self, token_embeddings, attention_mask):
        """Mean pooling with attention mask."""
        input_mask_expanded = np.expand_dims(attention_mask, -1)
        input_mask_expanded = np.broadcast_to(
            input_mask_expanded, token_embeddings.shape
        ).astype(np.float32)
        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
        sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
        return sum_embeddings / sum_mask

    def encode(self, texts, batch_size=32, show_progress=False):
        """
        Encode texts to embeddings using the hybrid approach.

        Args:
            texts: List of strings or a single string
            batch_size: Batch size for processing
            show_progress: Show progress bar (currently unused)

        Returns:
            embeddings: numpy array of shape (num_texts, 768)
        """
        if isinstance(texts, str):
            texts = [texts]

        all_embeddings = []

        # Process in batches
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]

            # Tokenize
            inputs = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=2048,
                return_tensors="np"
            )

            # Run ONNX model.
            # Take last_hidden_state (raw transformer output), not sentence_embedding:
            # sentence_embedding already has the base dense layers applied, which we don't want.
            last_hidden_state, _ = self.session.run(
                None,
                {
                    'input_ids': inputs['input_ids'].astype(np.int64),
                    'attention_mask': inputs['attention_mask'].astype(np.int64)
                }
            )

            # Do mean pooling ourselves (like the Ganaraj model does)
            pooled = self.mean_pooling(last_hidden_state, inputs['attention_mask'])

            # Apply the fine-tuned dense layers on the pooled output
            # Dense layer 1 (768 -> 3072)
            dense1_out = pooled @ self.dense1_weight.T

            # Dense layer 2 (3072 -> 768)
            dense2_out = dense1_out @ self.dense2_weight.T

            # L2 normalization
            norms = np.linalg.norm(dense2_out, axis=1, keepdims=True)
            normalized = dense2_out / np.clip(norms, a_min=1e-9, a_max=None)

            all_embeddings.append(normalized)

        return np.vstack(all_embeddings)


# Example usage
if __name__ == "__main__":
    # Initialize model
    model = RgvedaEmbeddingONNXHybrid(".")

    # Test queries and documents with Devanagari script
    prefixes = {
        "query": "task: search result | query: ",
        "document": "title: none | text: ",
    }

    query = prefixes["query"] + "वृष्टि-विद्युत्-सदृशं दैविकं आगमनम्"
    documents = [
        prefixes["document"] + "असामि हि प्रयज्यवः कण्वं दद प्रचेतसः",
        prefixes["document"] + "उत द्वार उशतीर् वि श्रयन्ताम् उत देवाṁ उशत आ वहेह",
        prefixes["document"] + "प्राग्नये बृहते यज्ञियाय ऋतस्य वृष्णे असुराय मन्म",
    ]

    # Encode
    print("\nEncoding query...")
    query_embedding = model.encode(query)
    print(f"Query embedding shape: {query_embedding.shape}")

    print("\nEncoding documents...")
    doc_embeddings = model.encode(documents)
    print(f"Document embeddings shape: {doc_embeddings.shape}")

    # Compute similarities
    similarities = query_embedding @ doc_embeddings.T

    print("\n" + "="*80)
    print("Results")
    print("="*80)
    print(f"\nQuery: {query}\n")
    print("Document similarities:")
    for i, (doc, sim) in enumerate(zip(documents, similarities[0])):
        print(f"  {i+1}. {sim:.4f} - {doc[:70]}...")
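

# ---------------------------------------------------------------------------
# Optional parity check (sketch, not invoked by the demo above): compares the
# hybrid ONNX embeddings against the original PyTorch checkpoint. It assumes
# the fine-tuned model id "Ganaraj/rgveda-embedding-gemma" from the module
# docstring, that sentence-transformers is installed, and that the same
# already-prefixed texts are passed to both models. Defined after the
# __main__ block so it stays out of the default demo; call it manually.
def parity_check(texts, model_dir=".", repo_id="Ganaraj/rgveda-embedding-gemma"):
    """Report the max absolute difference between ONNX and PyTorch embeddings."""
    from sentence_transformers import SentenceTransformer

    if isinstance(texts, str):
        texts = [texts]

    onnx_emb = RgvedaEmbeddingONNXHybrid(model_dir).encode(texts)
    # normalize_embeddings=True should be a no-op if the checkpoint already
    # ends with a Normalize module, but it keeps the comparison well-defined.
    torch_emb = SentenceTransformer(repo_id).encode(
        texts, convert_to_numpy=True, normalize_embeddings=True
    )

    max_diff = float(np.max(np.abs(onnx_emb - torch_emb)))
    print(f"Max |ONNX - PyTorch| difference: {max_diff:.6f}")
    return max_diff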