mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 01:15:56 +08:00
feat: support xinference's auth system (#7369)
This commit is contained in:
parent
bbb6fcc4f0
commit
acd72e3ab2
@ -85,7 +85,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||||||
tools=tools, stop=stop, stream=stream, user=user,
|
tools=tools, stop=stop, stream=stream, user=user,
|
||||||
extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
|
extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
|
||||||
server_url=credentials['server_url'],
|
server_url=credentials['server_url'],
|
||||||
model_uid=credentials['model_uid']
|
model_uid=credentials['model_uid'],
|
||||||
|
api_key=credentials.get('api_key'),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -106,7 +107,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
extra_param = XinferenceHelper.get_xinference_extra_parameter(
|
extra_param = XinferenceHelper.get_xinference_extra_parameter(
|
||||||
server_url=credentials['server_url'],
|
server_url=credentials['server_url'],
|
||||||
model_uid=credentials['model_uid']
|
model_uid=credentials['model_uid'],
|
||||||
|
api_key=credentials.get('api_key')
|
||||||
)
|
)
|
||||||
if 'completion_type' not in credentials:
|
if 'completion_type' not in credentials:
|
||||||
if 'chat' in extra_param.model_ability:
|
if 'chat' in extra_param.model_ability:
|
||||||
@ -396,7 +398,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||||||
else:
|
else:
|
||||||
extra_args = XinferenceHelper.get_xinference_extra_parameter(
|
extra_args = XinferenceHelper.get_xinference_extra_parameter(
|
||||||
server_url=credentials['server_url'],
|
server_url=credentials['server_url'],
|
||||||
model_uid=credentials['model_uid']
|
model_uid=credentials['model_uid'],
|
||||||
|
api_key=credentials.get('api_key')
|
||||||
)
|
)
|
||||||
|
|
||||||
if 'chat' in extra_args.model_ability:
|
if 'chat' in extra_args.model_ability:
|
||||||
@ -464,6 +467,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
xinference_client = Client(
|
xinference_client = Client(
|
||||||
base_url=credentials['server_url'],
|
base_url=credentials['server_url'],
|
||||||
|
api_key=credentials.get('api_key'),
|
||||||
)
|
)
|
||||||
|
|
||||||
xinference_model = xinference_client.get_model(credentials['model_uid'])
|
xinference_model = xinference_client.get_model(credentials['model_uid'])
|
||||||
|
@ -108,7 +108,8 @@ class XinferenceRerankModel(RerankModel):
|
|||||||
|
|
||||||
# initialize client
|
# initialize client
|
||||||
client = Client(
|
client = Client(
|
||||||
base_url=credentials['server_url']
|
base_url=credentials['server_url'],
|
||||||
|
api_key=credentials.get('api_key'),
|
||||||
)
|
)
|
||||||
|
|
||||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||||
|
@ -52,7 +52,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
|||||||
|
|
||||||
# initialize client
|
# initialize client
|
||||||
client = Client(
|
client = Client(
|
||||||
base_url=credentials['server_url']
|
base_url=credentials['server_url'],
|
||||||
|
api_key=credentials.get('api_key'),
|
||||||
)
|
)
|
||||||
|
|
||||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||||
|
@ -110,14 +110,22 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
|
|
||||||
server_url = credentials['server_url']
|
server_url = credentials['server_url']
|
||||||
model_uid = credentials['model_uid']
|
model_uid = credentials['model_uid']
|
||||||
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
|
api_key = credentials.get('api_key')
|
||||||
|
extra_args = XinferenceHelper.get_xinference_extra_parameter(
|
||||||
|
server_url=server_url,
|
||||||
|
model_uid=model_uid,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
if extra_args.max_tokens:
|
if extra_args.max_tokens:
|
||||||
credentials['max_tokens'] = extra_args.max_tokens
|
credentials['max_tokens'] = extra_args.max_tokens
|
||||||
if server_url.endswith('/'):
|
if server_url.endswith('/'):
|
||||||
server_url = server_url[:-1]
|
server_url = server_url[:-1]
|
||||||
|
|
||||||
client = Client(base_url=server_url)
|
client = Client(
|
||||||
|
base_url=server_url,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
handle = client.get_model(model_uid=model_uid)
|
handle = client.get_model(model_uid=model_uid)
|
||||||
|
@ -81,7 +81,8 @@ class XinferenceText2SpeechModel(TTSModel):
|
|||||||
|
|
||||||
extra_param = XinferenceHelper.get_xinference_extra_parameter(
|
extra_param = XinferenceHelper.get_xinference_extra_parameter(
|
||||||
server_url=credentials['server_url'],
|
server_url=credentials['server_url'],
|
||||||
model_uid=credentials['model_uid']
|
model_uid=credentials['model_uid'],
|
||||||
|
api_key=credentials.get('api_key'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if 'text-to-audio' not in extra_param.model_ability:
|
if 'text-to-audio' not in extra_param.model_ability:
|
||||||
@ -203,7 +204,11 @@ class XinferenceText2SpeechModel(TTSModel):
|
|||||||
credentials['server_url'] = credentials['server_url'][:-1]
|
credentials['server_url'] = credentials['server_url'][:-1]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
|
api_key = credentials.get('api_key')
|
||||||
|
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
||||||
|
handle = RESTfulAudioModelHandle(
|
||||||
|
credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers
|
||||||
|
)
|
||||||
|
|
||||||
model_support_voice = [x.get("value") for x in
|
model_support_voice = [x.get("value") for x in
|
||||||
self.get_tts_model_voices(model=model, credentials=credentials)]
|
self.get_tts_model_voices(model=model, credentials=credentials)]
|
||||||
|
@ -35,13 +35,13 @@ cache_lock = Lock()
|
|||||||
|
|
||||||
class XinferenceHelper:
|
class XinferenceHelper:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
|
def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
|
||||||
XinferenceHelper._clean_cache()
|
XinferenceHelper._clean_cache()
|
||||||
with cache_lock:
|
with cache_lock:
|
||||||
if model_uid not in cache:
|
if model_uid not in cache:
|
||||||
cache[model_uid] = {
|
cache[model_uid] = {
|
||||||
'expires': time() + 300,
|
'expires': time() + 300,
|
||||||
'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid)
|
'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key)
|
||||||
}
|
}
|
||||||
return cache[model_uid]['value']
|
return cache[model_uid]['value']
|
||||||
|
|
||||||
@ -56,7 +56,7 @@ class XinferenceHelper:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
|
def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
|
||||||
"""
|
"""
|
||||||
get xinference model extra parameter like model_format and model_handle_type
|
get xinference model extra parameter like model_format and model_handle_type
|
||||||
"""
|
"""
|
||||||
@ -70,9 +70,10 @@ class XinferenceHelper:
|
|||||||
session = Session()
|
session = Session()
|
||||||
session.mount('http://', HTTPAdapter(max_retries=3))
|
session.mount('http://', HTTPAdapter(max_retries=3))
|
||||||
session.mount('https://', HTTPAdapter(max_retries=3))
|
session.mount('https://', HTTPAdapter(max_retries=3))
|
||||||
|
headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = session.get(url, timeout=10)
|
response = session.get(url, headers=headers, timeout=10)
|
||||||
except (MissingSchema, ConnectionError, Timeout) as e:
|
except (MissingSchema, ConnectionError, Timeout) as e:
|
||||||
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
|
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user