diff --git a/agent/component/base.py b/agent/component/base.py index 9af0b546e..5a53c2736 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -19,7 +19,6 @@ import builtins import json import os from functools import partial -from typing import Tuple, Union import pandas as pd @@ -417,7 +416,7 @@ class ComponentBase(ABC): def _run(self, history, **kwargs): raise NotImplementedError() - def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]: + def output(self, allow_partial=True) -> tuple[str, pd.DataFrame | partial]: o = getattr(self._param, self._param.output_var_name) if not isinstance(o, partial) and not isinstance(o, pd.DataFrame): if not isinstance(o, list): o = [o] diff --git a/api/db/db_models.py b/api/db/db_models.py index be3aafb1f..b6975afb4 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -17,7 +17,6 @@ import logging import inspect import os import sys -import typing import operator from enum import Enum from functools import wraps @@ -121,13 +120,13 @@ class SerializedField(LongTextField): f"the serialized type {self._serialized_type} is not supported") -def is_continuous_field(cls: typing.Type) -> bool: +def is_continuous_field(cls: type) -> bool: if cls in CONTINUOUS_FIELD_TYPE: return True for p in cls.__bases__: if p in CONTINUOUS_FIELD_TYPE: return True - elif p != Field and p != object: + elif p is not Field and p is not object: if is_continuous_field(p): return True else: @@ -159,7 +158,7 @@ class BaseModel(Model): def to_dict(self): return self.__dict__['__data__'] - def to_human_model_dict(self, only_primary_with: list = None): + def to_human_model_dict(self, only_primary_with: list | None = None): model_dict = self.__dict__['__data__'] if not only_primary_with: diff --git a/api/db/db_utils.py b/api/db/db_utils.py index 47449e79c..e597e3358 100644 --- a/api/db/db_utils.py +++ b/api/db/db_utils.py @@ -15,7 +15,6 @@ # import operator from functools import reduce -from typing import Dict, Type, Union from playhouse.pool import PooledMySQLDatabase @@ -87,7 +86,7 @@ supported_operators = { def query_dict2expression( - model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]): + model: type[DataBaseModel], query: dict[str, bool | int | str | list | tuple]): expression = [] for field, value in query.items(): @@ -105,8 +104,8 @@ def query_dict2expression( return reduce(operator.iand, expression) -def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0, - query: dict = None, order_by: Union[str, list, tuple] = None): +def query_db(model: type[DataBaseModel], limit: int = 0, offset: int = 0, + query: dict | None = None, order_by: str | list | tuple | None = None): data = model.select() if query: data = data.where(query_dict2expression(model, query)) diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 09dd9fd33..db063da2c 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -14,6 +14,9 @@ # limitations under the License.
# +from beartype.claw import beartype_packages +beartype_packages(["agent", "api", "deepdoc", "plugins", "rag", "ragflow_sdk"]) # <-- raise exceptions in your code + import logging from api.utils.log_utils import initRootLogger initRootLogger("ragflow_server") diff --git a/api/utils/log_utils.py b/api/utils/log_utils.py index 4b296d036..b8ed722e2 100644 --- a/api/utils/log_utils.py +++ b/api/utils/log_utils.py @@ -28,13 +28,12 @@ def get_project_base_directory(): ) return PROJECT_BASE -def initRootLogger(script_path: str, log_level: int = logging.INFO, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"): +def initRootLogger(logfile_basename: str, log_level: int = logging.INFO, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"): logger = logging.getLogger() if logger.hasHandlers(): return - script_name = os.path.basename(script_path) - log_path = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.path.splitext(script_name)[0]}.log")) + log_path = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{logfile_basename}.log")) os.makedirs(os.path.dirname(log_path), exist_ok=True) logger.setLevel(log_level) @@ -50,5 +49,5 @@ def initRootLogger(script_path: str, log_level: int = logging.INFO, log_format: handler2.setFormatter(formatter) logger.addHandler(handler2) - msg = f"{script_name} log path: {log_path}" + msg = f"{logfile_basename} log path: {log_path}" logger.info(msg) \ No newline at end of file diff --git a/api/versions.py b/api/versions.py index 13f120f50..a52873483 100644 --- a/api/versions.py +++ b/api/versions.py @@ -13,11 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import dotenv -import typing import subprocess -def get_ragflow_version() -> typing.Optional[str]: +def get_ragflow_version() -> str: return RAGFLOW_VERSION_INFO @@ -42,7 +40,7 @@ def get_closest_tag_and_count(): return closest_tag else: return f"{commit_id}({closest_tag}~{commits_count})" - except Exception as e: + except Exception: return 'unknown' diff --git a/deepdoc/parser/json_parser.py b/deepdoc/parser/json_parser.py index 52b54d7b0..2c05614bf 100644 --- a/deepdoc/parser/json_parser.py +++ b/deepdoc/parser/json_parser.py @@ -3,12 +3,11 @@ # from https://github.com/langchain-ai/langchain/blob/master/libs/text-splitters/langchain_text_splitters/json.py import json -from typing import Any, Dict, List, Optional +from typing import Any from rag.nlp import find_codec - class RAGFlowJsonParser: def __init__( - self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None + self, max_chunk_size: int = 2000, min_chunk_size: int | None = None ): super().__init__() self.max_chunk_size = max_chunk_size * 2 @@ -27,12 +26,12 @@ class RAGFlowJsonParser: return sections @staticmethod - def _json_size(data: Dict) -> int: + def _json_size(data: dict) -> int: """Calculate the size of the serialized JSON object.""" return len(json.dumps(data, ensure_ascii=False)) @staticmethod - def _set_nested_dict(d: Dict, path: List[str], value: Any) -> None: + def _set_nested_dict(d: dict, path: list[str], value: Any) -> None: """Set a value in a nested dictionary based on the given path.""" for key in path[:-1]: d = d.setdefault(key, {}) @@ -54,10 +53,10 @@ class RAGFlowJsonParser: def _json_split( self, - data: Dict[str, Any], - current_path: Optional[List[str]] = None, - chunks: Optional[List[Dict]] = None, - ) -> List[Dict]: + data: dict[str, Any], + current_path: list[str] | 
None = None, + chunks: list[dict] | None = None, + ) -> list[dict]: """ Split json into maximum size dictionaries while preserving structure. """ @@ -87,9 +86,9 @@ class RAGFlowJsonParser: def split_json( self, - json_data: Dict[str, Any], + json_data: dict[str, Any], convert_lists: bool = False, - ) -> List[Dict]: + ) -> list[dict]: """Splits JSON into a list of JSON chunks""" if convert_lists: @@ -104,10 +103,10 @@ def split_text( self, - json_data: Dict[str, Any], + json_data: dict[str, Any], convert_lists: bool = False, ensure_ascii: bool = True, - ) -> List[str]: + ) -> list[str]: """Splits JSON into a list of JSON formatted strings""" chunks = self.split_json(json_data=json_data, convert_lists=convert_lists) diff --git a/docs/references/python_api_reference.md b/docs/references/python_api_reference.md index 093823abc..ff3e1a1c4 100644 --- a/docs/references/python_api_reference.md +++ b/docs/references/python_api_reference.md @@ -1059,7 +1059,7 @@ Deletes chat assistants by ID. #### ids: `list[str]` -The IDs of the chat assistants to delete. Defaults to `None`. If it is ot specified, all chat assistants in the system will be deleted. +The IDs of the chat assistants to delete. Defaults to `None`. If it is empty or not specified, all chat assistants in the system will be deleted. ### Returns diff --git a/graphrag/community_reports_extractor.py b/graphrag/community_reports_extractor.py index 213162f29..25f7b170b 100644 --- a/graphrag/community_reports_extractor.py +++ b/graphrag/community_reports_extractor.py @@ -9,8 +9,8 @@ import logging import json import re import traceback +from typing import Callable from dataclasses import dataclass -from typing import List, Callable import networkx as nx import pandas as pd from graphrag import leiden @@ -26,8 +26,8 @@ from timeit import default_timer as timer class CommunityReportsResult: """Community reports result class definition.""" - output: List[str] - structured_output: List[dict] + output: list[str] + structured_output: list[dict] class CommunityReportsExtractor: @@ -53,7 +53,7 @@ class CommunityReportsExtractor: self._max_report_length = max_report_length or 1500 def __call__(self, graph: nx.Graph, callback: Callable | None = None): - communities: dict[str, dict[str, List]] = leiden.run(graph, {}) + communities: dict[str, dict[str, list]] = leiden.run(graph, {}) total = sum([len(comm.items()) for _, comm in communities.items()]) relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)]) res_str = [] diff --git a/graphrag/entity_embedding.py b/graphrag/entity_embedding.py index ca9932b65..892d7db39 100644 --- a/graphrag/entity_embedding.py +++ b/graphrag/entity_embedding.py @@ -6,7 +6,6 @@ Reference: """ from typing import Any - import numpy as np import networkx as nx from graphrag.leiden import stable_largest_connected_component diff --git a/graphrag/graph_extractor.py b/graphrag/graph_extractor.py index e56a24780..2a9132cc6 100644 --- a/graphrag/graph_extractor.py +++ b/graphrag/graph_extractor.py @@ -9,8 +9,8 @@ import logging import numbers import re import traceback +from typing import Any, Callable from dataclasses import dataclass -from typing import Any, Mapping, Callable import tiktoken from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str diff --git a/graphrag/index.py b/graphrag/index.py index 26891803f..b0129a985 100644 --- a/graphrag/index.py +++ 
b/graphrag/index.py @@ -18,7 +18,6 @@ import os from concurrent.futures import ThreadPoolExecutor import json from functools import reduce -from typing import List import networkx as nx from api.db import LLMType from api.db.services.llm_service import LLMBundle @@ -53,7 +52,7 @@ def graph_merge(g1, g2): return g -def build_knowledge_graph_chunks(tenant_id: str, chunks: List[str], callback, entity_types=DEFAULT_ENTITY_TYPES): +def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, entity_types=DEFAULT_ENTITY_TYPES): _, tenant = TenantService.get_by_id(tenant_id) llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id) ext = GraphExtractor(llm_bdl) diff --git a/graphrag/leiden.py b/graphrag/leiden.py index 09440597e..9b03d628f 100644 --- a/graphrag/leiden.py +++ b/graphrag/leiden.py @@ -6,8 +6,8 @@ Reference: """ import logging -from typing import Any, cast, List import html +from typing import Any from graspologic.partition import hierarchical_leiden from graspologic.utils import largest_connected_component @@ -132,7 +132,7 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]: return results_by_level -def add_community_info2graph(graph: nx.Graph, nodes: List[str], community_title): +def add_community_info2graph(graph: nx.Graph, nodes: list[str], community_title): for n in nodes: if "communities" not in graph.nodes[n]: graph.nodes[n]["communities"] = [] diff --git a/graphrag/mind_map_extractor.py b/graphrag/mind_map_extractor.py index 1f4b924af..a472f8946 100644 --- a/graphrag/mind_map_extractor.py +++ b/graphrag/mind_map_extractor.py @@ -19,9 +19,9 @@ import collections import os import re import traceback +from typing import Any from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import Any from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT from graphrag.utils import ErrorHandlerFn, perform_variable_replacements diff --git a/graphrag/search.py b/graphrag/search.py index f32326d77..24735c4b5 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -15,7 +15,6 @@ # import json from copy import deepcopy -from typing import Dict import pandas as pd from rag.utils.doc_store_conn import OrderByExpr, FusionExpr @@ -25,7 +24,7 @@ from rag.nlp.search import Dealer class KGSearch(Dealer): def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False): - def merge_into_first(sres, title="") -> Dict[str, str]: + def merge_into_first(sres, title="") -> dict[str, str]: if not sres: return {} content_with_weight = "" diff --git a/graphrag/utils.py b/graphrag/utils.py index d4bbfadf9..3a8c5253f 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -7,8 +7,7 @@ Reference: import html import re -from collections.abc import Callable -from typing import Any +from typing import Any, Callable ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] diff --git a/poetry.lock b/poetry.lock index 6dc604a7f..e6fbc15c4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "accelerate" @@ -413,7 +413,7 @@ name = "aspose-slides" version = "24.11.0" description = "Aspose.Slides for Python via .NET is a presentation file formats processing library for working with Microsoft PowerPoint files without using Microsoft PowerPoint." 
optional = false -python-versions = "<3.14,>=3.5" +python-versions = ">=3.5,<3.14" files = [ {file = "Aspose.Slides-24.11.0-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:b4819364497f9e075e00e63ee8fba8745dda4c910e199d5201e4abeebdcdec89"}, {file = "Aspose.Slides-24.11.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:bbeb5f0b14901f29f209beeac694a183f8d36c9475556ddeed3b2edb8107536a"}, @@ -565,7 +565,7 @@ name = "bce-python-sdk" version = "0.9.23" description = "BCE SDK for python" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,<4,>=2.7" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4" files = [ {file = "bce_python_sdk-0.9.23-py3-none-any.whl", hash = "sha256:8debe21a040e00060f6044877d594765ed7b18bc765c6bf16b878bca864140a3"}, {file = "bce_python_sdk-0.9.23.tar.gz", hash = "sha256:19739fed5cd0725356fc5ffa2acbdd8fb23f2a81edb91db21a03174551d0cf41"}, @@ -1706,7 +1706,7 @@ name = "deprecated" version = "1.2.15" description = "Python @deprecated decorator to deprecate old python classes, functions or methods." optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ {file = "Deprecated-1.2.15-py2.py3-none-any.whl", hash = "sha256:353bc4a8ac4bfc96800ddab349d89c25dec1079f65fd53acdcc1e0b975b21320"}, {file = "deprecated-1.2.15.tar.gz", hash = "sha256:683e561a90de76239796e6b6feac66b99030d2dd3fcf61ef996330f14bbb9b0d"}, @@ -2023,7 +2023,7 @@ name = "fastembed" version = "0.3.6" description = "Fast, light, accurate library built for retrieval embedding generation" optional = false -python-versions = "<3.13,>=3.8.0" +python-versions = ">=3.8.0,<3.13" files = [ {file = "fastembed-0.3.6-py3-none-any.whl", hash = "sha256:2bf70edae28bb4ccd9e01617098c2075b0ba35b88025a3d22b0e1e85b2c488ce"}, {file = "fastembed-0.3.6.tar.gz", hash = "sha256:c93c8ec99b8c008c2d192d6297866b8d70ec7ac8f5696b34eb5ea91f85efd15f"}, @@ -2940,7 +2940,7 @@ name = "graspologic" version = "3.4.1" description = "A set of Python modules for graph statistics" optional = false -python-versions = "<3.13,>=3.9" +python-versions = ">=3.9,<3.13" files = [ {file = "graspologic-3.4.1-py3-none-any.whl", hash = "sha256:c6563e087eda599bad1de831d4b7321c0daa7a82f4e85a7d7737ff67e07cdda2"}, {file = "graspologic-3.4.1.tar.gz", hash = "sha256:7561f0b852a2bccd351bff77e8db07d9892f9dfa35a420fdec01690e4fdc8075"}, @@ -3625,7 +3625,7 @@ name = "infinity-emb" version = "0.0.66" description = "Infinity is a high-throughput, low-latency REST API for serving text-embeddings, reranking models and clip." 
optional = false -python-versions = "<4,>=3.9" +python-versions = ">=3.9,<4" files = [ {file = "infinity_emb-0.0.66-py3-none-any.whl", hash = "sha256:1dc6ed9fa48e6cbe83650a7583dbbb4bc393900c39c326bb0aff2ddc090ac018"}, {file = "infinity_emb-0.0.66.tar.gz", hash = "sha256:9c9a361ccebf8e8f626c1f685286518d03d0c35e7d14179ae7c2500b4fc68b98"}, @@ -4070,7 +4070,7 @@ name = "litellm" version = "1.48.0" description = "Library to easily interface with LLM API providers" optional = false -python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" +python-versions = ">=3.8, !=2.7.*, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*, !=3.7.*" files = [ {file = "litellm-1.48.0-py3-none-any.whl", hash = "sha256:7765e8a92069778f5fc66aacfabd0e2f8ec8d74fb117f5e475567d89b0d376b9"}, {file = "litellm-1.48.0.tar.gz", hash = "sha256:31a9b8a25a9daf44c24ddc08bf74298da920f2c5cea44135e5061278d0aa6fc9"}, @@ -6197,7 +6197,7 @@ name = "psutil" version = "6.1.0" description = "Cross-platform lib for process and system monitoring in Python." optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ {file = "psutil-6.1.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ff34df86226c0227c52f38b919213157588a678d049688eded74c76c8ba4a5d0"}, {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:c0e0c00aa18ca2d3b2b991643b799a15fc8f0563d2ebb6040f64ce8dc027b942"}, @@ -6219,8 +6219,8 @@ files = [ ] [package.extras] -dev = ["black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "wheel"] -test = ["pytest", "pytest-xdist", "setuptools"] +dev = ["black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx-rtd-theme", "toml-sort", "twine", "virtualenv", "wheel"] +test = ["enum34", "futures", "ipaddress", "mock (==1.0.1)", "pytest (==4.6.11)", "pytest-xdist", "setuptools", "unittest2"] [[package]] name = "psycopg2-binary" @@ -7690,40 +7690,30 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:a606ef75a60ecf3d924613892cc603b154178ee25abb3055db5062da811fd969"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd5415dded15c3822597455bc02bcd66e81ef8b7a48cb71a33628fc9fdde39df"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76"}, - {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6"}, - {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6"}, 
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d84318609196d6bd6da0edfa25cedfbabd8dbde5140a0a23af29ad4b8f91fb1e"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb43a269eb827806502c7c8efb7ae7e9e9d0573257a46e8e952f4d4caba4f31e"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52"}, - {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642"}, - {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:943f32bc9dedb3abff9879edc134901df92cfce2c3d5c9348f172f62eb2d771d"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c3829bb364fdb8e0332c9931ecf57d9be3519241323c5274bd82f709cebc0c"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd"}, - {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31"}, - {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4c8c5d82f50bb53986a5e02d1b3092b03622c02c2eb78e29bec33fd9593bae1a"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:e7e3736715fbf53e9be2a79eb4db68e4ed857017344d697e8b9749444ae57475"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b7e75b4965e1d4690e93021adfcecccbca7d61c7bddd8e22406ef2ff20d74ef"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96777d473c05ee3e5e3c3e999f5d23c6f4ec5b0c38c098b3a5229085f74236c6"}, - {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:3bc2a80e6420ca8b7d3590791e2dfc709c88ab9152c00eeb511c9875ce5778bf"}, - {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e188d2699864c11c36cdfdada94d781fd5d6b0071cd9c427bceb08ad3d7c70e1"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win32.whl", hash = 
"sha256:6442cb36270b3afb1b4951f060eccca1ce49f3d087ca1ca4563a6eb479cb3de6"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:e5b8daf27af0b90da7bb903a876477a9e6d7270be6146906b276605997c7e9a3"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:bc5f1e1c28e966d61d2519f2a3d451ba989f9ea0f2307de7bc45baa526de9e45"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a0e060aace4c24dcaf71023bbd7d42674e3b230f7e7b97317baf1e953e5b519"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7"}, - {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285"}, - {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b"}, {file = "ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f"}, @@ -7734,7 +7724,7 @@ name = "s3transfer" version = "0.10.3" description = "An Amazon S3 Transfer Manager" optional = false -python-versions = ">=3.8" +python-versions = ">= 3.8" files = [ {file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"}, {file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"}, @@ -8196,7 +8186,7 @@ name = "smart-open" version = "7.0.5" description = "Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)" optional = false -python-versions = "<4.0,>=3.7" +python-versions = ">=3.7,<4.0" files = [ {file = "smart_open-7.0.5-py3-none-any.whl", hash = "sha256:8523ed805c12dff3eaa50e9c903a6cb0ae78800626631c5fe7ea073439847b89"}, {file = "smart_open-7.0.5.tar.gz", hash = "sha256:d3672003b1dbc85e2013e4983b88eb9a5ccfd389b0d4e5015f39a9ee5620ec18"}, @@ -9967,4 +9957,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "045015d496b77682204fcdb5b2576ebf8f7efcf8340e1e9839de246d9d39c08a" +content-hash = "dcf6c6a1d7fc52f982ef717bc48b13837e693b53a4fc5d9c06cefd253227259f" diff --git a/pyproject.toml b/pyproject.toml index 7c6975a80..d8f1a7e22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ azure-storage-file-datalake = "12.16.0" anthropic = "=0.34.1" arxiv = "2.1.3" aspose-slides = { version = "^24.9.0", markers = "platform_machine == 'x86_64'" } +beartype = "^0.18.5" bio = "1.7.1" boto3 = "1.34.140" botocore = "1.34.140" diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index de65ed691..813aa04db 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -15,7 +15,6 @@ # import logging import re -from typing import Optional import threading import requests from huggingface_hub import snapshot_download @@ -242,10 +241,10 @@ class 
FastEmbed(Base): def __init__( self, - key: Optional[str] = None, + key: str | None = None, model_name: str = "BAAI/bge-small-en-v1.5", - cache_dir: Optional[str] = None, - threads: Optional[int] = None, + cache_dir: str | None = None, + threads: int | None = None, **kwargs, ): if not settings.LIGHTEN and not FastEmbed._model: diff --git a/rag/nlp/search.py b/rag/nlp/search.py index ffad85583..0aeee4ad9 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -17,7 +17,6 @@ import logging import re import json -from typing import List, Optional, Dict, Union from dataclasses import dataclass from rag.utils import rmSpace @@ -37,13 +36,13 @@ class Dealer: @dataclass class SearchResult: total: int - ids: List[str] - query_vector: List[float] = None - field: Optional[Dict] = None - highlight: Optional[Dict] = None - aggregation: Union[List, Dict, None] = None - keywords: Optional[List[str]] = None - group_docs: List[List] = None + ids: list[str] + query_vector: list[float] | None = None + field: dict | None = None + highlight: dict | None = None + aggregation: list | dict | None = None + keywords: list[str] | None = None + group_docs: list[list] | None = None def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1): qv, _ = emb_mdl.encode_queries(txt) diff --git a/rag/raptor.py b/rag/raptor.py index 405f9224e..5974e371d 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -17,7 +17,6 @@ import logging import re from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait from threading import Lock -from typing import Tuple import umap import numpy as np from sklearn.mixture import GaussianMixture @@ -45,7 +44,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: optimal_clusters = n_clusters[np.argmin(bics)] return optimal_clusters - def __call__(self, chunks: Tuple[str, np.ndarray], random_state, callback=None): + def __call__(self, chunks: tuple[str, np.ndarray], random_state, callback=None): layers = [(0, len(chunks))] start, end = 0, len(chunks) if len(chunks) <= 1: return diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index ccef8bfb1..8a38389d4 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from beartype.claw import beartype_packages +beartype_packages(["agent", "api", "deepdoc", "plugins", "rag", "ragflow_sdk"]) # <-- raise exceptions in your code + import logging import sys from api.utils.log_utils import initRootLogger diff --git a/rag/utils/doc_store_conn.py b/rag/utils/doc_store_conn.py index b550e5c1a..cc2ea673d 100644 --- a/rag/utils/doc_store_conn.py +++ b/rag/utils/doc_store_conn.py @@ -1,19 +1,17 @@ from abc import ABC, abstractmethod -from typing import Optional, Union from dataclasses import dataclass import numpy as np import polars as pl -from typing import List, Dict DEFAULT_MATCH_VECTOR_TOPN = 10 DEFAULT_MATCH_SPARSE_TOPN = 10 -VEC = Union[list, np.ndarray] +VEC = list | np.ndarray @dataclass class SparseVector: indices: list[int] - values: Union[list[float], list[int], None] = None + values: list[float] | list[int] | None = None def __post_init__(self): assert (self.values is None) or (len(self.indices) == len(self.values)) @@ -82,7 +80,7 @@ class MatchSparseExpr(ABC): sparse_data: SparseVector | dict, distance_type: str, topn: int, - opt_params: Optional[dict] = None, + opt_params: dict | None = None, ): self.vector_column_name = vector_column_name self.sparse_data = sparse_data @@ -98,7 +96,7 @@ class MatchTensorExpr(ABC): query_data: VEC, query_data_type: str, topn: int, - extra_option: Optional[dict] = None, + extra_option: dict | None = None, ): self.column_name = column_name self.query_data = query_data @@ -108,16 +106,13 @@ class MatchTensorExpr(ABC): class FusionExpr(ABC): - def __init__(self, method: str, topn: int, fusion_params: Optional[dict] = None): + def __init__(self, method: str, topn: int, fusion_params: dict | None = None): self.method = method self.topn = topn self.fusion_params = fusion_params -MatchExpr = Union[ - MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr -] - +MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr class OrderByExpr(ABC): def __init__(self): @@ -229,11 +224,11 @@ class DocStoreConnection(ABC): raise NotImplementedError("Not implemented") @abstractmethod - def getFields(self, res, fields: List[str]) -> Dict[str, dict]: + def getFields(self, res, fields: list[str]) -> dict[str, dict]: raise NotImplementedError("Not implemented") @abstractmethod - def getHighlight(self, res, keywords: List[str], fieldnm: str): + def getHighlight(self, res, keywords: list[str], fieldnm: str): raise NotImplementedError("Not implemented") @abstractmethod diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index b372ff67a..d8541eddc 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -3,7 +3,6 @@ import re import json import time import os -from typing import List, Dict import copy from elasticsearch import Elasticsearch @@ -363,7 +362,7 @@ class ESConnection(DocStoreConnection): rr.append(d["_source"]) return rr - def getFields(self, res, fields: List[str]) -> Dict[str, dict]: + def getFields(self, res, fields: list[str]) -> dict[str, dict]: res_fields = {} if not fields: return {} @@ -382,7 +381,7 @@ class ESConnection(DocStoreConnection): res_fields[d["id"]] = m return res_fields - def getHighlight(self, res, keywords: List[str], fieldnm: str): + def getHighlight(self, res, keywords: list[str], fieldnm: str): ans = {} for d in res["hits"]["hits"]: hlts = d.get("highlight") diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index f6ca3493b..ad5a83ab7 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -3,7 +3,6 
@@ import os import re import json import time -from typing import List, Dict import infinity from infinity.common import ConflictType, InfinityException from infinity.index import IndexInfo, IndexType @@ -384,7 +383,7 @@ class InfinityConnection(DocStoreConnection): def getChunkIds(self, res): return list(res["id"]) - def getFields(self, res, fields: List[str]) -> Dict[str, dict]: + def getFields(self, res, fields: list[str]) -> list[str, dict]: res_fields = {} if not fields: return {} @@ -412,7 +411,7 @@ class InfinityConnection(DocStoreConnection): res_fields[id] = m return res_fields - def getHighlight(self, res, keywords: List[str], fieldnm: str): + def getHighlight(self, res, keywords: list[str], fieldnm: str): ans = {} num_rows = len(res) column_id = res["id"] diff --git a/sdk/python/poetry.lock b/sdk/python/poetry.lock index 4ca41e400..76bc4b974 100644 --- a/sdk/python/poetry.lock +++ b/sdk/python/poetry.lock @@ -1,5 +1,23 @@ # This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +[[package]] +name = "beartype" +version = "0.18.5" +description = "Unbearably fast runtime type checking in pure Python." +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "beartype-0.18.5-py3-none-any.whl", hash = "sha256:5301a14f2a9a5540fe47ec6d34d758e9cd8331d36c4760fc7a5499ab86310089"}, + {file = "beartype-0.18.5.tar.gz", hash = "sha256:264ddc2f1da9ec94ff639141fbe33d22e12a9f75aa863b83b7046ffff1381927"}, +] + +[package.extras] +all = ["typing-extensions (>=3.10.0.0)"] +dev = ["autoapi (>=0.9.0)", "coverage (>=5.5)", "equinox", "mypy (>=0.800)", "numpy", "pandera", "pydata-sphinx-theme (<=0.7.2)", "pytest (>=4.0.0)", "sphinx", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)", "tox (>=3.20.1)", "typing-extensions (>=3.10.0.0)"] +doc-rtd = ["autoapi (>=0.9.0)", "pydata-sphinx-theme (<=0.7.2)", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)"] +test-tox = ["equinox", "mypy (>=0.800)", "numpy", "pandera", "pytest (>=4.0.0)", "sphinx", "typing-extensions (>=3.10.0.0)"] +test-tox-coverage = ["coverage (>=5.5)"] + [[package]] name = "certifi" version = "2024.8.30" @@ -177,13 +195,13 @@ files = [ [[package]] name = "packaging" -version = "24.1" +version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, - {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, + {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, + {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, ] [[package]] @@ -246,13 +264,13 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "tomli" -version = "2.0.2" +version = "2.1.0" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" files = [ - {file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"}, - {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"}, + {file = "tomli-2.1.0-py3-none-any.whl", hash = "sha256:a5c57c3d1c56f5ccdf89f6523458f60ef716e210fc47c4cfb188c5ba473e0391"}, + {file = "tomli-2.1.0.tar.gz", hash = 
"sha256:3f646cae2aec94e17d04973e4249548320197cfabdf130015d023de4b74d8ab8"}, ] [[package]] @@ -275,4 +293,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "202bfd3e121f1d57a2f9c9d91cd7a50eacf2362cd1995c9f6347bcb100cf9336" +content-hash = "19565d31d822b0573f505662c664d735194134a505f43bbd1657c033f87bb82d" diff --git a/sdk/python/pyproject.toml b/sdk/python/pyproject.toml index d32e68866..8724eb198 100644 --- a/sdk/python/pyproject.toml +++ b/sdk/python/pyproject.toml @@ -10,6 +10,7 @@ package-mode = true [tool.poetry.dependencies] python = "^3.10" requests = "^2.30.0" +beartype = "^0.18.5" pytest = "^8.0.0" diff --git a/sdk/python/ragflow_sdk/__init__.py b/sdk/python/ragflow_sdk/__init__.py index f2b0eb278..a99e97c6b 100644 --- a/sdk/python/ragflow_sdk/__init__.py +++ b/sdk/python/ragflow_sdk/__init__.py @@ -1,3 +1,6 @@ +from beartype.claw import beartype_this_package +beartype_this_package() # <-- raise exceptions in your code + import importlib.metadata __version__ = importlib.metadata.version("ragflow_sdk") diff --git a/sdk/python/ragflow_sdk/modules/chat.py b/sdk/python/ragflow_sdk/modules/chat.py index 9c91abe8a..7663dfdc1 100644 --- a/sdk/python/ragflow_sdk/modules/chat.py +++ b/sdk/python/ragflow_sdk/modules/chat.py @@ -1,4 +1,3 @@ -from typing import List from .base import Base from .session import Session @@ -58,7 +57,7 @@ class Chat(Base): raise Exception(res["message"]) def list_sessions(self,page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, - id: str = None, name: str = None) -> List[Session]: + id: str = None, name: str = None) -> list[Session]: res = self.get(f'/chats/{self.id}/sessions',{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name} ) res = res.json() if res.get("code") == 0: @@ -68,7 +67,7 @@ class Chat(Base): return result_list raise Exception(res["message"]) - def delete_sessions(self,ids:List[str]=None): + def delete_sessions(self,ids: list[str] | None = None): res = self.rm(f"/chats/{self.id}/sessions", {"ids": ids}) res = res.json() if res.get("code") != 0: diff --git a/sdk/python/ragflow_sdk/modules/dataset.py b/sdk/python/ragflow_sdk/modules/dataset.py index 506aa9179..63d95b12c 100644 --- a/sdk/python/ragflow_sdk/modules/dataset.py +++ b/sdk/python/ragflow_sdk/modules/dataset.py @@ -1,5 +1,3 @@ -from typing import List - from .document import Document from .base import Base @@ -35,7 +33,7 @@ class DataSet(Base): if res.get("code") != 0: raise Exception(res["message"]) - def upload_documents(self,document_list: List[dict]): + def upload_documents(self,document_list: list[dict]): url = f"/datasets/{self.id}/documents" files = [("file",(ele["displayed_name"],ele["blob"])) for ele in document_list] res = self.post(path=url,json=None,files=files) @@ -48,7 +46,7 @@ class DataSet(Base): return doc_list raise Exception(res.get("message")) - def list_documents(self, id: str = None, keywords: str = None, page: int =1, page_size: int = 30, orderby: str = "create_time", desc: bool = True): + def list_documents(self, id: str | None = None, keywords: str | None = None, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True): res = self.get(f"/datasets/{self.id}/documents",params={"id": id,"keywords": keywords,"page": page,"page_size": page_size,"orderby": orderby,"desc": desc}) res = res.json() documents = [] @@ -58,7 +56,7 @@ class DataSet(Base): return documents raise Exception(res["message"]) - def 
delete_documents(self,ids: List[str] = None): + def delete_documents(self,ids: list[str] | None = None): res = self.rm(f"/datasets/{self.id}/documents",{"ids":ids}) res = res.json() if res.get("code") != 0: diff --git a/sdk/python/ragflow_sdk/modules/document.py b/sdk/python/ragflow_sdk/modules/document.py index 052f9ccf1..62728636e 100644 --- a/sdk/python/ragflow_sdk/modules/document.py +++ b/sdk/python/ragflow_sdk/modules/document.py @@ -1,7 +1,6 @@ import json from .base import Base from .chunk import Chunk -from typing import List class Document(Base): @@ -63,14 +62,14 @@ class Document(Base): raise Exception(res.get("message")) - def add_chunk(self, content: str,important_keywords:List[str]=[]): + def add_chunk(self, content: str,important_keywords: list[str] = []): res = self.post(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', {"content":content,"important_keywords":important_keywords}) res = res.json() if res.get("code") == 0: return Chunk(self.rag,res["data"].get("chunk")) raise Exception(res.get("message")) - def delete_chunks(self,ids:List[str] = None): + def delete_chunks(self,ids:list[str] | None = None): res = self.rm(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks",{"chunk_ids":ids}) res = res.json() if res.get("code")!=0: diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index 95450af4a..1ee1fdc7f 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -13,14 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - import requests from .modules.chat import Chat from .modules.chunk import Chunk from .modules.dataset import DataSet -from .modules.document import Document class RAGFlow: @@ -64,7 +61,7 @@ class RAGFlow: return DataSet(self, res["data"]) raise Exception(res["message"]) - def delete_datasets(self, ids: List[str] = None): + def delete_datasets(self, ids: list[str] | None = None): res = self.delete("/datasets",{"ids": ids}) res=res.json() if res.get("code") != 0: @@ -77,8 +74,8 @@ class RAGFlow: raise Exception("Dataset %s not found" % name) def list_datasets(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, - id: str = None, name: str = None) -> \ - List[DataSet]: + id: str | None = None, name: str | None = None) -> \ + list[DataSet]: res = self.get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}) res = res.json() @@ -89,8 +86,8 @@ class RAGFlow: return result_list raise Exception(res["message"]) - def create_chat(self, name: str, avatar: str = "", dataset_ids: List[str] = [], - llm: Chat.LLM = None, prompt: Chat.Prompt = None) -> Chat: + def create_chat(self, name: str, avatar: str = "", dataset_ids: list[str] = [], + llm: Chat.LLM | None = None, prompt: Chat.Prompt | None = None) -> Chat: dataset_list = [] for id in dataset_ids: dataset_list.append(id) @@ -135,7 +132,7 @@ class RAGFlow: return Chat(self, res["data"]) raise Exception(res["message"]) - def delete_chats(self,ids: List[str] = None) -> bool: + def delete_chats(self,ids: list[str] | None = None): res = self.delete('/chats', {"ids":ids}) res = res.json() @@ -143,7 +140,7 @@ class RAGFlow: raise Exception(res["message"]) def list_chats(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, - id: str = None, name: str = None) -> List[Chat]: + id: str | None = None, name: str | None = None) -> 
list[Chat]: res = self.get("/chats",{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}) res = res.json() result_list = [] @@ -154,7 +151,7 @@ class RAGFlow: raise Exception(res["message"]) - def retrieve(self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id:str=None, keyword:bool=False, ): + def retrieve(self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id: str | None = None, keyword:bool=False, ): if document_ids is None: document_ids = [] data_json ={ @@ -170,7 +167,7 @@ class RAGFlow: "documents": document_ids } # Send a POST request to the backend service (using requests library as an example, actual implementation may vary) - res = self.post(f'/retrieval',json=data_json) + res = self.post('/retrieval',json=data_json) res = res.json() if res.get("code") ==0: chunks=[]
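
Note on the beartype hooks added above (an illustrative sketch follows; it is not part of the patch). beartype_packages() and beartype_this_package() install an import hook that applies @beartype to every function subsequently imported from the listed packages, which is why both entry points invoke the hook before any other project import. Once the hook is active, the PEP 604 unions (str | None) and PEP 585 builtin generics (list[str], dict[str, dict]) introduced throughout this diff are enforced at call time instead of remaining inert annotations; they also require Python >= 3.10, which both lockfiles satisfy (">=3.11,<3.13" and "^3.10"). The function below is a hypothetical stand-in mirroring the SDK's delete_chats signature, with the decorator applied explicitly to show the equivalent per-function behavior:

from beartype import beartype
from beartype.roar import BeartypeCallHintParamViolation

@beartype
def delete_chats(ids: list[str] | None = None) -> None:
    """Toy stand-in for RAGFlow.delete_chats; the body is irrelevant here."""

delete_chats(["chat_1"])    # OK: a list of str satisfies list[str] | None
delete_chats(None)          # OK: None is permitted by the union
try:
    delete_chats("chat_1")  # a bare str violates list[str] | None
except BeartypeCallHintParamViolation as exc:
    print(f"beartype rejected the call: {exc}")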