diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 04eb28c5d..b101edb17 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -21,6 +21,7 @@ from pathlib import Path from flask import Blueprint, Flask from werkzeug.wrappers.request import Request from flask_cors import CORS +from flasgger import Swagger from api.db import StatusEnum from api.db.db_models import close_connection @@ -34,27 +35,62 @@ from api.settings import API_VERSION, access_logger from api.utils.api_utils import server_error_response from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer -__all__ = ['app'] +__all__ = ["app"] -logger = logging.getLogger('flask.app') +logger = logging.getLogger("flask.app") for h in access_logger.handlers: logger.addHandler(h) Request.json = property(lambda self: self.get_json(force=True, silent=True)) app = Flask(__name__) -CORS(app, supports_credentials=True,max_age=2592000) + +# Add this at the beginning of your file to configure Swagger UI +swagger_config = { + "headers": [], + "specs": [ + { + "endpoint": "apispec", + "route": "/apispec.json", + "rule_filter": lambda rule: True, # Include all endpoints + "model_filter": lambda tag: True, # Include all models + } + ], + "static_url_path": "/flasgger_static", + "swagger_ui": True, + "specs_route": "/apidocs/", +} + +swagger = Swagger( + app, + config=swagger_config, + template={ + "swagger": "2.0", + "info": { + "title": "RAGFlow API", + "description": "", + "version": "1.0.0", + }, + "securityDefinitions": { + "ApiKeyAuth": {"type": "apiKey", "name": "Authorization", "in": "header"} + }, + }, +) + +CORS(app, supports_credentials=True, max_age=2592000) app.url_map.strict_slashes = False app.json_encoder = CustomJSONEncoder app.errorhandler(Exception)(server_error_response) ## convince for dev and debug -#app.config["LOGIN_DISABLED"] = True +# app.config["LOGIN_DISABLED"] = True app.config["SESSION_PERMANENT"] = False app.config["SESSION_TYPE"] = "filesystem" -app.config['MAX_CONTENT_LENGTH'] = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024)) +app.config["MAX_CONTENT_LENGTH"] = int( + os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024) +) Session(app) login_manager = LoginManager() @@ -64,17 +100,23 @@ commands.register_commands(app) def search_pages_path(pages_dir): - app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')] - api_path_list = [path for path in pages_dir.glob('*sdk/*.py') if not path.name.startswith('.')] + app_path_list = [ + path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".") + ] + api_path_list = [ + path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".") + ] app_path_list.extend(api_path_list) return app_path_list def register_page(page_path): - path = f'{page_path}' + path = f"{page_path}" - page_name = page_path.stem.rstrip('_app') - module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name,)) + page_name = page_path.stem.rstrip("_app") + module_name = ".".join( + page_path.parts[page_path.parts.index("api") : -1] + (page_name,) + ) spec = spec_from_file_location(module_name, page_path) page = module_from_spec(spec) @@ -82,8 +124,10 @@ def register_page(page_path): page.manager = Blueprint(page_name, module_name) sys.modules[module_name] = page spec.loader.exec_module(page) - page_name = getattr(page, 'page_name', page_name) - url_prefix = f'/api/{API_VERSION}' if "/sdk/" in path else f'/{API_VERSION}/{page_name}' + page_name = getattr(page, "page_name", 
page_name) + url_prefix = ( + f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}" + ) app.register_blueprint(page.manager, url_prefix=url_prefix) return url_prefix @@ -91,14 +135,12 @@ def register_page(page_path): pages_dir = [ Path(__file__).parent, - Path(__file__).parent.parent / 'api' / 'apps', - Path(__file__).parent.parent / 'api' / 'apps' / 'sdk', + Path(__file__).parent.parent / "api" / "apps", + Path(__file__).parent.parent / "api" / "apps" / "sdk", ] client_urls_prefix = [ - register_page(path) - for dir in pages_dir - for path in search_pages_path(dir) + register_page(path) for dir in pages_dir for path in search_pages_path(dir) ] @@ -109,7 +151,9 @@ def load_user(web_request): if authorization: try: access_token = str(jwt.loads(authorization)) - user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value) + user = UserService.query( + access_token=access_token, status=StatusEnum.VALID.value + ) if user: return user[0] else: @@ -123,4 +167,4 @@ def load_user(web_request): @app.teardown_request def _db_close(exc): - close_connection() \ No newline at end of file + close_connection() diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index 2d0520ddf..c00343f38 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -21,16 +21,72 @@ from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import TenantLLMService,LLMService +from api.db.services.llm_service import TenantLLMService, LLMService from api.db.services.user_service import TenantService from api.settings import RetCode from api.utils import get_uuid -from api.utils.api_utils import get_result, token_required, get_error_data_result, valid,get_parser_config +from api.utils.api_utils import ( + get_result, + token_required, + get_error_data_result, + valid, + get_parser_config, +) -@manager.route('/datasets', methods=['POST']) +@manager.route("/datasets", methods=["POST"]) @token_required def create(tenant_id): + """ + Create a new dataset. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Dataset creation parameters. + required: true + schema: + type: object + required: + - name + properties: + name: + type: string + description: Name of the dataset. + permission: + type: string + enum: ['me', 'team'] + description: Dataset permission. + language: + type: string + enum: ['Chinese', 'English'] + description: Language of the dataset. + chunk_method: + type: string + enum: ["naive", "manual", "qa", "table", "paper", "book", "laws", + "presentation", "picture", "one", "knowledge_graph", "email"] + description: Chunking method. + parser_config: + type: object + description: Parser configuration. + responses: + 200: + description: Successful operation. 
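+        # Illustrative request, not part of the generated spec (assumes the
+        # API_VERSION prefix resolves to "v1" and <API_KEY> is a valid token):
+        #   curl -X POST http://<host>/api/v1/datasets \
+        #     -H "Authorization: Bearer <API_KEY>" -H "Content-Type: application/json" \
+        #     -d '{"name": "my_dataset", "chunk_method": "naive", "permission": "me"}'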
+            schema:
+              type: object
+              properties:
+                data:
+                  type: object
+    """
     req = request.json
     e, t = TenantService.get_by_id(tenant_id)
     permission = req.get("permission")
@@ -38,49 +94,97 @@ def create(tenant_id):
     chunk_method = req.get("chunk_method")
     parser_config = req.get("parser_config")
     valid_permission = ["me", "team"]
-    valid_language =["Chinese", "English"]
-    valid_chunk_method = ["naive","manual","qa","table","paper","book","laws","presentation","picture","one","knowledge_graph","email"]
-    check_validation=valid(permission,valid_permission,language,valid_language,chunk_method,valid_chunk_method)
+    valid_language = ["Chinese", "English"]
+    valid_chunk_method = [
+        "naive",
+        "manual",
+        "qa",
+        "table",
+        "paper",
+        "book",
+        "laws",
+        "presentation",
+        "picture",
+        "one",
+        "knowledge_graph",
+        "email",
+    ]
+    check_validation = valid(
+        permission,
+        valid_permission,
+        language,
+        valid_language,
+        chunk_method,
+        valid_chunk_method,
+    )
     if check_validation:
         return check_validation
-    req["parser_config"]=get_parser_config(chunk_method,parser_config)
+    req["parser_config"] = get_parser_config(chunk_method, parser_config)
     if "tenant_id" in req:
-        return get_error_data_result(
-            retmsg="`tenant_id` must not be provided")
+        return get_error_data_result(retmsg="`tenant_id` must not be provided")
     if "chunk_count" in req or "document_count" in req:
-        return get_error_data_result(retmsg="`chunk_count` or `document_count` must not be provided")
-    if "name" not in req: return get_error_data_result(
-        retmsg="`name` is not empty!")
-    req['id'] = get_uuid()
+        return get_error_data_result(
+            retmsg="`chunk_count` or `document_count` must not be provided"
+        )
+    if "name" not in req:
+        return get_error_data_result(retmsg="`name` can not be empty!")
+    req["id"] = get_uuid()
     req["name"] = req["name"].strip()
     if req["name"] == "":
+        return get_error_data_result(retmsg="`name` can not be an empty string!")
+    if KnowledgebaseService.query(
+        name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value
+    ):
         return get_error_data_result(
-            retmsg="`name` is not empty string!")
-    if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
-        return get_error_data_result(
-            retmsg="Duplicated dataset name in creating dataset.")
-    req["tenant_id"] = req['created_by'] = tenant_id
+            retmsg="Duplicated dataset name in creating dataset." 
+ ) + req["tenant_id"] = req["created_by"] = tenant_id if not req.get("embedding_model"): - req['embedding_model'] = t.embd_id + req["embedding_model"] = t.embd_id else: - valid_embedding_models=["BAAI/bge-large-zh-v1.5","BAAI/bge-base-en-v1.5","BAAI/bge-large-en-v1.5","BAAI/bge-small-en-v1.5", - "BAAI/bge-small-zh-v1.5","jinaai/jina-embeddings-v2-base-en","jinaai/jina-embeddings-v2-small-en", - "nomic-ai/nomic-embed-text-v1.5","sentence-transformers/all-MiniLM-L6-v2","text-embedding-v2", - "text-embedding-v3","maidalun1020/bce-embedding-base_v1"] - embd_model=LLMService.query(llm_name=req["embedding_model"],model_type="embedding") + valid_embedding_models = [ + "BAAI/bge-large-zh-v1.5", + "BAAI/bge-base-en-v1.5", + "BAAI/bge-large-en-v1.5", + "BAAI/bge-small-en-v1.5", + "BAAI/bge-small-zh-v1.5", + "jinaai/jina-embeddings-v2-base-en", + "jinaai/jina-embeddings-v2-small-en", + "nomic-ai/nomic-embed-text-v1.5", + "sentence-transformers/all-MiniLM-L6-v2", + "text-embedding-v2", + "text-embedding-v3", + "maidalun1020/bce-embedding-base_v1", + ] + embd_model = LLMService.query( + llm_name=req["embedding_model"], model_type="embedding" + ) if not embd_model: - return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") + return get_error_data_result( + f"`embedding_model` {req.get('embedding_model')} doesn't exist" + ) if embd_model: - if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")): - return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") + if req[ + "embedding_model" + ] not in valid_embedding_models and not TenantLLMService.query( + tenant_id=tenant_id, + model_type="embedding", + llm_name=req.get("embedding_model"), + ): + return get_error_data_result( + f"`embedding_model` {req.get('embedding_model')} doesn't exist" + ) key_mapping = { "chunk_num": "chunk_count", "doc_num": "document_count", "parser_id": "chunk_method", - "embd_id": "embedding_model" + "embd_id": "embedding_model", + } + mapped_keys = { + new_key: req[old_key] + for new_key, old_key in key_mapping.items() + if old_key in req } - mapped_keys = {new_key: req[old_key] for new_key, old_key in key_mapping.items() if old_key in req} req.update(mapped_keys) if not KnowledgebaseService.save(**req): return get_error_data_result(retmsg="Create dataset error.(Database error)") @@ -91,21 +195,53 @@ def create(tenant_id): renamed_data[new_key] = value return get_result(data=renamed_data) -@manager.route('/datasets', methods=['DELETE']) + +@manager.route("/datasets", methods=["DELETE"]) @token_required def delete(tenant_id): + """ + Delete datasets. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Dataset deletion parameters. + required: true + schema: + type: object + properties: + ids: + type: array + items: + type: string + description: List of dataset IDs to delete. + responses: + 200: + description: Successful operation. 
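+        # Illustrative request body (example IDs only; per the handler below,
+        # omitting "ids" deletes every dataset owned by the tenant):
+        #   {"ids": ["<dataset_id_1>", "<dataset_id_2>"]}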
+ schema: + type: object + """ req = request.json if not req: - ids=None + ids = None else: - ids=req.get("ids") + ids = req.get("ids") if not ids: id_list = [] - kbs=KnowledgebaseService.query(tenant_id=tenant_id) + kbs = KnowledgebaseService.query(tenant_id=tenant_id) for kb in kbs: id_list.append(kb.id) else: - id_list=ids + id_list = ids for id in id_list: kbs = KnowledgebaseService.query(id=id, tenant_id=tenant_id) if not kbs: @@ -113,19 +249,75 @@ def delete(tenant_id): for doc in DocumentService.query(kb_id=id): if not DocumentService.remove_document(doc, tenant_id): return get_error_data_result( - retmsg="Remove document error.(Database error)") + retmsg="Remove document error.(Database error)" + ) f2d = File2DocumentService.get_by_document_id(doc.id) - FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) + FileService.filter_delete( + [ + File.source_type == FileSource.KNOWLEDGEBASE, + File.id == f2d[0].file_id, + ] + ) File2DocumentService.delete_by_document_id(doc.id) if not KnowledgebaseService.delete_by_id(id): - return get_error_data_result( - retmsg="Delete dataset error.(Database error)") + return get_error_data_result(retmsg="Delete dataset error.(Database error)") return get_result(retcode=RetCode.SUCCESS) -@manager.route('/datasets/', methods=['PUT']) + +@manager.route("/datasets/", methods=["PUT"]) @token_required -def update(tenant_id,dataset_id): - if not KnowledgebaseService.query(id=dataset_id,tenant_id=tenant_id): +def update(tenant_id, dataset_id): + """ + Update a dataset. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset to update. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Dataset update parameters. + required: true + schema: + type: object + properties: + name: + type: string + description: New name of the dataset. + permission: + type: string + enum: ['me', 'team'] + description: Updated permission. + language: + type: string + enum: ['Chinese', 'English'] + description: Updated language. + chunk_method: + type: string + enum: ["naive", "manual", "qa", "table", "paper", "book", "laws", + "presentation", "picture", "one", "knowledge_graph", "email"] + description: Updated chunking method. + parser_config: + type: object + description: Updated parser configuration. + responses: + 200: + description: Successful operation. 
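+        # Illustrative request body (example values only):
+        #   {"name": "renamed_dataset", "permission": "team"}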
+ schema: + type: object + """ + if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(retmsg="You don't own the dataset") req = request.json e, t = TenantService.get_by_id(tenant_id) @@ -138,91 +330,202 @@ def update(tenant_id,dataset_id): parser_config = req.get("parser_config") valid_permission = ["me", "team"] valid_language = ["Chinese", "English"] - valid_chunk_method = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", - "knowledge_graph", "email"] - check_validation = valid(permission, valid_permission, language, valid_language, chunk_method, valid_chunk_method) + valid_chunk_method = [ + "naive", + "manual", + "qa", + "table", + "paper", + "book", + "laws", + "presentation", + "picture", + "one", + "knowledge_graph", + "email", + ] + check_validation = valid( + permission, + valid_permission, + language, + valid_language, + chunk_method, + valid_chunk_method, + ) if check_validation: return check_validation if "tenant_id" in req: if req["tenant_id"] != tenant_id: - return get_error_data_result( - retmsg="Can't change `tenant_id`.") + return get_error_data_result(retmsg="Can't change `tenant_id`.") e, kb = KnowledgebaseService.get_by_id(dataset_id) if "parser_config" in req: - temp_dict=kb.parser_config + temp_dict = kb.parser_config temp_dict.update(req["parser_config"]) req["parser_config"] = temp_dict if "chunk_count" in req: if req["chunk_count"] != kb.chunk_num: - return get_error_data_result( - retmsg="Can't change `chunk_count`.") + return get_error_data_result(retmsg="Can't change `chunk_count`.") req.pop("chunk_count") if "document_count" in req: - if req['document_count'] != kb.doc_num: - return get_error_data_result( - retmsg="Can't change `document_count`.") + if req["document_count"] != kb.doc_num: + return get_error_data_result(retmsg="Can't change `document_count`.") req.pop("document_count") if "chunk_method" in req: - if kb.chunk_num != 0 and req['chunk_method'] != kb.parser_id: + if kb.chunk_num != 0 and req["chunk_method"] != kb.parser_id: return get_error_data_result( - retmsg="If `chunk_count` is not 0, `chunk_method` is not changeable.") - req['parser_id'] = req.pop('chunk_method') - if req['parser_id'] != kb.parser_id: + retmsg="If `chunk_count` is not 0, `chunk_method` is not changeable." + ) + req["parser_id"] = req.pop("chunk_method") + if req["parser_id"] != kb.parser_id: if not req.get("parser_config"): req["parser_config"] = get_parser_config(chunk_method, parser_config) if "embedding_model" in req: - if kb.chunk_num != 0 and req['embedding_model'] != kb.embd_id: + if kb.chunk_num != 0 and req["embedding_model"] != kb.embd_id: return get_error_data_result( - retmsg="If `chunk_count` is not 0, `embedding_model` is not changeable.") + retmsg="If `chunk_count` is not 0, `embedding_model` is not changeable." 
+ ) if not req.get("embedding_model"): return get_error_data_result("`embedding_model` can't be empty") - valid_embedding_models=["BAAI/bge-large-zh-v1.5","BAAI/bge-base-en-v1.5","BAAI/bge-large-en-v1.5","BAAI/bge-small-en-v1.5", - "BAAI/bge-small-zh-v1.5","jinaai/jina-embeddings-v2-base-en","jinaai/jina-embeddings-v2-small-en", - "nomic-ai/nomic-embed-text-v1.5","sentence-transformers/all-MiniLM-L6-v2","text-embedding-v2", - "text-embedding-v3","maidalun1020/bce-embedding-base_v1"] - embd_model=LLMService.query(llm_name=req["embedding_model"],model_type="embedding") + valid_embedding_models = [ + "BAAI/bge-large-zh-v1.5", + "BAAI/bge-base-en-v1.5", + "BAAI/bge-large-en-v1.5", + "BAAI/bge-small-en-v1.5", + "BAAI/bge-small-zh-v1.5", + "jinaai/jina-embeddings-v2-base-en", + "jinaai/jina-embeddings-v2-small-en", + "nomic-ai/nomic-embed-text-v1.5", + "sentence-transformers/all-MiniLM-L6-v2", + "text-embedding-v2", + "text-embedding-v3", + "maidalun1020/bce-embedding-base_v1", + ] + embd_model = LLMService.query( + llm_name=req["embedding_model"], model_type="embedding" + ) if not embd_model: - return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") + return get_error_data_result( + f"`embedding_model` {req.get('embedding_model')} doesn't exist" + ) if embd_model: - if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")): - return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") - req['embd_id'] = req.pop('embedding_model') + if req[ + "embedding_model" + ] not in valid_embedding_models and not TenantLLMService.query( + tenant_id=tenant_id, + model_type="embedding", + llm_name=req.get("embedding_model"), + ): + return get_error_data_result( + f"`embedding_model` {req.get('embedding_model')} doesn't exist" + ) + req["embd_id"] = req.pop("embedding_model") if "name" in req: req["name"] = req["name"].strip() - if req["name"].lower() != kb.name.lower() \ - and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, - status=StatusEnum.VALID.value)) > 0: + if ( + req["name"].lower() != kb.name.lower() + and len( + KnowledgebaseService.query( + name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value + ) + ) + > 0 + ): return get_error_data_result( - retmsg="Duplicated dataset name in updating dataset.") + retmsg="Duplicated dataset name in updating dataset." + ) if not KnowledgebaseService.update_by_id(kb.id, req): return get_error_data_result(retmsg="Update dataset error.(Database error)") return get_result(retcode=RetCode.SUCCESS) -@manager.route('/datasets', methods=['GET']) + +@manager.route("/datasets", methods=["GET"]) @token_required def list(tenant_id): + """ + List datasets. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: query + name: id + type: string + required: false + description: Dataset ID to filter. + - in: query + name: name + type: string + required: false + description: Dataset name to filter. + - in: query + name: page + type: integer + required: false + default: 1 + description: Page number. + - in: query + name: page_size + type: integer + required: false + default: 1024 + description: Number of items per page. + - in: query + name: orderby + type: string + required: false + default: "create_time" + description: Field to order by. 
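+      # Illustrative query, not part of the generated spec (assumes the
+      # API_VERSION prefix resolves to "v1"):
+      #   GET /api/v1/datasets?page=1&page_size=10&orderby=create_time&desc=true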
+ - in: query + name: desc + type: boolean + required: false + default: true + description: Order in descending. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Successful operation. + schema: + type: array + items: + type: object + """ id = request.args.get("id") name = request.args.get("name") - kbs = KnowledgebaseService.query(id=id,name=name,status=1) + kbs = KnowledgebaseService.query(id=id, name=name, status=1) if not kbs: return get_error_data_result(retmsg="The dataset doesn't exist") page_number = int(request.args.get("page", 1)) items_per_page = int(request.args.get("page_size", 1024)) orderby = request.args.get("orderby", "create_time") - if request.args.get("desc") == "False" or request.args.get("desc") == "false" : + if request.args.get("desc") == "False" or request.args.get("desc") == "false": desc = False else: desc = True tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) kbs = KnowledgebaseService.get_list( - [m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc, id, name) + [m["tenant_id"] for m in tenants], + tenant_id, + page_number, + items_per_page, + orderby, + desc, + id, + name, + ) renamed_list = [] for kb in kbs: key_mapping = { "chunk_num": "chunk_count", "doc_num": "document_count", "parser_id": "chunk_method", - "embd_id": "embedding_model" + "embd_id": "embedding_model", } renamed_data = {} for key, value in kb.items(): diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index ce616dfa5..a94f576bc 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -39,7 +39,7 @@ from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService from api.settings import RetCode, retrievaler -from api.utils.api_utils import construct_json_result,get_parser_config +from api.utils.api_utils import construct_json_result, get_parser_config from rag.nlp import search from rag.utils import rmSpace from rag.utils.es_conn import ELASTICSEARCH @@ -49,36 +49,93 @@ import os MAXIMUM_OF_UPLOADING_FILES = 256 - -@manager.route('/datasets//documents', methods=['POST']) +@manager.route("/datasets//documents", methods=["POST"]) @token_required def upload(dataset_id, tenant_id): - if 'file' not in request.files: + """ + Upload documents to a dataset. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: formData + name: file + type: file + required: true + description: Document files to upload. + responses: + 200: + description: Successfully uploaded documents. + schema: + type: object + properties: + data: + type: array + items: + type: object + properties: + id: + type: string + description: Document ID. + name: + type: string + description: Document name. + chunk_count: + type: integer + description: Number of chunks. + token_count: + type: integer + description: Number of tokens. + dataset_id: + type: string + description: ID of the dataset. + chunk_method: + type: string + description: Chunking method used. + run: + type: string + description: Processing status. 
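+    # Illustrative multipart request, not part of the generated spec (assumes
+    # the API_VERSION prefix resolves to "v1"; the handler below also enforces
+    # a 10MB total-size cap):
+    #   curl -X POST http://<host>/api/v1/datasets/<dataset_id>/documents \
+    #     -H "Authorization: Bearer <API_KEY>" -F "file=@./example.pdf"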
+ """ + if "file" not in request.files: return get_error_data_result( - retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR) - file_objs = request.files.getlist('file') + retmsg="No file part!", retcode=RetCode.ARGUMENT_ERROR + ) + file_objs = request.files.getlist("file") for file_obj in file_objs: - if file_obj.filename == '': + if file_obj.filename == "": return get_result( - retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) + retmsg="No file selected!", retcode=RetCode.ARGUMENT_ERROR + ) # total size total_size = 0 for file_obj in file_objs: file_obj.seek(0, os.SEEK_END) total_size += file_obj.tell() file_obj.seek(0) - MAX_TOTAL_FILE_SIZE=10*1024*1024 + MAX_TOTAL_FILE_SIZE = 10 * 1024 * 1024 if total_size > MAX_TOTAL_FILE_SIZE: return get_result( - retmsg=f'Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)', - retcode=RetCode.ARGUMENT_ERROR) + retmsg=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)", + retcode=RetCode.ARGUMENT_ERROR, + ) e, kb = KnowledgebaseService.get_by_id(dataset_id) if not e: raise LookupError(f"Can't find the dataset with ID {dataset_id}!") - err, files= FileService.upload_document(kb, file_objs, tenant_id) + err, files = FileService.upload_document(kb, file_objs, tenant_id) if err: - return get_result( - retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR) + return get_result(retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR) # rename key's name renamed_doc_list = [] for file in files: @@ -87,7 +144,7 @@ def upload(dataset_id, tenant_id): "chunk_num": "chunk_count", "kb_id": "dataset_id", "token_num": "token_count", - "parser_id": "chunk_method" + "parser_id": "chunk_method", } renamed_doc = {} for key, value in doc.items(): @@ -98,9 +155,54 @@ def upload(dataset_id, tenant_id): return get_result(data=renamed_doc_list) -@manager.route('/datasets//documents/', methods=['PUT']) +@manager.route("/datasets//documents/", methods=["PUT"]) @token_required def update_doc(tenant_id, dataset_id, document_id): + """ + Update a document within a dataset. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: path + name: document_id + type: string + required: true + description: ID of the document to update. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Document update parameters. + required: true + schema: + type: object + properties: + name: + type: string + description: New name of the document. + parser_config: + type: object + description: Parser configuration. + chunk_method: + type: string + description: Chunking method. + responses: + 200: + description: Document updated successfully. 
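+        # Illustrative request body (example values; `chunk_method` must be one
+        # of the values validated in the handler below):
+        #   {"name": "renamed.pdf", "chunk_method": "naive"}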
+ schema: + type: object + """ req = request.json if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(retmsg="You don't own the dataset.") @@ -115,20 +217,25 @@ def update_doc(tenant_id, dataset_id, document_id): if req["token_count"] != doc.token_num: return get_error_data_result(retmsg="Can't change `token_count`.") if "progress" in req: - if req['progress'] != doc.progress: + if req["progress"] != doc.progress: return get_error_data_result(retmsg="Can't change `progress`.") if "name" in req and req["name"] != doc.name: - if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix: - return get_result(retmsg="The extension of file can't be changed", retcode=RetCode.ARGUMENT_ERROR) + if ( + pathlib.Path(req["name"].lower()).suffix + != pathlib.Path(doc.name.lower()).suffix + ): + return get_result( + retmsg="The extension of file can't be changed", + retcode=RetCode.ARGUMENT_ERROR, + ) for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): if d.name == req["name"]: return get_error_data_result( - retmsg="Duplicated document name in the same dataset.") - if not DocumentService.update_by_id( - document_id, {"name": req["name"]}): - return get_error_data_result( - retmsg="Database error (Document rename)!") + retmsg="Duplicated document name in the same dataset." + ) + if not DocumentService.update_by_id(document_id, {"name": req["name"]}): + return get_error_data_result(retmsg="Database error (Document rename)!") informs = File2DocumentService.get_by_document_id(document_id) if informs: @@ -137,77 +244,231 @@ def update_doc(tenant_id, dataset_id, document_id): if "parser_config" in req: DocumentService.update_parser_config(doc.id, req["parser_config"]) if "chunk_method" in req: - valid_chunk_method = {"naive","manual","qa","table","paper","book","laws","presentation","picture","one","knowledge_graph","email"} + valid_chunk_method = { + "naive", + "manual", + "qa", + "table", + "paper", + "book", + "laws", + "presentation", + "picture", + "one", + "knowledge_graph", + "email", + } if req.get("chunk_method") not in valid_chunk_method: - return get_error_data_result(f"`chunk_method` {req['chunk_method']} doesn't exist") + return get_error_data_result( + f"`chunk_method` {req['chunk_method']} doesn't exist" + ) if doc.parser_id.lower() == req["chunk_method"].lower(): - return get_result() + return get_result() - if doc.type == FileType.VISUAL or re.search( - r"\.(ppt|pptx|pages)$", doc.name): + if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name): return get_error_data_result(retmsg="Not supported yet!") - e = DocumentService.update_by_id(doc.id, - {"parser_id": req["chunk_method"], "progress": 0, "progress_msg": "", - "run": TaskStatus.UNSTART.value}) + e = DocumentService.update_by_id( + doc.id, + { + "parser_id": req["chunk_method"], + "progress": 0, + "progress_msg": "", + "run": TaskStatus.UNSTART.value, + }, + ) if not e: return get_error_data_result(retmsg="Document not found!") - req["parser_config"] = get_parser_config(req["chunk_method"], req.get("parser_config")) + req["parser_config"] = get_parser_config( + req["chunk_method"], req.get("parser_config") + ) DocumentService.update_parser_config(doc.id, req["parser_config"]) if doc.token_num > 0: - e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, - doc.process_duation * -1) + e = DocumentService.increment_chunk_num( + doc.id, + doc.kb_id, + doc.token_num * -1, + doc.chunk_num * 
-1,
+            doc.process_duation * -1,
+        )
         if not e:
             return get_error_data_result(retmsg="Document not found!")
         ELASTICSEARCH.deleteByQuery(
-            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
+            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)
+        )
 
     return get_result()
 
 
-@manager.route('/datasets//documents/', methods=['GET'])
+@manager.route("/datasets//documents/", methods=["GET"])
 @token_required
 def download(tenant_id, dataset_id, document_id):
+    """
+    Download a document from a dataset.
+    ---
+    tags:
+      - Documents
+    security:
+      - ApiKeyAuth: []
+    produces:
+      - application/octet-stream
+    parameters:
+      - in: path
+        name: dataset_id
+        type: string
+        required: true
+        description: ID of the dataset.
+      - in: path
+        name: document_id
+        type: string
+        required: true
+        description: ID of the document to download.
+      - in: header
+        name: Authorization
+        type: string
+        required: true
+        description: Bearer token for authentication.
+    responses:
+      200:
+        description: Document file stream.
+        schema:
+          type: file
+      400:
+        description: Error message.
+        schema:
+          type: object
+    """
     if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
-        return get_error_data_result(retmsg=f'You do not own the dataset {dataset_id}.')
+        return get_error_data_result(retmsg=f"You do not own the dataset {dataset_id}.")
     doc = DocumentService.query(kb_id=dataset_id, id=document_id)
     if not doc:
-        return get_error_data_result(retmsg=f'The dataset not own the document {document_id}.')
+        return get_error_data_result(
+            retmsg=f"The dataset doesn't own the document {document_id}."
+        )
     # The process of downloading
-    doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id)  # minio address
+    doc_id, doc_location = File2DocumentService.get_storage_address(
+        doc_id=document_id
+    )  # minio address
     file_stream = STORAGE_IMPL.get(doc_id, doc_location)
     if not file_stream:
-        return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
+        return construct_json_result(
+            message="This file is empty.", code=RetCode.DATA_ERROR
+        )
     file = BytesIO(file_stream)
     # Use send_file with a proper filename and MIME type
     return send_file(
         file,
         as_attachment=True,
         download_name=doc[0].name,
-        mimetype='application/octet-stream'  # Set a default MIME type
+        mimetype="application/octet-stream",  # Set a default MIME type
     )
 
 
-@manager.route('/datasets//documents', methods=['GET'])
+@manager.route("/datasets//documents", methods=["GET"])
 @token_required
 def list_docs(dataset_id, tenant_id):
+    """
+    List documents in a dataset.
+    ---
+    tags:
+      - Documents
+    security:
+      - ApiKeyAuth: []
+    parameters:
+      - in: path
+        name: dataset_id
+        type: string
+        required: true
+        description: ID of the dataset.
+      - in: query
+        name: id
+        type: string
+        required: false
+        description: Filter by document ID.
+      - in: query
+        name: offset
+        type: integer
+        required: false
+        default: 1
+        description: Page number.
+      - in: query
+        name: limit
+        type: integer
+        required: false
+        default: 1024
+        description: Number of items per page.
+      - in: query
+        name: orderby
+        type: string
+        required: false
+        default: "create_time"
+        description: Field to order by.
+      - in: query
+        name: desc
+        type: boolean
+        required: false
+        default: true
+        description: Order in descending.
+      - in: header
+        name: Authorization
+        type: string
+        required: true
+        description: Bearer token for authentication.
+    responses:
+      200:
+        description: List of documents. 
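+        # Illustrative query, not part of the generated spec (assumes the
+        # API_VERSION prefix resolves to "v1"):
+        #   GET /api/v1/datasets/<dataset_id>/documents?offset=1&limit=30&keywords=report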
+        schema:
+          type: object
+          properties:
+            total:
+              type: integer
+              description: Total number of documents.
+            docs:
+              type: array
+              items:
+                type: object
+                properties:
+                  id:
+                    type: string
+                    description: Document ID.
+                  name:
+                    type: string
+                    description: Document name.
+                  chunk_count:
+                    type: integer
+                    description: Number of chunks.
+                  token_count:
+                    type: integer
+                    description: Number of tokens.
+                  dataset_id:
+                    type: string
+                    description: ID of the dataset.
+                  chunk_method:
+                    type: string
+                    description: Chunking method used.
+                  run:
+                    type: string
+                    description: Processing status.
+    """
     if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
         return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}. ")
     id = request.args.get("id")
     name = request.args.get("name")
-    if not DocumentService.query(id=id,kb_id=dataset_id):
+    if not DocumentService.query(id=id, kb_id=dataset_id):
         return get_error_data_result(retmsg=f"You don't own the document {id}.")
-    if not DocumentService.query(name=name,kb_id=dataset_id):
+    if not DocumentService.query(name=name, kb_id=dataset_id):
         return get_error_data_result(retmsg=f"You don't own the document {name}.")
     offset = int(request.args.get("offset", 1))
-    keywords = request.args.get("keywords","")
+    keywords = request.args.get("keywords", "")
     limit = int(request.args.get("limit", 1024))
     orderby = request.args.get("orderby", "create_time")
     if request.args.get("desc") == "False":
         desc = False
     else:
         desc = True
-    docs, tol = DocumentService.get_list(dataset_id, offset, limit, orderby, desc, keywords, id,name)
+    docs, tol = DocumentService.get_list(
+        dataset_id, offset, limit, orderby, desc, keywords, id, name
+    )
 
     # rename key's name
     renamed_doc_list = []
@@ -216,42 +477,80 @@ def list_docs(dataset_id, tenant_id):
             "chunk_num": "chunk_count",
             "kb_id": "dataset_id",
             "token_num": "token_count",
-            "parser_id": "chunk_method"
+            "parser_id": "chunk_method",
         }
         run_mapping = {
-            "0" :"UNSTART",
-            "1":"RUNNING",
-            "2":"CANCEL",
-            "3":"DONE",
-            "4":"FAIL"
+            "0": "UNSTART",
+            "1": "RUNNING",
+            "2": "CANCEL",
+            "3": "DONE",
+            "4": "FAIL",
         }
         renamed_doc = {}
         for key, value in doc.items():
             new_key = key_mapping.get(key, key)
             renamed_doc[new_key] = value
-            if key =="run":
-                renamed_doc["run"]=run_mapping.get(value)
+            if key == "run":
+                renamed_doc["run"] = run_mapping.get(str(value))
         renamed_doc_list.append(renamed_doc)
     return get_result(data={"total": tol, "docs": renamed_doc_list})
 
 
-@manager.route('/datasets//documents', methods=['DELETE'])
+@manager.route("/datasets//documents", methods=["DELETE"])
 @token_required
-def delete(tenant_id,dataset_id):
+def delete(tenant_id, dataset_id):
+    """
+    Delete documents from a dataset.
+    ---
+    tags:
+      - Documents
+    security:
+      - ApiKeyAuth: []
+    parameters:
+      - in: path
+        name: dataset_id
+        type: string
+        required: true
+        description: ID of the dataset.
+      - in: body
+        name: body
+        description: Document deletion parameters.
+        required: true
+        schema:
+          type: object
+          properties:
+            ids:
+              type: array
+              items:
+                type: string
+              description: List of document IDs to delete.
+      - in: header
+        name: Authorization
+        type: string
+        required: true
+        description: Bearer token for authentication.
+    responses:
+      200:
+        description: Documents deleted successfully.
+        schema:
+          type: object
+    """
     if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
         return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}. 
") req = request.json if not req: - doc_ids=None + doc_ids = None else: - doc_ids=req.get("ids") + doc_ids = req.get("ids") if not doc_ids: doc_list = [] - docs=DocumentService.query(kb_id=dataset_id) + docs = DocumentService.query(kb_id=dataset_id) for doc in docs: doc_list.append(doc.id) else: - doc_list=doc_ids + doc_list = doc_ids root_folder = FileService.get_root_folder(tenant_id) pf_id = root_folder["id"] FileService.init_knowledgebase_docs(pf_id, tenant_id) @@ -269,10 +568,16 @@ def delete(tenant_id,dataset_id): if not DocumentService.remove_document(doc, tenant_id): return get_error_data_result( - retmsg="Database error (Document removal)!") + retmsg="Database error (Document removal)!" + ) f2d = File2DocumentService.get_by_document_id(doc_id) - FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) + FileService.filter_delete( + [ + File.source_type == FileSource.KNOWLEDGEBASE, + File.id == f2d[0].file_id, + ] + ) File2DocumentService.delete_by_document_id(doc_id) STORAGE_IMPL.rm(b, n) @@ -285,25 +590,66 @@ def delete(tenant_id,dataset_id): return get_result() -@manager.route('/datasets//chunks', methods=['POST']) +@manager.route("/datasets//chunks", methods=["POST"]) @token_required -def parse(tenant_id,dataset_id): +def parse(tenant_id, dataset_id): + """ + Start parsing documents into chunks. + --- + tags: + - Chunks + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: body + name: body + description: Parsing parameters. + required: true + schema: + type: object + properties: + document_ids: + type: array + items: + type: string + description: List of document IDs to parse. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Parsing started successfully. + schema: + type: object + """ if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.") req = request.json if not req.get("document_ids"): return get_error_data_result("`document_ids` is required") for id in req["document_ids"]: - doc = DocumentService.query(id=id,kb_id=dataset_id) + doc = DocumentService.query(id=id, kb_id=dataset_id) if not doc: return get_error_data_result(retmsg=f"You don't own the document {id}.") + if doc[0].progress != 0.0: + return get_error_data_result( + "Can't stop parsing document with progress at 0 or 100" + ) info = {"run": "1", "progress": 0} info["progress_msg"] = "" info["chunk_num"] = 0 info["token_num"] = 0 DocumentService.update_by_id(id, info) ELASTICSEARCH.deleteByQuery( - Q("match", doc_id=id), idxnm=search.index_name(tenant_id)) + Q("match", doc_id=id), idxnm=search.index_name(tenant_id) + ) TaskService.filter_delete([Task.doc_id == id]) e, doc = DocumentService.get_by_id(id) doc = doc.to_dict() @@ -312,9 +658,46 @@ def parse(tenant_id,dataset_id): queue_tasks(doc, bucket, name) return get_result() -@manager.route('/datasets//chunks', methods=['DELETE']) + +@manager.route("/datasets//chunks", methods=["DELETE"]) @token_required -def stop_parsing(tenant_id,dataset_id): +def stop_parsing(tenant_id, dataset_id): + """ + Stop parsing documents into chunks. + --- + tags: + - Chunks + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. 
+ - in: body + name: body + description: Stop parsing parameters. + required: true + schema: + type: object + properties: + document_ids: + type: array + items: + type: string + description: List of document IDs to stop parsing. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Parsing stopped successfully. + schema: + type: object + """ if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.") req = request.json @@ -325,46 +708,125 @@ def stop_parsing(tenant_id,dataset_id): if not doc: return get_error_data_result(retmsg=f"You don't own the document {id}.") if int(doc[0].progress) == 1 or int(doc[0].progress) == 0: - return get_error_data_result("Can't stop parsing document with progress at 0 or 1") - info = {"run": "2", "progress": 0,"chunk_num":0} + return get_error_data_result( + "Can't stop parsing document with progress at 0 or 1" + ) + info = {"run": "2", "progress": 0, "chunk_num": 0} DocumentService.update_by_id(id, info) ELASTICSEARCH.deleteByQuery( - Q("match", doc_id=id), idxnm=search.index_name(tenant_id)) + Q("match", doc_id=id), idxnm=search.index_name(tenant_id) + ) return get_result() -@manager.route('/datasets//documents//chunks', methods=['GET']) +@manager.route("/datasets//documents//chunks", methods=["GET"]) @token_required -def list_chunks(tenant_id,dataset_id,document_id): +def list_chunks(tenant_id, dataset_id, document_id): + """ + List chunks of a document. + --- + tags: + - Chunks + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: path + name: document_id + type: string + required: true + description: ID of the document. + - in: query + name: offset + type: integer + required: false + default: 1 + description: Page number. + - in: query + name: limit + type: integer + required: false + default: 30 + description: Number of items per page. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: List of chunks. + schema: + type: object + properties: + total: + type: integer + description: Total number of chunks. + chunks: + type: array + items: + type: object + properties: + id: + type: string + description: Chunk ID. + content: + type: string + description: Chunk content. + document_id: + type: string + description: ID of the document. + important_keywords: + type: array + items: + type: string + description: Important keywords. + image_id: + type: string + description: Image ID associated with the chunk. + doc: + type: object + description: Document details. + """ if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.") - doc=DocumentService.query(id=document_id, kb_id=dataset_id) + doc = DocumentService.query(id=document_id, kb_id=dataset_id) if not doc: - return get_error_data_result(retmsg=f"You don't own the document {document_id}.") - doc=doc[0] + return get_error_data_result( + retmsg=f"You don't own the document {document_id}." 
+ ) + doc = doc[0] req = request.args doc_id = document_id page = int(req.get("offset", 1)) size = int(req.get("limit", 30)) question = req.get("keywords", "") query = { - "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True + "doc_ids": [doc_id], + "page": page, + "size": size, + "question": question, + "sort": True, } sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True) key_mapping = { "chunk_num": "chunk_count", "kb_id": "dataset_id", "token_num": "token_count", - "parser_id": "chunk_method" + "parser_id": "chunk_method", } run_mapping = { "0": "UNSTART", "1": "RUNNING", "2": "CANCEL", "3": "DONE", - "4": "FAIL" + "4": "FAIL", } - doc=doc.to_dict() + doc = doc.to_dict() renamed_doc = {} for key, value in doc.items(): new_key = key_mapping.get(key, key) @@ -377,21 +839,30 @@ def list_chunks(tenant_id,dataset_id,document_id): for id in sres.ids: d = { "chunk_id": id, - "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[ - id].get( - "content_with_weight", ""), + "content_with_weight": ( + rmSpace(sres.highlight[id]) + if question and id in sres.highlight + else sres.field[id].get("content_with_weight", "") + ), "doc_id": sres.field[id]["doc_id"], "docnm_kwd": sres.field[id]["docnm_kwd"], "important_kwd": sres.field[id].get("important_kwd", []), "img_id": sres.field[id].get("img_id", ""), "available_int": sres.field[id].get("available_int", 1), - "positions": sres.field[id].get("position_int", "").split("\t") + "positions": sres.field[id].get("position_int", "").split("\t"), } if len(d["positions"]) % 5 == 0: poss = [] for i in range(0, len(d["positions"]), 5): - poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]), - float(d["positions"][i + 3]), float(d["positions"][i + 4])]) + poss.append( + [ + float(d["positions"][i]), + float(d["positions"][i + 1]), + float(d["positions"][i + 2]), + float(d["positions"][i + 3]), + float(d["positions"][i + 4]), + ] + ) d["positions"] = poss origin_chunks.append(d) @@ -411,7 +882,7 @@ def list_chunks(tenant_id,dataset_id,document_id): "doc_id": "document_id", "important_kwd": "important_keywords", "img_id": "image_id", - "available_int":"available" + "available_int": "available", } renamed_chunk = {} for key, value in chunk.items(): @@ -425,31 +896,104 @@ def list_chunks(tenant_id,dataset_id,document_id): return get_result(data=res) - -@manager.route('/datasets//documents//chunks', methods=['POST']) +@manager.route( + "/datasets//documents//chunks", methods=["POST"] +) @token_required -def add_chunk(tenant_id,dataset_id,document_id): +def add_chunk(tenant_id, dataset_id, document_id): + """ + Add a chunk to a document. + --- + tags: + - Chunks + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: path + name: document_id + type: string + required: true + description: ID of the document. + - in: body + name: body + description: Chunk data. + required: true + schema: + type: object + properties: + content: + type: string + required: true + description: Content of the chunk. + important_keywords: + type: array + items: + type: string + description: Important keywords. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Chunk added successfully. 
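+          # Illustrative request body (example values only):
+          #   {"content": "Example chunk text.", "important_keywords": ["example"]}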
+ schema: + type: object + properties: + chunk: + type: object + properties: + id: + type: string + description: Chunk ID. + content: + type: string + description: Chunk content. + document_id: + type: string + description: ID of the document. + important_keywords: + type: array + items: + type: string + description: Important keywords. + """ if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.") doc = DocumentService.query(id=document_id, kb_id=dataset_id) if not doc: - return get_error_data_result(retmsg=f"You don't own the document {document_id}.") + return get_error_data_result( + retmsg=f"You don't own the document {document_id}." + ) doc = doc[0] req = request.json if not req.get("content"): return get_error_data_result(retmsg="`content` is required") if "important_keywords" in req: if type(req["important_keywords"]) != list: - return get_error_data_result("`important_keywords` is required to be a list") + return get_error_data_result( + "`important_keywords` is required to be a list" + ) md5 = hashlib.md5() md5.update((req["content"] + document_id).encode("utf-8")) chunk_id = md5.hexdigest() - d = {"id": chunk_id, "content_ltks": rag_tokenizer.tokenize(req["content"]), - "content_with_weight": req["content"]} + d = { + "id": chunk_id, + "content_ltks": rag_tokenizer.tokenize(req["content"]), + "content_with_weight": req["content"], + } d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) d["important_kwd"] = req.get("important_keywords", []) - d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_keywords", []))) + d["important_tks"] = rag_tokenizer.tokenize( + " ".join(req.get("important_keywords", [])) + ) d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] d["create_timestamp_flt"] = datetime.datetime.now().timestamp() d["kb_id"] = [doc.kb_id] @@ -457,17 +1001,17 @@ def add_chunk(tenant_id,dataset_id,document_id): d["doc_id"] = doc.id embd_id = DocumentService.get_embd_id(document_id) embd_mdl = TenantLLMService.model_instance( - tenant_id, LLMType.EMBEDDING.value, embd_id) - print(embd_mdl,flush=True) + tenant_id, LLMType.EMBEDDING.value, embd_id + ) + print(embd_mdl, flush=True) v, c = embd_mdl.encode([doc.name, req["content"]]) v = 0.1 * v[0] + 0.9 * v[1] d["q_%d_vec" % len(v)] = v.tolist() ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) - DocumentService.increment_chunk_num( - doc.id, doc.kb_id, c, 1, 0) + DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0) d["chunk_id"] = chunk_id - d["kb_id"]=doc.kb_id + d["kb_id"] = doc.kb_id # rename keys key_mapping = { "chunk_id": "id", @@ -477,7 +1021,7 @@ def add_chunk(tenant_id,dataset_id,document_id): "kb_id": "dataset_id", "create_timestamp_flt": "create_timestamp", "create_time": "create_time", - "document_keyword": "document" + "document_keyword": "document", } renamed_chunk = {} for key, value in d.items(): @@ -488,32 +1032,79 @@ def add_chunk(tenant_id,dataset_id,document_id): # return get_result(data={"chunk_id": chunk_id}) -@manager.route('datasets//documents//chunks', methods=['DELETE']) +@manager.route( + "datasets//documents//chunks", methods=["DELETE"] +) @token_required -def rm_chunk(tenant_id,dataset_id,document_id): +def rm_chunk(tenant_id, dataset_id, document_id): + """ + Remove chunks from a document. 
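+    An illustrative request body (example IDs only) would be
+    {"chunk_ids": ["<chunk_id_1>", "<chunk_id_2>"]}.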
+    ---
+    tags:
+      - Chunks
+    security:
+      - ApiKeyAuth: []
+    parameters:
+      - in: path
+        name: dataset_id
+        type: string
+        required: true
+        description: ID of the dataset.
+      - in: path
+        name: document_id
+        type: string
+        required: true
+        description: ID of the document.
+      - in: body
+        name: body
+        description: Chunk removal parameters.
+        required: true
+        schema:
+          type: object
+          properties:
+            chunk_ids:
+              type: array
+              items:
+                type: string
+              description: List of chunk IDs to remove.
+      - in: header
+        name: Authorization
+        type: string
+        required: true
+        description: Bearer token for authentication.
+    responses:
+      200:
+        description: Chunks removed successfully.
+        schema:
+          type: object
+    """
     if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
         return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
     doc = DocumentService.query(id=document_id, kb_id=dataset_id)
     if not doc:
-        return get_error_data_result(retmsg=f"You don't own the document {document_id}.")
+        return get_error_data_result(
+            retmsg=f"You don't own the document {document_id}."
+        )
     doc = doc[0]
     req = request.json
-    query = {
-        "doc_ids": [doc.id], "page": 1, "size": 1024, "question": "", "sort": True}
+    if not req or not req.get("chunk_ids"):
+        return get_error_data_result("`chunk_ids` is required")
+    query = {"doc_ids": [doc.id], "page": 1, "size": 1024, "question": "", "sort": True}
     sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
     if not req:
-        chunk_ids=None
+        chunk_ids = None
     else:
-        chunk_ids=req.get("chunk_ids")
+        chunk_ids = req.get("chunk_ids")
     if not chunk_ids:
-        chunk_list=sres.ids
+        chunk_list = sres.ids
     else:
-        chunk_list=chunk_ids
+        chunk_list = chunk_ids
     for chunk_id in chunk_list:
         if chunk_id not in sres.ids:
             return get_error_data_result(f"Chunk {chunk_id} not found")
     if not ELASTICSEARCH.deleteByQuery(
-            Q("ids", values=chunk_list), search.index_name(tenant_id)):
+        Q("ids", values=chunk_list), search.index_name(tenant_id)
+    ):
         return get_error_data_result(retmsg="Index updating failure")
     deleted_chunk_ids = chunk_list
     chunk_number = len(deleted_chunk_ids)
@@ -521,37 +1112,92 @@ def rm_chunk(tenant_id,dataset_id,document_id):
     return get_result()
 
 
-@manager.route('/datasets//documents//chunks/', methods=['PUT'])
+@manager.route(
+    "/datasets//documents//chunks/", methods=["PUT"]
+)
 @token_required
-def update_chunk(tenant_id,dataset_id,document_id,chunk_id):
+def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
+    """
+    Update a chunk within a document.
+    ---
+    tags:
+      - Chunks
+    security:
+      - ApiKeyAuth: []
+    parameters:
+      - in: path
+        name: dataset_id
+        type: string
+        required: true
+        description: ID of the dataset.
+      - in: path
+        name: document_id
+        type: string
+        required: true
+        description: ID of the document.
+      - in: path
+        name: chunk_id
+        type: string
+        required: true
+        description: ID of the chunk to update.
+      - in: body
+        name: body
+        description: Chunk update parameters.
+        required: true
+        schema:
+          type: object
+          properties:
+            content:
+              type: string
+              description: Updated content of the chunk.
+            important_keywords:
+              type: array
+              items:
+                type: string
+              description: Updated important keywords.
+            available:
+              type: boolean
+              description: Availability status of the chunk.
+      - in: header
+        name: Authorization
+        type: string
+        required: true
+        description: Bearer token for authentication.
+    responses:
+      200:
+        description: Chunk updated successfully. 
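+          # Illustrative request body (example values only):
+          #   {"content": "Updated text.", "important_keywords": ["kw"], "available": true}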
+ schema: + type: object + """ try: - res = ELASTICSEARCH.get( - chunk_id, search.index_name( - tenant_id)) + res = ELASTICSEARCH.get(chunk_id, search.index_name(tenant_id)) except Exception as e: return get_error_data_result(f"Can't find this chunk {chunk_id}") if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.") doc = DocumentService.query(id=document_id, kb_id=dataset_id) if not doc: - return get_error_data_result(retmsg=f"You don't own the document {document_id}.") + return get_error_data_result( + retmsg=f"You don't own the document {document_id}." + ) doc = doc[0] query = { - "doc_ids": [document_id], "page": 1, "size": 1024, "question": "", "sort": True + "doc_ids": [document_id], + "page": 1, + "size": 1024, + "question": "", + "sort": True, } sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True) if chunk_id not in sres.ids: return get_error_data_result(f"You don't own the chunk {chunk_id}") req = request.json - content=res["_source"].get("content_with_weight") - d = { - "id": chunk_id, - "content_with_weight": req.get("content",content)} + content = res["_source"].get("content_with_weight") + d = {"id": chunk_id, "content_with_weight": req.get("content", content)} d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"]) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) if "important_keywords" in req: - if not isinstance(req["important_keywords"],list): + if not isinstance(req["important_keywords"], list): return get_error_data_result("`important_keywords` should be a list") d["important_kwd"] = req.get("important_keywords") d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_keywords"])) @@ -559,18 +1205,18 @@ def update_chunk(tenant_id,dataset_id,document_id,chunk_id): d["available_int"] = int(req["available"]) embd_id = DocumentService.get_embd_id(document_id) embd_mdl = TenantLLMService.model_instance( - tenant_id, LLMType.EMBEDDING.value, embd_id) + tenant_id, LLMType.EMBEDDING.value, embd_id + ) if doc.parser_id == ParserType.QA: - arr = [ - t for t in re.split( - r"[\n\t]", - d["content_with_weight"]) if len(t) > 1] + arr = [t for t in re.split(r"[\n\t]", d["content_with_weight"]) if len(t) > 1] if len(arr) != 2: return get_error_data_result( - retmsg="Q&A must be separated by TAB/ENTER key.") + retmsg="Q&A must be separated by TAB/ENTER key." + ) q, a = rmPrefix(arr[0]), rmPrefix(arr[1]) - d = beAdoc(d, arr[0], arr[1], not any( - [rag_tokenizer.is_chinese(t) for t in q + a])) + d = beAdoc( + d, arr[0], arr[1], not any([rag_tokenizer.is_chinese(t) for t in q + a]) + ) v, c = embd_mdl.encode([doc.name, d["content_with_weight"]]) v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] @@ -579,41 +1225,120 @@ def update_chunk(tenant_id,dataset_id,document_id,chunk_id): return get_result() - -@manager.route('/retrieval', methods=['POST']) +@manager.route("/retrieval", methods=["POST"]) @token_required def retrieval_test(tenant_id): + """ + Retrieve chunks based on a query. + --- + tags: + - Retrieval + security: + - ApiKeyAuth: [] + parameters: + - in: body + name: body + description: Retrieval parameters. + required: true + schema: + type: object + properties: + dataset_ids: + type: array + items: + type: string + required: true + description: List of dataset IDs to search in. + question: + type: string + required: true + description: Query string. 
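+            # Illustrative request body, not part of the generated spec (the
+            # threshold/weight values mirror the handler defaults below):
+            #   {"dataset_ids": ["<dataset_id>"], "question": "What is RAGFlow?",
+            #    "similarity_threshold": 0.2, "vector_similarity_weight": 0.3}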
+            document_ids:
+              type: array
+              items:
+                type: string
+              description: List of document IDs to filter.
+            similarity_threshold:
+              type: number
+              format: float
+              description: Similarity threshold.
+            vector_similarity_weight:
+              type: number
+              format: float
+              description: Vector similarity weight.
+            top_k:
+              type: integer
+              description: Maximum number of chunks to return.
+            highlight:
+              type: boolean
+              description: Whether to highlight matched content.
+      - in: header
+        name: Authorization
+        type: string
+        required: true
+        description: Bearer token for authentication.
+    responses:
+      200:
+        description: Retrieval results.
+        schema:
+          type: object
+          properties:
+            chunks:
+              type: array
+              items:
+                type: object
+                properties:
+                  id:
+                    type: string
+                    description: Chunk ID.
+                  content:
+                    type: string
+                    description: Chunk content.
+                  document_id:
+                    type: string
+                    description: ID of the document.
+                  dataset_id:
+                    type: string
+                    description: ID of the dataset.
+                  similarity:
+                    type: number
+                    format: float
+                    description: Similarity score.
+    """
     req = request.json
     if not req.get("dataset_ids"):
         return get_error_data_result("`dataset_ids` is required.")
     kb_ids = req["dataset_ids"]
-    if not isinstance(kb_ids,list):
+    if not isinstance(kb_ids, list):
         return get_error_data_result("`dataset_ids` should be a list")
     kbs = KnowledgebaseService.get_by_ids(kb_ids)
     for id in kb_ids:
-        if not KnowledgebaseService.query(id=id,tenant_id=tenant_id):
+        if not KnowledgebaseService.query(id=id, tenant_id=tenant_id):
             return get_error_data_result(f"You don't own the dataset {id}.")
     embd_nms = list(set([kb.embd_id for kb in kbs]))
     if len(embd_nms) != 1:
         return get_result(
-            retmsg='Datasets use different embedding models."',
-            retcode=RetCode.AUTHENTICATION_ERROR)
+            retmsg="Datasets use different embedding models.",
+            retcode=RetCode.AUTHENTICATION_ERROR,
+        )
     if "question" not in req:
         return get_error_data_result("`question` is required.")
     page = int(req.get("offset", 1))
     size = int(req.get("limit", 1024))
     question = req["question"]
     doc_ids = req.get("document_ids", [])
-    if not isinstance(doc_ids,list):
-        return get_error_data_result("`documents` should be a list")
-    doc_ids_list=KnowledgebaseService.list_documents_by_ids(kb_ids)
+    if not isinstance(doc_ids, list):
+        return get_error_data_result("`document_ids` should be a list")
+    doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
     for doc_id in doc_ids:
         if doc_id not in doc_ids_list:
-            return get_error_data_result(f"The datasets don't own the document {doc_id}")
+            return get_error_data_result(
+                f"The datasets don't own the document {doc_id}"
+            )
     similarity_threshold = float(req.get("similarity_threshold", 0.2))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
     top = int(req.get("top_k", 1024))
-    if req.get("highlight")=="False" or req.get("highlight")=="false":
+    if req.get("highlight") == "False" or req.get("highlight") == "false":
         highlight = False
     else:
         highlight = True
@@ -622,21 +1347,34 @@ def retrieval_test(tenant_id):
         if not e:
             return get_error_data_result(retmsg="Dataset not found!")
         embd_mdl = TenantLLMService.model_instance(
-            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
+            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id
+        )
         rerank_mdl = None
         if req.get("rerank_id"):
             rerank_mdl = TenantLLMService.model_instance(
-                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]
+            )
         if req.get("keyword", False):
             chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
         retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
-        ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
-                               similarity_threshold, vector_similarity_weight, top,
-                               doc_ids, rerank_mdl=rerank_mdl, highlight=highlight)
+        ranks = retr.retrieval(
+            question,
+            embd_mdl,
+            kb.tenant_id,
+            kb_ids,
+            page,
+            size,
+            similarity_threshold,
+            vector_similarity_weight,
+            top,
+            doc_ids,
+            rerank_mdl=rerank_mdl,
+            highlight=highlight,
+        )
         for c in ranks["chunks"]:
             if "vector" in c:
                 del c["vector"]
@@ -649,7 +1387,7 @@ def retrieval_test(tenant_id):
                 "content_with_weight": "content",
                 "doc_id": "document_id",
                 "important_kwd": "important_keywords",
-                "docnm_kwd": "document_keyword"
+                "docnm_kwd": "document_keyword",
             }
             rename_chunk = {}
             for key, value in chunk.items():
@@ -660,6 +1398,8 @@ def retrieval_test(tenant_id):
         return get_result(data=ranks)
     except Exception as e:
         if str(e).find("not_found") > 0:
-            return get_result(retmsg=f'No chunk found! Check the chunk status please!',
-                              retcode=RetCode.DATA_ERROR)
-        return server_error_response(e)
\ No newline at end of file
+            return get_result(
+                retmsg="No chunk found! Check the chunk status, please!",
+                retcode=RetCode.DATA_ERROR,
+            )
+        return server_error_response(e)
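Note: the /retrieval handler above is registered under the SDK blueprint, so it is served beneath the /api/<API_VERSION> prefix. A minimal client sketch for exercising it, assuming a local deployment on port 9380, API version v1, and a placeholder API key (all three are assumptions, not part of this change):

    import requests

    BASE_URL = "http://127.0.0.1:9380/api/v1"  # assumed local server and version
    HEADERS = {"Authorization": "Bearer <your-api-key>"}  # placeholder token

    resp = requests.post(
        f"{BASE_URL}/retrieval",
        headers=HEADERS,
        json={
            "dataset_ids": ["<dataset-id>"],  # placeholder dataset ID
            "question": "What is RAGFlow?",
            "similarity_threshold": 0.2,  # defaults mirror the handler above
            "vector_similarity_weight": 0.3,
            "top_k": 1024,
        },
    )
    print(resp.json())
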
+ res["es"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0) except Exception as e: - res["es"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)} + res["es"] = { + "status": "red", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + "error": str(e), + } st = timer() try: STORAGE_IMPL.health() - res["storage"] = {"storage": STORAGE_IMPL_TYPE.lower(), "status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)} + res["storage"] = { + "storage": STORAGE_IMPL_TYPE.lower(), + "status": "green", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + } except Exception as e: - res["storage"] = {"storage": STORAGE_IMPL_TYPE.lower(), "status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)} + res["storage"] = { + "storage": STORAGE_IMPL_TYPE.lower(), + "status": "red", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + "error": str(e), + } st = timer() try: KnowledgebaseService.get_by_id("x") - res["database"] = {"database": DATABASE_TYPE.lower(), "status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)} + res["database"] = { + "database": DATABASE_TYPE.lower(), + "status": "green", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + } except Exception as e: - res["database"] = {"database": DATABASE_TYPE.lower(), "status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)} + res["database"] = { + "database": DATABASE_TYPE.lower(), + "status": "red", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + "error": str(e), + } st = timer() try: if not REDIS_CONN.health(): raise Exception("Lost connection!") - res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)} + res["redis"] = { + "status": "green", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + } except Exception as e: - res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)} + res["redis"] = { + "status": "red", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + "error": str(e), + } try: v = REDIS_CONN.get("TASKEXE") @@ -84,10 +167,12 @@ def status(): if len(arr) == 1: obj[id] = [0] else: - obj[id] = [arr[i+1]-arr[i] for i in range(len(arr)-1)] + obj[id] = [arr[i + 1] - arr[i] for i in range(len(arr) - 1)] elapsed = max(obj[id]) - if elapsed > 50: color = "yellow" - if elapsed > 120: color = "red" + if elapsed > 50: + color = "yellow" + if elapsed > 120: + color = "red" res["task_executor"] = {"status": color, "elapsed": obj} except Exception as e: res["task_executor"] = {"status": "red", "error": str(e)} @@ -95,21 +180,46 @@ def status(): return get_json_result(data=res) -@manager.route('/new_token', methods=['POST']) +@manager.route("/new_token", methods=["POST"]) @login_required def new_token(): + """ + Generate a new API token. + --- + tags: + - API Tokens + security: + - ApiKeyAuth: [] + parameters: + - in: query + name: name + type: string + required: false + description: Name of the token. + responses: + 200: + description: Token generated successfully. + schema: + type: object + properties: + token: + type: string + description: The generated API token. 
+ """ try: tenants = UserTenantService.query(user_id=current_user.id) if not tenants: return get_data_error_result(retmsg="Tenant not found!") tenant_id = tenants[0].tenant_id - obj = {"tenant_id": tenant_id, "token": generate_confirmation_token(tenant_id), - "create_time": current_timestamp(), - "create_date": datetime_format(datetime.now()), - "update_time": None, - "update_date": None - } + obj = { + "tenant_id": tenant_id, + "token": generate_confirmation_token(tenant_id), + "create_time": current_timestamp(), + "create_date": datetime_format(datetime.now()), + "update_time": None, + "update_date": None, + } if not APITokenService.save(**obj): return get_data_error_result(retmsg="Fail to new a dialog!") @@ -119,9 +229,37 @@ def new_token(): return server_error_response(e) -@manager.route('/token_list', methods=['GET']) +@manager.route("/token_list", methods=["GET"]) @login_required def token_list(): + """ + List all API tokens for the current user. + --- + tags: + - API Tokens + security: + - ApiKeyAuth: [] + responses: + 200: + description: List of API tokens. + schema: + type: object + properties: + tokens: + type: array + items: + type: object + properties: + token: + type: string + description: The API token. + name: + type: string + description: Name of the token. + create_time: + type: string + description: Token creation time. + """ try: tenants = UserTenantService.query(user_id=current_user.id) if not tenants: @@ -133,9 +271,33 @@ def token_list(): return server_error_response(e) -@manager.route('/token/', methods=['DELETE']) +@manager.route("/token/", methods=["DELETE"]) @login_required def rm(token): + """ + Remove an API token. + --- + tags: + - API Tokens + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: token + type: string + required: true + description: The API token to remove. + responses: + 200: + description: Token removed successfully. + schema: + type: object + properties: + success: + type: boolean + description: Deletion status. 
+ """ APITokenService.filter_delete( - [APIToken.tenant_id == current_user.id, APIToken.token == token]) - return get_json_result(data=True) \ No newline at end of file + [APIToken.tenant_id == current_user.id, APIToken.token == token] + ) + return get_json_result(data=True) diff --git a/api/apps/user_app.py b/api/apps/user_app.py index bfb9b291e..a9d5c51ac 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -23,65 +23,141 @@ from flask_login import login_required, current_user, login_user, logout_user from api.db.db_models import TenantLLM from api.db.services.llm_service import TenantLLMService, LLMService -from api.utils.api_utils import server_error_response, validate_request, get_data_error_result -from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format +from api.utils.api_utils import ( + server_error_response, + validate_request, + get_data_error_result, +) +from api.utils import ( + get_uuid, + get_format_time, + decrypt, + download_img, + current_timestamp, + datetime_format, +) from api.db import UserTenantRole, LLMType, FileType -from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \ - API_KEY, \ - LLM_FACTORY, LLM_BASE_URL, RERANK_MDL +from api.settings import ( + RetCode, + GITHUB_OAUTH, + FEISHU_OAUTH, + CHAT_MDL, + EMBEDDING_MDL, + ASR_MDL, + IMAGE2TEXT_MDL, + PARSERS, + API_KEY, + LLM_FACTORY, + LLM_BASE_URL, + RERANK_MDL, +) from api.db.services.user_service import UserService, TenantService, UserTenantService from api.db.services.file_service import FileService from api.settings import stat_logger from api.utils.api_utils import get_json_result, construct_response -@manager.route('/login', methods=['POST', 'GET']) +@manager.route("/login", methods=["POST", "GET"]) def login(): + """ + User login endpoint. + --- + tags: + - User + parameters: + - in: body + name: body + description: Login credentials. + required: true + schema: + type: object + properties: + email: + type: string + description: User email. + password: + type: string + description: User password. + responses: + 200: + description: Login successful. + schema: + type: object + 401: + description: Authentication failed. + schema: + type: object + """ if not request.json: - return get_json_result(data=False, - retcode=RetCode.AUTHENTICATION_ERROR, - retmsg='Unauthorized!') + return get_json_result( + data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg="Unauthorized!" 
diff --git a/api/apps/user_app.py b/api/apps/user_app.py
index bfb9b291e..a9d5c51ac 100644
--- a/api/apps/user_app.py
+++ b/api/apps/user_app.py
@@ -23,65 +23,141 @@ from flask_login import login_required, current_user, login_user, logout_user
 from api.db.db_models import TenantLLM
 from api.db.services.llm_service import TenantLLMService, LLMService
-from api.utils.api_utils import server_error_response, validate_request, get_data_error_result
-from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
+from api.utils.api_utils import (
+    server_error_response,
+    validate_request,
+    get_data_error_result,
+)
+from api.utils import (
+    get_uuid,
+    get_format_time,
+    decrypt,
+    download_img,
+    current_timestamp,
+    datetime_format,
+)
 from api.db import UserTenantRole, LLMType, FileType
-from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \
-    API_KEY, \
-    LLM_FACTORY, LLM_BASE_URL, RERANK_MDL
+from api.settings import (
+    RetCode,
+    GITHUB_OAUTH,
+    FEISHU_OAUTH,
+    CHAT_MDL,
+    EMBEDDING_MDL,
+    ASR_MDL,
+    IMAGE2TEXT_MDL,
+    PARSERS,
+    API_KEY,
+    LLM_FACTORY,
+    LLM_BASE_URL,
+    RERANK_MDL,
+)
 from api.db.services.user_service import UserService, TenantService, UserTenantService
 from api.db.services.file_service import FileService
 from api.settings import stat_logger
 from api.utils.api_utils import get_json_result, construct_response
 
 
-@manager.route('/login', methods=['POST', 'GET'])
+@manager.route("/login", methods=["POST", "GET"])
 def login():
+    """
+    User login endpoint.
+    ---
+    tags:
+      - User
+    parameters:
+      - in: body
+        name: body
+        description: Login credentials.
+        required: true
+        schema:
+          type: object
+          properties:
+            email:
+              type: string
+              description: User email.
+            password:
+              type: string
+              description: User password.
+    responses:
+      200:
+        description: Login successful.
+        schema:
+          type: object
+      401:
+        description: Authentication failed.
+        schema:
+          type: object
+    """
     if not request.json:
-        return get_json_result(data=False,
-                               retcode=RetCode.AUTHENTICATION_ERROR,
-                               retmsg='Unauthorized!')
+        return get_json_result(
+            data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg="Unauthorized!"
+        )
 
-    email = request.json.get('email', "")
+    email = request.json.get("email", "")
     users = UserService.query(email=email)
     if not users:
-        return get_json_result(data=False,
-                               retcode=RetCode.AUTHENTICATION_ERROR,
-                               retmsg=f'Email: {email} is not registered!')
+        return get_json_result(
+            data=False,
+            retcode=RetCode.AUTHENTICATION_ERROR,
+            retmsg=f"Email: {email} is not registered!",
+        )
 
-    password = request.json.get('password')
+    password = request.json.get("password")
     try:
         password = decrypt(password)
     except BaseException:
-        return get_json_result(data=False,
-                               retcode=RetCode.SERVER_ERROR,
-                               retmsg='Fail to crypt password')
+        return get_json_result(
+            data=False, retcode=RetCode.SERVER_ERROR, retmsg="Failed to decrypt password"
+        )
 
     user = UserService.query_user(email, password)
     if user:
         response_data = user.to_json()
         user.access_token = get_uuid()
         login_user(user)
-        user.update_time = current_timestamp(),
-        user.update_date = datetime_format(datetime.now()),
+        user.update_time = current_timestamp()
+        user.update_date = datetime_format(datetime.now())
         user.save()
         msg = "Welcome back!"
         return construct_response(data=response_data, auth=user.get_id(), retmsg=msg)
     else:
-        return get_json_result(data=False,
-                               retcode=RetCode.AUTHENTICATION_ERROR,
-                               retmsg='Email and password do not match!')
+        return get_json_result(
+            data=False,
+            retcode=RetCode.AUTHENTICATION_ERROR,
+            retmsg="Email and password do not match!",
+        )
 
 
-@manager.route('/github_callback', methods=['GET'])
+@manager.route("/github_callback", methods=["GET"])
 def github_callback():
+    """
+    GitHub OAuth callback endpoint.
+    ---
+    tags:
+      - OAuth
+    parameters:
+      - in: query
+        name: code
+        type: string
+        required: true
+        description: Authorization code from GitHub.
+    responses:
+      200:
+        description: Authentication successful.
+        schema:
+          type: object
+    """
     import requests
-    res = requests.post(GITHUB_OAUTH.get("url"),
-                        data={
-                            "client_id": GITHUB_OAUTH.get("client_id"),
-                            "client_secret": GITHUB_OAUTH.get("secret_key"),
-                            "code": request.args.get('code')},
-                        headers={"Accept": "application/json"})
+
+    res = requests.post(
+        GITHUB_OAUTH.get("url"),
+        data={
+            "client_id": GITHUB_OAUTH.get("client_id"),
+            "client_secret": GITHUB_OAUTH.get("secret_key"),
+            "code": request.args.get("code"),
+        },
+        headers={"Accept": "application/json"},
+    )
     res = res.json()
     if "error" in res:
         return redirect("/?error=%s" % res["error_description"])
@@ -103,19 +179,22 @@ def github_callback():
     except Exception as e:
         stat_logger.exception(e)
         avatar = ""
-    users = user_register(user_id, {
-        "access_token": session["access_token"],
-        "email": email_address,
-        "avatar": avatar,
-        "nickname": user_info["login"],
-        "login_channel": "github",
-        "last_login_time": get_format_time(),
-        "is_superuser": False,
-    })
+    users = user_register(
+        user_id,
+        {
+            "access_token": session["access_token"],
+            "email": email_address,
+            "avatar": avatar,
+            "nickname": user_info["login"],
+            "login_channel": "github",
+            "last_login_time": get_format_time(),
+            "is_superuser": False,
+        },
+    )
     if not users:
-        raise Exception(f'Fail to register {email_address}.')
+        raise Exception(f"Failed to register {email_address}.")
     if len(users) > 1:
-        raise Exception(f'Same email: {email_address} exists!')
+        raise Exception(f"Same email: {email_address} exists!")
 
     # Try to log in
     user = users[0]
@@ -134,30 +213,56 @@ def github_callback():
     return redirect("/?auth=%s" % user.get_id())
 
 
-@manager.route('/feishu_callback', methods=['GET'])
+@manager.route("/feishu_callback", methods=["GET"])
 def feishu_callback():
+    """
+    Feishu OAuth callback endpoint.
+    ---
+    tags:
+      - OAuth
+    parameters:
+      - in: query
+        name: code
+        type: string
+        required: true
+        description: Authorization code from Feishu.
+    responses:
+      200:
+        description: Authentication successful.
+        schema:
+          type: object
+    """
     import requests
-    app_access_token_res = requests.post(FEISHU_OAUTH.get("app_access_token_url"),
-                                         data=json.dumps({
-                                             "app_id": FEISHU_OAUTH.get("app_id"),
-                                             "app_secret": FEISHU_OAUTH.get("app_secret")
-                                         }),
-                                         headers={"Content-Type": "application/json; charset=utf-8"})
+
+    app_access_token_res = requests.post(
+        FEISHU_OAUTH.get("app_access_token_url"),
+        data=json.dumps(
+            {
+                "app_id": FEISHU_OAUTH.get("app_id"),
+                "app_secret": FEISHU_OAUTH.get("app_secret"),
+            }
+        ),
+        headers={"Content-Type": "application/json; charset=utf-8"},
+    )
     app_access_token_res = app_access_token_res.json()
-    if app_access_token_res['code'] != 0:
+    if app_access_token_res["code"] != 0:
         return redirect("/?error=%s" % app_access_token_res)
 
-    res = requests.post(FEISHU_OAUTH.get("user_access_token_url"),
-                        data=json.dumps({
-                            "grant_type": FEISHU_OAUTH.get("grant_type"),
-                            "code": request.args.get('code')
-                        }),
-                        headers={
-                            "Content-Type": "application/json; charset=utf-8",
-                            'Authorization': f"Bearer {app_access_token_res['app_access_token']}"
-                        })
+    res = requests.post(
+        FEISHU_OAUTH.get("user_access_token_url"),
+        data=json.dumps(
+            {
+                "grant_type": FEISHU_OAUTH.get("grant_type"),
+                "code": request.args.get("code"),
+            }
+        ),
+        headers={
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {app_access_token_res['app_access_token']}",
+        },
+    )
     res = res.json()
-    if res['code'] != 0:
+    if res["code"] != 0:
         return redirect("/?error=%s" % res["message"])
 
     if "contact:user.email:readonly" not in res["data"]["scope"].split(" "):
@@ -176,19 +281,22 @@ def feishu_callback():
     except Exception as e:
         stat_logger.exception(e)
         avatar = ""
-    users = user_register(user_id, {
-        "access_token": session["access_token"],
-        "email": email_address,
-        "avatar": avatar,
-        "nickname": user_info["en_name"],
-        "login_channel": "feishu",
-        "last_login_time": get_format_time(),
-        "is_superuser": False,
-    })
+    users = user_register(
+        user_id,
+        {
+            "access_token": session["access_token"],
+            "email": email_address,
+            "avatar": avatar,
+            "nickname": user_info["en_name"],
+            "login_channel": "feishu",
+            "last_login_time": get_format_time(),
+            "is_superuser": False,
+        },
+    )
     if not users:
-        raise Exception(f'Fail to register {email_address}.')
+        raise Exception(f"Failed to register {email_address}.")
     if len(users) > 1:
-        raise Exception(f'Same email: {email_address} exists!')
+        raise Exception(f"Same email: {email_address} exists!")
 
     # Try to log in
     user = users[0]
@@ -209,11 +317,14 @@ def feishu_callback():
 
 def user_info_from_feishu(access_token):
     import requests
-    headers = {"Content-Type": "application/json; charset=utf-8",
-               'Authorization': f"Bearer {access_token}"}
+
+    headers = {
+        "Content-Type": "application/json; charset=utf-8",
+        "Authorization": f"Bearer {access_token}",
+    }
     res = requests.get(
-        f"https://open.feishu.cn/open-apis/authen/v1/user_info",
-        headers=headers)
+        "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers
+    )
     user_info = res.json()["data"]
     user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
     return user_info
@@ -221,24 +332,38 @@ def user_info_from_feishu(access_token):
 
 def user_info_from_github(access_token):
     import requests
-    headers = {"Accept": "application/json",
-               'Authorization': f"token {access_token}"}
+
+    headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
     res = requests.get(
f"https://api.github.com/user?access_token={access_token}", headers=headers + ) user_info = res.json() email_info = requests.get( f"https://api.github.com/user/emails?access_token={access_token}", - headers=headers).json() + headers=headers, + ).json() user_info["email"] = next( - (email for email in email_info if email['primary'] == True), - None)["email"] + (email for email in email_info if email["primary"] == True), None + )["email"] return user_info -@manager.route("/logout", methods=['GET']) +@manager.route("/logout", methods=["GET"]) @login_required def log_out(): + """ + User logout endpoint. + --- + tags: + - User + security: + - ApiKeyAuth: [] + responses: + 200: + description: Logout successful. + schema: + type: object + """ current_user.access_token = "" current_user.save() logout_user() @@ -248,20 +373,62 @@ def log_out(): @manager.route("/setting", methods=["POST"]) @login_required def setting_user(): + """ + Update user settings. + --- + tags: + - User + security: + - ApiKeyAuth: [] + parameters: + - in: body + name: body + description: User settings to update. + required: true + schema: + type: object + properties: + nickname: + type: string + description: New nickname. + email: + type: string + description: New email. + responses: + 200: + description: Settings updated successfully. + schema: + type: object + """ update_dict = {} request_data = request.json if request_data.get("password"): new_password = request_data.get("new_password") if not check_password_hash( - current_user.password, decrypt(request_data["password"])): - return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!') + current_user.password, decrypt(request_data["password"]) + ): + return get_json_result( + data=False, + retcode=RetCode.AUTHENTICATION_ERROR, + retmsg="Password error!", + ) if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password)) for k in request_data.keys(): - if k in ["password", "new_password", "email", "status", "is_superuser", "login_channel", "is_anonymous", - "is_active", "is_authenticated", "last_login_time"]: + if k in [ + "password", + "new_password", + "email", + "status", + "is_superuser", + "login_channel", + "is_anonymous", + "is_active", + "is_authenticated", + "last_login_time", + ]: continue update_dict[k] = request_data[k] @@ -270,12 +437,37 @@ def setting_user(): return get_json_result(data=True) except Exception as e: stat_logger.exception(e) - return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR) + return get_json_result( + data=False, retmsg="Update failure!", retcode=RetCode.EXCEPTION_ERROR + ) @manager.route("/info", methods=["GET"]) @login_required def user_profile(): + """ + Get user profile information. + --- + tags: + - User + security: + - ApiKeyAuth: [] + responses: + 200: + description: User profile retrieved successfully. + schema: + type: object + properties: + id: + type: string + description: User ID. + nickname: + type: string + description: User nickname. + email: + type: string + description: User email. 
+ """ return get_json_result(data=current_user.to_dict()) @@ -310,13 +502,13 @@ def user_register(user_id, user): "asr_id": ASR_MDL, "parser_ids": PARSERS, "img2txt_id": IMAGE2TEXT_MDL, - "rerank_id": RERANK_MDL + "rerank_id": RERANK_MDL, } usr_tenant = { "tenant_id": user_id, "user_id": user_id, "invited_by": user_id, - "role": UserTenantRole.OWNER + "role": UserTenantRole.OWNER, } file_id = get_uuid() file = { @@ -331,13 +523,16 @@ def user_register(user_id, user): } tenant_llm = [] for llm in LLMService.query(fid=LLM_FACTORY): - tenant_llm.append({"tenant_id": user_id, - "llm_factory": LLM_FACTORY, - "llm_name": llm.llm_name, - "model_type": llm.model_type, - "api_key": API_KEY, - "api_base": LLM_BASE_URL - }) + tenant_llm.append( + { + "tenant_id": user_id, + "llm_factory": LLM_FACTORY, + "llm_name": llm.llm_name, + "model_type": llm.model_type, + "api_key": API_KEY, + "api_base": LLM_BASE_URL, + } + ) if not UserService.save(**user): return @@ -351,21 +546,52 @@ def user_register(user_id, user): @manager.route("/register", methods=["POST"]) @validate_request("nickname", "email", "password") def user_add(): + """ + Register a new user. + --- + tags: + - User + parameters: + - in: body + name: body + description: Registration details. + required: true + schema: + type: object + properties: + nickname: + type: string + description: User nickname. + email: + type: string + description: User email. + password: + type: string + description: User password. + responses: + 200: + description: Registration successful. + schema: + type: object + """ req = request.json email_address = req["email"] # Validate the email address if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,5}$", email_address): - return get_json_result(data=False, - retmsg=f'Invalid email address: {email_address}!', - retcode=RetCode.OPERATING_ERROR) + return get_json_result( + data=False, + retmsg=f"Invalid email address: {email_address}!", + retcode=RetCode.OPERATING_ERROR, + ) # Check if the email address is already used if UserService.query(email=email_address): return get_json_result( data=False, - retmsg=f'Email: {email_address} has already registered!', - retcode=RetCode.OPERATING_ERROR) + retmsg=f"Email: {email_address} has already registered!", + retcode=RetCode.OPERATING_ERROR, + ) # Construct user info data nickname = req["nickname"] @@ -383,25 +609,55 @@ def user_add(): try: users = user_register(user_id, user_dict) if not users: - raise Exception(f'Fail to register {email_address}.') + raise Exception(f"Fail to register {email_address}.") if len(users) > 1: - raise Exception(f'Same email: {email_address} exists!') + raise Exception(f"Same email: {email_address} exists!") user = users[0] login_user(user) - return construct_response(data=user.to_json(), - auth=user.get_id(), - retmsg=f"{nickname}, welcome aboard!") + return construct_response( + data=user.to_json(), + auth=user.get_id(), + retmsg=f"{nickname}, welcome aboard!", + ) except Exception as e: rollback_user_registration(user_id) stat_logger.exception(e) - return get_json_result(data=False, - retmsg=f'User registration failure, error: {str(e)}', - retcode=RetCode.EXCEPTION_ERROR) + return get_json_result( + data=False, + retmsg=f"User registration failure, error: {str(e)}", + retcode=RetCode.EXCEPTION_ERROR, + ) @manager.route("/tenant_info", methods=["GET"]) @login_required def tenant_info(): + """ + Get tenant information. + --- + tags: + - Tenant + security: + - ApiKeyAuth: [] + responses: + 200: + description: Tenant information retrieved successfully. 
+        schema:
+          type: object
+          properties:
+            tenant_id:
+              type: string
+              description: Tenant ID.
+            name:
+              type: string
+              description: Tenant name.
+            llm_id:
+              type: string
+              description: LLM ID.
+            embd_id:
+              type: string
+              description: Embedding model ID.
+    """
     try:
         tenants = TenantService.get_info_by(current_user.id)
         if not tenants:
@@ -415,6 +671,42 @@ def tenant_info():
 @login_required
 @validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
 def set_tenant_info():
+    """
+    Update tenant information.
+    ---
+    tags:
+      - Tenant
+    security:
+      - ApiKeyAuth: []
+    parameters:
+      - in: body
+        name: body
+        description: Tenant information to update.
+        required: true
+        schema:
+          type: object
+          properties:
+            tenant_id:
+              type: string
+              description: Tenant ID.
+            llm_id:
+              type: string
+              description: LLM ID.
+            embd_id:
+              type: string
+              description: Embedding model ID.
+            asr_id:
+              type: string
+              description: ASR model ID.
+            img2txt_id:
+              type: string
+              description: Image to Text model ID.
+    responses:
+      200:
+        description: Tenant information updated successfully.
+        schema:
+          type: object
+    """
     req = request.json
     try:
         tid = req["tenant_id"]
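The user_add handler above gates registration on a compact regex rather than a full RFC 5322 validator, so some unusual but valid addresses are rejected. A quick illustrative check of what the pattern accepts (stdlib only):

    import re

    # Same pattern as in user_add
    EMAIL_RE = r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,5}$"

    for addr in ["a.b-c_d@example.co", "bad@@example.com", "x@sub.domain.org"]:
        print(addr, "->", bool(re.match(EMAIL_RE, addr)))
    # a.b-c_d@example.co -> True
    # bad@@example.com -> False
    # x@sub.domain.org -> True
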
diff --git a/api/ragflow_server.py b/api/ragflow_server.py
index 297227796..186b00b11 100644
--- a/api/ragflow_server.py
+++ b/api/ragflow_server.py
@@ -27,7 +27,11 @@ from api.apps import app
 from api.db.runtime_config import RuntimeConfig
 from api.db.services.document_service import DocumentService
 from api.settings import (
-    HOST, HTTP_PORT, access_logger, database_logger, stat_logger,
+    HOST,
+    HTTP_PORT,
+    access_logger,
+    database_logger,
+    stat_logger,
 )
 from api import utils
 
@@ -45,27 +49,33 @@ def update_progress():
             stat_logger.error("update_progress exception:" + str(e))
 
 
-if __name__ == '__main__':
-    print(r"""
+if __name__ == "__main__":
+    print(
+        r"""
     ____   ___    ______ ______ __
    / __ \ /   |  / ____// ____// /____  _      __
   / /_/ // /| | / / __ / /_   / // __ \| | /| / /
  / _, _// ___ |/ /_/ // __/  / // /_/ /| |/ |/ /
 /_/ |_|/_/  |_|\____//_/    /_/ \____/ |__/|__/
-    """, flush=True)
-    stat_logger.info(
-        f'project base: {utils.file_utils.get_project_base_directory()}'
+    """,
+        flush=True,
     )
+    stat_logger.info(f"project base: {utils.file_utils.get_project_base_directory()}")
 
     # init db
     init_web_db()
     init_web_data()
     # init runtime config
     import argparse
+
     parser = argparse.ArgumentParser()
-    parser.add_argument('--version', default=False, help="rag flow version", action='store_true')
-    parser.add_argument('--debug', default=False, help="debug mode", action='store_true')
+    parser.add_argument(
+        "--version", default=False, help="rag flow version", action="store_true"
+    )
+    parser.add_argument(
+        "--debug", default=False, help="debug mode", action="store_true"
+    )
     args = parser.parse_args()
     if args.version:
         print(get_versions())
@@ -78,7 +88,7 @@ if __name__ == '__main__':
     RuntimeConfig.init_env()
     RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
 
-    peewee_logger = logging.getLogger('peewee')
+    peewee_logger = logging.getLogger("peewee")
     peewee_logger.propagate = False
     # rag_arch.common.log.ROpenHandler
     peewee_logger.addHandler(database_logger.handlers[0])
@@ -93,7 +103,14 @@ if __name__ == '__main__':
         werkzeug_logger = logging.getLogger("werkzeug")
         for h in access_logger.handlers:
             werkzeug_logger.addHandler(h)
-        run_simple(hostname=HOST, port=HTTP_PORT, application=app, threaded=True, use_reloader=RuntimeConfig.DEBUG, use_debugger=RuntimeConfig.DEBUG)
+        run_simple(
+            hostname=HOST,
+            port=HTTP_PORT,
+            application=app,
+            threaded=True,
+            use_reloader=RuntimeConfig.DEBUG,
+            use_debugger=RuntimeConfig.DEBUG,
+        )
     except Exception:
         traceback.print_exc()
-        os.kill(os.getpid(), signal.SIGKILL)
\ No newline at end of file
+        os.kill(os.getpid(), signal.SIGKILL)
diff --git a/poetry.lock b/poetry.lock
index affd929c7..6429c6650 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -435,6 +435,17 @@ files = [
     {file = "Aspose.Slides-24.10.0-py3-none-win_amd64.whl", hash = "sha256:8980015fbc32c1e70e80444c70a642597511300ead6b352183bf74ba3da67f2d"},
 ]
 
+[[package]]
+name = "async-timeout"
+version = "4.0.3"
+description = "Timeout context manager for asyncio programs"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
+    {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
+]
+
 [[package]]
 name = "attrs"
 version = "24.2.0"
@@ -1912,7 +1923,10 @@ files = [
 huggingface-hub = ">=0.20,<1.0"
 loguru = ">=0.7.2,<0.8.0"
 mmh3 = ">=4.0,<5.0"
-numpy = {version = ">=1.26,<2", markers = "python_version >= \"3.12\""}
+numpy = [
+    {version = ">=1.21,<2", markers = "python_version < \"3.12\""},
+    {version = ">=1.26,<2", markers = "python_version >= \"3.12\""},
+]
 onnx = ">=1.15.0,<2.0.0"
 onnxruntime = ">=1.17.0,<2.0.0"
 pillow = ">=10.3.0,<11.0.0"
@@ -2037,6 +2051,24 @@ sentence_transformers = "*"
 torch = ">=1.6.0"
 transformers = ">=4.33.0"
 
+[[package]]
+name = "flasgger"
+version = "0.9.7.1"
+description = "Extract swagger specs from your flask project"
+optional = false
+python-versions = "*"
+files = [
+    {file = "flasgger-0.9.7.1.tar.gz", hash = "sha256:ca098e10bfbb12f047acc6299cc70a33851943a746e550d86e65e60d4df245fb"},
+]
+
+[package.dependencies]
+Flask = ">=0.10"
+jsonschema = ">=3.0.1"
+mistune = "*"
+packaging = "*"
+PyYAML = ">=3.0"
+six = ">=1.10.0"
+
 [[package]]
 name = "flask"
 version = "3.0.3"
@@ -4381,6 +4413,17 @@ httpx = ">=0.25,<1"
 orjson = ">=3.9.10,<3.11"
 pydantic = ">=2.5.2,<3"
 
+[[package]]
+name = "mistune"
+version = "3.0.2"
+description = "A sane and fast Markdown parser with useful plugins and renderers"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "mistune-3.0.2-py3-none-any.whl", hash = "sha256:71481854c30fdbc938963d3605b72501f5c10a9320ecd412c121c163a1c7d205"},
+    {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"},
+]
+
 [[package]]
 name = "mkl"
 version = "2021.4.0"
@@ -5149,7 +5192,10 @@ files = [
 ]
 
 [package.dependencies]
-numpy = {version = ">=1.26.0", markers = "python_version >= \"3.12\""}
+numpy = [
+    {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
+    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+]
 
 [[package]]
 name = "opencv-python-headless"
@@ -5168,7 +5214,10 @@ files = [
 ]
 
 [package.dependencies]
-numpy = {version = ">=1.26.0", markers = "python_version >= \"3.12\""}
+numpy = [
+    {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
+    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+]
 
 [[package]]
 name = "openpyxl"
@@ -5350,7 +5399,10 @@ files = [
 ]
 
 [package.dependencies]
-numpy = {version = ">=1.26.0", markers = "python_version >= \"3.12\""}
+numpy = [
+    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
+    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
tzdata = ">=2022.7" @@ -7009,6 +7061,24 @@ lxml = "*" [package.extras] test = ["timeout-decorator"] +[[package]] +name = "redis" +version = "5.0.3" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.7" +files = [ + {file = "redis-5.0.3-py3-none-any.whl", hash = "sha256:5da9b8fe9e1254293756c16c008e8620b3d15fcc6dde6babde9541850e72a32d"}, + {file = "redis-5.0.3.tar.gz", hash = "sha256:4973bae7444c0fbed64a06b87446f79361cb7e4ec1538c022d696ed7a5015580"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} + +[package.extras] +hiredis = ["hiredis (>=1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "referencing" version = "0.35.1" @@ -8468,6 +8538,7 @@ nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \" nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" +triton = {version = "2.3.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} typing-extensions = ">=4.8.0" [package.extras] @@ -8611,6 +8682,29 @@ files = [ trio = ">=0.11" wsproto = ">=0.14" +[[package]] +name = "triton" +version = "2.3.0" +description = "A language and compiler for custom Deep Learning operations" +optional = false +python-versions = "*" +files = [ + {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"}, + {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"}, + {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"}, + {file = "triton-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381ec6b3dac06922d3e4099cfc943ef032893b25415de295e82b1a82b0359d2c"}, + {file = "triton-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:038e06a09c06a164fef9c48de3af1e13a63dc1ba3c792871e61a8e79720ea440"}, + {file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"}, +] + +[package.dependencies] +filelock = "*" + +[package.extras] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] +tutorials = ["matplotlib", "pandas", "tabulate", "torch"] + [[package]] name = "typer" version = "0.12.5" @@ -9446,5 +9540,5 @@ files = [ [metadata] lock-version = "2.0" -python-versions = ">=3.12,<3.13" -content-hash = "9c488418342dcd2a1ff625db0da677d086e309c9e4285b46c622f1099af4850f" +python-versions = ">=3.11,<3.13" +content-hash = "74a9b4afef47cc36d638b43fd918ece27d65259af1ca9e5b17f6b239774e8bf9" diff --git a/pyproject.toml b/pyproject.toml index ad667d2cc..3cb576492 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ readme = "README.md" package-mode = false [tool.poetry.dependencies] -python = ">=3.12,<3.13" +python = ">=3.11,<3.13" datrie = "0.8.2" akshare = "^1.14.81" azure-storage-blob = "12.22.0" @@ -114,6 +114,7 @@ graspologic = 
"^3.4.1" pymysql = "^1.1.1" mini-racer = "^0.12.4" pyicu = "^2.13.1" +flasgger = "^0.9.7.1" [tool.poetry.group.full]