mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-06-04 11:14:10 +08:00
enhance: speed up xinference embedding & rerank (#3587)
This commit is contained in:
parent
b4d2d635f7
commit
4365843c20
@ -47,17 +47,8 @@ class XinferenceRerankModel(RerankModel):
|
|||||||
if credentials['server_url'].endswith('/'):
|
if credentials['server_url'].endswith('/'):
|
||||||
credentials['server_url'] = credentials['server_url'][:-1]
|
credentials['server_url'] = credentials['server_url'][:-1]
|
||||||
|
|
||||||
# initialize client
|
handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={})
|
||||||
client = Client(
|
response = handle.rerank(
|
||||||
base_url=credentials['server_url']
|
|
||||||
)
|
|
||||||
|
|
||||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
|
||||||
|
|
||||||
if not isinstance(xinference_client, RESTfulRerankModelHandle):
|
|
||||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model')
|
|
||||||
|
|
||||||
response = xinference_client.rerank(
|
|
||||||
documents=docs,
|
documents=docs,
|
||||||
query=query,
|
query=query,
|
||||||
top_n=top_n,
|
top_n=top_n,
|
||||||
@ -98,6 +89,20 @@ class XinferenceRerankModel(RerankModel):
|
|||||||
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
|
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
|
||||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||||
|
|
||||||
|
if credentials['server_url'].endswith('/'):
|
||||||
|
credentials['server_url'] = credentials['server_url'][:-1]
|
||||||
|
|
||||||
|
# initialize client
|
||||||
|
client = Client(
|
||||||
|
base_url=credentials['server_url']
|
||||||
|
)
|
||||||
|
|
||||||
|
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||||
|
|
||||||
|
if not isinstance(xinference_client, RESTfulRerankModelHandle):
|
||||||
|
raise InvokeBadRequestError(
|
||||||
|
'please check model type, the model you want to invoke is not a rerank model')
|
||||||
|
|
||||||
self.invoke(
|
self.invoke(
|
||||||
model=model,
|
model=model,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
|
@ -47,17 +47,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
if server_url.endswith('/'):
|
if server_url.endswith('/'):
|
||||||
server_url = server_url[:-1]
|
server_url = server_url[:-1]
|
||||||
|
|
||||||
client = Client(base_url=server_url)
|
|
||||||
|
|
||||||
try:
|
|
||||||
handle = client.get_model(model_uid=model_uid)
|
|
||||||
except RuntimeError as e:
|
|
||||||
raise InvokeAuthorizationError(e)
|
|
||||||
|
|
||||||
if not isinstance(handle, RESTfulEmbeddingModelHandle):
|
|
||||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
|
||||||
embeddings = handle.create_embedding(input=texts)
|
embeddings = handle.create_embedding(input=texts)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise InvokeServerUnavailableError(e)
|
raise InvokeServerUnavailableError(e)
|
||||||
@ -122,6 +113,18 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
|
|
||||||
if extra_args.max_tokens:
|
if extra_args.max_tokens:
|
||||||
credentials['max_tokens'] = extra_args.max_tokens
|
credentials['max_tokens'] = extra_args.max_tokens
|
||||||
|
if server_url.endswith('/'):
|
||||||
|
server_url = server_url[:-1]
|
||||||
|
|
||||||
|
client = Client(base_url=server_url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
handle = client.get_model(model_uid=model_uid)
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise InvokeAuthorizationError(e)
|
||||||
|
|
||||||
|
if not isinstance(handle, RESTfulEmbeddingModelHandle):
|
||||||
|
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
|
||||||
|
|
||||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||||
except InvokeAuthorizationError as e:
|
except InvokeAuthorizationError as e:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user