mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 00:45:53 +08:00
fix some tidb bugs (#4960)
This commit is contained in:
parent
64c8093c1e
commit
02e4de5166
@ -79,8 +79,6 @@ class TiDBVector(BaseVector):
|
|||||||
return
|
return
|
||||||
with Session(self._engine) as session:
|
with Session(self._engine) as session:
|
||||||
session.begin()
|
session.begin()
|
||||||
drop_statement = sql_text(f"""DROP TABLE IF EXISTS {self._collection_name}; """)
|
|
||||||
session.execute(drop_statement)
|
|
||||||
create_statement = sql_text(f"""
|
create_statement = sql_text(f"""
|
||||||
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
||||||
id CHAR(36) PRIMARY KEY,
|
id CHAR(36) PRIMARY KEY,
|
||||||
@ -123,7 +121,7 @@ class TiDBVector(BaseVector):
|
|||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
result = self.get_ids_by_metadata_field('doc_id', id)
|
result = self.get_ids_by_metadata_field('doc_id', id)
|
||||||
return len(result) > 0
|
return bool(result)
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
with Session(self._engine) as session:
|
with Session(self._engine) as session:
|
||||||
@ -184,14 +182,14 @@ class TiDBVector(BaseVector):
|
|||||||
docs = []
|
docs = []
|
||||||
if self._distance_func == 'l2':
|
if self._distance_func == 'l2':
|
||||||
tidb_func = 'Vec_l2_distance'
|
tidb_func = 'Vec_l2_distance'
|
||||||
elif self._distance_func == 'l2':
|
elif self._distance_func == 'cosine':
|
||||||
tidb_func = 'Vec_Cosine_distance'
|
tidb_func = 'Vec_Cosine_distance'
|
||||||
else:
|
else:
|
||||||
tidb_func = 'Vec_Cosine_distance'
|
tidb_func = 'Vec_Cosine_distance'
|
||||||
|
|
||||||
with Session(self._engine) as session:
|
with Session(self._engine) as session:
|
||||||
select_statement = sql_text(
|
select_statement = sql_text(
|
||||||
f"""SELECT meta, text FROM (
|
f"""SELECT meta, text, distance FROM (
|
||||||
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
|
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
|
||||||
FROM {self._collection_name}
|
FROM {self._collection_name}
|
||||||
ORDER BY distance
|
ORDER BY distance
|
||||||
@ -199,9 +197,11 @@ class TiDBVector(BaseVector):
|
|||||||
) t WHERE distance < {distance};"""
|
) t WHERE distance < {distance};"""
|
||||||
)
|
)
|
||||||
res = session.execute(select_statement)
|
res = session.execute(select_statement)
|
||||||
results = [(row[0], row[1]) for row in res]
|
results = [(row[0], row[1], row[2]) for row in res]
|
||||||
for meta, text in results:
|
for meta, text, distance in results:
|
||||||
docs.append(Document(page_content=text, metadata=json.loads(meta)))
|
metadata = json.loads(meta)
|
||||||
|
metadata['score'] = 1 - distance
|
||||||
|
docs.append(Document(page_content=text, metadata=metadata))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user