diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index b22839db47..00b929f67a 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -79,8 +79,6 @@ class TiDBVector(BaseVector): return with Session(self._engine) as session: session.begin() - drop_statement = sql_text(f"""DROP TABLE IF EXISTS {self._collection_name}; """) - session.execute(drop_statement) create_statement = sql_text(f""" CREATE TABLE IF NOT EXISTS {self._collection_name} ( id CHAR(36) PRIMARY KEY, @@ -123,7 +121,7 @@ class TiDBVector(BaseVector): def text_exists(self, id: str) -> bool: result = self.get_ids_by_metadata_field('doc_id', id) - return len(result) > 0 + return bool(result) def delete_by_ids(self, ids: list[str]) -> None: with Session(self._engine) as session: @@ -184,14 +182,14 @@ class TiDBVector(BaseVector): docs = [] if self._distance_func == 'l2': tidb_func = 'Vec_l2_distance' - elif self._distance_func == 'l2': + elif self._distance_func == 'cosine': tidb_func = 'Vec_Cosine_distance' else: tidb_func = 'Vec_Cosine_distance' with Session(self._engine) as session: select_statement = sql_text( - f"""SELECT meta, text FROM ( + f"""SELECT meta, text, distance FROM ( SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance FROM {self._collection_name} ORDER BY distance @@ -199,9 +197,11 @@ class TiDBVector(BaseVector): ) t WHERE distance < {distance};""" ) res = session.execute(select_statement) - results = [(row[0], row[1]) for row in res] - for meta, text in results: - docs.append(Document(page_content=text, metadata=json.loads(meta))) + results = [(row[0], row[1], row[2]) for row in res] + for meta, text, distance in results: + metadata = json.loads(meta) + metadata['score'] = 1 - distance + docs.append(Document(page_content=text, metadata=metadata)) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: