mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 11:25:55 +08:00
fix: AnalyticdbVector retrieval scores (#8803)
This commit is contained in:
parent
d6b9587a97
commit
4c1063e1c5
@ -40,19 +40,8 @@ class AnalyticdbConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class AnalyticdbVector(BaseVector):
|
class AnalyticdbVector(BaseVector):
|
||||||
_instance = None
|
|
||||||
_init = False
|
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def __init__(self, collection_name: str, config: AnalyticdbConfig):
|
def __init__(self, collection_name: str, config: AnalyticdbConfig):
|
||||||
# collection_name must be updated every time
|
|
||||||
self._collection_name = collection_name.lower()
|
self._collection_name = collection_name.lower()
|
||||||
if AnalyticdbVector._init:
|
|
||||||
return
|
|
||||||
try:
|
try:
|
||||||
from alibabacloud_gpdb20160503.client import Client
|
from alibabacloud_gpdb20160503.client import Client
|
||||||
from alibabacloud_tea_openapi import models as open_api_models
|
from alibabacloud_tea_openapi import models as open_api_models
|
||||||
@ -62,7 +51,6 @@ class AnalyticdbVector(BaseVector):
|
|||||||
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
|
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
|
||||||
self._client = Client(self._client_config)
|
self._client = Client(self._client_config)
|
||||||
self._initialize()
|
self._initialize()
|
||||||
AnalyticdbVector._init = True
|
|
||||||
|
|
||||||
def _initialize(self) -> None:
|
def _initialize(self) -> None:
|
||||||
cache_key = f"vector_indexing_{self.config.instance_id}"
|
cache_key = f"vector_indexing_{self.config.instance_id}"
|
||||||
@ -257,11 +245,14 @@ class AnalyticdbVector(BaseVector):
|
|||||||
documents = []
|
documents = []
|
||||||
for match in response.body.matches.match:
|
for match in response.body.matches.match:
|
||||||
if match.score > score_threshold:
|
if match.score > score_threshold:
|
||||||
|
metadata = json.loads(match.metadata.get("metadata_"))
|
||||||
|
metadata["score"] = match.score
|
||||||
doc = Document(
|
doc = Document(
|
||||||
page_content=match.metadata.get("page_content"),
|
page_content=match.metadata.get("page_content"),
|
||||||
metadata=json.loads(match.metadata.get("metadata_")),
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
documents.append(doc)
|
documents.append(doc)
|
||||||
|
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
@ -286,12 +277,14 @@ class AnalyticdbVector(BaseVector):
|
|||||||
for match in response.body.matches.match:
|
for match in response.body.matches.match:
|
||||||
if match.score > score_threshold:
|
if match.score > score_threshold:
|
||||||
metadata = json.loads(match.metadata.get("metadata_"))
|
metadata = json.loads(match.metadata.get("metadata_"))
|
||||||
|
metadata["score"] = match.score
|
||||||
doc = Document(
|
doc = Document(
|
||||||
page_content=match.metadata.get("page_content"),
|
page_content=match.metadata.get("page_content"),
|
||||||
vector=match.metadata.get("vector"),
|
vector=match.metadata.get("vector"),
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
documents.append(doc)
|
documents.append(doc)
|
||||||
|
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user