From 86594851cbd16a5ce50528af3e1fbb334ee94652 Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Wed, 16 Oct 2024 16:00:21 +0800 Subject: [PATCH] refactor: update the default values of top-k parameter in vdb to be consistent (#9367) --- .../vdb/elasticsearch/elasticsearch_vector.py | 2 +- .../rag/datasource/vdb/myscale/myscale_vector.py | 2 +- api/core/rag/datasource/vdb/oracle/oraclevector.py | 10 +--------- .../rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py | 14 +------------- api/core/rag/datasource/vdb/pgvector/pgvector.py | 2 +- api/core/rag/datasource/vdb/relyt/relyt_vector.py | 2 +- .../rag/datasource/vdb/tidb_vector/tidb_vector.py | 2 +- .../rag/datasource/vdb/vikingdb/vikingdb_vector.py | 2 +- .../rag/datasource/vdb/weaviate/weaviate_vector.py | 2 +- 9 files changed, 9 insertions(+), 29 deletions(-) diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 66bc31a4bf..f420373d5b 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -112,7 +112,7 @@ class ElasticSearchVector(BaseVector): self._client.indices.delete(index=self._collection_name) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - top_k = kwargs.get("top_k", 10) + top_k = kwargs.get("top_k", 4) num_candidates = math.ceil(top_k * 1.5) knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates} diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 2320a69a30..b30aa7ca22 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -121,7 +121,7 @@ class MyScaleVector(BaseVector): return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs) def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: - top_k = kwargs.get("top_k", 5) + top_k = kwargs.get("top_k", 4) score_threshold = float(kwargs.get("score_threshold") or 0.0) where_str = ( f"WHERE dist < {1 - score_threshold}" diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 77ec45b4d3..84a4381cd1 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -168,14 +168,6 @@ class OracleVector(BaseVector): docs.append(Document(page_content=record[1], metadata=record[0])) return docs - # def get_ids_by_metadata_field(self, key: str, value: str): - # with self._get_cursor() as cur: - # cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" ) - # idss = [] - # for record in cur: - # idss.append(record[0]) - # return idss - def delete_by_ids(self, ids: list[str]) -> None: with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) @@ -192,7 +184,7 @@ class OracleVector(BaseVector): :param top_k: The number of nearest neighbors to return, default is 5. :return: List of Documents that are nearest to the query vector. """ - top_k = kwargs.get("top_k", 5) + top_k = kwargs.get("top_k", 4) with self._get_cursor() as cur: cur.execute( f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}" diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index de2d65b223..a82a9b96dd 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -186,7 +186,7 @@ class PGVectoRS(BaseVector): query_vector, ).label("distance"), ) - .limit(kwargs.get("top_k", 2)) + .limit(kwargs.get("top_k", 4)) .order_by("distance") ) res = session.execute(stmt) @@ -205,18 +205,6 @@ class PGVectoRS(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - # with Session(self._client) as session: - # select_statement = sql_text( - # f"SELECT text, meta FROM {self._collection_name} WHERE to_tsvector(text) @@ '{query}'::tsquery" - # ) - # results = session.execute(select_statement).fetchall() - # if results: - # docs = [] - # for result in results: - # doc = Document(page_content=result[0], - # metadata=result[1]) - # docs.append(doc) - # return docs return [] diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 25a10a1e48..6f336d27e7 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -143,7 +143,7 @@ class PGVector(BaseVector): :param top_k: The number of nearest neighbors to return, default is 5. :return: List of Documents that are nearest to the query vector. """ - top_k = kwargs.get("top_k", 5) + top_k = kwargs.get("top_k", 4) with self._get_cursor() as cur: cur.execute( diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 254956970f..13a63784be 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -224,7 +224,7 @@ class RelytVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: results = self.similarity_search_with_score_by_vector( - k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter") + k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter") ) # Organize results. diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 20490d3215..7837c5a4aa 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -184,7 +184,7 @@ class TiDBVector(BaseVector): self._delete_by_ids(ids) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - top_k = kwargs.get("top_k", 5) + top_k = kwargs.get("top_k", 4) score_threshold = float(kwargs.get("score_threshold") or 0.0) filter = kwargs.get("filter") distance = 1 - score_threshold diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 22d0e92586..5f60f10acb 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -173,7 +173,7 @@ class VikingDBVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: results = self._client.get_index(self._collection_name, self._index_name).search_by_vector( - query_vector, limit=kwargs.get("top_k", 50) + query_vector, limit=kwargs.get("top_k", 4) ) score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(results, score_threshold) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 6eee344b9b..4009efe7a7 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -235,7 +235,7 @@ class WeaviateVector(BaseVector): query_obj = query_obj.with_where(kwargs.get("where_filter")) query_obj = query_obj.with_additional(["vector"]) properties = ["text"] - result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do() + result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do() if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") docs = []