mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 06:05:51 +08:00
Add rerank model type for LocalAI provider (#3952)
This commit is contained in:
parent
2c1c660c6e
commit
a588df4371
@ -15,6 +15,7 @@ help:
|
|||||||
supported_model_types:
|
supported_model_types:
|
||||||
- llm
|
- llm
|
||||||
- text-embedding
|
- text-embedding
|
||||||
|
- rerank
|
||||||
- speech2text
|
- speech2text
|
||||||
configurate_methods:
|
configurate_methods:
|
||||||
- customizable-model
|
- customizable-model
|
||||||
|
120
api/core/model_runtime/model_providers/localai/rerank/rerank.py
Normal file
120
api/core/model_runtime/model_providers/localai/rerank/rerank.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
from json import dumps
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from requests import post
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
|
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||||
|
|
||||||
|
|
||||||
|
class LocalaiRerankModel(RerankModel):
|
||||||
|
"""
|
||||||
|
LocalAI rerank model API is compatible with Jina rerank model API. So just copy the JinaRerankModel class code here.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _invoke(self, model: str, credentials: dict,
|
||||||
|
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||||
|
user: Optional[str] = None) -> RerankResult:
|
||||||
|
"""
|
||||||
|
Invoke rerank model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param query: search query
|
||||||
|
:param docs: docs for reranking
|
||||||
|
:param score_threshold: score threshold
|
||||||
|
:param top_n: top n documents to return
|
||||||
|
:param user: unique user id
|
||||||
|
:return: rerank result
|
||||||
|
"""
|
||||||
|
if len(docs) == 0:
|
||||||
|
return RerankResult(model=model, docs=[])
|
||||||
|
|
||||||
|
server_url = credentials['server_url']
|
||||||
|
model_name = model
|
||||||
|
|
||||||
|
if not server_url:
|
||||||
|
raise CredentialsValidateFailedError('server_url is required')
|
||||||
|
if not model_name:
|
||||||
|
raise CredentialsValidateFailedError('model_name is required')
|
||||||
|
|
||||||
|
url = server_url
|
||||||
|
headers = {
|
||||||
|
'Authorization': f"Bearer {credentials.get('api_key')}",
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": model_name,
|
||||||
|
"query": query,
|
||||||
|
"documents": docs,
|
||||||
|
"top_n": top_n
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
results = response.json()
|
||||||
|
|
||||||
|
rerank_documents = []
|
||||||
|
for result in results['results']:
|
||||||
|
rerank_document = RerankDocument(
|
||||||
|
index=result['index'],
|
||||||
|
text=result['document']['text'],
|
||||||
|
score=result['relevance_score'],
|
||||||
|
)
|
||||||
|
if score_threshold is None or result['relevance_score'] >= score_threshold:
|
||||||
|
rerank_documents.append(rerank_document)
|
||||||
|
|
||||||
|
return RerankResult(model=model, docs=rerank_documents)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise InvokeServerUnavailableError(str(e))
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
|
||||||
|
self._invoke(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
query="What is the capital of the United States?",
|
||||||
|
docs=[
|
||||||
|
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||||
|
"Census, Carson City had a population of 55,274.",
|
||||||
|
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||||
|
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||||
|
],
|
||||||
|
score_threshold=0.8
|
||||||
|
)
|
||||||
|
except Exception as ex:
|
||||||
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
"""
|
||||||
|
Map model invoke error to unified error
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [httpx.ConnectError],
|
||||||
|
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
|
||||||
|
InvokeRateLimitError: [],
|
||||||
|
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||||
|
InvokeBadRequestError: [httpx.RequestError]
|
||||||
|
}
|
158
api/tests/integration_tests/model_runtime/localai/test_rerank.py
Normal file
158
api/tests/integration_tests/model_runtime/localai/test_rerank.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from api.core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
|
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials_for_chat_model():
|
||||||
|
model = LocalaiRerankModel()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
model.validate_credentials(
|
||||||
|
model='bge-reranker-v2-m3',
|
||||||
|
credentials={
|
||||||
|
'server_url': 'hahahaha',
|
||||||
|
'completion_type': 'completion',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
model.validate_credentials(
|
||||||
|
model='bge-reranker-base',
|
||||||
|
credentials={
|
||||||
|
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
|
||||||
|
'completion_type': 'completion',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_invoke_rerank_model():
|
||||||
|
model = LocalaiRerankModel()
|
||||||
|
|
||||||
|
response = model.invoke(
|
||||||
|
model='bge-reranker-base',
|
||||||
|
credentials={
|
||||||
|
'server_url': os.environ.get('LOCALAI_SERVER_URL')
|
||||||
|
},
|
||||||
|
query='Organic skincare products for sensitive skin',
|
||||||
|
docs=[
|
||||||
|
"Eco-friendly kitchenware for modern homes",
|
||||||
|
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||||
|
"Organic cotton baby clothes for sensitive skin",
|
||||||
|
"Natural organic skincare range for sensitive skin",
|
||||||
|
"Tech gadgets for smart homes: 2024 edition",
|
||||||
|
"Sustainable gardening tools and compost solutions",
|
||||||
|
"Sensitive skin-friendly facial cleansers and toners",
|
||||||
|
"Organic food wraps and storage solutions",
|
||||||
|
"Yoga mats made from recycled materials"
|
||||||
|
],
|
||||||
|
top_n=3,
|
||||||
|
score_threshold=0.75,
|
||||||
|
user="abc-123"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, RerankResult)
|
||||||
|
assert len(response.docs) == 3
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from api.core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||||
|
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials_for_chat_model():
|
||||||
|
model = LocalaiRerankModel()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
model.validate_credentials(
|
||||||
|
model='bge-reranker-v2-m3',
|
||||||
|
credentials={
|
||||||
|
'server_url': 'hahahaha',
|
||||||
|
'completion_type': 'completion',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
model.validate_credentials(
|
||||||
|
model='bge-reranker-base',
|
||||||
|
credentials={
|
||||||
|
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
|
||||||
|
'completion_type': 'completion',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_invoke_rerank_model():
|
||||||
|
model = LocalaiRerankModel()
|
||||||
|
|
||||||
|
response = model.invoke(
|
||||||
|
model='bge-reranker-base',
|
||||||
|
credentials={
|
||||||
|
'server_url': os.environ.get('LOCALAI_SERVER_URL')
|
||||||
|
},
|
||||||
|
query='Organic skincare products for sensitive skin',
|
||||||
|
docs=[
|
||||||
|
"Eco-friendly kitchenware for modern homes",
|
||||||
|
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||||
|
"Organic cotton baby clothes for sensitive skin",
|
||||||
|
"Natural organic skincare range for sensitive skin",
|
||||||
|
"Tech gadgets for smart homes: 2024 edition",
|
||||||
|
"Sustainable gardening tools and compost solutions",
|
||||||
|
"Sensitive skin-friendly facial cleansers and toners",
|
||||||
|
"Organic food wraps and storage solutions",
|
||||||
|
"Yoga mats made from recycled materials"
|
||||||
|
],
|
||||||
|
top_n=3,
|
||||||
|
score_threshold=0.75,
|
||||||
|
user="abc-123"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, RerankResult)
|
||||||
|
assert len(response.docs) == 3
|
||||||
|
|
||||||
|
def test__invoke():
|
||||||
|
model = LocalaiRerankModel()
|
||||||
|
|
||||||
|
# Test case 1: Empty docs
|
||||||
|
result = model._invoke(
|
||||||
|
model='bge-reranker-base',
|
||||||
|
credentials={
|
||||||
|
'server_url': 'https://example.com',
|
||||||
|
'api_key': '1234567890'
|
||||||
|
},
|
||||||
|
query='Organic skincare products for sensitive skin',
|
||||||
|
docs=[],
|
||||||
|
top_n=3,
|
||||||
|
score_threshold=0.75,
|
||||||
|
user="abc-123"
|
||||||
|
)
|
||||||
|
assert isinstance(result, RerankResult)
|
||||||
|
assert len(result.docs) == 0
|
||||||
|
|
||||||
|
# Test case 2: Valid invocation
|
||||||
|
result = model._invoke(
|
||||||
|
model='bge-reranker-base',
|
||||||
|
credentials={
|
||||||
|
'server_url': 'https://example.com',
|
||||||
|
'api_key': '1234567890'
|
||||||
|
},
|
||||||
|
query='Organic skincare products for sensitive skin',
|
||||||
|
docs=[
|
||||||
|
"Eco-friendly kitchenware for modern homes",
|
||||||
|
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||||
|
"Organic cotton baby clothes for sensitive skin",
|
||||||
|
"Natural organic skincare range for sensitive skin",
|
||||||
|
"Tech gadgets for smart homes: 2024 edition",
|
||||||
|
"Sustainable gardening tools and compost solutions",
|
||||||
|
"Sensitive skin-friendly facial cleansers and toners",
|
||||||
|
"Organic food wraps and storage solutions",
|
||||||
|
"Yoga mats made from recycled materials"
|
||||||
|
],
|
||||||
|
top_n=3,
|
||||||
|
score_threshold=0.75,
|
||||||
|
user="abc-123"
|
||||||
|
)
|
||||||
|
assert isinstance(result, RerankResult)
|
||||||
|
assert len(result.docs) == 3
|
||||||
|
assert all(isinstance(doc, RerankDocument) for doc in result.docs)
|
Loading…
x
Reference in New Issue
Block a user