diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index dc9b594d5a..4f5a3b1604 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -75,6 +75,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if extra_param.support_function_call: credentials['support_function_call'] = True + if extra_param.context_length: + credentials['context_length'] = extra_param.context_length + except RuntimeError as e: raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') except KeyError as e: @@ -296,6 +299,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') support_function_call = credentials.get('support_function_call', False) + context_length = credentials.get('context_length', 2048) entity = AIModelEntity( model=model, @@ -309,6 +313,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ] if support_function_call else [], model_properties={ ModelPropertyKey.MODE: completion_type, + ModelPropertyKey.CONTEXT_SIZE: context_length }, parameter_rules=rules ) diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 64612ca3fa..1b8d1fbfd7 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -14,15 +14,17 @@ class XinferenceModelExtraParameter(object): model_handle_type: str model_ability: List[str] max_tokens: int = 512 + context_length: int = 2048 support_function_call: bool = False def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str], - support_function_call: bool, max_tokens: int) -> None: + support_function_call: bool, max_tokens: int, context_length: int) -> None: self.model_format = model_format self.model_handle_type = model_handle_type self.model_ability = model_ability self.support_function_call = support_function_call self.max_tokens = max_tokens + self.context_length = context_length cache = {} cache_lock = Lock() @@ -57,7 +59,7 @@ class XinferenceHelper: url = path.join(server_url, 'v1/models', model_uid) - # this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 + # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 session = Session() session.mount('http://', HTTPAdapter(max_retries=3)) session.mount('https://', HTTPAdapter(max_retries=3)) @@ -88,11 +90,14 @@ class XinferenceHelper: support_function_call = 'tools' in model_ability max_tokens = response_json.get('max_tokens', 512) + + context_length = response_json.get('context_length', 2048) return XinferenceModelExtraParameter( model_format=model_format, model_handle_type=model_handle_type, model_ability=model_ability, support_function_call=support_function_call, - max_tokens=max_tokens + max_tokens=max_tokens, + context_length=context_length ) \ No newline at end of file