Merge branch 'main' into feat/attachments

StyleZhang 2024-07-30 10:06:40 +08:00
commit b322dda3f6
38 changed files with 582 additions and 186 deletions

View File

@ -64,4 +64,6 @@ class DifyConfig(
return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
SSRF_PROXY_HTTP_URL: str | None = None
SSRF_PROXY_HTTPS_URL: str | None = None
MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')
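
The new setting is an ordinary pydantic-settings field, so it can be overridden by an environment variable of the same name. A minimal sketch of that behavior, assuming DifyConfig follows the standard pydantic-settings pattern (the DemoConfig class is hypothetical; only the field itself comes from the diff):

import os
from pydantic import Field
from pydantic_settings import BaseSettings

class DemoConfig(BaseSettings):
    # hypothetical stand-in for DifyConfig; the field mirrors the one above
    MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')

print(DemoConfig().MODERATION_BUFFER_SIZE)  # 300 (default)
os.environ['MODERATION_BUFFER_SIZE'] = '500'
print(DemoConfig().MODERATION_BUFFER_SIZE)  # 500 (env override)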

View File

@ -13,18 +13,10 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
position_file_name = os.path.join(folder_path, file_name)
if not position_file_name or not os.path.exists(position_file_name):
return {}
positions = load_yaml_file(position_file_name, ignore_error=True)
position_map = {}
index = 0
for _, name in enumerate(positions):
if name and isinstance(name, str):
position_map[name.strip()] = index
index += 1
return position_map
position_file_path = os.path.join(folder_path, file_name)
yaml_content = load_yaml_file(file_path=position_file_path, default_value=[])
positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()]
return {name: index for index, name in enumerate(positions)}
def sort_by_position_map(
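
The rewritten get_position_map delegates file handling to load_yaml_file and filters out empty or non-string entries before enumerating. A small behavioral sketch, with a hypothetical folder and an assumed import path:

# given /tmp/providers/_position.yaml containing:
#   - openai
#   -              # empty entries and non-strings are dropped
#   - anthropic
from core.helper.position_helper import get_position_map  # assumed module path

print(get_position_map('/tmp/providers'))
# {'openai': 0, 'anthropic': 1}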

View File

@ -162,7 +162,7 @@ class AIModel(ABC):
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
# read yaml data from yaml file
yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
yaml_data = load_yaml_file(model_schema_yaml_path)
new_parameter_rules = []
for parameter_rule in yaml_data.get('parameter_rules', []):

View File

@ -44,7 +44,7 @@ class ModelProvider(ABC):
# read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = load_yaml_file(yaml_path, ignore_error=True)
yaml_data = load_yaml_file(yaml_path)
try:
# yaml_data to entity

View File

@ -23,7 +23,7 @@ parameter_rules:
type: int
default: 4096
min: 1
max: 4096
max: 8192
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
@ -57,6 +57,18 @@ parameter_rules:
help:
zh_Hans: 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。
en_US: A number between -2.0 and 2.0. If the value is positive, new tokens are penalized based on their frequency of occurrence in existing text, reducing the likelihood that the model will repeat the same content.
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '1'
output: '2'
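
Once this rule is declared, callers can pass response_format through model_parameters like any other parameter. A hypothetical invocation sketch (the model name, credentials, and model_instance are placeholders, not from the diff):

from core.model_runtime.entities.message_entities import UserPromptMessage

# model_instance is assumed to be an already-resolved LLM instance for this provider
response = model_instance.invoke(
    model='hunyuan-standard',  # placeholder model name
    credentials={'secret_id': '...', 'secret_key': '...'},
    prompt_messages=[UserPromptMessage(content='Reply with a JSON object containing a "city" key.')],
    model_parameters={
        'temperature': 0.5,
        'response_format': 'json_object',  # one of the declared options: text | json_object
    },
    stream=False,
)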

View File

@ -18,6 +18,7 @@ help:
en_US: https://console.cloud.tencent.com/cam/capi
supported_model_types:
- llm
- text-embedding
configurate_methods:
- predefined-model
provider_credential_schema:

View File

@ -0,0 +1,5 @@
model: hunyuan-embedding
model_type: text-embedding
model_properties:
context_size: 1024
max_chunks: 1

View File

@ -0,0 +1,173 @@
import json
import logging
import time
from typing import Optional
from tencentcloud.common import credential
from tencentcloud.common.exception import TencentCloudSDKException
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
logger = logging.getLogger(__name__)
class HunyuanTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Hunyuan text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
if model != 'hunyuan-embedding':
raise ValueError('Invalid model name')
client = self._setup_hunyuan_client(credentials)
embeddings = []
token_usage = 0
for input in texts:
request = models.GetEmbeddingRequest()
params = {
"Input": input
}
request.from_json_string(json.dumps(params))
response = client.GetEmbedding(request)
usage = response.Usage.TotalTokens
embeddings.extend([data.Embedding for data in response.Data])
token_usage += usage
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
model=model,
credentials=credentials,
tokens=token_usage
)
)
return result
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate credentials
"""
try:
client = self._setup_hunyuan_client(credentials)
req = models.ChatCompletionsRequest()
params = {
"Model": model,
"Messages": [{
"Role": "user",
"Content": "hello"
}],
"TopP": 1,
"Temperature": 0,
"Stream": False
}
req.from_json_string(json.dumps(params))
client.ChatCompletions(req)
except Exception as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
def _setup_hunyuan_client(self, credentials):
secret_id = credentials['secret_id']
secret_key = credentials['secret_key']
cred = credential.Credential(secret_id, secret_key)
httpProfile = HttpProfile()
httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
clientProfile = ClientProfile()
clientProfile.httpProfile = httpProfile
client = hunyuan_client.HunyuanClient(cred, "", clientProfile)
return client
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeError: [TencentCloudSDKException],
}
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
# client = self._setup_hunyuan_client(credentials)
num_tokens = 0
for text in texts:
num_tokens += self._get_num_tokens_by_gpt2(text)
# use client.GetTokenCount to get num tokens
# request = models.GetTokenCountRequest()
# params = {
# "Prompt": text
# }
# request.from_json_string(json.dumps(params))
# response = client.GetTokenCount(request)
# num_tokens += response.TokenCount
return num_tokens

View File

@ -0,0 +1,30 @@
model: deepseek-ai/DeepSeek-Coder-V2-Instruct
label:
en_US: deepseek-ai/DeepSeek-Coder-V2-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: max_tokens
use_template: max_tokens
type: int
default: 512
min: 1
max: 4096
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
- name: top_p
use_template: top_p
- name: frequency_penalty
use_template: frequency_penalty
pricing:
input: '1.33'
output: '1.33'
unit: '0.000001'
currency: RMB

View File

@ -1,11 +1,9 @@
model: deepseek-ai/deepseek-v2-chat
label:
en_US: deepseek-ai/deepseek-v2-chat
en_US: deepseek-ai/DeepSeek-V2-Chat
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 32768

View File

@ -1,11 +1,9 @@
model: zhipuai/glm4-9B-chat
label:
en_US: zhipuai/glm4-9B-chat
en_US: THUDM/glm-4-9b-chat
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 32768

View File

@ -1,11 +1,9 @@
model: alibaba/Qwen2-57B-A14B-Instruct
label:
en_US: alibaba/Qwen2-57B-A14B-Instruct
en_US: Qwen/Qwen2-57B-A14B-Instruct
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 32768

View File

@ -1,11 +1,9 @@
model: alibaba/Qwen2-72B-Instruct
label:
en_US: alibaba/Qwen2-72B-Instruct
en_US: Qwen/Qwen2-72B-Instruct
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 32768

View File

@ -1,11 +1,9 @@
model: alibaba/Qwen2-7B-Instruct
label:
en_US: alibaba/Qwen2-7B-Instruct
en_US: Qwen/Qwen2-7B-Instruct
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 32768

View File

@ -1,11 +1,9 @@
model: 01-ai/Yi-1.5-34B-Chat
label:
en_US: 01-ai/Yi-1.5-34B-Chat
en_US: 01-ai/Yi-1.5-34B-Chat-16K
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16384

View File

@ -3,9 +3,7 @@ label:
en_US: 01-ai/Yi-1.5-6B-Chat
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 4096

View File

@ -1,11 +1,9 @@
model: 01-ai/Yi-1.5-9B-Chat
label:
en_US: 01-ai/Yi-1.5-9B-Chat
en_US: 01-ai/Yi-1.5-9B-Chat-16K
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16384

View File

@ -19,7 +19,7 @@ class SiliconflowProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='deepseek-ai/deepseek-v2-chat',
model='deepseek-ai/DeepSeek-V2-Chat',
credentials=credentials
)
except CredentialsValidateFailedError as ex:

View File

@ -21,8 +21,6 @@ class ModerationRule(BaseModel):
class OutputModeration(BaseModel):
DEFAULT_BUFFER_SIZE: int = 300
tenant_id: str
app_id: str
@ -77,10 +75,10 @@ class OutputModeration(BaseModel):
return final_output
def start_thread(self) -> threading.Thread:
buffer_size = int(dify_config.config.MODERATION_BUFFER_SIZE)
buffer_size = dify_config.MODERATION_BUFFER_SIZE
thread = threading.Thread(target=self.worker, kwargs={
'flask_app': current_app._get_current_object(),
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
'buffer_size': buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE
})
thread.start()

View File

@ -298,34 +298,29 @@ class TraceTask:
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def execute(self):
method_name, trace_info = self.preprocess()
return trace_info
return self.preprocess()
def preprocess(self):
if self.trace_type == TraceTaskName.CONVERSATION_TRACE:
return TraceTaskName.CONVERSATION_TRACE, self.conversation_trace(**self.kwargs)
if self.trace_type == TraceTaskName.WORKFLOW_TRACE:
return TraceTaskName.WORKFLOW_TRACE, self.workflow_trace(self.workflow_run, self.conversation_id)
elif self.trace_type == TraceTaskName.MESSAGE_TRACE:
return TraceTaskName.MESSAGE_TRACE, self.message_trace(self.message_id)
elif self.trace_type == TraceTaskName.MODERATION_TRACE:
return TraceTaskName.MODERATION_TRACE, self.moderation_trace(self.message_id, self.timer, **self.kwargs)
elif self.trace_type == TraceTaskName.SUGGESTED_QUESTION_TRACE:
return TraceTaskName.SUGGESTED_QUESTION_TRACE, self.suggested_question_trace(
self.message_id, self.timer, **self.kwargs
)
elif self.trace_type == TraceTaskName.DATASET_RETRIEVAL_TRACE:
return TraceTaskName.DATASET_RETRIEVAL_TRACE, self.dataset_retrieval_trace(
self.message_id, self.timer, **self.kwargs
)
elif self.trace_type == TraceTaskName.TOOL_TRACE:
return TraceTaskName.TOOL_TRACE, self.tool_trace(self.message_id, self.timer, **self.kwargs)
elif self.trace_type == TraceTaskName.GENERATE_NAME_TRACE:
return TraceTaskName.GENERATE_NAME_TRACE, self.generate_name_trace(
self.conversation_id, self.timer, **self.kwargs
)
else:
return '', {}
preprocess_map = {
TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(self.workflow_run, self.conversation_id),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
self.message_id, self.timer, **self.kwargs
),
TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
self.message_id, self.timer, **self.kwargs
),
TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
self.message_id, self.timer, **self.kwargs
),
TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs),
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
self.conversation_id, self.timer, **self.kwargs
),
}
return preprocess_map.get(self.trace_type, lambda: None)()
# process methods for different trace types
def conversation_trace(self, **kwargs):

View File

@ -5,6 +5,7 @@ from uuid import uuid4
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException, connections
from pymilvus.milvus_client import IndexParams
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
@ -250,11 +251,15 @@ class MilvusVector(BaseVector):
# Since primary field is auto-id, no need to track it
self._fields.remove(Field.PRIMARY_KEY.value)
# Create Index params for the collection
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
# Create the collection
collection_name = self._collection_name
self._client.create_collection_with_schema(collection_name=collection_name,
schema=schema, index_param=index_params,
consistency_level=self._consistency_level)
self._client.create_collection(collection_name=collection_name,
schema=schema, index_params=index_params_obj,
consistency_level=self._consistency_level)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
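
This tracks the pymilvus 2.4 MilvusClient API, where index parameters are built up front and passed to create_collection. A minimal standalone sketch against Milvus Lite (collection and field names invented for illustration):

from pymilvus import MilvusClient, DataType

client = MilvusClient('./demo.db')  # Milvus Lite, bundled with pymilvus 2.4

schema = client.create_schema(auto_id=True)
schema.add_field('id', DataType.INT64, is_primary=True)
schema.add_field('vector', DataType.FLOAT_VECTOR, dim=8)

# build IndexParams first, then hand them to create_collection
index_params = client.prepare_index_params()
index_params.add_index(field_name='vector', index_type='AUTOINDEX', metric_type='IP')

client.create_collection(collection_name='demo', schema=schema, index_params=index_params)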

View File

@ -27,7 +27,7 @@ class BuiltinToolProviderController(ToolProviderController):
provider = self.__class__.__module__.split('.')[-1]
yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
try:
provider_yaml = load_yaml_file(yaml_path)
provider_yaml = load_yaml_file(yaml_path, ignore_error=False)
except Exception as e:
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}')
@ -58,7 +58,7 @@ class BuiltinToolProviderController(ToolProviderController):
for tool_file in tool_files:
# get tool name
tool_name = tool_file.split(".")[0]
tool = load_yaml_file(path.join(tool_path, tool_file))
tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
# get tool class, import the module
assistant_tool_class = load_single_subclass_from_source(

View File

@ -1,35 +1,32 @@
import logging
import os
from typing import Any
import yaml
from yaml import YAMLError
logger = logging.getLogger(__name__)
def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
"""
Safe loading a YAML file to a dict
Safe loading a YAML file
:param file_path: the path of the YAML file
:param ignore_error:
if True, return empty dict if error occurs and the error will be logged in warning level
if True, return default_value if error occurs and the error will be logged in debug level
if False, raise error if error occurs
:return: a dict of the YAML content
:param default_value: the value returned when errors ignored
:return: an object of the YAML content
"""
try:
if not file_path or not os.path.exists(file_path):
raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')
with open(file_path, encoding='utf-8') as file:
with open(file_path, encoding='utf-8') as yaml_file:
try:
return yaml.safe_load(file)
yaml_content = yaml.safe_load(yaml_file)
return yaml_content if yaml_content else default_value
except Exception as e:
raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
except FileNotFoundError as e:
logger.debug(f'Failed to load YAML file {file_path}: {e}')
return {}
except Exception as e:
if ignore_error:
logger.warning(f'Failed to load YAML file {file_path}: {e}')
return {}
logger.debug(f'Failed to load YAML file {file_path}: {e}')
return default_value
else:
raise e
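
In short, the loader is now forgiving by default and strictness is opt-in. A quick usage sketch, assuming the helper keeps its usual module path (the file names are hypothetical):

from core.tools.utils.yaml_utils import load_yaml_file  # assumed module path

load_yaml_file('missing.yaml')                      # {} -- errors ignored by default now
load_yaml_file('missing.yaml', default_value=[])    # [] -- caller-chosen fallback
load_yaml_file('missing.yaml', ignore_error=False)  # raises FileNotFoundError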

api/poetry.lock generated
View File

@ -448,63 +448,6 @@ doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphin
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
trio = ["trio (>=0.23)"]
[[package]]
name = "argon2-cffi"
version = "23.1.0"
description = "Argon2 for Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "argon2_cffi-23.1.0-py3-none-any.whl", hash = "sha256:c670642b78ba29641818ab2e68bd4e6a78ba53b7eff7b4c3815ae16abf91c7ea"},
{file = "argon2_cffi-23.1.0.tar.gz", hash = "sha256:879c3e79a2729ce768ebb7d36d4609e3a78a4ca2ec3a9f12286ca057e3d0db08"},
]
[package.dependencies]
argon2-cffi-bindings = "*"
[package.extras]
dev = ["argon2-cffi[tests,typing]", "tox (>4)"]
docs = ["furo", "myst-parser", "sphinx", "sphinx-copybutton", "sphinx-notfound-page"]
tests = ["hypothesis", "pytest"]
typing = ["mypy"]
[[package]]
name = "argon2-cffi-bindings"
version = "21.2.0"
description = "Low-level CFFI bindings for Argon2"
optional = false
python-versions = ">=3.6"
files = [
{file = "argon2-cffi-bindings-21.2.0.tar.gz", hash = "sha256:bb89ceffa6c791807d1305ceb77dbfacc5aa499891d2c55661c6459651fc39e3"},
{file = "argon2_cffi_bindings-21.2.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ccb949252cb2ab3a08c02024acb77cfb179492d5701c7cbdbfd776124d4d2367"},
{file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9524464572e12979364b7d600abf96181d3541da11e23ddf565a32e70bd4dc0d"},
{file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b746dba803a79238e925d9046a63aa26bf86ab2a2fe74ce6b009a1c3f5c8f2ae"},
{file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58ed19212051f49a523abb1dbe954337dc82d947fb6e5a0da60f7c8471a8476c"},
{file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:bd46088725ef7f58b5a1ef7ca06647ebaf0eb4baff7d1d0d177c6cc8744abd86"},
{file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_i686.whl", hash = "sha256:8cd69c07dd875537a824deec19f978e0f2078fdda07fd5c42ac29668dda5f40f"},
{file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:f1152ac548bd5b8bcecfb0b0371f082037e47128653df2e8ba6e914d384f3c3e"},
{file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win32.whl", hash = "sha256:603ca0aba86b1349b147cab91ae970c63118a0f30444d4bc80355937c950c082"},
{file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win_amd64.whl", hash = "sha256:b2ef1c30440dbbcba7a5dc3e319408b59676e2e039e2ae11a8775ecf482b192f"},
{file = "argon2_cffi_bindings-21.2.0-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e415e3f62c8d124ee16018e491a009937f8cf7ebf5eb430ffc5de21b900dad93"},
{file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3e385d1c39c520c08b53d63300c3ecc28622f076f4c2b0e6d7e796e9f6502194"},
{file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c3e3cc67fdb7d82c4718f19b4e7a87123caf8a93fde7e23cf66ac0337d3cb3f"},
{file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a22ad9800121b71099d0fb0a65323810a15f2e292f2ba450810a7316e128ee5"},
{file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f9f8b450ed0547e3d473fdc8612083fd08dd2120d6ac8f73828df9b7d45bb351"},
{file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:93f9bf70084f97245ba10ee36575f0c3f1e7d7724d67d8e5b08e61787c320ed7"},
{file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3b9ef65804859d335dc6b31582cad2c5166f0c3e7975f324d9ffaa34ee7e6583"},
{file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4966ef5848d820776f5f562a7d45fdd70c2f330c961d0d745b784034bd9f48d"},
{file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ef543a89dee4db46a1a6e206cd015360e5a75822f76df533845c3cbaf72670"},
{file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed2937d286e2ad0cc79a7087d3c272832865f779430e0cc2b4f3718d3159b0cb"},
{file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5e00316dabdaea0b2dd82d141cc66889ced0cdcbfa599e8b471cf22c620c329a"},
]
[package.dependencies]
cffi = ">=1.0.1"
[package.extras]
dev = ["cogapp", "pre-commit", "pytest", "wheel"]
tests = ["pytest"]
[[package]]
name = "arxiv"
version = "2.1.0"
@ -4616,22 +4559,20 @@ files = [
]
[[package]]
name = "minio"
version = "7.2.7"
description = "MinIO Python SDK for Amazon S3 Compatible Cloud Storage"
name = "milvus-lite"
version = "2.4.8"
description = "A lightweight version of Milvus wrapped with Python."
optional = false
python-versions = "*"
python-versions = ">=3.7"
files = [
{file = "minio-7.2.7-py3-none-any.whl", hash = "sha256:59d1f255d852fe7104018db75b3bebbd987e538690e680f7c5de835e422de837"},
{file = "minio-7.2.7.tar.gz", hash = "sha256:473d5d53d79f340f3cd632054d0c82d2f93177ce1af2eac34a235bea55708d98"},
{file = "milvus_lite-2.4.8-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:b7e90b34b214884cd44cdc112ab243d4cb197b775498355e2437b6cafea025fe"},
{file = "milvus_lite-2.4.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:519dfc62709d8f642d98a1c5b1dcde7080d107e6e312d677fef5a3412a40ac08"},
{file = "milvus_lite-2.4.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b21f36d24cbb0e920b4faad607019bb28c1b2c88b4d04680ac8c7697a4ae8a4d"},
{file = "milvus_lite-2.4.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:08332a2b9abfe7c4e1d7926068937e46f8fb81f2707928b7bc02c9dc99cebe41"},
]
[package.dependencies]
argon2-cffi = "*"
certifi = "*"
pycryptodome = "*"
typing-extensions = "*"
urllib3 = "*"
tqdm = "*"
[[package]]
name = "mmh3"
@ -6078,6 +6019,19 @@ files = [
{file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
{file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
{file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"},
{file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"},
{file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"},
{file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"},
{file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"},
{file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"},
]
[package.dependencies]
@ -6374,24 +6328,29 @@ tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
[[package]]
name = "pymilvus"
version = "2.3.1"
version = "2.4.4"
description = "Python Sdk for Milvus"
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "pymilvus-2.3.1-py3-none-any.whl", hash = "sha256:ce65e1de8700f33bd9aade20f013291629702e25b05726773208f1f0b22548ff"},
{file = "pymilvus-2.3.1.tar.gz", hash = "sha256:d460f6204d7deb2cff93716bd65670c1b440694b77701fb0ab0ead791aa582c6"},
{file = "pymilvus-2.4.4-py3-none-any.whl", hash = "sha256:073b76bc36f6f4e70f0f0a0023a53324f0ba8ef9a60883f87cd30a44b6c6f2b5"},
{file = "pymilvus-2.4.4.tar.gz", hash = "sha256:50c53eb103e034fbffe936fe942751ea3dbd2452e18cf79acc52360ed4987fb7"},
]
[package.dependencies]
environs = "<=9.5.0"
grpcio = ">=1.49.1,<=1.58.0"
minio = "*"
grpcio = ">=1.49.1,<=1.63.0"
milvus-lite = {version = ">=2.4.0,<2.5.0", markers = "sys_platform != \"win32\""}
pandas = ">=1.2.4"
protobuf = ">=3.20.0"
requests = "*"
setuptools = ">=67"
ujson = ">=2.0.0"
[package.extras]
bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "requests"]
dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"]
model = ["milvus-model (>=0.1.0)"]
[[package]]
name = "pymysql"
version = "1.1.1"
@ -9543,4 +9502,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "9619ddabdd67710981c13dcfa3ddae0a48497c9f694afc81b820e882440c1265"
content-hash = "a8b61d74d9322302b7447b6f8728ad606abc160202a8a122a05a8ef3cec7055b"

View File

@ -206,7 +206,7 @@ chromadb = "0.5.1"
oracledb = "~2.2.1"
pgvecto-rs = "0.1.4"
pgvector = "0.2.5"
pymilvus = "2.3.1"
pymilvus = "~2.4.4"
pymysql = "1.1.1"
tcvectordb = "1.3.2"
tidb-vector = "0.0.9"
@ -216,18 +216,6 @@ alibabacloud_gpdb20160503 = "~3.8.0"
alibabacloud_tea_openapi = "~0.3.9"
clickhouse-connect = "~0.7.16"
############################################################
# Transparent dependencies required by main dependencies
# for pinning versions
############################################################
[tool.poetry.group.transparent.dependencies]
kaleido = "0.2.1"
lxml = "5.1.0"
sympy = "1.12"
tenacity = "~8.3.0"
xlrd = "~2.0.1"
############################################################
# Dev dependencies for running tests
############################################################

View File

@ -0,0 +1,104 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.hunyuan.text_embedding.text_embedding import HunyuanTextEmbeddingModel
def test_validate_credentials():
model = HunyuanTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='hunyuan-embedding',
credentials={
'secret_id': 'invalid_key',
'secret_key': 'invalid_key'
}
)
model.validate_credentials(
model='hunyuan-embedding',
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
}
)
def test_invoke_model():
model = HunyuanTextEmbeddingModel()
result = model.invoke(
model='hunyuan-embedding',
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
},
texts=[
"hello",
"world"
],
user="abc-123"
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 6
def test_get_num_tokens():
model = HunyuanTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='hunyuan-embedding',
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
},
texts=[
"hello",
"world"
]
)
assert num_tokens == 2
def test_max_chunks():
model = HunyuanTextEmbeddingModel()
result = model.invoke(
model='hunyuan-embedding',
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
},
texts=[
"hello",
"world",
"hello",
"world",
"hello",
"world",
"hello",
"world",
"hello",
"world",
"hello",
"world",
"hello",
"world",
"hello",
"world",
"hello",
"world",
"hello",
"world",
"hello",
"world",
]
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 22

View File

@ -0,0 +1,106 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.siliconflow.llm.llm import SiliconflowLargeLanguageModel
def test_validate_credentials():
model = SiliconflowLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': os.environ.get('API_KEY')
}
)
def test_invoke_model():
model = SiliconflowLargeLanguageModel()
response = model.invoke(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': os.environ.get('API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 10
},
stop=['How'],
stream=False,
user="abc-123"
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
def test_invoke_stream_model():
model = SiliconflowLargeLanguageModel()
response = model.invoke(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': os.environ.get('API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100,
'seed': 1234
},
stream=True,
user="abc-123"
)
assert isinstance(response, Generator)
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_get_num_tokens():
model = SiliconflowLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': os.environ.get('API_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Hello World!'
)
]
)
assert num_tokens == 12

View File

@ -0,0 +1,21 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.siliconflow.siliconflow import SiliconflowProvider
def test_validate_provider_credentials():
provider = SiliconflowProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('API_KEY')
}
)

View File

@ -21,6 +21,20 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
return str(tmp_path)
@pytest.fixture
def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent(
"""\
# - commented1
# - commented2
-
-
"""))
return str(tmp_path)
def test_position_helper(prepare_example_positions_yaml):
position_map = get_position_map(
folder_path=prepare_example_positions_yaml,
@ -32,3 +46,10 @@ def test_position_helper(prepare_example_positions_yaml):
'third': 2,
'forth': 3,
}
def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml):
position_map = get_position_map(
folder_path=prepare_empty_commented_positions_yaml,
file_name='example_positions_all_commented.yaml')
assert position_map == {}

View File

@ -53,6 +53,9 @@ def test_load_yaml_non_existing_file():
assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
assert load_yaml_file(file_path='') == {}
with pytest.raises(FileNotFoundError):
load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False)
def test_load_valid_yaml_file(prepare_example_yaml_file):
yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
@ -68,7 +71,7 @@ def test_load_valid_yaml_file(prepare_example_yaml_file):
def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
# yaml syntax error
with pytest.raises(YAMLError):
load_yaml_file(file_path=prepare_invalid_yaml_file)
load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False)
# ignore error
assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}
assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {}

View File

@ -38,7 +38,7 @@ services:
milvus-standalone:
container_name: milvus-standalone
image: milvusdb/milvus:v2.3.1
image: milvusdb/milvus:v2.4.6
command: ["milvus", "run", "standalone"]
environment:
ETCD_ENDPOINTS: etcd:2379

View File

@ -25,7 +25,7 @@
}
.action-btn-xs {
@apply p-0 w-5 h-5 rounded
@apply p-0 w-4 h-4 rounded
}
.action-btn.action-btn-active {

View File

@ -409,7 +409,7 @@ const translation = {
},
retrieveMultiWay: {
title: 'Multi-path retrieval',
description: 'Based on user intent, queries across all Knowledge, retrieves relevant text from multi-sources, and selects the best results matching the user query after reranking. Configuration of the Rerank model API is required.',
description: 'Based on user intent, queries across all Knowledge, retrieves relevant text from multi-sources, and selects the best results matching the user query after reranking. ',
},
rerankModelRequired: 'Rerank model is required',
params: 'Params',

View File

@ -33,7 +33,7 @@ const translation = {
},
hybrid_search: {
title: 'Hybrid Search',
description: 'Execute full-text search and vector searches simultaneously, re-rank to select the best match for the user\'s query. Configuration of the Rerank model APIs necessary.',
description: 'Execute full-text search and vector searches simultaneously, re-rank to select the best match for the user\'s query. Users can choose to set weights or configure a Rerank model.',
recommend: 'Recommend',
},
invertedIndex: {
@ -67,7 +67,7 @@ const translation = {
semantic: 'Semantic',
keyword: 'Keyword',
},
nTo1RetrievalLegacy: 'According to product planning, N-to-1 retrieval will be officially deprecated in September. Until then you can still use it normally.',
nTo1RetrievalLegacy: 'As part of optimizing and upgrading the retrieval strategy, N-to-1 retrieval will be officially deprecated in September. Until then you can still use it normally.',
}
export default translation

View File

@ -404,7 +404,7 @@ const translation = {
},
retrieveMultiWay: {
title: '多路召回',
description: '根据用户意图同时匹配所有知识库,从多路知识库查询相关文本片段,经过重排序步骤,从多路查询结果中选择匹配用户问题的最佳结果,需配置 Rerank 模型 API。',
description: '根据用户意图同时匹配所有知识库,从多路知识库查询相关文本片段,经过重排序步骤,从多路查询结果中选择匹配用户问题的最佳结果。',
},
rerankModelRequired: '请选择 Rerank 模型',
params: '参数设置',

View File

@ -33,7 +33,7 @@ const translation = {
},
hybrid_search: {
title: '混合检索',
description: '同时执行全文检索和向量检索,并应用重排序步骤,从两类查询结果中选择匹配用户问题的最佳结果,需配置 Rerank 模型 API',
description: '同时执行全文检索和向量检索,并应用重排序步骤,从两类查询结果中选择匹配用户问题的最佳结果,用户可以选择设置权重或配置重新排序模型。',
recommend: '推荐',
},
invertedIndex: {
@ -67,7 +67,7 @@ const translation = {
semantic: '语义',
keyword: '关键词',
},
nTo1RetrievalLegacy: '根据产品规划,N 选 1 召回将于 9 月正式弃用。在那之前,您仍然可以正常使用它。',
nTo1RetrievalLegacy: '为了对检索策略进行优化和升级,N 选 1 检索功能将于九月份正式被弃用。在此之前,您仍然可以正常使用该功能。',
}
export default translation