diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index ef48d22963..612542dab1 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = kwargs.get("score_threshold") or 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 41f749279d..610aa498ab 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -92,7 +92,7 @@ class ChromaVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: collection = self._client.get_or_create_collection(self._collection_name) results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) ids: list[str] = results["ids"][0] documents: list[str] = results["documents"][0] diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index c52ca5f488..f8300cc271 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector): docs = [] for doc, score in docs_and_scores: - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) if score > score_threshold: doc.metadata["score"] = score docs.append(doc) diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 7d78f54256..bdca59f869 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -141,7 +141,7 @@ class MilvusVector(BaseVector): for result in results[0]: metadata = result["entity"].get(Field.METADATA_KEY.value) metadata["score"] = result["distance"] - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) if result["distance"] > score_threshold: doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index fbd7864edd..1bd5bcd3e4 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -122,7 +122,7 @@ class MyScaleVector(BaseVector): def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) where_str = ( f"WHERE dist < {1 - score_threshold}" if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index a6f70c8733..8d2e0a86ab 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector): metadata = {} metadata["score"] = hit["_score"] - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) if hit["_score"] > score_threshold: doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 1636582362..b974fa80a4 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -200,7 +200,7 @@ class OracleVector(BaseVector): [numpy.array(query_vector)], ) docs = [] - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) for record in cur: metadata, text, distance = record score = 1 - distance @@ -212,7 +212,7 @@ class OracleVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) # just not implement fetch by score_threshold now, may be later - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) if len(query) > 0: # Check which language the query is in zh_pattern = re.compile("[\u4e00-\u9fa5]+") diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 765436006b..de2d65b223 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -198,7 +198,7 @@ class PGVectoRS(BaseVector): metadata = record.meta score = 1 - dis metadata["score"] = score - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) if score > score_threshold: doc = Document(page_content=record.text, metadata=metadata) docs.append(doc) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index c35464b464..79879d4f63 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -144,7 +144,7 @@ class PGVector(BaseVector): (json.dumps(query_vector),), ) docs = [] - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) for record in cur: metadata, text, distance = record score = 1 - distance diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 2bf417e993..f418e3ca05 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -333,13 +333,13 @@ class QdrantVector(BaseVector): limit=kwargs.get("top_k", 4), with_payload=True, with_vectors=True, - score_threshold=kwargs.get("score_threshold", 0.0), + score_threshold=float(kwargs.get("score_threshold") or 0.0), ) docs = [] for result in results: metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) if result.score > score_threshold: metadata["score"] = result.score doc = Document( diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 76f214ffa4..76da144965 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -230,7 +230,7 @@ class RelytVector(BaseVector): # Organize results. docs = [] for document, score in results: - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) if 1 - score > score_threshold: docs.append(document) return docs diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index b550743f39..faa373017b 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -153,7 +153,7 @@ class TencentVector(BaseVector): limit=kwargs.get("top_k", 4), timeout=self._client_config.timeout, ) - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(res, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index a2eff6ff04..20490d3215 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -185,7 +185,7 @@ class TiDBVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) filter = kwargs.get("filter") distance = 1 - score_threshold diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 64e92c5b82..6eee344b9b 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -205,7 +205,7 @@ class WeaviateVector(BaseVector): docs = [] for doc, score in docs_and_scores: - score_threshold = kwargs.get("score_threshold", 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) # check score threshold if score > score_threshold: doc.metadata["score"] = score