From ee9c7e204f4c5a8781ac39792521d0bb0918eec4 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 19 Jan 2024 21:37:54 +0800 Subject: [PATCH] delete document cache embedding (#2101) Co-authored-by: jyong --- api/core/embedding/cached_embedding.py | 73 +++++++++----------------- 1 file changed, 26 insertions(+), 47 deletions(-) diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 94ae1556aa..ba25eaa787 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -1,10 +1,12 @@ import base64 import json import logging -from typing import List, Optional +from typing import List, Optional, cast import numpy as np from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelPropertyKey +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db from langchain.embeddings.base import Embeddings @@ -22,56 +24,33 @@ class CacheEmbedding(Embeddings): self._user = user def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Embed search docs.""" - # use doc embedding cache or store if not exists - text_embeddings = [None for _ in range(len(texts))] - embedding_queue_indices = [] - for i, text in enumerate(texts): - hash = helper.generate_text_hash(text) - embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' - embedding = redis_client.get(embedding_cache_key) - if embedding: - redis_client.expire(embedding_cache_key, 3600) - text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float")) + """Embed search docs in batches of 10.""" + text_embeddings = [] + try: + model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) + model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials) + max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \ + if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1 + for i in range(0, len(texts), max_chunks): + batch_texts = texts[i:i + max_chunks] - else: - embedding_queue_indices.append(i) - - if embedding_queue_indices: - try: embedding_result = self._model_instance.invoke_text_embedding( - texts=[texts[i] for i in embedding_queue_indices], + texts=batch_texts, user=self._user ) - embedding_results = embedding_result.embeddings - except Exception as ex: - logger.error('Failed to embed documents: ', ex) - raise ex + for vector in embedding_result.embeddings: + try: + normalized_embedding = (vector / np.linalg.norm(vector)).tolist() + text_embeddings.append(normalized_embedding) + except IntegrityError: + db.session.rollback() + except Exception as e: + logging.exception('Failed to add embedding to redis') - for i, indice in enumerate(embedding_queue_indices): - hash = helper.generate_text_hash(texts[indice]) - - try: - embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' - vector = embedding_results[i] - normalized_embedding = (vector / np.linalg.norm(vector)).tolist() - text_embeddings[indice] = normalized_embedding - # encode embedding to base64 - embedding_vector = np.array(normalized_embedding) - vector_bytes = embedding_vector.tobytes() - # Transform to Base64 - encoded_vector = base64.b64encode(vector_bytes) - # Transform to string - encoded_str = encoded_vector.decode("utf-8") - redis_client.setex(embedding_cache_key, 3600, encoded_str) - - except IntegrityError: - db.session.rollback() - continue - except: - logging.exception('Failed to add embedding to redis') - continue + except Exception as ex: + logger.error('Failed to embed documents: ', ex) + raise ex return text_embeddings @@ -82,7 +61,7 @@ class CacheEmbedding(Embeddings): embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' embedding = redis_client.get(embedding_cache_key) if embedding: - redis_client.expire(embedding_cache_key, 3600) + redis_client.expire(embedding_cache_key, 600) return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) @@ -105,7 +84,7 @@ class CacheEmbedding(Embeddings): encoded_vector = base64.b64encode(vector_bytes) # Transform to string encoded_str = encoded_vector.decode("utf-8") - redis_client.setex(embedding_cache_key, 3600, encoded_str) + redis_client.setex(embedding_cache_key, 600, encoded_str) except IntegrityError: db.session.rollback()