From 2b080b5cfcf55aeb3013615169f44c6eef2e876d Mon Sep 17 00:00:00 2001
From: liuzhenghua <1090179900@qq.com>
Date: Fri, 28 Jun 2024 00:27:20 +0800
Subject: [PATCH] =?UTF-8?q?feature:=20Add=20presence=5Fpenalty=20and=20fre?=
 =?UTF-8?q?quency=5Fpenalty=20parameters=20to=20the=20=E2=80=A6=20(#5637)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: liuzhenghua-jk
---
 .../model_providers/xinference/llm/llm.py     | 98 ++++++++++++++-----
 1 file changed, 73 insertions(+), 25 deletions(-)

diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py
index 637e9b32e6..0ef63f8e23 100644
--- a/api/core/model_runtime/model_providers/xinference/llm/llm.py
+++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py
@@ -39,6 +39,7 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
+    DefaultParameterName,
     FetchFrom,
     ModelFeature,
     ModelPropertyKey,
@@ -67,7 +68,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                 tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True,
                 user: str | None = None) \
-            -> LLMResult | Generator:
+        -> LLMResult | Generator:
         """
         invoke LLM

@@ -113,7 +114,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             elif 'generate' in extra_param.model_ability:
                 credentials['completion_type'] = 'completion'
             else:
-                raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
+                raise ValueError(
+                    f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')

             if extra_param.support_function_call:
                 credentials['support_function_call'] = True
@@ -206,6 +208,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         :param tools: tools for tool calling
         :return: number of tokens
         """
+
         def tokens(text: str):
             return self._get_num_tokens_by_gpt2(text)

@@ -339,6 +342,45 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     zh_Hans='最大生成长度',
                     en_US='Max Tokens'
                 )
+            ),
+            ParameterRule(
+                name=DefaultParameterName.PRESENCE_PENALTY,
+                use_template=DefaultParameterName.PRESENCE_PENALTY,
+                type=ParameterType.FLOAT,
+                label=I18nObject(
+                    en_US='Presence Penalty',
+                    zh_Hans='存在惩罚',
+                ),
+                required=False,
+                help=I18nObject(
+                    en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they '
+                          'appear in the text so far, increasing the model\'s likelihood to talk about new topics.',
+                    zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。'
+                ),
+                default=0.0,
+                min=-2.0,
+                max=2.0,
+                precision=2
+            ),
+            ParameterRule(
+                name=DefaultParameterName.FREQUENCY_PENALTY,
+                use_template=DefaultParameterName.FREQUENCY_PENALTY,
+                type=ParameterType.FLOAT,
+                label=I18nObject(
+                    en_US='Frequency Penalty',
+                    zh_Hans='频率惩罚',
+                ),
+                required=False,
+                help=I18nObject(
+                    en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on their '
+                          'existing frequency in the text so far, decreasing the model\'s likelihood to repeat the '
+                          'same line verbatim.',
+                    zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。'
+                ),
+                default=0.0,
+                min=-2.0,
+                max=2.0,
+                precision=2
             )
         ]

@@ -364,7 +406,6 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             else:
                 raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')

-
         features = []

         support_function_call = credentials.get('support_function_call', False)
@@ -395,9 +436,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         return entity

     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                 model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
-                 tools: list[PromptMessageTool] | None = None,
-                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+                  model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
+                  tools: list[PromptMessageTool] | None = None,
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
         generate text from LLM
@@ -429,6 +470,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             'temperature': model_parameters.get('temperature', 1.0),
             'top_p': model_parameters.get('top_p', 0.7),
             'max_tokens': model_parameters.get('max_tokens', 512),
+            'presence_penalty': model_parameters.get('presence_penalty', 0.0),
+            'frequency_penalty': model_parameters.get('frequency_penalty', 0.0),
         }

         if stop:
@@ -453,10 +496,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             if stream:
                 if tools and len(tools) > 0:
                     raise InvokeBadRequestError('xinference tool calls does not support stream mode')
-                return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                         tools=tools, resp=resp)
-            return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                       tools=tools, resp=resp)
+                return self._handle_chat_stream_response(model=model, credentials=credentials,
+                                                         prompt_messages=prompt_messages,
+                                                         tools=tools, resp=resp)
+            return self._handle_chat_generate_response(model=model, credentials=credentials,
+                                                       prompt_messages=prompt_messages,
+                                                       tools=tools, resp=resp)
         elif isinstance(xinference_model, RESTfulGenerateModelHandle):
             resp = client.completions.create(
                 model=credentials['model_uid'],
@@ -466,10 +511,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 **generate_config,
             )
             if stream:
-                return self._handle_completion_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                               tools=tools, resp=resp)
-            return self._handle_completion_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                             tools=tools, resp=resp)
+                return self._handle_completion_stream_response(model=model, credentials=credentials,
+                                                               prompt_messages=prompt_messages,
+                                                               tools=tools, resp=resp)
+            return self._handle_completion_generate_response(model=model, credentials=credentials,
+                                                             prompt_messages=prompt_messages,
+                                                             tools=tools, resp=resp)
         else:
             raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported')
@@ -523,8 +570,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         return tool_call

     def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                      tools: list[PromptMessageTool],
-                                      resp: ChatCompletion) -> LLMResult:
+                                       tools: list[PromptMessageTool],
+                                       resp: ChatCompletion) -> LLMResult:
         """
         handle normal chat generate response
         """
@@ -549,7 +596,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
         completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)

-        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
+                                          completion_tokens=completion_tokens)

         response = LLMResult(
             model=model,
@@ -560,10 +608,10 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         )

         return response
-    
+
     def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                    tools: list[PromptMessageTool],
-                                    resp: Iterator[ChatCompletionChunk]) -> Generator:
+                                     tools: list[PromptMessageTool],
+                                     resp: Iterator[ChatCompletionChunk]) -> Generator:
         """
         handle stream chat generate response
         """
@@ -634,8 +682,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 full_response += delta.delta.content

     def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                            tools: list[PromptMessageTool],
-                                            resp: Completion) -> LLMResult:
+                                             tools: list[PromptMessageTool],
+                                             resp: Completion) -> LLMResult:
         """
         handle normal completion generate response
         """
@@ -671,8 +719,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         return response

     def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                          tools: list[PromptMessageTool],
-                                          resp: Iterator[Completion]) -> Generator:
+                                           tools: list[PromptMessageTool],
+                                           resp: Iterator[Completion]) -> Generator:
         """
         handle stream completion generate response
         """
@@ -764,4 +812,4 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             InvokeBadRequestError: [
                 ValueError
             ]
-        }
\ No newline at end of file
+        }
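
Note (reviewer addition, not part of the patch): a minimal sketch of what the new
generate_config fields mean on the wire. Xinference serves an OpenAI-compatible API and
_generate() sends its request through an OpenAI client, so the stock openai package can
exercise the same presence_penalty/frequency_penalty values the patch forwards. The
server URL, API key, and model UID below are illustrative assumptions, not values from
the patch.

    # sketch.py - assumes a running Xinference server with a launched chat model;
    # SERVER_URL and MODEL_UID are hypothetical placeholders.
    from openai import OpenAI

    SERVER_URL = 'http://127.0.0.1:9997/v1'  # assumption: default local Xinference port
    MODEL_UID = 'my-chat-model'              # assumption: uid of a launched chat model

    # The client requires an api_key string; a local Xinference server with auth
    # disabled does not check it.
    client = OpenAI(base_url=SERVER_URL, api_key='not-used')

    # Mirrors the generate_config assembled in _generate(): both penalties default
    # to 0.0 (no penalty), and the ParameterRules above bound them to [-2.0, 2.0].
    completion = client.chat.completions.create(
        model=MODEL_UID,
        messages=[{'role': 'user', 'content': 'Name three uses for a brick.'}],
        temperature=1.0,
        top_p=0.7,
        max_tokens=512,
        presence_penalty=0.5,   # > 0 nudges the model toward new topics
        frequency_penalty=0.5,  # > 0 discourages verbatim repetition
    )
    print(completion.choices[0].message.content)

Raising either penalty above 0 is a common first remedy when a local model loops or
repeats itself; negative values have the opposite effect and are rarely useful outside
experimentation.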