Fix some TiDB bugs (#4960)

This commit is contained in:
Jyong 2024-06-05 19:14:18 +08:00 committed by GitHub
parent 64c8093c1e
commit 02e4de5166
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -79,8 +79,6 @@ class TiDBVector(BaseVector):
return return
with Session(self._engine) as session: with Session(self._engine) as session:
session.begin() session.begin()
drop_statement = sql_text(f"""DROP TABLE IF EXISTS {self._collection_name}; """)
session.execute(drop_statement)
create_statement = sql_text(f""" create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} ( CREATE TABLE IF NOT EXISTS {self._collection_name} (
id CHAR(36) PRIMARY KEY, id CHAR(36) PRIMARY KEY,
@ -123,7 +121,7 @@ class TiDBVector(BaseVector):
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
result = self.get_ids_by_metadata_field('doc_id', id) result = self.get_ids_by_metadata_field('doc_id', id)
return len(result) > 0 return bool(result)
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
with Session(self._engine) as session: with Session(self._engine) as session:
@ -184,14 +182,14 @@ class TiDBVector(BaseVector):
docs = [] docs = []
if self._distance_func == 'l2': if self._distance_func == 'l2':
tidb_func = 'Vec_l2_distance' tidb_func = 'Vec_l2_distance'
elif self._distance_func == 'l2': elif self._distance_func == 'cosine':
tidb_func = 'Vec_Cosine_distance' tidb_func = 'Vec_Cosine_distance'
else: else:
tidb_func = 'Vec_Cosine_distance' tidb_func = 'Vec_Cosine_distance'
with Session(self._engine) as session: with Session(self._engine) as session:
select_statement = sql_text( select_statement = sql_text(
f"""SELECT meta, text FROM ( f"""SELECT meta, text, distance FROM (
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
FROM {self._collection_name} FROM {self._collection_name}
ORDER BY distance ORDER BY distance
@ -199,9 +197,11 @@ class TiDBVector(BaseVector):
) t WHERE distance < {distance};""" ) t WHERE distance < {distance};"""
) )
res = session.execute(select_statement) res = session.execute(select_statement)
results = [(row[0], row[1]) for row in res] results = [(row[0], row[1], row[2]) for row in res]
for meta, text in results: for meta, text, distance in results:
docs.append(Document(page_content=text, metadata=json.loads(meta))) metadata = json.loads(meta)
metadata['score'] = 1 - distance
docs.append(Document(page_content=text, metadata=metadata))
return docs return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: