mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 15:25:58 +08:00
Update Oracle db connection library and change connection pool to single connection (#18466)
This commit is contained in:
parent
30c051d485
commit
7b6523e54d
@ -2,12 +2,12 @@ import array
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import jieba.posseg as pseg # type: ignore
|
||||
import numpy
|
||||
import oracledb
|
||||
from oracledb.connection import Connection
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
@ -70,6 +70,7 @@ class OracleVector(BaseVector):
|
||||
super().__init__(collection_name)
|
||||
self.pool = self._create_connection_pool(config)
|
||||
self.table_name = f"embedding_{collection_name}"
|
||||
self.config = config
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.ORACLE
|
||||
@ -107,16 +108,19 @@ class OracleVector(BaseVector):
|
||||
outconverter=self.numpy_converter_out,
|
||||
)
|
||||
|
||||
def _get_connection(self) -> Connection:
|
||||
connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
|
||||
return connection
|
||||
|
||||
def _create_connection_pool(self, config: OracleVectorConfig):
|
||||
pool_params = {
|
||||
"user": config.user,
|
||||
"password": config.password,
|
||||
"dsn": config.dsn,
|
||||
"min": 1,
|
||||
"max": 50,
|
||||
"max": 5,
|
||||
"increment": 1,
|
||||
}
|
||||
|
||||
if config.is_autonomous:
|
||||
pool_params.update(
|
||||
{
|
||||
@ -125,22 +129,8 @@ class OracleVector(BaseVector):
|
||||
"wallet_password": config.wallet_password,
|
||||
}
|
||||
)
|
||||
|
||||
return oracledb.create_pool(**pool_params)
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
conn = self.pool.acquire()
|
||||
conn.inputtypehandler = self.input_type_handler
|
||||
conn.outputtypehandler = self.output_type_handler
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
yield cur
|
||||
finally:
|
||||
cur.close()
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
@ -162,41 +152,68 @@ class OracleVector(BaseVector):
|
||||
numpy.array(embeddings[i]),
|
||||
)
|
||||
)
|
||||
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
|
||||
with self._get_cursor() as cur:
|
||||
cur.executemany(
|
||||
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
|
||||
)
|
||||
with self._get_connection() as conn:
|
||||
conn.inputtypehandler = self.input_type_handler
|
||||
conn.outputtypehandler = self.output_type_handler
|
||||
# with conn.cursor() as cur:
|
||||
# cur.executemany(
|
||||
# f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
|
||||
# )
|
||||
# conn.commit()
|
||||
for value in values:
|
||||
with conn.cursor() as cur:
|
||||
try:
|
||||
cur.execute(
|
||||
f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
|
||||
VALUES (:1, :2, :3, :4)""",
|
||||
value,
|
||||
)
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
conn.close()
|
||||
return pks
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
|
||||
return cur.fetchone() is not None
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
|
||||
return cur.fetchone() is not None
|
||||
conn.close()
|
||||
|
||||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
docs = []
|
||||
for record in cur:
|
||||
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
docs = []
|
||||
for record in cur:
|
||||
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||
self.pool.release(connection=conn)
|
||||
conn.close()
|
||||
return docs
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Search the nearest neighbors to a vector.
|
||||
|
||||
:param query_vector: The input vector to search for similar items.
|
||||
:param top_k: The number of nearest neighbors to return, default is 5.
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
@ -205,20 +222,25 @@ class OracleVector(BaseVector):
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
|
||||
f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
|
||||
[numpy.array(query_vector)],
|
||||
)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in cur:
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
with self._get_connection() as conn:
|
||||
conn.inputtypehandler = self.input_type_handler
|
||||
conn.outputtypehandler = self.output_type_handler
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
|
||||
AS distance FROM {self.table_name}
|
||||
{where_clause} ORDER BY distance fetch first {top_k} rows only""",
|
||||
[numpy.array(query_vector)],
|
||||
)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in cur:
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
conn.close()
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@ -228,7 +250,7 @@ class OracleVector(BaseVector):
|
||||
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
# just not implement fetch by score_threshold now, may be later
|
||||
# score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if len(query) > 0:
|
||||
# Check which language the query is in
|
||||
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
|
||||
@ -239,7 +261,7 @@ class OracleVector(BaseVector):
|
||||
words = pseg.cut(query)
|
||||
current_entity = ""
|
||||
for word, pos in words:
|
||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
|
||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
|
||||
current_entity += word
|
||||
else:
|
||||
if current_entity:
|
||||
@ -260,30 +282,35 @@ class OracleVector(BaseVector):
|
||||
for token in all_tokens:
|
||||
if token not in stop_words:
|
||||
entities.append(token)
|
||||
with self._get_cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
cur.execute(
|
||||
f"select meta, text, embedding FROM {self.table_name}"
|
||||
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
|
||||
f"order by score(1) desc fetch first {top_k} rows only",
|
||||
[" ACCUM ".join(entities)],
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata, text, embedding = record
|
||||
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
cur.execute(
|
||||
f"""select meta, text, embedding FROM {self.table_name}
|
||||
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
|
||||
order by score(1) desc fetch first {top_k} rows only""",
|
||||
kk=" ACCUM ".join(entities),
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata, text, embedding = record
|
||||
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
|
||||
conn.close()
|
||||
return docs
|
||||
else:
|
||||
return [Document(page_content="", metadata={})]
|
||||
return []
|
||||
|
||||
def delete(self) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
@ -293,11 +320,14 @@ class OracleVector(BaseVector):
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
class OracleVectorFactory(AbstractVectorFactory):
|
||||
|
@ -178,7 +178,7 @@ vdb = [
|
||||
"couchbase~=4.3.0",
|
||||
"elasticsearch==8.14.0",
|
||||
"opensearch-py==2.4.0",
|
||||
"oracledb~=2.2.1",
|
||||
"oracledb==3.0.0",
|
||||
"pgvecto-rs[sqlalchemy]~=0.2.1",
|
||||
"pgvector==0.2.5",
|
||||
"pymilvus~=2.5.0",
|
||||
|
26
api/uv.lock
generated
26
api/uv.lock
generated
@ -1471,7 +1471,7 @@ vdb = [
|
||||
{ name = "couchbase", specifier = "~=4.3.0" },
|
||||
{ name = "elasticsearch", specifier = "==8.14.0" },
|
||||
{ name = "opensearch-py", specifier = "==2.4.0" },
|
||||
{ name = "oracledb", specifier = "~=2.2.1" },
|
||||
{ name = "oracledb", specifier = "==3.0.0" },
|
||||
{ name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" },
|
||||
{ name = "pgvector", specifier = "==0.2.5" },
|
||||
{ name = "pymilvus", specifier = "~=2.5.0" },
|
||||
@ -3600,23 +3600,23 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "oracledb"
|
||||
version = "2.2.1"
|
||||
version = "3.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cryptography" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/36/fb/3fbacb351833dd794abb184303a5761c4bb33df9d770fd15d01ead2ff738/oracledb-2.2.1.tar.gz", hash = "sha256:8464c6f0295f3318daf6c2c72c83c2dcbc37e13f8fd44e3e39ff8665f442d6b6", size = 580818 }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/bf/39/712f797b75705c21148fa1d98651f63c2e5cc6876e509a0a9e2f5b406572/oracledb-3.0.0.tar.gz", hash = "sha256:64dc86ee5c032febc556798b06e7b000ef6828bb0252084f6addacad3363db85", size = 840431 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/74/b7/a4238295944670fb8cc50a8cc082e0af5a0440bfb1c2bac2b18429c0a579/oracledb-2.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fb6d9a4d7400398b22edb9431334f9add884dec9877fd9c4ae531e1ccc6ee1fd", size = 3551303 },
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/5f/98481d44976cd2b3086361f2d50026066b24090b0e6cd1f2a12c824e9717/oracledb-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07757c240afbb4f28112a6affc2c5e4e34b8a92e5bb9af81a40fba398da2b028", size = 12258455 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/54/06b2540286e2b63f60877d6f3c6c40747e216b6eeda0756260e194897076/oracledb-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63daec72f853c47179e98493e9b732909d96d495bdceb521c5973a3940d28142", size = 12317476 },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/1a/67814439a4e24df83281a72cb0ba433d6b74e1bff52a9975b87a725bcba5/oracledb-2.2.1-cp311-cp311-win32.whl", hash = "sha256:fec5318d1e0ada7e4674574cb6c8d1665398e8b9c02982279107212f05df1660", size = 1369368 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/b8/b2a8f0607be17f58ec6689ad5fd15c2956f4996c64547325e96439570edf/oracledb-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:5134dccb5a11bc755abf02fd49be6dc8141dfcae4b650b55d40509323d00b5c2", size = 1655035 },
|
||||
{ url = "https://files.pythonhosted.org/packages/24/5b/2fff762243030f31a6b1561fc8eeb142e69ba6ebd3e7fbe4a2c82f0eb6f0/oracledb-2.2.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ac5716bc9a48247fdf563f5f4ec097f5c9f074a60fd130cdfe16699208ca29b5", size = 3583960 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/88/34117ae830e7338af7c0481f1c0fc6eda44d558e12f9203b45b491e53071/oracledb-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c150bddb882b7c73fb462aa2d698744da76c363e404570ed11d05b65811d96c3", size = 11749006 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/58/bac788f18c21f727955652fe238de2d24a12c2b455ed4db18a6d23ff781e/oracledb-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193e1888411bc21187ade4b16b76820bd1e8f216e25602f6cd0a97d45723c1dc", size = 11950663 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/e2/005f66ae919c6f7c73e06863256cf43aa844330e2dc61a5f9779ae44a801/oracledb-2.2.1-cp312-cp312-win32.whl", hash = "sha256:44a960f8bbb0711af222e0a9690e037b6a2a382e0559ae8eeb9cfafe26c7a3bc", size = 1324255 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/25/759eb2143134513382e66d874c4aacfd691dec3fef7141170cfa6c1b154f/oracledb-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:470136add32f0d0084225c793f12a52b61b52c3dc00c9cd388ec6a3db3a7643e", size = 1613047 },
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/bf/d872c4b3fc15cd3261fe0ea72b21d181700c92dbc050160e161654987062/oracledb-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:52daa9141c63dfa75c07d445e9bb7f69f43bfb3c5a173ecc48c798fe50288d26", size = 4312963 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b1/ea/01ee29e76a610a53bb34fdc1030f04b7669c3f80b25f661e07850fc6160e/oracledb-3.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af98941789df4c6aaaf4338f5b5f6b7f2c8c3fe6f8d6a9382f177f350868747a", size = 2661536 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/8e/ad380e34a46819224423b4773e58c350bc6269643c8969604097ced8c3bc/oracledb-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9812bb48865aaec35d73af54cd1746679f2a8a13cbd1412ab371aba2e39b3943", size = 2867461 },
|
||||
{ url = "https://files.pythonhosted.org/packages/96/09/ecc4384a27fd6e1e4de824ae9c160e4ad3aaebdaade5b4bdcf56a4d1ff63/oracledb-3.0.0-cp311-cp311-win32.whl", hash = "sha256:6c27fe0de64f2652e949eb05b3baa94df9b981a4a45fa7f8a991e1afb450c8e2", size = 1752046 },
|
||||
{ url = "https://files.pythonhosted.org/packages/62/e8/f34bde24050c6e55eeba46b23b2291f2dd7fd272fa8b322dcbe71be55778/oracledb-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:f922709672002f0b40997456f03a95f03e5712a86c61159951c5ce09334325e0", size = 2101210 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/fc/24590c3a3d41e58494bd3c3b447a62835138e5f9b243d9f8da0cfb5da8dc/oracledb-3.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:acd0e747227dea01bebe627b07e958bf36588a337539f24db629dc3431d3f7eb", size = 4351993 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/b6/1f3b0b7bb94d53e8857d77b2e8dbdf6da091dd7e377523e24b79dac4fd71/oracledb-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f8b402f77c22af031cd0051aea2472ecd0635c1b452998f511aa08b7350c90a4", size = 2532640 },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/1a/1815f6c086ab49c00921cf155ff5eede5267fb29fcec37cb246339a5ce4d/oracledb-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:378a27782e9a37918bd07a5a1427a77cb6f777d0a5a8eac9c070d786f50120ef", size = 2765949 },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/8d/208900f8d372909792ee70b2daad3f7361181e55f2217c45ed9dff658b54/oracledb-3.0.0-cp312-cp312-win32.whl", hash = "sha256:54a28c2cb08316a527cd1467740a63771cc1c1164697c932aa834c0967dc4efc", size = 1709373 },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/5e/c21754f19c896102793c3afec2277e2180aa7d505e4d7fcca24b52d14e4f/oracledb-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8289bad6d103ce42b140e40576cf0c81633e344d56e2d738b539341eacf65624", size = 2056452 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
Loading…
x
Reference in New Issue
Block a user