feat: add gte rerank for tongyi (#9153)

This commit is contained in:
Fei He 2024-10-11 10:35:56 +08:00 committed by GitHub
parent cabdb4ef17
commit 5c76131d3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 182 additions and 0 deletions

View File

@ -0,0 +1 @@
- gte-rerank

View File

@ -0,0 +1,4 @@
# Predefined model declaration for the "gte-rerank" model (model_type: rerank).
model: gte-rerank
model_type: rerank
model_properties:
# NOTE(review): the unit of context_size is not stated here (presumably tokens) — confirm.
context_size: 4000

View File

@ -0,0 +1,136 @@
from typing import Optional
import dashscope
from dashscope.common.error import (
AuthenticationError,
InvalidParameter,
RequestFailure,
ServiceUnavailableError,
UnsupportedHTTPMethod,
UnsupportedModel,
)
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class GTERerankModel(RerankModel):
    """
    Model class for GTE rerank model.

    Sends a query and candidate documents to ``dashscope.TextReRank`` and
    converts the returned results into a ``RerankResult``.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials; must contain ``dashscope_api_key``
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: when given, only results whose relevance score
            is greater than or equal to this value are kept
        :param top_n: top n, forwarded unchanged to the rerank API
        :param user: unique user id (not used by this implementation)
        :return: rerank result
        :raises AuthenticationError: the API answered with HTTP 401/403
        :raises InvalidParameter: the API answered with HTTP 400
        :raises ServiceUnavailableError: any other unsuccessful response
        """
        # Nothing to rank: skip the remote call entirely.
        if not docs:
            return RerankResult(model=model, docs=[])

        # initialize client
        # NOTE(review): this assigns the module-level ``dashscope.api_key``;
        # consider passing the key per call if the SDK supports it — confirm.
        dashscope.api_key = credentials["dashscope_api_key"]

        response = dashscope.TextReRank.call(
            query=query,
            documents=docs,
            model=model,
            top_n=top_n,
            return_documents=True,
        )

        # A failed call carries no usable ``output``. Raise one of the SDK
        # error types listed in ``_invoke_error_mapping`` instead of failing
        # later with an opaque AttributeError on ``response.output.results``.
        status_code = response.status_code
        if status_code != 200 or not response.output:
            message = response.message or f"DashScope rerank request failed with status {status_code}"
            if status_code in (401, 403):
                raise AuthenticationError(message)
            if status_code == 400:
                raise InvalidParameter(message)
            raise ServiceUnavailableError(message)

        rerank_documents = []
        for result in response.output.results:
            # score threshold check: keep everything when no threshold is set
            if score_threshold is None or result.relevance_score >= score_threshold:
                rerank_documents.append(
                    RerankDocument(
                        index=result.index,
                        score=result.relevance_score,
                        text=result["document"]["text"],
                    )
                )

        return RerankResult(model=model, docs=rerank_documents)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        Performs a small real rerank request through ``invoke``; any failure
        is reported as a credentials validation failure.

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if the probe request fails
        """
        try:
            self.invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            # Keep the original exception as the cause so the root failure
            # stays visible in tracebacks.
            raise CredentialsValidateFailedError(str(ex)) from ex

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                RequestFailure,
            ],
            InvokeServerUnavailableError: [
                ServiceUnavailableError,
            ],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [
                AuthenticationError,
            ],
            InvokeBadRequestError: [
                InvalidParameter,
                UnsupportedModel,
                UnsupportedHTTPMethod,
            ],
        }

View File

@ -18,6 +18,7 @@ supported_model_types:
- llm
- tts
- text-embedding
- rerank
configurate_methods:
- predefined-model
- customizable-model

View File

@ -0,0 +1,40 @@
import os
import dashscope
import pytest
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.tongyi.rerank.rerank import GTERerankModel
def test_validate_credentials():
    """Check that an invalid API key is rejected and the configured key is accepted.

    Integration test: reads the real key from the ``TONGYI_DASHSCOPE_API_KEY``
    environment variable.
    """
    model = GTERerankModel()

    # An invalid key must surface as a credentials validation failure.
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="gte-rerank", credentials={"dashscope_api_key": "invalid_key"})

    # The key from the environment must validate without raising.
    model.validate_credentials(
        model="gte-rerank", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}
    )
def test_invoke_model():
    """Rerank three candidate passages and expect only the first to pass the 0.7 threshold.

    Integration test: reads the real key from the ``TONGYI_DASHSCOPE_API_KEY``
    environment variable.
    """
    api_credentials = {"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}
    candidate_docs = [
        "文本排序模型广泛用于搜索引擎和推荐系统中,它们根据文本相关性对候选文本进行排序",
        "量子计算是计算科学的一个前沿领域",
        "预训练语言模型的发展给文本排序模型带来了新的进展",
    ]

    reranker = GTERerankModel()
    result = reranker.invoke(
        model=dashscope.TextReRank.Models.gte_rerank,
        credentials=api_credentials,
        query="什么是文本排序模型",
        docs=candidate_docs,
        score_threshold=0.7,
    )

    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1

    # The single surviving document is the first candidate.
    best_match = result.docs[0]
    assert best_match.index == 0
    assert best_match.score >= 0.7