diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 3505615a2c..dbdfe026d9 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -18,31 +18,30 @@ class CacheEmbedding(Embeddings): def embed_documents(self, texts: List[str]) -> List[List[float]]: """Embed search docs.""" # use doc embedding cache or store if not exists - text_embeddings = [] - embedding_queue_texts = [] - for text in texts: + text_embeddings = [None for _ in range(len(texts))] + embedding_queue_indices = [] + for i, text in enumerate(texts): hash = helper.generate_text_hash(text) embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first() if embedding: - text_embeddings.append(embedding.get_embedding()) + text_embeddings[i] = embedding.get_embedding() else: - embedding_queue_texts.append(text) + embedding_queue_indices.append(i) - if embedding_queue_texts: + if embedding_queue_indices: try: - embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts) + embedding_results = self._embeddings.client.embed_documents([texts[i] for i in embedding_queue_indices]) except Exception as ex: raise self._embeddings.handle_exceptions(ex) - i = 0 - normalized_embedding_results = [] - for text in embedding_queue_texts: - hash = helper.generate_text_hash(text) + + for i, indice in enumerate(embedding_queue_indices): + hash = helper.generate_text_hash(texts[indice]) try: embedding = Embedding(model_name=self._embeddings.name, hash=hash) vector = embedding_results[i] normalized_embedding = (vector / np.linalg.norm(vector)).tolist() - normalized_embedding_results.append(normalized_embedding) + text_embeddings[indice] = normalized_embedding embedding.set_embedding(normalized_embedding) db.session.add(embedding) db.session.commit() @@ -52,10 +51,7 @@ class CacheEmbedding(Embeddings): except: logging.exception('Failed to add embedding to db') continue - finally: - i += 1 - text_embeddings.extend(normalized_embedding_results) return text_embeddings def embed_query(self, text: str) -> List[float]: