mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 16:19:05 +08:00
VolcEngine SDK V3 adaptation (#2082)
1) Configuration interface update 2) Back-end adaptation API update Note: The official no longer supports the Skylark1/2 series, and all have been switched to the Doubao series  ### What problem does this PR solve? _Briefly describe what this PR aims to solve. Include background context that will help reviewers understand the purpose of the PR._ ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [ ] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [ ] Performance Improvement - [ ] Other (please describe): Co-authored-by: 海贼宅 <stu_xyx@163.com>
This commit is contained in:
parent
e953f01951
commit
7539d142a9
@ -113,13 +113,10 @@ def add_llm():
|
|||||||
|
|
||||||
if factory == "VolcEngine":
|
if factory == "VolcEngine":
|
||||||
# For VolcEngine, due to its special authentication method
|
# For VolcEngine, due to its special authentication method
|
||||||
# Assemble volc_ak, volc_sk, endpoint_id into api_key
|
# Assemble ark_api_key endpoint_id into api_key
|
||||||
temp = list(ast.literal_eval(req["llm_name"]).items())[0]
|
llm_name = req["llm_name"]
|
||||||
llm_name = temp[0]
|
api_key = '{' + f'"ark_api_key": "{req.get("ark_api_key", "")}", ' \
|
||||||
endpoint_id = temp[1]
|
f'"ep_id": "{req.get("endpoint_id", "")}", ' + '}'
|
||||||
api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
|
|
||||||
f'"volc_sk": "{req.get("volc_sk", "")}", ' \
|
|
||||||
f'"ep_id": "{endpoint_id}", ' + '}'
|
|
||||||
elif factory == "Tencent Hunyuan":
|
elif factory == "Tencent Hunyuan":
|
||||||
api_key = '{' + f'"hunyuan_sid": "{req.get("hunyuan_sid", "")}", ' \
|
api_key = '{' + f'"hunyuan_sid": "{req.get("hunyuan_sid", "")}", ' \
|
||||||
f'"hunyuan_sk": "{req.get("hunyuan_sk", "")}"' + '}'
|
f'"hunyuan_sk": "{req.get("hunyuan_sk", "")}"' + '}'
|
||||||
|
@ -349,13 +349,19 @@
|
|||||||
"status": "1",
|
"status": "1",
|
||||||
"llm": [
|
"llm": [
|
||||||
{
|
{
|
||||||
"llm_name": "Skylark2-pro-32k",
|
"llm_name": "Doubao-pro-128k",
|
||||||
|
"tags": "LLM,CHAT,128k",
|
||||||
|
"max_tokens": 131072,
|
||||||
|
"model_type": "chat"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"llm_name": "Doubao-pro-32k",
|
||||||
"tags": "LLM,CHAT,32k",
|
"tags": "LLM,CHAT,32k",
|
||||||
"max_tokens": 32768,
|
"max_tokens": 32768,
|
||||||
"model_type": "chat"
|
"model_type": "chat"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"llm_name": "Skylark2-pro-4k",
|
"llm_name": "Doubao-pro-4k",
|
||||||
"tags": "LLM,CHAT,4k",
|
"tags": "LLM,CHAT,4k",
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"model_type": "chat"
|
"model_type": "chat"
|
||||||
|
@ -450,72 +450,16 @@ class LocalLLM(Base):
|
|||||||
|
|
||||||
|
|
||||||
class VolcEngineChat(Base):
|
class VolcEngineChat(Base):
|
||||||
def __init__(self, key, model_name, base_url):
|
def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/api/v3'):
|
||||||
"""
|
"""
|
||||||
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
|
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
|
||||||
Assemble ak, sk, ep_id into api_key, store it as a dictionary type, and parse it for use
|
Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
|
||||||
model_name is for display only
|
model_name is for display only
|
||||||
"""
|
"""
|
||||||
self.client = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
|
base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
|
||||||
self.volc_ak = eval(key).get('volc_ak', '')
|
ark_api_key = eval(key).get('ark_api_key', '')
|
||||||
self.volc_sk = eval(key).get('volc_sk', '')
|
model_name = eval(key).get('ep_id', '')
|
||||||
self.client.set_ak(self.volc_ak)
|
super().__init__(ark_api_key, model_name, base_url)
|
||||||
self.client.set_sk(self.volc_sk)
|
|
||||||
self.model_name = eval(key).get('ep_id', '')
|
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf):
|
|
||||||
if system:
|
|
||||||
history.insert(0, {"role": "system", "content": system})
|
|
||||||
try:
|
|
||||||
req = {
|
|
||||||
"parameters": {
|
|
||||||
"min_new_tokens": gen_conf.get("min_new_tokens", 1),
|
|
||||||
"top_k": gen_conf.get("top_k", 0),
|
|
||||||
"max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
|
|
||||||
"temperature": gen_conf.get("temperature", 0.1),
|
|
||||||
"max_new_tokens": gen_conf.get("max_tokens", 1000),
|
|
||||||
"top_p": gen_conf.get("top_p", 0.3),
|
|
||||||
},
|
|
||||||
"messages": history
|
|
||||||
}
|
|
||||||
response = self.client.chat(self.model_name, req)
|
|
||||||
ans = response.choices[0].message.content.strip()
|
|
||||||
if response.choices[0].finish_reason == "length":
|
|
||||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
|
||||||
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
|
||||||
return ans, response.usage.total_tokens
|
|
||||||
except Exception as e:
|
|
||||||
return "**ERROR**: " + str(e), 0
|
|
||||||
|
|
||||||
def chat_streamly(self, system, history, gen_conf):
|
|
||||||
if system:
|
|
||||||
history.insert(0, {"role": "system", "content": system})
|
|
||||||
ans = ""
|
|
||||||
tk_count = 0
|
|
||||||
try:
|
|
||||||
req = {
|
|
||||||
"parameters": {
|
|
||||||
"min_new_tokens": gen_conf.get("min_new_tokens", 1),
|
|
||||||
"top_k": gen_conf.get("top_k", 0),
|
|
||||||
"max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
|
|
||||||
"temperature": gen_conf.get("temperature", 0.1),
|
|
||||||
"max_new_tokens": gen_conf.get("max_tokens", 1000),
|
|
||||||
"top_p": gen_conf.get("top_p", 0.3),
|
|
||||||
},
|
|
||||||
"messages": history
|
|
||||||
}
|
|
||||||
stream = self.client.stream_chat(self.model_name, req)
|
|
||||||
for resp in stream:
|
|
||||||
if not resp.choices[0].message.content:
|
|
||||||
continue
|
|
||||||
ans += resp.choices[0].message.content
|
|
||||||
if resp.choices[0].finish_reason == "stop":
|
|
||||||
tk_count = resp.usage.total_tokens
|
|
||||||
yield ans
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
yield ans + "\n**ERROR**: " + str(e)
|
|
||||||
yield tk_count
|
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxChat(Base):
|
class MiniMaxChat(Base):
|
||||||
|
@ -502,12 +502,11 @@ The above is the content you need to summarize.`,
|
|||||||
baseUrlNameMessage: 'Please input your base url!',
|
baseUrlNameMessage: 'Please input your base url!',
|
||||||
vision: 'Does it support Vision?',
|
vision: 'Does it support Vision?',
|
||||||
ollamaLink: 'How to integrate {{name}}',
|
ollamaLink: 'How to integrate {{name}}',
|
||||||
volcModelNameMessage:
|
volcModelNameMessage: 'Please input your model name!',
|
||||||
'Please input your model name! Format: {"ModelName":"EndpointID"}',
|
addEndpointID: 'EndpointID of the model',
|
||||||
addVolcEngineAK: 'VOLC ACCESS_KEY',
|
endpointIDMessage: 'Please input your EndpointID of the model',
|
||||||
volcAKMessage: 'Please input your VOLC_ACCESS_KEY',
|
addArkApiKey: 'VOLC ARK_API_KEY',
|
||||||
addVolcEngineSK: 'VOLC SECRET_KEY',
|
ArkApiKeyMessage: 'Please input your ARK_API_KEY',
|
||||||
volcSKMessage: 'Please input your SECRET_KEY',
|
|
||||||
bedrockModelNameMessage: 'Please input your model name!',
|
bedrockModelNameMessage: 'Please input your model name!',
|
||||||
addBedrockEngineAK: 'ACCESS KEY',
|
addBedrockEngineAK: 'ACCESS KEY',
|
||||||
bedrockAKMessage: 'Please input your ACCESS KEY',
|
bedrockAKMessage: 'Please input your ACCESS KEY',
|
||||||
|
@ -465,11 +465,11 @@ export default {
|
|||||||
modelTypeMessage: '請輸入模型類型!',
|
modelTypeMessage: '請輸入模型類型!',
|
||||||
baseUrlNameMessage: '請輸入基礎 Url!',
|
baseUrlNameMessage: '請輸入基礎 Url!',
|
||||||
ollamaLink: '如何集成 {{name}}',
|
ollamaLink: '如何集成 {{name}}',
|
||||||
volcModelNameMessage: '請輸入模型名稱!格式:{"模型名稱":"EndpointID"}',
|
volcModelNameMessage: '請輸入模型名稱!',
|
||||||
addVolcEngineAK: '火山 ACCESS_KEY',
|
addEndpointID: '模型 EndpointID',
|
||||||
volcAKMessage: '請輸入VOLC_ACCESS_KEY',
|
endpointIDMessage: '請輸入模型對應的EndpointID',
|
||||||
addVolcEngineSK: '火山 SECRET_KEY',
|
addArkApiKey: '火山 ARK_API_KEY',
|
||||||
volcSKMessage: '請輸入VOLC_SECRET_KEY',
|
ArkApiKeyMessage: '請輸入火山創建的ARK_API_KEY',
|
||||||
bedrockModelNameMessage: '請輸入名稱!',
|
bedrockModelNameMessage: '請輸入名稱!',
|
||||||
addBedrockEngineAK: 'ACCESS KEY',
|
addBedrockEngineAK: 'ACCESS KEY',
|
||||||
bedrockAKMessage: '請輸入 ACCESS KEY',
|
bedrockAKMessage: '請輸入 ACCESS KEY',
|
||||||
|
@ -482,11 +482,11 @@ export default {
|
|||||||
modelTypeMessage: '请输入模型类型!',
|
modelTypeMessage: '请输入模型类型!',
|
||||||
baseUrlNameMessage: '请输入基础 Url!',
|
baseUrlNameMessage: '请输入基础 Url!',
|
||||||
ollamaLink: '如何集成 {{name}}',
|
ollamaLink: '如何集成 {{name}}',
|
||||||
volcModelNameMessage: '请输入模型名称!格式:{"模型名称":"EndpointID"}',
|
volcModelNameMessage: '请输入模型名称!',
|
||||||
addVolcEngineAK: '火山 ACCESS_KEY',
|
addEndpointID: '模型 EndpointID',
|
||||||
volcAKMessage: '请输入VOLC_ACCESS_KEY',
|
endpointIDMessage: '请输入模型对应的EndpointID',
|
||||||
addVolcEngineSK: '火山 SECRET_KEY',
|
addArkApiKey: '火山 ARK_API_KEY',
|
||||||
volcSKMessage: '请输入VOLC_SECRET_KEY',
|
ArkApiKeyMessage: '请输入火山创建的ARK_API_KEY',
|
||||||
bedrockModelNameMessage: '请输入名称!',
|
bedrockModelNameMessage: '请输入名称!',
|
||||||
addBedrockEngineAK: 'ACCESS KEY',
|
addBedrockEngineAK: 'ACCESS KEY',
|
||||||
bedrockAKMessage: '请输入 ACCESS KEY',
|
bedrockAKMessage: '请输入 ACCESS KEY',
|
||||||
|
@ -8,6 +8,8 @@ type FieldType = IAddLlmRequestBody & {
|
|||||||
vision: boolean;
|
vision: boolean;
|
||||||
volc_ak: string;
|
volc_ak: string;
|
||||||
volc_sk: string;
|
volc_sk: string;
|
||||||
|
endpoint_id: string;
|
||||||
|
ark_api_key: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const { Option } = Select;
|
const { Option } = Select;
|
||||||
@ -51,7 +53,7 @@ const VolcEngineModal = ({
|
|||||||
return (
|
return (
|
||||||
<Flex justify={'space-between'}>
|
<Flex justify={'space-between'}>
|
||||||
<a
|
<a
|
||||||
href="https://www.volcengine.com/docs/82379/1095322"
|
href="https://www.volcengine.com/docs/82379/1302008"
|
||||||
target="_blank"
|
target="_blank"
|
||||||
rel="noreferrer"
|
rel="noreferrer"
|
||||||
>
|
>
|
||||||
@ -88,18 +90,18 @@ const VolcEngineModal = ({
|
|||||||
<Input placeholder={t('volcModelNameMessage')} />
|
<Input placeholder={t('volcModelNameMessage')} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item<FieldType>
|
<Form.Item<FieldType>
|
||||||
label={t('addVolcEngineAK')}
|
label={t('addEndpointID')}
|
||||||
name="volc_ak"
|
name="endpoint_id"
|
||||||
rules={[{ required: true, message: t('volcAKMessage') }]}
|
rules={[{ required: true, message: t('endpointIDMessage') }]}
|
||||||
>
|
>
|
||||||
<Input placeholder={t('volcAKMessage')} />
|
<Input placeholder={t('endpointIDMessage')} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item<FieldType>
|
<Form.Item<FieldType>
|
||||||
label={t('addVolcEngineSK')}
|
label={t('addArkApiKey')}
|
||||||
name="volc_sk"
|
name="ark_api_key"
|
||||||
rules={[{ required: true, message: t('volcAKMessage') }]}
|
rules={[{ required: true, message: t('ArkApiKeyMessage') }]}
|
||||||
>
|
>
|
||||||
<Input placeholder={t('volcAKMessage')} />
|
<Input placeholder={t('ArkApiKeyMessage')} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item noStyle dependencies={['model_type']}>
|
<Form.Item noStyle dependencies={['model_type']}>
|
||||||
{({ getFieldValue }) =>
|
{({ getFieldValue }) =>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user