feat: remove redundant division by token count before L2-normalizing embeddings (L2 normalization is invariant to positive scaling)

This commit is contained in:
Himadri Bhattacharjee
2025-07-22 19:38:39 +05:30
parent 87e096f0bc
commit daccd63006

View File

@@ -81,10 +81,8 @@ impl Embed {
let token_type_ids = token_ids.zeros_like()?;
let embeddings = self.model.forward(&token_ids, &token_type_ids, None)?;
let (_n_sentence, n_tokens, hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = normalize_l2(&embeddings)?
.reshape(hidden_size)?
let embeddings = normalize_l2(&embeddings.sum(1)?)?
.reshape(self.hidden_size)?
.to_vec1::<f32>()?;
Ok(embeddings)
}