feature: Add presence_penalty and frequency_penalty parameters to the … (#5637)
Co-authored-by: liuzhenghua-jk <liuzhenghua-jk@360shuke.com>
parent e8b8f6c6dd
commit 2b080b5cfc
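This commit declares the two OpenAI-style sampling penalties as ParameterRule entries (so they show up as configurable model parameters in Dify) and forwards them through generate_config to the Xinference server. For context, a minimal sketch of the conventional semantics of the two penalties, which the serving backend applies to logits before sampling — illustrative only, not code from this commit:

from collections import Counter

def apply_penalties(logits: dict[str, float], generated: list[str],
                    presence_penalty: float = 0.0,
                    frequency_penalty: float = 0.0) -> dict[str, float]:
    # Count how often each token has already been generated.
    counts = Counter(generated)
    adjusted = dict(logits)
    for token, count in counts.items():
        if token in adjusted:
            # presence_penalty: flat, applied once a token has appeared at all;
            # frequency_penalty: scales with how many times it has appeared.
            adjusted[token] -= presence_penalty + frequency_penalty * count
    return adjusted

# Positive values lower the scores of already-seen tokens, nudging the model
# toward new topics (presence) and away from verbatim repeats (frequency).
print(apply_penalties({'the': 2.0, 'cat': 1.5}, ['the', 'the', 'cat'],
                      presence_penalty=0.5, frequency_penalty=0.3))
# {'the': 0.9, 'cat': 0.7}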
@@ -39,6 +39,7 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
+    DefaultParameterName,
     FetchFrom,
     ModelFeature,
     ModelPropertyKey,
@@ -67,7 +68,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                 model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
         invoke LLM
 
@@ -113,7 +114,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         elif 'generate' in extra_param.model_ability:
             credentials['completion_type'] = 'completion'
         else:
-            raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
+            raise ValueError(
+                f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
 
         if extra_param.support_function_call:
             credentials['support_function_call'] = True
@@ -206,6 +208,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         :param tools: tools for tool calling
         :return: number of tokens
         """
 
         def tokens(text: str):
             return self._get_num_tokens_by_gpt2(text)
+
@@ -339,6 +342,45 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     zh_Hans='最大生成长度',
                     en_US='Max Tokens'
                 )
+            ),
+            ParameterRule(
+                name=DefaultParameterName.PRESENCE_PENALTY,
+                use_template=DefaultParameterName.PRESENCE_PENALTY,
+                type=ParameterType.FLOAT,
+                label=I18nObject(
+                    en_US='Presence Penalty',
+                    zh_Hans='存在惩罚',
+                ),
+                required=False,
+                help=I18nObject(
+                    en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they '
+                          'appear in the text so far, increasing the model\'s likelihood to talk about new topics.',
+                    zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。'
+                ),
+                default=0.0,
+                min=-2.0,
+                max=2.0,
+                precision=2
+            ),
+            ParameterRule(
+                name=DefaultParameterName.FREQUENCY_PENALTY,
+                use_template=DefaultParameterName.FREQUENCY_PENALTY,
+                type=ParameterType.FLOAT,
+                label=I18nObject(
+                    en_US='Frequency Penalty',
+                    zh_Hans='频率惩罚',
+                ),
+                required=False,
+                help=I18nObject(
+                    en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on their '
+                          'existing frequency in the text so far, decreasing the model\'s likelihood to repeat the '
+                          'same line verbatim.',
+                    zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。'
+                ),
+                default=0.0,
+                min=-2.0,
+                max=2.0,
+                precision=2
+            )
             )
         ]
 
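Each rule above ships min, max, default, and precision metadata that the model runtime can use to sanitize user-supplied values before they reach the backend. A hypothetical sketch of that kind of sanitization — the helper below is an assumption for illustration, not Dify's actual validation code:

def sanitize_float_param(value, default: float, lo: float, hi: float,
                         precision: int) -> float:
    # Hypothetical: fall back to the rule's default, clamp into [min, max],
    # and round to the declared precision.
    if value is None:
        return default
    return round(min(max(float(value), lo), hi), precision)

user_input = '3.5'  # out-of-range value coming from the UI
print(sanitize_float_param(user_input, default=0.0, lo=-2.0, hi=2.0, precision=2))
# 2.0 -- clamped to the rule's max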
@@ -364,7 +406,6 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         else:
             raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
 
-
         features = []
 
         support_function_call = credentials.get('support_function_call', False)
@@ -395,9 +436,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         return entity
 
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
                   tools: list[PromptMessageTool] | None = None,
                   stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
         generate text from LLM
@@ -429,6 +470,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             'temperature': model_parameters.get('temperature', 1.0),
             'top_p': model_parameters.get('top_p', 0.7),
             'max_tokens': model_parameters.get('max_tokens', 512),
+            'presence_penalty': model_parameters.get('presence_penalty', 0.0),
+            'frequency_penalty': model_parameters.get('frequency_penalty', 0.0),
         }
 
         if stop:
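With the two keys added to generate_config above, every request the integration issues now carries the penalties; the hunks below splat this config into the OpenAI-compatible client calls. A sketch of the equivalent request against an Xinference server's OpenAI-compatible endpoint — base_url, api_key, and the model UID are placeholders:

from openai import OpenAI

# Xinference serves an OpenAI-compatible API; the values here are examples.
client = OpenAI(base_url='http://localhost:9997/v1', api_key='not-used')

resp = client.chat.completions.create(
    model='my-model-uid',  # corresponds to credentials['model_uid'] in the diff
    messages=[{'role': 'user', 'content': 'Write a limerick about caching.'}],
    temperature=1.0,
    top_p=0.7,
    max_tokens=512,
    presence_penalty=0.0,   # new: discourages reused topics when > 0
    frequency_penalty=0.0,  # new: discourages verbatim repetition when > 0
)
print(resp.choices[0].message.content)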
@@ -453,10 +496,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             if stream:
                 if tools and len(tools) > 0:
                     raise InvokeBadRequestError('xinference tool calls does not support stream mode')
-                return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                         tools=tools, resp=resp)
-            return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                       tools=tools, resp=resp)
+                return self._handle_chat_stream_response(model=model, credentials=credentials,
+                                                         prompt_messages=prompt_messages,
+                                                         tools=tools, resp=resp)
+            return self._handle_chat_generate_response(model=model, credentials=credentials,
+                                                       prompt_messages=prompt_messages,
+                                                       tools=tools, resp=resp)
         elif isinstance(xinference_model, RESTfulGenerateModelHandle):
             resp = client.completions.create(
                 model=credentials['model_uid'],
@@ -466,10 +511,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 **generate_config,
             )
             if stream:
-                return self._handle_completion_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                               tools=tools, resp=resp)
-            return self._handle_completion_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                             tools=tools, resp=resp)
+                return self._handle_completion_stream_response(model=model, credentials=credentials,
+                                                               prompt_messages=prompt_messages,
+                                                               tools=tools, resp=resp)
+            return self._handle_completion_generate_response(model=model, credentials=credentials,
+                                                             prompt_messages=prompt_messages,
+                                                             tools=tools, resp=resp)
         else:
             raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported')
 
@@ -523,8 +570,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         return tool_call
 
     def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                        tools: list[PromptMessageTool],
                                        resp: ChatCompletion) -> LLMResult:
         """
         handle normal chat generate response
         """
@@ -549,7 +596,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
         completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
 
-        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
+                                          completion_tokens=completion_tokens)
 
         response = LLMResult(
             model=model,
@@ -560,10 +608,10 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         )
 
         return response
 
     def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                      tools: list[PromptMessageTool],
                                      resp: Iterator[ChatCompletionChunk]) -> Generator:
         """
         handle stream chat generate response
         """
@@ -634,8 +682,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             full_response += delta.delta.content
 
     def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                              tools: list[PromptMessageTool],
                                              resp: Completion) -> LLMResult:
         """
         handle normal completion generate response
         """
@@ -671,8 +719,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         return response
 
     def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                            tools: list[PromptMessageTool],
                                            resp: Iterator[Completion]) -> Generator:
         """
         handle stream completion generate response
         """
@@ -764,4 +812,4 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             InvokeBadRequestError: [
                 ValueError
             ]
         }