mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 01:39:00 +08:00
refactor add LLM (#2508)
### What problem does this PR solve? #2487 ### Type of change - [x] Refactoring
This commit is contained in:
parent
4f962d6bff
commit
5968f148bc
@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
||||
@ -126,55 +128,56 @@ def add_llm():
|
||||
req = request.json
|
||||
factory = req["llm_factory"]
|
||||
|
||||
def apikey_json(keys):
|
||||
nonlocal req
|
||||
return json.dumps({k: req.get(k, "") for k in keys})
|
||||
|
||||
if factory == "VolcEngine":
|
||||
# For VolcEngine, due to its special authentication method
|
||||
# Assemble ark_api_key endpoint_id into api_key
|
||||
llm_name = req["llm_name"]
|
||||
api_key = f'{{ "ark_api_key":"{req.get("ark_api_key", "")}", "ep_id":"{req.get("endpoint_id", "")}" }}'
|
||||
api_key = apikey_json(["ark_api_key", "endpoint_id"])
|
||||
|
||||
elif factory == "Tencent Hunyuan":
|
||||
api_key = '{' + f'"hunyuan_sid": "{req.get("hunyuan_sid", "")}", ' \
|
||||
f'"hunyuan_sk": "{req.get("hunyuan_sk", "")}"' + '}'
|
||||
req["api_key"] = api_key
|
||||
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
|
||||
return set_api_key()
|
||||
|
||||
elif factory == "Tencent Cloud":
|
||||
api_key = '{' + f'"tencent_cloud_sid": "{req.get("tencent_cloud_sid", "")}", ' \
|
||||
f'"tencent_cloud_sk": "{req.get("tencent_cloud_sk", "")}"' + '}'
|
||||
req["api_key"] = api_key
|
||||
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
|
||||
|
||||
elif factory == "Bedrock":
|
||||
# For Bedrock, due to its special authentication method
|
||||
# Assemble bedrock_ak, bedrock_sk, bedrock_region
|
||||
llm_name = req["llm_name"]
|
||||
api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \
|
||||
f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \
|
||||
f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}'
|
||||
api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
|
||||
|
||||
elif factory == "LocalAI":
|
||||
llm_name = req["llm_name"]+"___LocalAI"
|
||||
api_key = "xxxxxxxxxxxxxxx"
|
||||
|
||||
elif factory == "OpenAI-API-Compatible":
|
||||
llm_name = req["llm_name"]+"___OpenAI-API"
|
||||
api_key = req.get("api_key","xxxxxxxxxxxxxxx")
|
||||
|
||||
elif factory =="XunFei Spark":
|
||||
llm_name = req["llm_name"]
|
||||
api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
|
||||
api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
|
||||
|
||||
elif factory == "BaiduYiyan":
|
||||
llm_name = req["llm_name"]
|
||||
api_key = '{' + f'"yiyan_ak": "{req.get("yiyan_ak", "")}", ' \
|
||||
f'"yiyan_sk": "{req.get("yiyan_sk", "")}"' + '}'
|
||||
api_key = apikey_json(["yiyan_ak", "yiyan_sk"])
|
||||
|
||||
elif factory == "Fish Audio":
|
||||
llm_name = req["llm_name"]
|
||||
api_key = '{' + f'"fish_audio_ak": "{req.get("fish_audio_ak", "")}", ' \
|
||||
f'"fish_audio_refid": "{req.get("fish_audio_refid", "59cb5986671546eaa6ca8ae6f29f6d22")}"' + '}'
|
||||
api_key = apikey_json(["fish_audio_ak", "fish_audio_refid"])
|
||||
|
||||
elif factory == "Google Cloud":
|
||||
llm_name = req["llm_name"]
|
||||
api_key = (
|
||||
"{" + f'"google_project_id": "{req.get("google_project_id", "")}", '
|
||||
f'"google_region": "{req.get("google_region", "")}", '
|
||||
f'"google_service_account_key": "{req.get("google_service_account_key", "")}"'
|
||||
+ "}"
|
||||
)
|
||||
api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])
|
||||
|
||||
else:
|
||||
llm_name = req["llm_name"]
|
||||
api_key = req.get("api_key","xxxxxxxxxxxxxxx")
|
||||
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
|
||||
|
||||
llm = {
|
||||
"tenant_id": current_user.id,
|
||||
|
@ -458,7 +458,7 @@ class VolcEngineChat(Base):
|
||||
"""
|
||||
base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
|
||||
ark_api_key = json.loads(key).get('ark_api_key', '')
|
||||
model_name = json.loads(key).get('ep_id', '')
|
||||
model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
|
||||
super().__init__(ark_api_key, model_name, base_url)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user