From 35b7d17d97cfeacf48a150dbf2267b5ec40dc8ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E8=85=BE?= <101850389+hangters@users.noreply.github.com> Date: Wed, 11 Sep 2024 12:17:44 +0800 Subject: [PATCH] fix SILICONFLOW embedding error (#2363) ### What problem does this PR solve? #2335 fix SILICONFLOW embedding error ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Zhedong Cen --- rag/llm/embedding_model.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 0c9ce73fc..7cfd3e319 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -577,11 +577,40 @@ class UpstageEmbed(OpenAIEmbed): super().__init__(key, model_name, base_url) -class SILICONFLOWEmbed(OpenAIEmbed): - def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"): +class SILICONFLOWEmbed(Base): + def __init__( + self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings" + ): if not base_url: - base_url = "https://api.siliconflow.cn/v1" - super().__init__(key, model_name, base_url) + base_url = "https://api.siliconflow.cn/v1/embeddings" + self.headers = { + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer {key}", + } + self.base_url = base_url + self.model_name = model_name + + def encode(self, texts: list, batch_size=32): + payload = { + "model": self.model_name, + "input": texts, + "encoding_format": "float", + } + res = requests.post(self.base_url, json=payload, headers=self.headers).json() + return ( + np.array([d["embedding"] for d in res["data"]]), + res["usage"]["total_tokens"], + ) + + def encode_queries(self, text): + payload = { + "model": self.model_name, + "input": text, + "encoding_format": "float", + } + res = requests.post(self.base_url, json=payload, headers=self.headers).json() + return np.array(res["data"][0]["embedding"]), res["usage"]["total_tokens"] class ReplicateEmbed(Base):