Introduced beartype (#3460)

### What problem does this PR solve?

Introduced [beartype](https://github.com/beartype/beartype) for runtime
type-checking.
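
For context, a minimal sketch of the import hook this PR wires into the server entry points (the `add` function below is a hypothetical illustration, not code from this PR):

```python
# beartype_packages() must run before the named packages are imported;
# every annotated function they define is then checked on each call.
from beartype.claw import beartype_packages

beartype_packages(["agent", "api", "deepdoc", "plugins", "rag", "ragflow_sdk"])

# What a violation looks like, using the plain decorator form:
from beartype import beartype

@beartype
def add(x: int, y: int) -> int:
    return x + y

add(1, 2)  # fine
try:
    add("1", 2)  # wrong argument type
except Exception as exc:
    print(type(exc).__name__)  # BeartypeCallHintParamViolation
```

The hook only affects modules imported after it runs, which is why both entry points place it above all other imports.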

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Zhichang Yu 2024-11-18 17:38:17 +08:00 committed by GitHub
parent 3824c1fec0
commit 4413683898
32 changed files with 125 additions and 134 deletions

View File

@ -19,7 +19,6 @@ import builtins
import json
import os
from functools import partial
from typing import Tuple, Union
import pandas as pd
@ -417,7 +416,7 @@ class ComponentBase(ABC):
def _run(self, history, **kwargs):
raise NotImplementedError()
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
def output(self, allow_partial=True) -> tuple[str, pd.DataFrame | partial]:
o = getattr(self._param, self._param.output_var_name)
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
if not isinstance(o, list): o = [o]
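
Most hunks in this PR follow the same mechanical pattern: `typing.Tuple`/`Union`/`Optional`/`Dict`/`List` become PEP 585 builtin generics and PEP 604 unions, which need no `typing` import on the Python `>=3.11` floor this project pins. A hedged before/after sketch with a stand-in body:

```python
from functools import partial
import pandas as pd

# Before (needed: from typing import Tuple, Union):
#   def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]: ...

# After: same meaning, no typing import, and checkable by beartype.
def output(allow_partial: bool = True) -> tuple[str, pd.DataFrame | partial]:
    return "output_var", pd.DataFrame()  # stand-in return value for illustration
```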

View File

@ -17,7 +17,6 @@ import logging
import inspect
import os
import sys
import typing
import operator
from enum import Enum
from functools import wraps
@ -121,13 +120,13 @@ class SerializedField(LongTextField):
f"the serialized type {self._serialized_type} is not supported")
def is_continuous_field(cls: typing.Type) -> bool:
def is_continuous_field(cls: type) -> bool:
if cls in CONTINUOUS_FIELD_TYPE:
return True
for p in cls.__bases__:
if p in CONTINUOUS_FIELD_TYPE:
return True
elif p != Field and p != object:
elif p is not Field and p is not object:
if is_continuous_field(p):
return True
else:
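
The `!=` to `is not` change above is more than style: for class objects, identity is the precise test and is immune to any custom `__eq__` a metaclass could define. A contrived sketch:

```python
class Meta(type):
    def __eq__(cls, other):  # pathological metaclass equality
        return True
    __hash__ = type.__hash__

class Weird(metaclass=Meta):
    pass

print(Weird == object)   # True  -- equality can lie
print(Weird is object)   # False -- identity cannot
```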
@ -159,7 +158,7 @@ class BaseModel(Model):
def to_dict(self):
return self.__dict__['__data__']
def to_human_model_dict(self, only_primary_with: list = None):
def to_human_model_dict(self, only_primary_with: list | None = None):
model_dict = self.__dict__['__data__']
if not only_primary_with:

View File

@ -15,7 +15,6 @@
#
import operator
from functools import reduce
from typing import Dict, Type, Union
from playhouse.pool import PooledMySQLDatabase
@ -87,7 +86,7 @@ supported_operators = {
def query_dict2expression(
model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
model: type[DataBaseModel], query: dict[str, bool | int | str | list | tuple]):
expression = []
for field, value in query.items():
@ -105,8 +104,8 @@ def query_dict2expression(
return reduce(operator.iand, expression)
def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0,
query: dict = None, order_by: Union[str, list, tuple] = None):
def query_db(model: type[DataBaseModel], limit: int = 0, offset: int = 0,
query: dict = None, order_by: str | list | tuple | None = None):
data = model.select()
if query:
data = data.where(query_dict2expression(model, query))

View File

@ -14,6 +14,9 @@
# limitations under the License.
#
from beartype.claw import beartype_packages
beartype_packages(["agent", "api", "deepdoc", "plugins", "rag", "ragflow_sdk"]) # <-- raise exceptions in your code
import logging
from api.utils.log_utils import initRootLogger
initRootLogger("ragflow_server")

View File

@ -28,13 +28,12 @@ def get_project_base_directory():
)
return PROJECT_BASE
def initRootLogger(script_path: str, log_level: int = logging.INFO, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"):
def initRootLogger(logfile_basename: str, log_level: int = logging.INFO, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"):
logger = logging.getLogger()
if logger.hasHandlers():
return
script_name = os.path.basename(script_path)
log_path = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.path.splitext(script_name)[0]}.log"))
log_path = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{logfile_basename}.log"))
os.makedirs(os.path.dirname(log_path), exist_ok=True)
logger.setLevel(log_level)
@ -50,5 +49,5 @@ def initRootLogger(script_path: str, log_level: int = logging.INFO, log_format:
handler2.setFormatter(formatter)
logger.addHandler(handler2)
msg = f"{script_name} log path: {log_path}"
msg = f"{logfile_basename} log path: {log_path}"
logger.info(msg)
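
A hypothetical call under the renamed parameter, matching the entry-point change elsewhere in this PR:

```python
# Callers now pass a log-file basename instead of a script path:
initRootLogger("ragflow_server")  # writes to <project_base>/logs/ragflow_server.log
```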

View File

@ -13,11 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dotenv
import typing
import subprocess
def get_ragflow_version() -> typing.Optional[str]:
def get_ragflow_version() -> str:
return RAGFLOW_VERSION_INFO
@ -42,7 +40,7 @@ def get_closest_tag_and_count():
return closest_tag
else:
return f"{commit_id}({closest_tag}~{commits_count})"
except Exception as e:
except Exception:
return 'unknown'

View File

@ -3,12 +3,11 @@
# from https://github.com/langchain-ai/langchain/blob/master/libs/text-splitters/langchain_text_splitters/json.py
import json
from typing import Any, Dict, List, Optional
from typing import Any
from rag.nlp import find_codec
class RAGFlowJsonParser:
def __init__(
self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None
self, max_chunk_size: int = 2000, min_chunk_size: int | None = None
):
super().__init__()
self.max_chunk_size = max_chunk_size * 2
@ -27,12 +26,12 @@ class RAGFlowJsonParser:
return sections
@staticmethod
def _json_size(data: Dict) -> int:
def _json_size(data: dict) -> int:
"""Calculate the size of the serialized JSON object."""
return len(json.dumps(data, ensure_ascii=False))
@staticmethod
def _set_nested_dict(d: Dict, path: List[str], value: Any) -> None:
def _set_nested_dict(d: dict, path: list[str], value: Any) -> None:
"""Set a value in a nested dictionary based on the given path."""
for key in path[:-1]:
d = d.setdefault(key, {})
@ -54,10 +53,10 @@ class RAGFlowJsonParser:
def _json_split(
self,
data: Dict[str, Any],
current_path: Optional[List[str]] = None,
chunks: Optional[List[Dict]] = None,
) -> List[Dict]:
data: dict[str, Any],
current_path: list[str] | None,
chunks: list[dict] | None,
) -> list[dict]:
"""
Split json into maximum size dictionaries while preserving structure.
"""
@ -87,9 +86,9 @@ class RAGFlowJsonParser:
def split_json(
self,
json_data: Dict[str, Any],
json_data: dict[str, Any],
convert_lists: bool = False,
) -> List[Dict]:
) -> list[dict]:
"""Splits JSON into a list of JSON chunks"""
if convert_lists:
@ -104,10 +103,10 @@ class RAGFlowJsonParser:
def split_text(
self,
json_data: Dict[str, Any],
json_data: dict[str, Any],
convert_lists: bool = False,
ensure_ascii: bool = True,
) -> List[str]:
) -> list[str]:
"""Splits JSON into a list of JSON formatted strings"""
chunks = self.split_json(json_data=json_data, convert_lists=convert_lists)
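
A hedged usage sketch of the parser with the updated hints, assuming `RAGFlowJsonParser` is imported from its module (the sample payload is made up):

```python
parser = RAGFlowJsonParser(max_chunk_size=2000)          # min_chunk_size defaults to None
pieces = parser.split_text({"doc": {"body": "x" * 50}})  # -> list[str] of JSON strings
assert all(isinstance(p, str) for p in pieces)
```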

View File

@ -1059,7 +1059,7 @@ Deletes chat assistants by ID.
#### ids: `list[str]`
The IDs of the chat assistants to delete. Defaults to `None`. If it is ot specified, all chat assistants in the system will be deleted.
The IDs of the chat assistants to delete. Defaults to `None`. If it is empty or not specified, all chat assistants in the system will be deleted.
### Returns

View File

@ -9,8 +9,8 @@ import logging
import json
import re
import traceback
from typing import Callable
from dataclasses import dataclass
from typing import List, Callable
import networkx as nx
import pandas as pd
from graphrag import leiden
@ -26,8 +26,8 @@ from timeit import default_timer as timer
class CommunityReportsResult:
"""Community reports result class definition."""
output: List[str]
structured_output: List[dict]
output: list[str]
structured_output: list[dict]
class CommunityReportsExtractor:
@ -53,7 +53,7 @@ class CommunityReportsExtractor:
self._max_report_length = max_report_length or 1500
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
communities: dict[str, dict[str, List]] = leiden.run(graph, {})
communities: dict[str, dict[str, list]] = leiden.run(graph, {})
total = sum([len(comm.items()) for _, comm in communities.items()])
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
res_str = []

View File

@ -6,7 +6,6 @@ Reference:
"""
from typing import Any
import numpy as np
import networkx as nx
from graphrag.leiden import stable_largest_connected_component

View File

@ -9,8 +9,8 @@ import logging
import numbers
import re
import traceback
from typing import Any, Callable
from dataclasses import dataclass
from typing import Any, Mapping, Callable
import tiktoken
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str

View File

@ -18,7 +18,6 @@ import os
from concurrent.futures import ThreadPoolExecutor
import json
from functools import reduce
from typing import List
import networkx as nx
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
@ -53,7 +52,7 @@ def graph_merge(g1, g2):
return g
def build_knowledge_graph_chunks(tenant_id: str, chunks: List[str], callback, entity_types=DEFAULT_ENTITY_TYPES):
def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, entity_types=DEFAULT_ENTITY_TYPES):
_, tenant = TenantService.get_by_id(tenant_id)
llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
ext = GraphExtractor(llm_bdl)

View File

@ -6,8 +6,8 @@ Reference:
"""
import logging
from typing import Any, cast, List
import html
from typing import Any
from graspologic.partition import hierarchical_leiden
from graspologic.utils import largest_connected_component
@ -132,7 +132,7 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
return results_by_level
def add_community_info2graph(graph: nx.Graph, nodes: List[str], community_title):
def add_community_info2graph(graph: nx.Graph, nodes: list[str], community_title):
for n in nodes:
if "communities" not in graph.nodes[n]:
graph.nodes[n]["communities"] = []

View File

@ -19,9 +19,9 @@ import collections
import os
import re
import traceback
from typing import Any
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any
from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements

View File

@ -15,7 +15,6 @@
#
import json
from copy import deepcopy
from typing import Dict
import pandas as pd
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
@ -25,7 +24,7 @@ from rag.nlp.search import Dealer
class KGSearch(Dealer):
def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
def merge_into_first(sres, title="") -> Dict[str, str]:
def merge_into_first(sres, title="") -> dict[str, str]:
if not sres:
return {}
content_with_weight = ""

View File

@ -7,8 +7,7 @@ Reference:
import html
import re
from collections.abc import Callable
from typing import Any
from typing import Any, Callable
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]

poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
[[package]]
name = "accelerate"
@ -413,7 +413,7 @@ name = "aspose-slides"
version = "24.11.0"
description = "Aspose.Slides for Python via .NET is a presentation file formats processing library for working with Microsoft PowerPoint files without using Microsoft PowerPoint."
optional = false
python-versions = "<3.14,>=3.5"
python-versions = ">=3.5,<3.14"
files = [
{file = "Aspose.Slides-24.11.0-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:b4819364497f9e075e00e63ee8fba8745dda4c910e199d5201e4abeebdcdec89"},
{file = "Aspose.Slides-24.11.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:bbeb5f0b14901f29f209beeac694a183f8d36c9475556ddeed3b2edb8107536a"},
@ -565,7 +565,7 @@ name = "bce-python-sdk"
version = "0.9.23"
description = "BCE SDK for python"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,<4,>=2.7"
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4"
files = [
{file = "bce_python_sdk-0.9.23-py3-none-any.whl", hash = "sha256:8debe21a040e00060f6044877d594765ed7b18bc765c6bf16b878bca864140a3"},
{file = "bce_python_sdk-0.9.23.tar.gz", hash = "sha256:19739fed5cd0725356fc5ffa2acbdd8fb23f2a81edb91db21a03174551d0cf41"},
@ -1706,7 +1706,7 @@ name = "deprecated"
version = "1.2.15"
description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "Deprecated-1.2.15-py2.py3-none-any.whl", hash = "sha256:353bc4a8ac4bfc96800ddab349d89c25dec1079f65fd53acdcc1e0b975b21320"},
{file = "deprecated-1.2.15.tar.gz", hash = "sha256:683e561a90de76239796e6b6feac66b99030d2dd3fcf61ef996330f14bbb9b0d"},
@ -2023,7 +2023,7 @@ name = "fastembed"
version = "0.3.6"
description = "Fast, light, accurate library built for retrieval embedding generation"
optional = false
python-versions = "<3.13,>=3.8.0"
python-versions = ">=3.8.0,<3.13"
files = [
{file = "fastembed-0.3.6-py3-none-any.whl", hash = "sha256:2bf70edae28bb4ccd9e01617098c2075b0ba35b88025a3d22b0e1e85b2c488ce"},
{file = "fastembed-0.3.6.tar.gz", hash = "sha256:c93c8ec99b8c008c2d192d6297866b8d70ec7ac8f5696b34eb5ea91f85efd15f"},
@ -2940,7 +2940,7 @@ name = "graspologic"
version = "3.4.1"
description = "A set of Python modules for graph statistics"
optional = false
python-versions = "<3.13,>=3.9"
python-versions = ">=3.9,<3.13"
files = [
{file = "graspologic-3.4.1-py3-none-any.whl", hash = "sha256:c6563e087eda599bad1de831d4b7321c0daa7a82f4e85a7d7737ff67e07cdda2"},
{file = "graspologic-3.4.1.tar.gz", hash = "sha256:7561f0b852a2bccd351bff77e8db07d9892f9dfa35a420fdec01690e4fdc8075"},
@ -3625,7 +3625,7 @@ name = "infinity-emb"
version = "0.0.66"
description = "Infinity is a high-throughput, low-latency REST API for serving text-embeddings, reranking models and clip."
optional = false
python-versions = "<4,>=3.9"
python-versions = ">=3.9,<4"
files = [
{file = "infinity_emb-0.0.66-py3-none-any.whl", hash = "sha256:1dc6ed9fa48e6cbe83650a7583dbbb4bc393900c39c326bb0aff2ddc090ac018"},
{file = "infinity_emb-0.0.66.tar.gz", hash = "sha256:9c9a361ccebf8e8f626c1f685286518d03d0c35e7d14179ae7c2500b4fc68b98"},
@ -4070,7 +4070,7 @@ name = "litellm"
version = "1.48.0"
description = "Library to easily interface with LLM API providers"
optional = false
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
python-versions = ">=3.8, !=2.7.*, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*, !=3.7.*"
files = [
{file = "litellm-1.48.0-py3-none-any.whl", hash = "sha256:7765e8a92069778f5fc66aacfabd0e2f8ec8d74fb117f5e475567d89b0d376b9"},
{file = "litellm-1.48.0.tar.gz", hash = "sha256:31a9b8a25a9daf44c24ddc08bf74298da920f2c5cea44135e5061278d0aa6fc9"},
@ -6197,7 +6197,7 @@ name = "psutil"
version = "6.1.0"
description = "Cross-platform lib for process and system monitoring in Python."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
files = [
{file = "psutil-6.1.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ff34df86226c0227c52f38b919213157588a678d049688eded74c76c8ba4a5d0"},
{file = "psutil-6.1.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:c0e0c00aa18ca2d3b2b991643b799a15fc8f0563d2ebb6040f64ce8dc027b942"},
@ -6219,8 +6219,8 @@ files = [
]
[package.extras]
dev = ["black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "wheel"]
test = ["pytest", "pytest-xdist", "setuptools"]
dev = ["black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx-rtd-theme", "toml-sort", "twine", "virtualenv", "wheel"]
test = ["enum34", "futures", "ipaddress", "mock (==1.0.1)", "pytest (==4.6.11)", "pytest-xdist", "setuptools", "unittest2"]
[[package]]
name = "psycopg2-binary"
@ -7690,40 +7690,30 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:a606ef75a60ecf3d924613892cc603b154178ee25abb3055db5062da811fd969"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd5415dded15c3822597455bc02bcd66e81ef8b7a48cb71a33628fc9fdde39df"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d84318609196d6bd6da0edfa25cedfbabd8dbde5140a0a23af29ad4b8f91fb1e"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb43a269eb827806502c7c8efb7ae7e9e9d0573257a46e8e952f4d4caba4f31e"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:943f32bc9dedb3abff9879edc134901df92cfce2c3d5c9348f172f62eb2d771d"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c3829bb364fdb8e0332c9931ecf57d9be3519241323c5274bd82f709cebc0c"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4c8c5d82f50bb53986a5e02d1b3092b03622c02c2eb78e29bec33fd9593bae1a"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:e7e3736715fbf53e9be2a79eb4db68e4ed857017344d697e8b9749444ae57475"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b7e75b4965e1d4690e93021adfcecccbca7d61c7bddd8e22406ef2ff20d74ef"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96777d473c05ee3e5e3c3e999f5d23c6f4ec5b0c38c098b3a5229085f74236c6"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:3bc2a80e6420ca8b7d3590791e2dfc709c88ab9152c00eeb511c9875ce5778bf"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e188d2699864c11c36cdfdada94d781fd5d6b0071cd9c427bceb08ad3d7c70e1"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win32.whl", hash = "sha256:6442cb36270b3afb1b4951f060eccca1ce49f3d087ca1ca4563a6eb479cb3de6"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:e5b8daf27af0b90da7bb903a876477a9e6d7270be6146906b276605997c7e9a3"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:bc5f1e1c28e966d61d2519f2a3d451ba989f9ea0f2307de7bc45baa526de9e45"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a0e060aace4c24dcaf71023bbd7d42674e3b230f7e7b97317baf1e953e5b519"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b"},
{file = "ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f"},
@ -7734,7 +7724,7 @@ name = "s3transfer"
version = "0.10.3"
description = "An Amazon S3 Transfer Manager"
optional = false
python-versions = ">=3.8"
python-versions = ">= 3.8"
files = [
{file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"},
{file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"},
@ -8196,7 +8186,7 @@ name = "smart-open"
version = "7.0.5"
description = "Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)"
optional = false
python-versions = "<4.0,>=3.7"
python-versions = ">=3.7,<4.0"
files = [
{file = "smart_open-7.0.5-py3-none-any.whl", hash = "sha256:8523ed805c12dff3eaa50e9c903a6cb0ae78800626631c5fe7ea073439847b89"},
{file = "smart_open-7.0.5.tar.gz", hash = "sha256:d3672003b1dbc85e2013e4983b88eb9a5ccfd389b0d4e5015f39a9ee5620ec18"},
@ -9967,4 +9957,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.11,<3.13"
content-hash = "045015d496b77682204fcdb5b2576ebf8f7efcf8340e1e9839de246d9d39c08a"
content-hash = "dcf6c6a1d7fc52f982ef717bc48b13837e693b53a4fc5d9c06cefd253227259f"

View File

@ -17,6 +17,7 @@ azure-storage-file-datalake = "12.16.0"
anthropic = "=0.34.1"
arxiv = "2.1.3"
aspose-slides = { version = "^24.9.0", markers = "platform_machine == 'x86_64'" }
beartype = "^0.18.5"
bio = "1.7.1"
boto3 = "1.34.140"
botocore = "1.34.140"

View File

@ -15,7 +15,6 @@
#
import logging
import re
from typing import Optional
import threading
import requests
from huggingface_hub import snapshot_download
@ -242,10 +241,10 @@ class FastEmbed(Base):
def __init__(
self,
key: Optional[str] = None,
key: str | None = None,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
cache_dir: str | None = None,
threads: int | None = None,
**kwargs,
):
if not settings.LIGHTEN and not FastEmbed._model:

View File

@ -17,7 +17,6 @@
import logging
import re
import json
from typing import List, Optional, Dict, Union
from dataclasses import dataclass
from rag.utils import rmSpace
@ -37,13 +36,13 @@ class Dealer:
@dataclass
class SearchResult:
total: int
ids: List[str]
query_vector: List[float] = None
field: Optional[Dict] = None
highlight: Optional[Dict] = None
aggregation: Union[List, Dict, None] = None
keywords: Optional[List[str]] = None
group_docs: List[List] = None
ids: list[str]
query_vector: list[float] | None = None
field: dict | None = None
highlight: dict | None = None
aggregation: list | dict | None = None
keywords: list[str] | None = None
group_docs: list[list] | None = None
def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
qv, _ = emb_mdl.encode_queries(txt)

View File

@ -17,7 +17,6 @@ import logging
import re
from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
from threading import Lock
from typing import Tuple
import umap
import numpy as np
from sklearn.mixture import GaussianMixture
@ -45,7 +44,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
optimal_clusters = n_clusters[np.argmin(bics)]
return optimal_clusters
def __call__(self, chunks: Tuple[str, np.ndarray], random_state, callback=None):
def __call__(self, chunks: tuple[str, np.ndarray], random_state, callback=None):
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
if len(chunks) <= 1: return

View File

@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from beartype.claw import beartype_packages
beartype_packages(["agent", "api", "deepdoc", "plugins", "rag", "ragflow_sdk"]) # <-- raise exceptions in your code
import logging
import sys
from api.utils.log_utils import initRootLogger

View File

@ -1,19 +1,17 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
from dataclasses import dataclass
import numpy as np
import polars as pl
from typing import List, Dict
DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
VEC = Union[list, np.ndarray]
VEC = list | np.ndarray
@dataclass
class SparseVector:
indices: list[int]
values: Union[list[float], list[int], None] = None
values: list[float] | list[int] | None = None
def __post_init__(self):
assert (self.values is None) or (len(self.indices) == len(self.values))
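
Under the claw hook, the dataclass hints above are enforced at construction time; a quick sketch of what `list[float] | list[int] | None` admits (assuming beartype checks the generated `__init__`):

```python
SparseVector(indices=[0, 3], values=[0.5, 0.25])  # list[float]: accepted
SparseVector(indices=[1, 2], values=[7, 9])       # list[int]: accepted
SparseVector(indices=[1, 2])                      # None: accepted, skips the length assert
```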
@ -82,7 +80,7 @@ class MatchSparseExpr(ABC):
sparse_data: SparseVector | dict,
distance_type: str,
topn: int,
opt_params: Optional[dict] = None,
opt_params: dict | None = None,
):
self.vector_column_name = vector_column_name
self.sparse_data = sparse_data
@ -98,7 +96,7 @@ class MatchTensorExpr(ABC):
query_data: VEC,
query_data_type: str,
topn: int,
extra_option: Optional[dict] = None,
extra_option: dict | None = None,
):
self.column_name = column_name
self.query_data = query_data
@ -108,16 +106,13 @@ class MatchTensorExpr(ABC):
class FusionExpr(ABC):
def __init__(self, method: str, topn: int, fusion_params: Optional[dict] = None):
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
self.method = method
self.topn = topn
self.fusion_params = fusion_params
MatchExpr = Union[
MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr
]
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
class OrderByExpr(ABC):
def __init__(self):
@ -229,11 +224,11 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
@abstractmethod
def getHighlight(self, res, keywords: List[str], fieldnm: str):
def getHighlight(self, res, keywords: list[str], fieldnm: str):
raise NotImplementedError("Not implemented")
@abstractmethod

View File

@ -3,7 +3,6 @@ import re
import json
import time
import os
from typing import List, Dict
import copy
from elasticsearch import Elasticsearch
@ -363,7 +362,7 @@ class ESConnection(DocStoreConnection):
rr.append(d["_source"])
return rr
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
@ -382,7 +381,7 @@ class ESConnection(DocStoreConnection):
res_fields[d["id"]] = m
return res_fields
def getHighlight(self, res, keywords: List[str], fieldnm: str):
def getHighlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")

View File

@ -3,7 +3,6 @@ import os
import re
import json
import time
from typing import List, Dict
import infinity
from infinity.common import ConflictType, InfinityException
from infinity.index import IndexInfo, IndexType
@ -384,7 +383,7 @@ class InfinityConnection(DocStoreConnection):
def getChunkIds(self, res):
return list(res["id"])
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
@ -412,7 +411,7 @@ class InfinityConnection(DocStoreConnection):
res_fields[id] = m
return res_fields
def getHighlight(self, res, keywords: List[str], fieldnm: str):
def getHighlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
num_rows = len(res)
column_id = res["id"]

sdk/python/poetry.lock generated
View File

@ -1,5 +1,23 @@
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
[[package]]
name = "beartype"
version = "0.18.5"
description = "Unbearably fast runtime type checking in pure Python."
optional = false
python-versions = ">=3.8.0"
files = [
{file = "beartype-0.18.5-py3-none-any.whl", hash = "sha256:5301a14f2a9a5540fe47ec6d34d758e9cd8331d36c4760fc7a5499ab86310089"},
{file = "beartype-0.18.5.tar.gz", hash = "sha256:264ddc2f1da9ec94ff639141fbe33d22e12a9f75aa863b83b7046ffff1381927"},
]
[package.extras]
all = ["typing-extensions (>=3.10.0.0)"]
dev = ["autoapi (>=0.9.0)", "coverage (>=5.5)", "equinox", "mypy (>=0.800)", "numpy", "pandera", "pydata-sphinx-theme (<=0.7.2)", "pytest (>=4.0.0)", "sphinx", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)", "tox (>=3.20.1)", "typing-extensions (>=3.10.0.0)"]
doc-rtd = ["autoapi (>=0.9.0)", "pydata-sphinx-theme (<=0.7.2)", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)"]
test-tox = ["equinox", "mypy (>=0.800)", "numpy", "pandera", "pytest (>=4.0.0)", "sphinx", "typing-extensions (>=3.10.0.0)"]
test-tox-coverage = ["coverage (>=5.5)"]
[[package]]
name = "certifi"
version = "2024.8.30"
@ -177,13 +195,13 @@ files = [
[[package]]
name = "packaging"
version = "24.1"
version = "24.2"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"},
{file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"},
{file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"},
{file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"},
]
[[package]]
@ -246,13 +264,13 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "tomli"
version = "2.0.2"
version = "2.1.0"
description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
files = [
{file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"},
{file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"},
{file = "tomli-2.1.0-py3-none-any.whl", hash = "sha256:a5c57c3d1c56f5ccdf89f6523458f60ef716e210fc47c4cfb188c5ba473e0391"},
{file = "tomli-2.1.0.tar.gz", hash = "sha256:3f646cae2aec94e17d04973e4249548320197cfabdf130015d023de4b74d8ab8"},
]
[[package]]
@ -275,4 +293,4 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "202bfd3e121f1d57a2f9c9d91cd7a50eacf2362cd1995c9f6347bcb100cf9336"
content-hash = "19565d31d822b0573f505662c664d735194134a505f43bbd1657c033f87bb82d"

View File

@ -10,6 +10,7 @@ package-mode = true
[tool.poetry.dependencies]
python = "^3.10"
requests = "^2.30.0"
beartype = "^0.18.5"
pytest = "^8.0.0"

View File

@ -1,3 +1,6 @@
from beartype.claw import beartype_this_package
beartype_this_package() # <-- raise exceptions in your code
import importlib.metadata
__version__ = importlib.metadata.version("ragflow_sdk")
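
With `beartype_this_package()` at the top of `__init__`, every module in `ragflow_sdk` is imported under runtime checking. A hedged sketch of the effect for SDK users (the constructor arguments are assumed placeholders):

```python
from ragflow_sdk import RAGFlow

rag = RAGFlow(api_key="<key>", base_url="http://localhost:9380")
rag.delete_datasets(ids=["abc123"])  # list[str] | None: accepted
# rag.delete_datasets(ids="abc123")  # a bare str now fails fast with a
# beartype violation instead of a confusing server-side error
```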

View File

@ -1,4 +1,3 @@
from typing import List
from .base import Base
from .session import Session
@ -58,7 +57,7 @@ class Chat(Base):
raise Exception(res["message"])
def list_sessions(self,page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True,
id: str = None, name: str = None) -> List[Session]:
id: str = None, name: str = None) -> list[Session]:
res = self.get(f'/chats/{self.id}/sessions',{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name} )
res = res.json()
if res.get("code") == 0:
@ -68,7 +67,7 @@ class Chat(Base):
return result_list
raise Exception(res["message"])
def delete_sessions(self,ids:List[str]=None):
def delete_sessions(self,ids: list[str] | None = None):
res = self.rm(f"/chats/{self.id}/sessions", {"ids": ids})
res = res.json()
if res.get("code") != 0:

View File

@ -1,5 +1,3 @@
from typing import List
from .document import Document
from .base import Base
@ -35,7 +33,7 @@ class DataSet(Base):
if res.get("code") != 0:
raise Exception(res["message"])
def upload_documents(self,document_list: List[dict]):
def upload_documents(self,document_list: list[dict]):
url = f"/datasets/{self.id}/documents"
files = [("file",(ele["displayed_name"],ele["blob"])) for ele in document_list]
res = self.post(path=url,json=None,files=files)
@ -48,7 +46,7 @@ class DataSet(Base):
return doc_list
raise Exception(res.get("message"))
def list_documents(self, id: str = None, keywords: str = None, page: int =1, page_size: int = 30, orderby: str = "create_time", desc: bool = True):
def list_documents(self, id: str | None = None, keywords: str | None = None, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True):
res = self.get(f"/datasets/{self.id}/documents",params={"id": id,"keywords": keywords,"page": page,"page_size": page_size,"orderby": orderby,"desc": desc})
res = res.json()
documents = []
@ -58,7 +56,7 @@ class DataSet(Base):
return documents
raise Exception(res["message"])
def delete_documents(self,ids: List[str] = None):
def delete_documents(self,ids: list[str] | None = None):
res = self.rm(f"/datasets/{self.id}/documents",{"ids":ids})
res = res.json()
if res.get("code") != 0:

View File

@ -1,7 +1,6 @@
import json
from .base import Base
from .chunk import Chunk
from typing import List
class Document(Base):
@ -63,14 +62,14 @@ class Document(Base):
raise Exception(res.get("message"))
def add_chunk(self, content: str,important_keywords:List[str]=[]):
def add_chunk(self, content: str,important_keywords: list[str] = []):
res = self.post(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', {"content":content,"important_keywords":important_keywords})
res = res.json()
if res.get("code") == 0:
return Chunk(self.rag,res["data"].get("chunk"))
raise Exception(res.get("message"))
def delete_chunks(self,ids:List[str] = None):
def delete_chunks(self,ids:list[str] | None = None):
res = self.rm(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks",{"chunk_ids":ids})
res = res.json()
if res.get("code")!=0:

View File

@ -13,14 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import requests
from .modules.chat import Chat
from .modules.chunk import Chunk
from .modules.dataset import DataSet
from .modules.document import Document
class RAGFlow:
@ -64,7 +61,7 @@ class RAGFlow:
return DataSet(self, res["data"])
raise Exception(res["message"])
def delete_datasets(self, ids: List[str] = None):
def delete_datasets(self, ids: list[str] | None = None):
res = self.delete("/datasets",{"ids": ids})
res=res.json()
if res.get("code") != 0:
@ -77,8 +74,8 @@ class RAGFlow:
raise Exception("Dataset %s not found" % name)
def list_datasets(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True,
id: str = None, name: str = None) -> \
List[DataSet]:
id: str | None = None, name: str | None = None) -> \
list[DataSet]:
res = self.get("/datasets",
{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
res = res.json()
@ -89,8 +86,8 @@ class RAGFlow:
return result_list
raise Exception(res["message"])
def create_chat(self, name: str, avatar: str = "", dataset_ids: List[str] = [],
llm: Chat.LLM = None, prompt: Chat.Prompt = None) -> Chat:
def create_chat(self, name: str, avatar: str = "", dataset_ids: list[str] = [],
llm: Chat.LLM | None = None, prompt: Chat.Prompt | None = None) -> Chat:
dataset_list = []
for id in dataset_ids:
dataset_list.append(id)
@ -135,7 +132,7 @@ class RAGFlow:
return Chat(self, res["data"])
raise Exception(res["message"])
def delete_chats(self,ids: List[str] = None) -> bool:
def delete_chats(self,ids: list[str] | None = None):
res = self.delete('/chats',
{"ids":ids})
res = res.json()
@ -143,7 +140,7 @@ class RAGFlow:
raise Exception(res["message"])
def list_chats(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True,
id: str = None, name: str = None) -> List[Chat]:
id: str | None = None, name: str | None = None) -> list[Chat]:
res = self.get("/chats",{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
res = res.json()
result_list = []
@ -154,7 +151,7 @@ class RAGFlow:
raise Exception(res["message"])
def retrieve(self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id:str=None, keyword:bool=False, ):
def retrieve(self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id: str | None = None, keyword:bool=False, ):
if document_ids is None:
document_ids = []
data_json ={
@ -170,7 +167,7 @@ class RAGFlow:
"documents": document_ids
}
# Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
res = self.post(f'/retrieval',json=data_json)
res = self.post('/retrieval',json=data_json)
res = res.json()
if res.get("code") ==0:
chunks=[]