From 463fbe268047520fba99b60c123a88a4c5141884 Mon Sep 17 00:00:00 2001 From: yihong Date: Fri, 20 Dec 2024 09:28:32 +0800 Subject: [PATCH] fix: better gard nan value from numpy for issue #11827 (#11864) Signed-off-by: yihong0618 --- .../azure_openai/text_embedding/text_embedding.py | 5 ++++- .../model_providers/cohere/text_embedding/text_embedding.py | 5 ++++- .../model_providers/openai/text_embedding/text_embedding.py | 5 ++++- .../model_providers/upstage/text_embedding/text_embedding.py | 5 ++++- api/core/rag/embedding/cached_embedding.py | 2 ++ 5 files changed, 18 insertions(+), 4 deletions(-) diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index c45ce87ea7..69d2cfaded 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -92,7 +92,10 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding # calc usage usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 5fd4d637be..9e4df27060 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -88,7 +88,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding # calc usage usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index bec01fe679..9c8c8d5882 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -97,7 +97,10 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding # calc usage usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index 7dd495b55e..5b340e53bb 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -100,7 +100,10 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 652f7e145f..8ddda7e983 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -116,6 +116,8 @@ class CacheEmbedding(Embeddings): embedding_results = embedding_result.embeddings[0] embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() + if np.isnan(embedding_results).any(): + raise ValueError("Normalized embedding is nan please try again") except Exception as ex: if dify_config.DEBUG: logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")