### What problem does this PR solve?

#994 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
KevinHuSh 2024-05-31 09:46:22 +08:00 committed by GitHub
parent dc7afe46fb
commit b9bb11879f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -123,30 +123,38 @@ class QWenEmbed(Base):
def encode(self, texts: list, batch_size=10):
    """Embed ``texts`` as documents via DashScope, ``batch_size`` texts per call.

    Every text is first passed through ``truncate(t, 2048)`` (helper defined
    elsewhere in this file), then the texts are sent in slices of
    ``batch_size`` to ``dashscope.TextEmbedding.call`` with
    ``text_type="document"``.

    Returns:
        A ``(embeddings, token_count)`` tuple: a NumPy array built from the
        per-text embeddings (placed by the ``text_index`` the service
        reports), and the sum of ``usage.total_tokens`` over all calls.

    Raises:
        Exception: if anything inside the request/response handling fails.
            The message is kept unchanged for callers; the underlying error
            is attached as ``__cause__``.
    """
    import dashscope

    try:
        vectors = []
        token_count = 0
        texts = [truncate(t, 2048) for t in texts]
        for start in range(0, len(texts), batch_size):
            resp = dashscope.TextEmbedding.call(
                model=self.model_name,
                input=texts[start:start + batch_size],
                text_type="document"
            )
            items = resp["output"]["embeddings"]
            # Slot each embedding by its reported text_index instead of
            # relying on the order in which the response lists them.
            ordered = [[] for _ in items]
            for item in items:
                ordered[item["text_index"]] = item["embedding"]
            vectors.extend(ordered)
            token_count += resp["usage"]["total_tokens"]
        return np.array(vectors), token_count
    except Exception as e:
        # NOTE(review): this handler covers every failure in the block, not
        # only account problems; chaining keeps the real cause visible.
        raise Exception("Account abnormal. Please ensure it's on good standing.") from e
def encode_queries(self, text):
    """Embed a single query string via DashScope.

    Only the first 2048 characters of ``text`` are sent, using
    ``text_type="query"``.

    Returns:
        A ``(embedding, token_count)`` tuple: a NumPy array of the first
        embedding in the response, and the response's
        ``usage.total_tokens``.

    Raises:
        Exception: if the request or the response handling fails. The
            message is kept unchanged for callers; the underlying error is
            attached as ``__cause__``.
    """
    # Imported locally, as the sibling ``encode`` does, so that a missing
    # name is not swallowed by the handler below and reported as an
    # account problem.
    import dashscope

    try:
        resp = dashscope.TextEmbedding.call(
            model=self.model_name,
            input=text[:2048],
            text_type="query"
        )
        embedding = resp["output"]["embeddings"][0]["embedding"]
        return np.array(embedding), resp["usage"]["total_tokens"]
    except Exception as e:
        # NOTE(review): this handler covers every failure in the block, not
        # only account problems; chaining keeps the real cause visible.
        raise Exception("Account abnormal. Please ensure it's on good standing.") from e
class ZhipuEmbed(Base): class ZhipuEmbed(Base):