Merge branch 'main' into feat/attachments
Commit b322dda3f6
@@ -64,4 +64,6 @@ class DifyConfig(
         return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
 
+    SSRF_PROXY_HTTP_URL: str | None = None
+    SSRF_PROXY_HTTPS_URL: str | None = None
 
     MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')
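The two new SSRF proxy settings are plain optional strings. A minimal sketch of how a caller might consume them — the httpx wiring below is an illustrative assumption, not part of this diff:

    import httpx

    from configs import dify_config

    # Route outbound requests through the configured SSRF proxy; fall back
    # to a direct connection when both settings are left at their None defaults.
    proxies = None
    if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
        proxies = {
            'http://': dify_config.SSRF_PROXY_HTTP_URL,
            'https://': dify_config.SSRF_PROXY_HTTPS_URL,
        }
    client = httpx.Client(proxies=proxies)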
@@ -13,18 +13,10 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
     :param file_name: the YAML file name, default to '_position.yaml'
     :return: a dict with name as key and index as value
     """
-    position_file_name = os.path.join(folder_path, file_name)
-    if not position_file_name or not os.path.exists(position_file_name):
-        return {}
-
-    positions = load_yaml_file(position_file_name, ignore_error=True)
-    position_map = {}
-    index = 0
-    for _, name in enumerate(positions):
-        if name and isinstance(name, str):
-            position_map[name.strip()] = index
-            index += 1
-    return position_map
+    position_file_path = os.path.join(folder_path, file_name)
+    yaml_content = load_yaml_file(file_path=position_file_path, default_value=[])
+    positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()]
+    return {name: index for index, name in enumerate(positions)}
 
 
 def sort_by_position_map(
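A worked example of the rewritten mapping. Note the comprehension also drops entries that are blank after stripping, which the old loop did not:

    # Simulates a _position.yaml parsed to: ['openai', None, '  ', 'anthropic']
    yaml_content = ['openai', None, '  ', 'anthropic']
    positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()]
    assert {name: index for index, name in enumerate(positions)} == {'openai': 0, 'anthropic': 1}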
@@ -162,7 +162,7 @@ class AIModel(ABC):
         # traverse all model_schema_yaml_paths
         for model_schema_yaml_path in model_schema_yaml_paths:
             # read yaml data from yaml file
-            yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
+            yaml_data = load_yaml_file(model_schema_yaml_path)
 
             new_parameter_rules = []
             for parameter_rule in yaml_data.get('parameter_rules', []):
@@ -44,7 +44,7 @@ class ModelProvider(ABC):
 
         # read provider schema from yaml file
         yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
-        yaml_data = load_yaml_file(yaml_path, ignore_error=True)
+        yaml_data = load_yaml_file(yaml_path)
 
         try:
             # yaml_data to entity
@@ -23,7 +23,7 @@ parameter_rules:
     type: int
     default: 4096
     min: 1
-    max: 4096
+    max: 8192
    help:
       zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
       en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
@@ -57,6 +57,18 @@ parameter_rules:
     help:
       zh_Hans: 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。
       en_US: A number between -2.0 and 2.0. If the value is positive, new tokens are penalized based on their frequency of occurrence in existing text, reducing the likelihood that the model will repeat the same content.
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '1'
   output: '2'
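The new response_format rule surfaces as an ordinary model parameter. A minimal sketch of passing it through the runtime's invoke signature — the model id and credentials are placeholders, not taken from this diff:

    from core.model_runtime.entities.message_entities import UserPromptMessage

    response = model.invoke(                 # `model` is an LLM runtime instance
        model='<provider-model-id>',         # placeholder for a model declaring this rule
        credentials={'api_key': '<api-key>'},
        prompt_messages=[UserPromptMessage(content='Reply with {"ok": true}')],
        model_parameters={
            'temperature': 0.1,
            'response_format': 'json_object',  # one of the declared options: text | json_object
        },
    )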
@@ -18,6 +18,7 @@ help:
     en_US: https://console.cloud.tencent.com/cam/capi
 supported_model_types:
   - llm
+  - text-embedding
 configurate_methods:
   - predefined-model
 provider_credential_schema:
@@ -0,0 +1,5 @@
+model: hunyuan-embedding
+model_type: text-embedding
+model_properties:
+  context_size: 1024
+  max_chunks: 1
@@ -0,0 +1,173 @@
+import json
+import logging
+import time
+from typing import Optional
+
+from tencentcloud.common import credential
+from tencentcloud.common.exception import TencentCloudSDKException
+from tencentcloud.common.profile.client_profile import ClientProfile
+from tencentcloud.common.profile.http_profile import HttpProfile
+from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
+
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+    InvokeError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+logger = logging.getLogger(__name__)
+
+
+class HunyuanTextEmbeddingModel(TextEmbeddingModel):
+    """
+    Model class for Hunyuan text embedding model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        if model != 'hunyuan-embedding':
+            raise ValueError('Invalid model name')
+
+        client = self._setup_hunyuan_client(credentials)
+
+        embeddings = []
+        token_usage = 0
+
+        for input in texts:
+            request = models.GetEmbeddingRequest()
+            params = {
+                "Input": input
+            }
+            request.from_json_string(json.dumps(params))
+            response = client.GetEmbedding(request)
+            usage = response.Usage.TotalTokens
+
+            embeddings.extend([data.Embedding for data in response.Data])
+            token_usage += usage
+
+        result = TextEmbeddingResult(
+            model=model,
+            embeddings=embeddings,
+            usage=self._calc_response_usage(
+                model=model,
+                credentials=credentials,
+                tokens=token_usage
+            )
+        )
+
+        return result
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate credentials
+        """
+        try:
+            client = self._setup_hunyuan_client(credentials)
+
+            req = models.ChatCompletionsRequest()
+            params = {
+                "Model": model,
+                "Messages": [{
+                    "Role": "user",
+                    "Content": "hello"
+                }],
+                "TopP": 1,
+                "Temperature": 0,
+                "Stream": False
+            }
+            req.from_json_string(json.dumps(params))
+            client.ChatCompletions(req)
+        except Exception as e:
+            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
+
+    def _setup_hunyuan_client(self, credentials):
+        secret_id = credentials['secret_id']
+        secret_key = credentials['secret_key']
+        cred = credential.Credential(secret_id, secret_key)
+        httpProfile = HttpProfile()
+        httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
+        clientProfile = ClientProfile()
+        clientProfile.httpProfile = httpProfile
+        client = hunyuan_client.HunyuanClient(cred, "", clientProfile)
+        return client
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeError: [TencentCloudSDKException],
+        }
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        # client = self._setup_hunyuan_client(credentials)
+
+        num_tokens = 0
+        for text in texts:
+            num_tokens += self._get_num_tokens_by_gpt2(text)
+            # use client.GetTokenCount to get num tokens
+            # request = models.GetTokenCountRequest()
+            # params = {
+            #     "Prompt": text
+            # }
+            # request.from_json_string(json.dumps(params))
+            # response = client.GetTokenCount(request)
+            # num_tokens += response.TokenCount
+
+        return num_tokens
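A hedged usage sketch of the new embedding class, mirroring the integration tests added later in this diff (the credentials are placeholders):

    model = HunyuanTextEmbeddingModel()
    result = model.invoke(
        model='hunyuan-embedding',
        credentials={'secret_id': '<tencent-secret-id>', 'secret_key': '<tencent-secret-key>'},
        texts=['hello', 'world'],
        user='abc-123',
    )
    # One embedding vector per input text; token usage is aggregated across
    # the per-text GetEmbedding calls made by _invoke.
    assert len(result.embeddings) == 2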
@@ -0,0 +1,30 @@
+model: deepseek-ai/DeepSeek-Coder-V2-Instruct
+label:
+  en_US: deepseek-ai/DeepSeek-Coder-V2-Instruct
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 32768
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: max_tokens
+    use_template: max_tokens
+    type: int
+    default: 512
+    min: 1
+    max: 4096
+    help:
+      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
+      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
+  - name: top_p
+    use_template: top_p
+  - name: frequency_penalty
+    use_template: frequency_penalty
+pricing:
+  input: '1.33'
+  output: '1.33'
+  unit: '0.000001'
+  currency: RMB
@@ -1,11 +1,9 @@
 model: deepseek-ai/deepseek-v2-chat
 label:
-  en_US: deepseek-ai/deepseek-v2-chat
+  en_US: deepseek-ai/DeepSeek-V2-Chat
 model_type: llm
 features:
-  - multi-tool-call
   - agent-thought
-  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768
@@ -1,11 +1,9 @@
 model: zhipuai/glm4-9B-chat
 label:
-  en_US: zhipuai/glm4-9B-chat
+  en_US: THUDM/glm-4-9b-chat
 model_type: llm
 features:
-  - multi-tool-call
   - agent-thought
-  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768
@@ -1,11 +1,9 @@
 model: alibaba/Qwen2-57B-A14B-Instruct
 label:
-  en_US: alibaba/Qwen2-57B-A14B-Instruct
+  en_US: Qwen/Qwen2-57B-A14B-Instruct
 model_type: llm
 features:
-  - multi-tool-call
   - agent-thought
-  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768
@@ -1,11 +1,9 @@
 model: alibaba/Qwen2-72B-Instruct
 label:
-  en_US: alibaba/Qwen2-72B-Instruct
+  en_US: Qwen/Qwen2-72B-Instruct
 model_type: llm
 features:
-  - multi-tool-call
   - agent-thought
-  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768
@@ -1,11 +1,9 @@
 model: alibaba/Qwen2-7B-Instruct
 label:
-  en_US: alibaba/Qwen2-7B-Instruct
+  en_US: Qwen/Qwen2-7B-Instruct
 model_type: llm
 features:
-  - multi-tool-call
   - agent-thought
-  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768
@@ -1,11 +1,9 @@
 model: 01-ai/Yi-1.5-34B-Chat
 label:
-  en_US: 01-ai/Yi-1.5-34B-Chat
+  en_US: 01-ai/Yi-1.5-34B-Chat-16K
 model_type: llm
 features:
-  - multi-tool-call
   - agent-thought
-  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16384
@@ -3,9 +3,7 @@ label:
   en_US: 01-ai/Yi-1.5-6B-Chat
 model_type: llm
 features:
-  - multi-tool-call
   - agent-thought
-  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 4096
@@ -1,11 +1,9 @@
 model: 01-ai/Yi-1.5-9B-Chat
 label:
-  en_US: 01-ai/Yi-1.5-9B-Chat
+  en_US: 01-ai/Yi-1.5-9B-Chat-16K
 model_type: llm
 features:
-  - multi-tool-call
   - agent-thought
-  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16384
@@ -19,7 +19,7 @@ class SiliconflowProvider(ModelProvider):
             model_instance = self.get_model_instance(ModelType.LLM)
 
             model_instance.validate_credentials(
-                model='deepseek-ai/deepseek-v2-chat',
+                model='deepseek-ai/DeepSeek-V2-Chat',
                 credentials=credentials
             )
         except CredentialsValidateFailedError as ex:
@@ -21,8 +21,6 @@ class ModerationRule(BaseModel):
 
 
 class OutputModeration(BaseModel):
-    DEFAULT_BUFFER_SIZE: int = 300
-
     tenant_id: str
     app_id: str
 
@@ -77,10 +75,10 @@ class OutputModeration(BaseModel):
         return final_output
 
     def start_thread(self) -> threading.Thread:
-        buffer_size = int(dify_config.config.MODERATION_BUFFER_SIZE)
+        buffer_size = dify_config.MODERATION_BUFFER_SIZE
         thread = threading.Thread(target=self.worker, kwargs={
             'flask_app': current_app._get_current_object(),
-            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
+            'buffer_size': buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE
         })
 
         thread.start()
@@ -298,34 +298,29 @@ class TraceTask:
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
 
     def execute(self):
-        method_name, trace_info = self.preprocess()
-        return trace_info
+        return self.preprocess()
 
     def preprocess(self):
-        if self.trace_type == TraceTaskName.CONVERSATION_TRACE:
-            return TraceTaskName.CONVERSATION_TRACE, self.conversation_trace(**self.kwargs)
-        if self.trace_type == TraceTaskName.WORKFLOW_TRACE:
-            return TraceTaskName.WORKFLOW_TRACE, self.workflow_trace(self.workflow_run, self.conversation_id)
-        elif self.trace_type == TraceTaskName.MESSAGE_TRACE:
-            return TraceTaskName.MESSAGE_TRACE, self.message_trace(self.message_id)
-        elif self.trace_type == TraceTaskName.MODERATION_TRACE:
-            return TraceTaskName.MODERATION_TRACE, self.moderation_trace(self.message_id, self.timer, **self.kwargs)
-        elif self.trace_type == TraceTaskName.SUGGESTED_QUESTION_TRACE:
-            return TraceTaskName.SUGGESTED_QUESTION_TRACE, self.suggested_question_trace(
-                self.message_id, self.timer, **self.kwargs
-            )
-        elif self.trace_type == TraceTaskName.DATASET_RETRIEVAL_TRACE:
-            return TraceTaskName.DATASET_RETRIEVAL_TRACE, self.dataset_retrieval_trace(
-                self.message_id, self.timer, **self.kwargs
-            )
-        elif self.trace_type == TraceTaskName.TOOL_TRACE:
-            return TraceTaskName.TOOL_TRACE, self.tool_trace(self.message_id, self.timer, **self.kwargs)
-        elif self.trace_type == TraceTaskName.GENERATE_NAME_TRACE:
-            return TraceTaskName.GENERATE_NAME_TRACE, self.generate_name_trace(
-                self.conversation_id, self.timer, **self.kwargs
-            )
-        else:
-            return '', {}
+        preprocess_map = {
+            TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
+            TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(self.workflow_run, self.conversation_id),
+            TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
+            TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
+                self.message_id, self.timer, **self.kwargs
+            ),
+            TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
+                self.message_id, self.timer, **self.kwargs
+            ),
+            TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
+                self.message_id, self.timer, **self.kwargs
+            ),
+            TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs),
+            TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
+                self.conversation_id, self.timer, **self.kwargs
+            ),
+        }
+
+        return preprocess_map.get(self.trace_type, lambda: None)()
 
     # process methods for different trace types
     def conversation_trace(self, **kwargs):
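The refactor swaps the if/elif ladder for a dispatch map. The same pattern in miniature — a generic, self-contained sketch, not tied to the TraceTask fields:

    from enum import Enum


    class Kind(Enum):
        A = 'a'
        B = 'b'


    def handle(kind: Kind, payload: str):
        # Lambdas defer evaluation, so only the selected branch runs,
        # preserving the lazy behavior of the old if/elif chain.
        dispatch = {
            Kind.A: lambda: f'A:{payload}',
            Kind.B: lambda: f'B:{payload}',
        }
        return dispatch.get(kind, lambda: None)()


    assert handle(Kind.A, 'x') == 'A:x'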
@@ -5,6 +5,7 @@ from uuid import uuid4
 
 from pydantic import BaseModel, model_validator
 from pymilvus import MilvusClient, MilvusException, connections
+from pymilvus.milvus_client import IndexParams
 
 from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
@@ -250,11 +251,15 @@ class MilvusVector(BaseVector):
                 # Since primary field is auto-id, no need to track it
                 self._fields.remove(Field.PRIMARY_KEY.value)
 
+                # Create Index params for the collection
+                index_params_obj = IndexParams()
+                index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
+
                 # Create the collection
                 collection_name = self._collection_name
-                self._client.create_collection_with_schema(collection_name=collection_name,
-                                                           schema=schema, index_param=index_params,
-                                                           consistency_level=self._consistency_level)
+                self._client.create_collection(collection_name=collection_name,
+                                               schema=schema, index_params=index_params_obj,
+                                               consistency_level=self._consistency_level)
                 redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
     def _init_client(self, config) -> MilvusClient:
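For context, the MilvusClient 2.4-style collection creation that the new code targets looks roughly like this — a standalone sketch with assumed field names, dimensions, and index settings, not the vdb module itself:

    from pymilvus import DataType, MilvusClient
    from pymilvus.milvus_client import IndexParams

    client = MilvusClient(uri='http://localhost:19530')  # assumed local Milvus

    schema = client.create_schema(auto_id=True, enable_dynamic_field=True)
    schema.add_field(field_name='id', datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name='vector', datatype=DataType.FLOAT_VECTOR, dim=768)

    # IndexParams replaces the plain dict passed as create_collection_with_schema(index_param=...)
    index_params = IndexParams()
    index_params.add_index(field_name='vector', index_type='HNSW', metric_type='IP',
                           params={'M': 8, 'efConstruction': 64})

    client.create_collection(collection_name='demo', schema=schema, index_params=index_params)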
@@ -27,7 +27,7 @@ class BuiltinToolProviderController(ToolProviderController):
         provider = self.__class__.__module__.split('.')[-1]
         yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
         try:
-            provider_yaml = load_yaml_file(yaml_path)
+            provider_yaml = load_yaml_file(yaml_path, ignore_error=False)
         except Exception as e:
             raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}')
 
@@ -58,7 +58,7 @@ class BuiltinToolProviderController(ToolProviderController):
         for tool_file in tool_files:
             # get tool name
             tool_name = tool_file.split(".")[0]
-            tool = load_yaml_file(path.join(tool_path, tool_file))
+            tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
 
             # get tool class, import the module
             assistant_tool_class = load_single_subclass_from_source(
@@ -1,35 +1,32 @@
 import logging
-import os
+from typing import Any
 
 import yaml
 from yaml import YAMLError
 
 logger = logging.getLogger(__name__)
 
-def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
+
+def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
     """
-    Safe loading a YAML file to a dict
+    Safe loading a YAML file
     :param file_path: the path of the YAML file
     :param ignore_error:
-        if True, return empty dict if error occurs and the error will be logged in warning level
+        if True, return default_value if error occurs and the error will be logged in debug level
         if False, raise error if error occurs
-    :return: a dict of the YAML content
+    :param default_value: the value returned when errors ignored
+    :return: an object of the YAML content
     """
     try:
-        if not file_path or not os.path.exists(file_path):
-            raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')
-
-        with open(file_path, encoding='utf-8') as file:
+        with open(file_path, encoding='utf-8') as yaml_file:
             try:
-                return yaml.safe_load(file)
+                yaml_content = yaml.safe_load(yaml_file)
+                return yaml_content if yaml_content else default_value
             except Exception as e:
                 raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
-    except FileNotFoundError as e:
-        logger.debug(f'Failed to load YAML file {file_path}: {e}')
-        return {}
     except Exception as e:
         if ignore_error:
-            logger.warning(f'Failed to load YAML file {file_path}: {e}')
-            return {}
+            logger.debug(f'Failed to load YAML file {file_path}: {e}')
+            return default_value
        else:
             raise e
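Behavioral summary of the new signature, as exercised by the tests later in this diff (illustrative calls):

    load_yaml_file('missing.yaml')                       # -> {} (default_value); logged at debug level
    load_yaml_file('providers.yaml', default_value=[])   # -> [] on error or an empty document
    load_yaml_file('missing.yaml', ignore_error=False)   # -> raises FileNotFoundError from open()
    # An empty-but-valid YAML document also yields default_value instead of None.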
api/poetry.lock (generated, 111 lines changed)
@@ -448,63 +448,6 @@ doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphin
 test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
 trio = ["trio (>=0.23)"]
 
-[[package]]
-name = "argon2-cffi"
-version = "23.1.0"
-description = "Argon2 for Python"
-optional = false
-python-versions = ">=3.7"
-files = [
-    {file = "argon2_cffi-23.1.0-py3-none-any.whl", hash = "sha256:c670642b78ba29641818ab2e68bd4e6a78ba53b7eff7b4c3815ae16abf91c7ea"},
-    {file = "argon2_cffi-23.1.0.tar.gz", hash = "sha256:879c3e79a2729ce768ebb7d36d4609e3a78a4ca2ec3a9f12286ca057e3d0db08"},
-]
-
-[package.dependencies]
-argon2-cffi-bindings = "*"
-
-[package.extras]
-dev = ["argon2-cffi[tests,typing]", "tox (>4)"]
-docs = ["furo", "myst-parser", "sphinx", "sphinx-copybutton", "sphinx-notfound-page"]
-tests = ["hypothesis", "pytest"]
-typing = ["mypy"]
-
-[[package]]
-name = "argon2-cffi-bindings"
-version = "21.2.0"
-description = "Low-level CFFI bindings for Argon2"
-optional = false
-python-versions = ">=3.6"
-files = [
-    {file = "argon2-cffi-bindings-21.2.0.tar.gz", hash = "sha256:bb89ceffa6c791807d1305ceb77dbfacc5aa499891d2c55661c6459651fc39e3"},
-    {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ccb949252cb2ab3a08c02024acb77cfb179492d5701c7cbdbfd776124d4d2367"},
-    {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9524464572e12979364b7d600abf96181d3541da11e23ddf565a32e70bd4dc0d"},
-    {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b746dba803a79238e925d9046a63aa26bf86ab2a2fe74ce6b009a1c3f5c8f2ae"},
-    {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58ed19212051f49a523abb1dbe954337dc82d947fb6e5a0da60f7c8471a8476c"},
-    {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:bd46088725ef7f58b5a1ef7ca06647ebaf0eb4baff7d1d0d177c6cc8744abd86"},
-    {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_i686.whl", hash = "sha256:8cd69c07dd875537a824deec19f978e0f2078fdda07fd5c42ac29668dda5f40f"},
-    {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:f1152ac548bd5b8bcecfb0b0371f082037e47128653df2e8ba6e914d384f3c3e"},
-    {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win32.whl", hash = "sha256:603ca0aba86b1349b147cab91ae970c63118a0f30444d4bc80355937c950c082"},
-    {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win_amd64.whl", hash = "sha256:b2ef1c30440dbbcba7a5dc3e319408b59676e2e039e2ae11a8775ecf482b192f"},
-    {file = "argon2_cffi_bindings-21.2.0-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e415e3f62c8d124ee16018e491a009937f8cf7ebf5eb430ffc5de21b900dad93"},
-    {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3e385d1c39c520c08b53d63300c3ecc28622f076f4c2b0e6d7e796e9f6502194"},
-    {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c3e3cc67fdb7d82c4718f19b4e7a87123caf8a93fde7e23cf66ac0337d3cb3f"},
-    {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a22ad9800121b71099d0fb0a65323810a15f2e292f2ba450810a7316e128ee5"},
-    {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f9f8b450ed0547e3d473fdc8612083fd08dd2120d6ac8f73828df9b7d45bb351"},
-    {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:93f9bf70084f97245ba10ee36575f0c3f1e7d7724d67d8e5b08e61787c320ed7"},
-    {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3b9ef65804859d335dc6b31582cad2c5166f0c3e7975f324d9ffaa34ee7e6583"},
-    {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4966ef5848d820776f5f562a7d45fdd70c2f330c961d0d745b784034bd9f48d"},
-    {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ef543a89dee4db46a1a6e206cd015360e5a75822f76df533845c3cbaf72670"},
-    {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed2937d286e2ad0cc79a7087d3c272832865f779430e0cc2b4f3718d3159b0cb"},
-    {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5e00316dabdaea0b2dd82d141cc66889ced0cdcbfa599e8b471cf22c620c329a"},
-]
-
-[package.dependencies]
-cffi = ">=1.0.1"
-
-[package.extras]
-dev = ["cogapp", "pre-commit", "pytest", "wheel"]
-tests = ["pytest"]
-
 [[package]]
 name = "arxiv"
 version = "2.1.0"
@@ -4616,22 +4559,20 @@
 ]
 
 [[package]]
-name = "minio"
-version = "7.2.7"
-description = "MinIO Python SDK for Amazon S3 Compatible Cloud Storage"
+name = "milvus-lite"
+version = "2.4.8"
+description = "A lightweight version of Milvus wrapped with Python."
 optional = false
-python-versions = "*"
+python-versions = ">=3.7"
 files = [
-    {file = "minio-7.2.7-py3-none-any.whl", hash = "sha256:59d1f255d852fe7104018db75b3bebbd987e538690e680f7c5de835e422de837"},
-    {file = "minio-7.2.7.tar.gz", hash = "sha256:473d5d53d79f340f3cd632054d0c82d2f93177ce1af2eac34a235bea55708d98"},
+    {file = "milvus_lite-2.4.8-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:b7e90b34b214884cd44cdc112ab243d4cb197b775498355e2437b6cafea025fe"},
+    {file = "milvus_lite-2.4.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:519dfc62709d8f642d98a1c5b1dcde7080d107e6e312d677fef5a3412a40ac08"},
+    {file = "milvus_lite-2.4.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b21f36d24cbb0e920b4faad607019bb28c1b2c88b4d04680ac8c7697a4ae8a4d"},
+    {file = "milvus_lite-2.4.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:08332a2b9abfe7c4e1d7926068937e46f8fb81f2707928b7bc02c9dc99cebe41"},
 ]
 
 [package.dependencies]
-argon2-cffi = "*"
-certifi = "*"
-pycryptodome = "*"
-typing-extensions = "*"
-urllib3 = "*"
+tqdm = "*"
 
 [[package]]
 name = "mmh3"
@@ -6078,6 +6019,19 @@
     {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
+    {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
+    {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
+    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"},
+    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"},
+    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"},
+    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"},
+    {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"},
+    {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"},
+    {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"},
+    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"},
+    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"},
+    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"},
+    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"},
+    {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"},
     {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"},
 ]
 
 [package.dependencies]
@@ -6374,24 +6328,29 @@ tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
 
 [[package]]
 name = "pymilvus"
-version = "2.3.1"
+version = "2.4.4"
 description = "Python Sdk for Milvus"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "pymilvus-2.3.1-py3-none-any.whl", hash = "sha256:ce65e1de8700f33bd9aade20f013291629702e25b05726773208f1f0b22548ff"},
-    {file = "pymilvus-2.3.1.tar.gz", hash = "sha256:d460f6204d7deb2cff93716bd65670c1b440694b77701fb0ab0ead791aa582c6"},
+    {file = "pymilvus-2.4.4-py3-none-any.whl", hash = "sha256:073b76bc36f6f4e70f0f0a0023a53324f0ba8ef9a60883f87cd30a44b6c6f2b5"},
+    {file = "pymilvus-2.4.4.tar.gz", hash = "sha256:50c53eb103e034fbffe936fe942751ea3dbd2452e18cf79acc52360ed4987fb7"},
 ]
 
 [package.dependencies]
 environs = "<=9.5.0"
-grpcio = ">=1.49.1,<=1.58.0"
-minio = "*"
+grpcio = ">=1.49.1,<=1.63.0"
+milvus-lite = {version = ">=2.4.0,<2.5.0", markers = "sys_platform != \"win32\""}
 pandas = ">=1.2.4"
 protobuf = ">=3.20.0"
-requests = "*"
 setuptools = ">=67"
 ujson = ">=2.0.0"
 
+[package.extras]
+bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "requests"]
+dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"]
+model = ["milvus-model (>=0.1.0)"]
+
 [[package]]
 name = "pymysql"
 version = "1.1.1"
@@ -9543,4 +9502,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "9619ddabdd67710981c13dcfa3ddae0a48497c9f694afc81b820e882440c1265"
+content-hash = "a8b61d74d9322302b7447b6f8728ad606abc160202a8a122a05a8ef3cec7055b"
@@ -206,7 +206,7 @@ chromadb = "0.5.1"
 oracledb = "~2.2.1"
 pgvecto-rs = "0.1.4"
 pgvector = "0.2.5"
-pymilvus = "2.3.1"
+pymilvus = "~2.4.4"
 pymysql = "1.1.1"
 tcvectordb = "1.3.2"
 tidb-vector = "0.0.9"
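For reference, Poetry's tilde constraint `~2.4.4` resolves to `>=2.4.4,<2.5.0`, which is what allows the lock file above to record pymilvus 2.4.4 together with the marker-gated `milvus-lite >=2.4.0,<2.5.0` dependency.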
@@ -216,18 +216,6 @@ alibabacloud_gpdb20160503 = "~3.8.0"
 alibabacloud_tea_openapi = "~0.3.9"
 clickhouse-connect = "~0.7.16"
 
-############################################################
-# Transparent dependencies required by main dependencies
-# for pinning versions
-############################################################
-
-[tool.poetry.group.transparent.dependencies]
-kaleido = "0.2.1"
-lxml = "5.1.0"
-sympy = "1.12"
-tenacity = "~8.3.0"
-xlrd = "~2.0.1"
-
 ############################################################
 # Dev dependencies for running tests
 ############################################################
@@ -0,0 +1,104 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.hunyuan.text_embedding.text_embedding import HunyuanTextEmbeddingModel
+
+
+def test_validate_credentials():
+    model = HunyuanTextEmbeddingModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='hunyuan-embedding',
+            credentials={
+                'secret_id': 'invalid_key',
+                'secret_key': 'invalid_key'
+            }
+        )
+
+    model.validate_credentials(
+        model='hunyuan-embedding',
+        credentials={
+            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
+            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+        }
+    )
+
+
+def test_invoke_model():
+    model = HunyuanTextEmbeddingModel()
+
+    result = model.invoke(
+        model='hunyuan-embedding',
+        credentials={
+            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
+            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+        },
+        texts=[
+            "hello",
+            "world"
+        ],
+        user="abc-123"
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 2
+    assert result.usage.total_tokens == 6
+
+def test_get_num_tokens():
+    model = HunyuanTextEmbeddingModel()
+
+    num_tokens = model.get_num_tokens(
+        model='hunyuan-embedding',
+        credentials={
+            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
+            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+        },
+        texts=[
+            "hello",
+            "world"
+        ]
+    )
+
+    assert num_tokens == 2
+
+def test_max_chunks():
+    model = HunyuanTextEmbeddingModel()
+
+    result = model.invoke(
+        model='hunyuan-embedding',
+        credentials={
+            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
+            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+        },
+        texts=[
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+        ]
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 22
@@ -0,0 +1,106 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.siliconflow.llm.llm import SiliconflowLargeLanguageModel
+
+
+def test_validate_credentials():
+    model = SiliconflowLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='deepseek-ai/DeepSeek-V2-Chat',
+            credentials={
+                'api_key': 'invalid_key'
+            }
+        )
+
+    model.validate_credentials(
+        model='deepseek-ai/DeepSeek-V2-Chat',
+        credentials={
+            'api_key': os.environ.get('API_KEY')
+        }
+    )
+
+
+def test_invoke_model():
+    model = SiliconflowLargeLanguageModel()
+
+    response = model.invoke(
+        model='deepseek-ai/DeepSeek-V2-Chat',
+        credentials={
+            'api_key': os.environ.get('API_KEY')
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Who are you?'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.5,
+            'max_tokens': 10
+        },
+        stop=['How'],
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+
+
+def test_invoke_stream_model():
+    model = SiliconflowLargeLanguageModel()
+
+    response = model.invoke(
+        model='deepseek-ai/DeepSeek-V2-Chat',
+        credentials={
+            'api_key': os.environ.get('API_KEY')
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.5,
+            'max_tokens': 100,
+            'seed': 1234
+        },
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(response, Generator)
+
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_get_num_tokens():
+    model = SiliconflowLargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model='deepseek-ai/DeepSeek-V2-Chat',
+        credentials={
+            'api_key': os.environ.get('API_KEY')
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ]
+    )
+
+    assert num_tokens == 12
@@ -0,0 +1,21 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.siliconflow.siliconflow import SiliconflowProvider
+
+
+def test_validate_provider_credentials():
+    provider = SiliconflowProvider()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        provider.validate_provider_credentials(
+            credentials={}
+        )
+
+    provider.validate_provider_credentials(
+        credentials={
+            'api_key': os.environ.get('API_KEY')
+        }
+    )
@@ -21,6 +21,20 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
     return str(tmp_path)
 
 
+@pytest.fixture
+def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
+    monkeypatch.chdir(tmp_path)
+    tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent(
+        """\
+        # - commented1
+        # - commented2
+        -
+        -
+
+        """))
+    return str(tmp_path)
+
+
 def test_position_helper(prepare_example_positions_yaml):
     position_map = get_position_map(
         folder_path=prepare_example_positions_yaml,
@@ -32,3 +46,10 @@ def test_position_helper(prepare_example_positions_yaml):
         'third': 2,
         'forth': 3,
     }
+
+
+def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml):
+    position_map = get_position_map(
+        folder_path=prepare_empty_commented_positions_yaml,
+        file_name='example_positions_all_commented.yaml')
+    assert position_map == {}
@@ -53,6 +53,9 @@ def test_load_yaml_non_existing_file():
     assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
     assert load_yaml_file(file_path='') == {}
 
+    with pytest.raises(FileNotFoundError):
+        load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False)
+
 
 def test_load_valid_yaml_file(prepare_example_yaml_file):
     yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
@@ -68,7 +71,7 @@ def test_load_valid_yaml_file(prepare_example_yaml_file):
 def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
     # yaml syntax error
     with pytest.raises(YAMLError):
-        load_yaml_file(file_path=prepare_invalid_yaml_file)
+        load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False)
 
     # ignore error
-    assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}
+    assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {}
@@ -38,7 +38,7 @@ services:
 
   milvus-standalone:
     container_name: milvus-standalone
-    image: milvusdb/milvus:v2.3.1
+    image: milvusdb/milvus:v2.4.6
     command: ["milvus", "run", "standalone"]
     environment:
       ETCD_ENDPOINTS: etcd:2379
@@ -25,7 +25,7 @@
 }
 
 .action-btn-xs {
-  @apply p-0 w-5 h-5 rounded
+  @apply p-0 w-4 h-4 rounded
 }
 
 .action-btn.action-btn-active {
@@ -409,7 +409,7 @@ const translation = {
   },
   retrieveMultiWay: {
     title: 'Multi-path retrieval',
-    description: 'Based on user intent, queries across all Knowledge, retrieves relevant text from multi-sources, and selects the best results matching the user query after reranking. Configuration of the Rerank model API is required.',
+    description: 'Based on user intent, queries across all Knowledge, retrieves relevant text from multi-sources, and selects the best results matching the user query after reranking. ',
   },
   rerankModelRequired: 'Rerank model is required',
   params: 'Params',
@@ -33,7 +33,7 @@ const translation = {
   },
   hybrid_search: {
     title: 'Hybrid Search',
-    description: 'Execute full-text search and vector searches simultaneously, re-rank to select the best match for the user\'s query. Configuration of the Rerank model APIs necessary.',
+    description: 'Execute full-text search and vector searches simultaneously, re-rank to select the best match for the user\'s query. Users can choose to set weights or configure to a Rerank model.',
     recommend: 'Recommend',
   },
   invertedIndex: {
@@ -67,7 +67,7 @@ const translation = {
     semantic: 'Semantic',
     keyword: 'Keyword',
   },
-  nTo1RetrievalLegacy: 'According to product planning, N-to-1 retrieval will be officially deprecated in September. Until then you can still use it normally.',
+  nTo1RetrievalLegacy: 'According to the optimization and upgrade of the retrieval strategy, N-to-1 retrieval will be officially deprecated in September. Until then you can still use it normally.',
 }
 
 export default translation
@@ -404,7 +404,7 @@ const translation = {
   },
   retrieveMultiWay: {
     title: '多路召回',
-    description: '根据用户意图同时匹配所有知识库,从多路知识库查询相关文本片段,经过重排序步骤,从多路查询结果中选择匹配用户问题的最佳结果,需配置 Rerank 模型 API。',
+    description: '根据用户意图同时匹配所有知识库,从多路知识库查询相关文本片段,经过重排序步骤,从多路查询结果中选择匹配用户问题的最佳结果。',
   },
   rerankModelRequired: '请选择 Rerank 模型',
   params: '参数设置',
|
@ -33,7 +33,7 @@ const translation = {
|
||||
},
|
||||
hybrid_search: {
|
||||
title: '混合检索',
|
||||
description: '同时执行全文检索和向量检索,并应用重排序步骤,从两类查询结果中选择匹配用户问题的最佳结果,需配置 Rerank 模型 API',
|
||||
description: '同时执行全文检索和向量检索,并应用重排序步骤,从两类查询结果中选择匹配用户问题的最佳结果,用户可以选择设置权重或配置重新排序模型。',
|
||||
recommend: '推荐',
|
||||
},
|
||||
invertedIndex: {
|
||||
@@ -67,7 +67,7 @@ const translation = {
     semantic: '语义',
     keyword: '关键词',
   },
-  nTo1RetrievalLegacy: '根据产品规划,N 选 1 召回将于 9 月正式弃用。在那之前,您仍然可以正常使用它。',
+  nTo1RetrievalLegacy: '为了对检索策略进行优化和升级,N 选 1 检索功能将于九月份正式被优化。在此之前,您仍然可以正常使用该功能。',
 }
 
 export default translation