From b9bb11879fd05e3862cf761241e8541ca92e203a Mon Sep 17 00:00:00 2001
From: KevinHuSh
Date: Fri, 31 May 2024 09:46:22 +0800
Subject: [PATCH] fix #994 (#1006)

### What problem does this PR solve?

#994

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 rag/llm/embedding_model.py | 52 ++++++++++++++++++++++----------------
 1 file changed, 30 insertions(+), 22 deletions(-)

diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 5083a6945..328e43bac 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -123,30 +123,38 @@ class QWenEmbed(Base):
 
     def encode(self, texts: list, batch_size=10):
         import dashscope
-        res = []
-        token_count = 0
-        texts = [truncate(t, 2048) for t in texts]
-        for i in range(0, len(texts), batch_size):
-            resp = dashscope.TextEmbedding.call(
-                model=self.model_name,
-                input=texts[i:i + batch_size],
-                text_type="document"
-            )
-            embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
-            for e in resp["output"]["embeddings"]:
-                embds[e["text_index"]] = e["embedding"]
-            res.extend(embds)
-            token_count += resp["usage"]["total_tokens"]
-        return np.array(res), token_count
+        try:
+            res = []
+            token_count = 0
+            texts = [truncate(t, 2048) for t in texts]
+            for i in range(0, len(texts), batch_size):
+                resp = dashscope.TextEmbedding.call(
+                    model=self.model_name,
+                    input=texts[i:i + batch_size],
+                    text_type="document"
+                )
+                embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
+                for e in resp["output"]["embeddings"]:
+                    embds[e["text_index"]] = e["embedding"]
+                res.extend(embds)
+                token_count += resp["usage"]["total_tokens"]
+            return np.array(res), token_count
+        except Exception as e:
+            raise Exception("Account abnormal. Please ensure it's on good standing.")
+        return np.array([]), 0
 
     def encode_queries(self, text):
-        resp = dashscope.TextEmbedding.call(
-            model=self.model_name,
-            input=text[:2048],
-            text_type="query"
-        )
-        return np.array(resp["output"]["embeddings"][0]
-                        ["embedding"]), resp["usage"]["total_tokens"]
+        try:
+            resp = dashscope.TextEmbedding.call(
+                model=self.model_name,
+                input=text[:2048],
+                text_type="query"
+            )
+            return np.array(resp["output"]["embeddings"][0]
+                            ["embedding"]), resp["usage"]["total_tokens"]
+        except Exception as e:
+            raise Exception("Account abnormal. Please ensure it's on good standing.")
+        return np.array([]), 0
 
 
 class ZhipuEmbed(Base):