From 8b357769164f387518a5b88c36044b6449ad8a56 Mon Sep 17 00:00:00 2001 From: liuhua <10215101452@stu.ecnu.edu.cn> Date: Wed, 27 Nov 2024 09:30:49 +0800 Subject: [PATCH] Fix a bug in VolcEngine (#3658) ### What problem does this PR solve? Fix a bug in VolcEngine #3553 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn> --- rag/llm/__init__.py | 1 + rag/llm/embedding_model.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 6d26a1fe6..647b2a909 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -48,6 +48,7 @@ EmbeddingModel = { "BaiduYiyan": BaiduYiyanEmbed, "Voyage AI": VoyageEmbed, "HuggingFace": HuggingFaceEmbed, + "VolcEngine":VolcEngineEmbed, } CvModel = { diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 567fde437..5ad6e2a92 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -718,3 +718,10 @@ class HuggingFaceEmbed(Base): else: raise Exception(f"Error: {response.status_code} - {response.text}") +class VolcEngineEmbed(OpenAIEmbed): + def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"): + if not base_url: + base_url = "https://ark.cn-beijing.volces.com/api/v3" + ark_api_key = json.loads(key).get('ark_api_key', '') + model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '') + super().__init__(ark_api_key,model_name,base_url)