Upload inference_onnx.py with huggingface_hub
inference_onnx.py CHANGED (+20 -9)
@@ -29,10 +29,10 @@ class RgvedaEmbeddingONNXHybrid:
         print(f"Loading ONNX model: {model_path}")
         self.session = ort.InferenceSession(str(model_path))
 
-        # Load tokenizer
+        # Load tokenizer from local directory
         print("Loading tokenizer...")
         self.tokenizer = AutoTokenizer.from_pretrained(
-
+            str(self.model_dir)
         )
 
         # Load fine-tuned dense weights
@@ -47,6 +47,17 @@ class RgvedaEmbeddingONNXHybrid:
         print(f" Dense1: {self.dense1_weight.shape}")
         print(f" Dense2: {self.dense2_weight.shape}")
 
+    def mean_pooling(self, token_embeddings, attention_mask):
+        """Mean pooling with attention mask."""
+        input_mask_expanded = np.expand_dims(attention_mask, -1)
+        input_mask_expanded = np.broadcast_to(
+            input_mask_expanded, token_embeddings.shape
+        ).astype(np.float32)
+
+        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
+        sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
+        return sum_embeddings / sum_mask
+
     def encode(self, texts, batch_size=32, show_progress=False):
         """
         Encode texts to embeddings using hybrid approach.
@@ -78,9 +89,9 @@ class RgvedaEmbeddingONNXHybrid:
         )
 
         # Run ONNX model
-        #
-        #
-
+        # Get last_hidden_state (raw transformer output) not sentence_embedding
+        # sentence_embedding already has base dense layers which we don't want
+        last_hidden_state, _ = self.session.run(
             None,
             {
                 'input_ids': inputs['input_ids'].astype(np.int64),
@@ -88,12 +99,12 @@ class RgvedaEmbeddingONNXHybrid:
             }
         )
 
-        #
-
-        # the Rigveda-specific fine-tuned ones instead
+        # Do mean pooling ourselves (like the Ganaraj model does)
+        pooled = self.mean_pooling(last_hidden_state, inputs['attention_mask'])
 
+        # Now apply fine-tuned dense layers on the pooled output
         # Dense layer 1 (768 -> 3072)
-        dense1_out =
+        dense1_out = pooled @ self.dense1_weight.T
 
         # Dense layer 2 (3072 -> 768)
         dense2_out = dense1_out @ self.dense2_weight.T