diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 2e67934d3..b705fc551 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -140,7 +140,11 @@ def add_llm(): api_key = req.get("api_key","xxxxxxxxxxxxxxx") elif factory =="XunFei Spark": llm_name = req["llm_name"] - api_key = req.get("spark_api_password","") + api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx") + elif factory == "BaiduYiyan": + llm_name = req["llm_name"] + api_key = '{' + f'"yiyan_ak": "{req.get("yiyan_ak", "")}", ' \ + f'"yiyan_sk": "{req.get("yiyan_sk", "")}"' + '}' else: llm_name = req["llm_name"] api_key = req.get("api_key","xxxxxxxxxxxxxxx") @@ -157,7 +161,7 @@ def add_llm(): msg = "" if llm["model_type"] == LLMType.EMBEDDING.value: mdl = EmbeddingModel[factory]( - key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None, + key=llm['api_key'], model_name=llm["llm_name"], base_url=llm["api_base"]) try: @@ -168,7 +172,7 @@ def add_llm(): msg += f"\nFail to access embedding model({llm['llm_name']})." 
+ str(e) elif llm["model_type"] == LLMType.CHAT.value: mdl = ChatModel[factory]( - key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate","XunFei Spark"] else None, + key=llm['api_key'], model_name=llm["llm_name"], base_url=llm["api_base"] ) @@ -182,7 +186,9 @@ def add_llm(): e) elif llm["model_type"] == LLMType.RERANK: mdl = RerankModel[factory]( - key=None, model_name=llm["llm_name"], base_url=llm["api_base"] + key=llm["api_key"], + model_name=llm["llm_name"], + base_url=llm["api_base"] ) try: arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!"]) @@ -193,7 +199,9 @@ def add_llm(): e) elif llm["model_type"] == LLMType.IMAGE2TEXT.value: mdl = CvModel[factory]( - key=llm["api_key"] if factory in ["OpenAI-API-Compatible"] else None, model_name=llm["llm_name"], base_url=llm["api_base"] + key=llm["api_key"], + model_name=llm["llm_name"], + base_url=llm["api_base"] ) try: img_url = ( diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 62afcad83..ff2563176 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -3201,6 +3201,13 @@ "tags": "LLM", "status": "1", "llm": [] + }, + { + "name": "BaiduYiyan", + "logo": "", + "tags": "LLM", + "status": "1", + "llm": [] } ] } diff --git a/docs/quickstart.mdx b/docs/quickstart.mdx index dd691887f..d41251f4f 100644 --- a/docs/quickstart.mdx +++ b/docs/quickstart.mdx @@ -119,7 +119,7 @@ This section provides instructions on setting up the RAGFlow server on Linux. If ``` :::note - If the above steps does not work, consider using [this workaround](https://github.com/docker/for-mac/issues/7047#issuecomment-1791912053), which employs a container and does not require manual editing of the macOS settings. + If the above steps do not work, consider using [this workaround](https://github.com/docker/for-mac/issues/7047#issuecomment-1791912053), which employs a container and does not require manual editing of the macOS settings. 
::: diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index b0c1b7a56..589e7d85d 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -43,7 +43,8 @@ EmbeddingModel = { "PerfXCloud": PerfXCloudEmbed, "Upstage": UpstageEmbed, "SILICONFLOW": SILICONFLOWEmbed, - "Replicate": ReplicateEmbed + "Replicate": ReplicateEmbed, + "BaiduYiyan": BaiduYiyanEmbed } @@ -101,7 +102,8 @@ ChatModel = { "01.AI": YiChat, "Replicate": ReplicateChat, "Tencent Hunyuan": HunyuanChat, - "XunFei Spark": SparkChat + "XunFei Spark": SparkChat, + "BaiduYiyan": BaiduYiyanChat } @@ -115,7 +117,8 @@ RerankModel = { "OpenAI-API-Compatible": OpenAI_APIRerank, "cohere": CoHereRerank, "TogetherAI": TogetherAIRerank, - "SILICONFLOW": SILICONFLOWRerank + "SILICONFLOW": SILICONFLOWRerank, + "BaiduYiyan": BaiduYiyanRerank } diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index b6d035859..1d3fb8ccf 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1185,3 +1185,69 @@ class SparkChat(Base): } model_version = model2version[model_name] super().__init__(key, model_version, base_url) + + +class BaiduYiyanChat(Base): + def __init__(self, key, model_name, base_url=None): + import qianfan + + key = json.loads(key) + ak = key.get("yiyan_ak","") + sk = key.get("yiyan_sk","") + self.client = qianfan.ChatCompletion(ak=ak,sk=sk) + self.model_name = model_name.lower() + self.system = "" + + def chat(self, system, history, gen_conf): + if system: + self.system = system + gen_conf["penalty_score"] = ( + (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2 + ) + 1 + if "max_tokens" in gen_conf: + gen_conf["max_output_tokens"] = gen_conf["max_tokens"] + ans = "" + + try: + response = self.client.do( + model=self.model_name, + messages=history, + system=self.system, + **gen_conf + ).body + ans = response['result'] + return ans, response["usage"]["total_tokens"] + + except Exception as e: + return ans + "\n**ERROR**: " + str(e), 0 + + def chat_streamly(self, 
system, history, gen_conf): + if system: + self.system = system + gen_conf["penalty_score"] = ( + (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2 + ) + 1 + if "max_tokens" in gen_conf: + gen_conf["max_output_tokens"] = gen_conf["max_tokens"] + ans = "" + total_tokens = 0 + + try: + response = self.client.do( + model=self.model_name, + messages=history, + system=self.system, + stream=True, + **gen_conf + ) + for resp in response: + resp = resp.body + ans += resp['result'] + total_tokens = resp["usage"]["total_tokens"] + + yield ans + + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield total_tokens diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 2045e4169..bc8241372 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -32,6 +32,7 @@ import asyncio from api.utils.file_utils import get_home_cache_dir from rag.utils import num_tokens_from_string, truncate import google.generativeai as genai +import json class Base(ABC): def __init__(self, key, model_name): @@ -591,11 +592,34 @@ class ReplicateEmbed(Base): self.client = Client(api_token=key) def encode(self, texts: list, batch_size=32): - from json import dumps - - res = self.client.run(self.model_name, input={"texts": dumps(texts)}) + res = self.client.run(self.model_name, input={"texts": json.dumps(texts)}) return np.array(res), sum([num_tokens_from_string(text) for text in texts]) def encode_queries(self, text): res = self.client.embed(self.model_name, input={"texts": [text]}) return np.array(res), num_tokens_from_string(text) + + +class BaiduYiyanEmbed(Base): + def __init__(self, key, model_name, base_url=None): + import qianfan + + key = json.loads(key) + ak = key.get("yiyan_ak", "") + sk = key.get("yiyan_sk", "") + self.client = qianfan.Embedding(ak=ak, sk=sk) + self.model_name = model_name + + def encode(self, texts: list, batch_size=32): + res = self.client.do(model=self.model_name, texts=texts).body + 
return ( + np.array([r["embedding"] for r in res["data"]]), + res["usage"]["total_tokens"], + ) + + def encode_queries(self, text): + res = self.client.do(model=self.model_name, texts=[text]).body + return ( + np.array([r["embedding"] for r in res["data"]]), + res["usage"]["total_tokens"], + ) diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index f452ea6c7..cc7f5525c 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -24,6 +24,7 @@ from abc import ABC import numpy as np from api.utils.file_utils import get_home_cache_dir from rag.utils import num_tokens_from_string, truncate +import json def sigmoid(x): return 1 / (1 + np.exp(-x)) @@ -288,3 +289,25 @@ class SILICONFLOWRerank(Base): rank[indexs], response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"], ) + + +class BaiduYiyanRerank(Base): + def __init__(self, key, model_name, base_url=None): + from qianfan.resources import Reranker + + key = json.loads(key) + ak = key.get("yiyan_ak", "") + sk = key.get("yiyan_sk", "") + self.client = Reranker(ak=ak, sk=sk) + self.model_name = model_name + + def similarity(self, query: str, texts: list): + res = self.client.do( + model=self.model_name, + query=query, + documents=texts, + top_n=len(texts), + ).body + rank = np.array([d["relevance_score"] for d in res["results"]]) + indexs = [d["index"] for d in res["results"]] + return rank[indexs], res["usage"]["total_tokens"] diff --git a/requirements.txt b/requirements.txt index 3dd67ceda..ee98fd1f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -62,6 +62,7 @@ pytest==8.2.2 python-dotenv==1.0.1 python_dateutil==2.8.2 python_pptx==0.6.23 +qianfan==0.4.6 readability_lxml==0.8.1 redis==5.0.3 Requests==2.32.2 diff --git a/requirements_arm.txt b/requirements_arm.txt index b03166fcd..f96c98fc5 100644 --- a/requirements_arm.txt +++ b/requirements_arm.txt @@ -100,6 +100,7 @@ python-docx==1.1.0 python-dotenv==1.0.1 python-pptx==0.6.23 PyYAML==6.0.1 +qianfan==0.4.6 
redis==5.0.3 regex==2023.12.25 replicate==0.31.0 diff --git a/web/src/assets/svg/llm/yiyan.svg b/web/src/assets/svg/llm/yiyan.svg new file mode 100644 index 000000000..4c571c34a --- /dev/null +++ b/web/src/assets/svg/llm/yiyan.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index e08fbafb8..e368f6d81 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -528,6 +528,11 @@ The above is the content you need to summarize.`, SparkModelNameMessage: 'Please select Spark model', addSparkAPIPassword: 'Spark APIPassword', SparkAPIPasswordMessage: 'please input your APIPassword', + yiyanModelNameMessage: 'Please input model name', + addyiyanAK: 'yiyan API KEY', + yiyanAKMessage: 'Please input your API KEY', + addyiyanSK: 'yiyan Secret KEY', + yiyanSKMessage: 'Please input your Secret KEY', }, message: { registered: 'Registered!', diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts index 6e27d3b1b..86afb866b 100644 --- a/web/src/locales/zh-traditional.ts +++ b/web/src/locales/zh-traditional.ts @@ -491,6 +491,11 @@ export default { SparkModelNameMessage: '請選擇星火模型!', addSparkAPIPassword: '星火 APIPassword', SparkAPIPasswordMessage: '請輸入 APIPassword', + yiyanModelNameMessage: '輸入模型名稱', + addyiyanAK: '一言 API KEY', + yiyanAKMessage: '請輸入 API KEY', + addyiyanSK: '一言 Secret KEY', + yiyanSKMessage: '請輸入 Secret KEY', }, message: { registered: '註冊成功', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 414e4c1cd..68c91cbbf 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -508,6 +508,11 @@ export default { SparkModelNameMessage: '请选择星火模型!', addSparkAPIPassword: '星火 APIPassword', SparkAPIPasswordMessage: '请输入 APIPassword', + yiyanModelNameMessage: '请输入模型名称', + addyiyanAK: '一言 API KEY', + yiyanAKMessage: '请输入 API KEY', + addyiyanSK: '一言 Secret KEY', + yiyanSKMessage: '请输入 Secret KEY', }, message: { registered: '注册成功', diff --git 
a/web/src/pages/user-setting/setting-model/constant.ts b/web/src/pages/user-setting/setting-model/constant.ts index de44f45ae..443e9fb3e 100644 --- a/web/src/pages/user-setting/setting-model/constant.ts +++ b/web/src/pages/user-setting/setting-model/constant.ts @@ -34,6 +34,7 @@ export const IconMap = { Replicate: 'replicate', 'Tencent Hunyuan': 'hunyuan', 'XunFei Spark': 'spark', + BaiduYiyan: 'yiyan', }; export const BedrockRegionList = [ diff --git a/web/src/pages/user-setting/setting-model/hooks.ts b/web/src/pages/user-setting/setting-model/hooks.ts index dcf33bbd0..00665d12b 100644 --- a/web/src/pages/user-setting/setting-model/hooks.ts +++ b/web/src/pages/user-setting/setting-model/hooks.ts @@ -217,6 +217,33 @@ export const useSubmitSpark = () => { }; }; +export const useSubmityiyan = () => { + const { addLlm, loading } = useAddLlm(); + const { + visible: yiyanAddingVisible, + hideModal: hideyiyanAddingModal, + showModal: showyiyanAddingModal, + } = useSetModalState(); + + const onyiyanAddingOk = useCallback( + async (payload: IAddLlmRequestBody) => { + const ret = await addLlm(payload); + if (ret === 0) { + hideyiyanAddingModal(); + } + }, + [hideyiyanAddingModal, addLlm], + ); + + return { + yiyanAddingLoading: loading, + onyiyanAddingOk, + yiyanAddingVisible, + hideyiyanAddingModal, + showyiyanAddingModal, + }; +}; + export const useSubmitBedrock = () => { const { addLlm, loading } = useAddLlm(); const { diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx index 7e0500df0..1536a5a11 100644 --- a/web/src/pages/user-setting/setting-model/index.tsx +++ b/web/src/pages/user-setting/setting-model/index.tsx @@ -39,6 +39,7 @@ import { useSubmitSpark, useSubmitSystemModelSetting, useSubmitVolcEngine, + useSubmityiyan, } from './hooks'; import HunyuanModal from './hunyuan-modal'; import styles from './index.less'; @@ -46,6 +47,7 @@ import OllamaModal from './ollama-modal'; import SparkModal from 
'./spark-modal'; import SystemModelSettingModal from './system-model-setting-modal'; import VolcEngineModal from './volcengine-modal'; +import YiyanModal from './yiyan-modal'; const LlmIcon = ({ name }: { name: string }) => { const icon = IconMap[name as keyof typeof IconMap]; @@ -95,7 +97,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => { {isLocalLlmFactory(item.name) || item.name === 'VolcEngine' || item.name === 'Tencent Hunyuan' || - item.name === 'XunFei Spark' + item.name === 'XunFei Spark' || + item.name === 'BaiduYiyan' ? t('addTheModel') : 'API-Key'} @@ -185,6 +188,14 @@ const UserSettingModel = () => { SparkAddingLoading, } = useSubmitSpark(); + const { + yiyanAddingVisible, + hideyiyanAddingModal, + showyiyanAddingModal, + onyiyanAddingOk, + yiyanAddingLoading, + } = useSubmityiyan(); + const { bedrockAddingLoading, onBedrockAddingOk, @@ -199,12 +210,14 @@ const UserSettingModel = () => { VolcEngine: showVolcAddingModal, 'Tencent Hunyuan': showHunyuanAddingModal, 'XunFei Spark': showSparkAddingModal, + BaiduYiyan: showyiyanAddingModal, }), [ showBedrockAddingModal, showVolcAddingModal, showHunyuanAddingModal, showSparkAddingModal, + showyiyanAddingModal, ], ); @@ -330,6 +343,13 @@ const UserSettingModel = () => { loading={SparkAddingLoading} llmFactory={'XunFei Spark'} > + & { llmFactory: string }) => { + const [form] = Form.useForm(); + + const { t } = useTranslate('setting'); + + const handleOk = async () => { + const values = await form.validateFields(); + const modelType = + values.model_type === 'chat' && values.vision + ? 'image2text' + : values.model_type; + + const data = { + ...omit(values, ['vision']), + model_type: modelType, + llm_factory: llmFactory, + }; + console.info(data); + + onOk?.(data); + }; + + return ( + +
+ + label={t('modelType')} + name="model_type" + initialValue={'chat'} + rules={[{ required: true, message: t('modelTypeMessage') }]} + > + + + + label={t('modelName')} + name="llm_name" + rules={[{ required: true, message: t('yiyanModelNameMessage') }]} + > + + + + label={t('addyiyanAK')} + name="yiyan_ak" + rules={[{ required: true, message: t('yiyanAKMessage') }]} + > + + + + label={t('addyiyanSK')} + name="yiyan_sk" + rules={[{ required: true, message: t('yiyanSKMessage') }]} + > + + + +
+ ); +}; + +export default YiyanModal;