diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 6d26a1fe6..647b2a909 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -48,6 +48,7 @@ EmbeddingModel = { "BaiduYiyan": BaiduYiyanEmbed, "Voyage AI": VoyageEmbed, "HuggingFace": HuggingFaceEmbed, + "VolcEngine":VolcEngineEmbed, } CvModel = { diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 567fde437..5ad6e2a92 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -718,3 +718,10 @@ class HuggingFaceEmbed(Base): else: raise Exception(f"Error: {response.status_code} - {response.text}") +class VolcEngineEmbed(OpenAIEmbed): + def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"): + if not base_url: + base_url = "https://ark.cn-beijing.volces.com/api/v3" + ark_api_key = json.loads(key).get('ark_api_key', '') + model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '') + super().__init__(ark_api_key,model_name,base_url)