fix: better guard nan value from numpy for issue #11827 (#11864)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
yihong 2024-12-20 09:28:32 +08:00 committed by GitHub
parent 95a7e50137
commit 463fbe2680
5 changed files with 18 additions and 4 deletions

@@ -92,7 +92,10 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
                 average = embeddings_batch[0]
             else:
                 average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
-            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+            embedding = (average / np.linalg.norm(average)).tolist()
+            if np.isnan(embedding).any():
+                raise ValueError("Normalized embedding is nan please try again")
+            embeddings[i] = embedding

         # calc usage
         usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
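
Why the guard is needed, as a minimal sketch outside the commit: if a provider returns an all-zero embedding, np.linalg.norm(average) is 0 and the elementwise 0 / 0 division produces NaN instead of raising, so the bad vector would previously flow into embeddings[i] unnoticed. The zero vector below is an assumed degenerate input, used only for illustration.

import numpy as np

# Assumed degenerate case: an all-zero embedding vector from the provider.
average = np.zeros(4)

# NumPy evaluates 0 / 0 elementwise to NaN (with a RuntimeWarning, not an error).
embedding = (average / np.linalg.norm(average)).tolist()

# The check added in this commit turns that silent NaN into an explicit failure.
if np.isnan(embedding).any():
    raise ValueError("Normalized embedding is nan please try again")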

@@ -88,7 +88,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
                 average = embeddings_batch[0]
             else:
                 average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
-            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+            embedding = (average / np.linalg.norm(average)).tolist()
+            if np.isnan(embedding).any():
+                raise ValueError("Normalized embedding is nan please try again")
+            embeddings[i] = embedding

         # calc usage
         usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

@@ -97,7 +97,10 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
                 average = embeddings_batch[0]
             else:
                 average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
-            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+            embedding = (average / np.linalg.norm(average)).tolist()
+            if np.isnan(embedding).any():
+                raise ValueError("Normalized embedding is nan please try again")
+            embeddings[i] = embedding

         # calc usage
         usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

@@ -100,7 +100,10 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
                 average = embeddings_batch[0]
             else:
                 average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
-            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+            embedding = (average / np.linalg.norm(average)).tolist()
+            if np.isnan(embedding).any():
+                raise ValueError("Normalized embedding is nan please try again")
+            embeddings[i] = embedding

         usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

@@ -116,6 +116,8 @@ class CacheEmbedding(Embeddings):
             embedding_results = embedding_result.embeddings[0]
             embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
+            if np.isnan(embedding_results).any():
+                raise ValueError("Normalized embedding is nan please try again")
         except Exception as ex:
             if dify_config.DEBUG:
                 logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")