diff --git a/src/embed.rs b/src/embed.rs index 8d4ec4b..176deaa 100644 --- a/src/embed.rs +++ b/src/embed.rs @@ -81,10 +81,8 @@ impl Embed { let token_type_ids = token_ids.zeros_like()?; let embeddings = self.model.forward(&token_ids, &token_type_ids, None)?; - let (_n_sentence, n_tokens, hidden_size) = embeddings.dims3()?; - let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?; - let embeddings = normalize_l2(&embeddings)? - .reshape(hidden_size)? + let embeddings = normalize_l2(&embeddings.sum(1)?)? + .reshape(self.hidden_size)? .to_vec1::()?; Ok(embeddings) }