feat: remove redundant normalization by token count before l2_norm of embeddings
This commit is contained in:
@@ -81,10 +81,8 @@ impl Embed {
|
||||
let token_type_ids = token_ids.zeros_like()?;
|
||||
|
||||
let embeddings = self.model.forward(&token_ids, &token_type_ids, None)?;
|
||||
-let (_n_sentence, n_tokens, hidden_size) = embeddings.dims3()?;
|
||||
-let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
||||
-let embeddings = normalize_l2(&embeddings)?
|
||||
-.reshape(hidden_size)?
|
||||
+let embeddings = normalize_l2(&embeddings.sum(1)?)?
|
||||
+.reshape(self.hidden_size)?
|
||||
.to_vec1::<f32>()?;
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user