diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index d7a852b0b..0a72b3f64 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import json + from flask import request from flask_login import login_required, current_user from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService @@ -126,55 +128,56 @@ def add_llm(): req = request.json factory = req["llm_factory"] + def apikey_json(keys): + nonlocal req + return json.dumps({k: req.get(k, "") for k in keys}) + if factory == "VolcEngine": # For VolcEngine, due to its special authentication method # Assemble ark_api_key endpoint_id into api_key llm_name = req["llm_name"] - api_key = f'{{ "ark_api_key":"{req.get("ark_api_key", "")}", "ep_id":"{req.get("endpoint_id", "")}" }}' + api_key = apikey_json(["ark_api_key", "endpoint_id"]) + elif factory == "Tencent Hunyuan": - api_key = '{' + f'"hunyuan_sid": "{req.get("hunyuan_sid", "")}", ' \ - f'"hunyuan_sk": "{req.get("hunyuan_sk", "")}"' + '}' - req["api_key"] = api_key + req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"]) return set_api_key() + elif factory == "Tencent Cloud": - api_key = '{' + f'"tencent_cloud_sid": "{req.get("tencent_cloud_sid", "")}", ' \ - f'"tencent_cloud_sk": "{req.get("tencent_cloud_sk", "")}"' + '}' - req["api_key"] = api_key + req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"]) + elif factory == "Bedrock": # For Bedrock, due to its special authentication method # Assemble bedrock_ak, bedrock_sk, bedrock_region llm_name = req["llm_name"] - api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \ - f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \ - f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}' + api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"]) + elif factory == "LocalAI": llm_name = req["llm_name"]+"___LocalAI" api_key = "xxxxxxxxxxxxxxx" + elif factory == "OpenAI-API-Compatible": llm_name = req["llm_name"]+"___OpenAI-API" api_key = req.get("api_key","xxxxxxxxxxxxxxx") + elif factory =="XunFei Spark": llm_name = req["llm_name"] - api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx") + api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx") + elif factory == "BaiduYiyan": llm_name = req["llm_name"] - api_key = '{' + f'"yiyan_ak": "{req.get("yiyan_ak", "")}", ' \ - f'"yiyan_sk": "{req.get("yiyan_sk", "")}"' + '}' + api_key = apikey_json(["yiyan_ak", "yiyan_sk"]) + elif factory == "Fish Audio": llm_name = req["llm_name"] - api_key = '{' + f'"fish_audio_ak": "{req.get("fish_audio_ak", "")}", ' \ - f'"fish_audio_refid": "{req.get("fish_audio_refid", "59cb5986671546eaa6ca8ae6f29f6d22")}"' + '}' + api_key = apikey_json(["fish_audio_ak", "fish_audio_refid"]) + elif factory == "Google Cloud": llm_name = req["llm_name"] - api_key = ( - "{" + f'"google_project_id": "{req.get("google_project_id", "")}", ' - f'"google_region": "{req.get("google_region", "")}", ' - f'"google_service_account_key": "{req.get("google_service_account_key", "")}"' - + "}" - ) + api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"]) + else: llm_name = req["llm_name"] - api_key = req.get("api_key","xxxxxxxxxxxxxxx") + api_key = req.get("api_key", "xxxxxxxxxxxxxxx") llm = { "tenant_id": current_user.id, diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 500b08b87..eb3b74cf3 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -458,7 +458,7 @@ class VolcEngineChat(Base): """ base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3' ark_api_key = json.loads(key).get('ark_api_key', '') - model_name = json.loads(key).get('ep_id', '') + model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '') super().__init__(ark_api_key, model_name, base_url)