diff --git a/agent/component/exesql.py b/agent/component/exesql.py index 919ec4e5f..eac305e8b 100644 --- a/agent/component/exesql.py +++ b/agent/component/exesql.py @@ -45,6 +45,9 @@ class ExeSQLParam(ComponentParamBase): self.check_positive_integer(self.port, "IP Port") self.check_empty(self.password, "Database password") self.check_positive_integer(self.top_n, "Number of records") + if self.database == "rag_flow": + if self.host == "ragflow-mysql": raise ValueError("The host is not accessible.") + if self.password == "infini_rag_flow": raise ValueError("The host is not accessible.") class ExeSQL(ComponentBase, ABC): diff --git a/api/apps/document_app.py b/api/apps/document_app.py index f056313b2..0e4af1447 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -209,9 +209,17 @@ def list_docs(): @manager.route('/infos', methods=['POST']) +@login_required def docinfos(): req = request.json doc_ids = req["doc_ids"] + for doc_id in doc_ids: + if not DocumentService.accessible(doc_id, current_user.id): + return get_json_result( + data=False, + retmsg='No authorization.', + retcode=RetCode.AUTHENTICATION_ERROR + ) docs = DocumentService.get_by_ids(doc_ids) return get_json_result(data=list(docs.dicts())) @@ -242,11 +250,17 @@ def thumbnails(): def change_status(): req = request.json if str(req["status"]) not in ["0", "1"]: - get_json_result( + return get_json_result( data=False, retmsg='"Status" must be either 0 or 1!', retcode=RetCode.ARGUMENT_ERROR) + if not DocumentService.accessible(req["doc_id"], current_user.id): + return get_json_result( + data=False, + retmsg='No authorization.', + retcode=RetCode.AUTHENTICATION_ERROR) + try: e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: @@ -285,6 +299,15 @@ def rm(): req = request.json doc_ids = req["doc_id"] if isinstance(doc_ids, str): doc_ids = [doc_ids] + + for doc_id in doc_ids: + if not DocumentService.accessible4deletion(doc_id, current_user.id): + return get_json_result( + data=False, + retmsg='No authorization.', + retcode=RetCode.AUTHENTICATION_ERROR + ) + root_folder = FileService.get_root_folder(current_user.id) pf_id = root_folder["id"] FileService.init_knowledgebase_docs(pf_id, current_user.id) @@ -323,6 +346,13 @@ def rm(): @validate_request("doc_ids", "run") def run(): req = request.json + for doc_id in req["doc_ids"]: + if not DocumentService.accessible(doc_id, current_user.id): + return get_json_result( + data=False, + retmsg='No authorization.', + retcode=RetCode.AUTHENTICATION_ERROR + ) try: for id in req["doc_ids"]: info = {"run": str(req["run"]), "progress": 0} @@ -356,6 +386,12 @@ def run(): @validate_request("doc_id", "name") def rename(): req = request.json + if not DocumentService.accessible(req["doc_id"], current_user.id): + return get_json_result( + data=False, + retmsg='No authorization.', + retcode=RetCode.AUTHENTICATION_ERROR + ) try: e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: @@ -416,6 +452,13 @@ def get(doc_id): @validate_request("doc_id", "parser_id") def change_parser(): req = request.json + + if not DocumentService.accessible(req["doc_id"], current_user.id): + return get_json_result( + data=False, + retmsg='No authorization.', + retcode=RetCode.AUTHENTICATION_ERROR + ) try: e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 7d7f86e2d..551e7867e 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from elasticsearch_dsl import Q from flask import request from flask_login import login_required, current_user @@ -23,14 +22,12 @@ from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.user_service import TenantService, UserTenantService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request -from api.utils import get_uuid, get_format_time -from api.db import StatusEnum, UserTenantRole, FileSource +from api.utils import get_uuid +from api.db import StatusEnum, FileSource from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.db_models import Knowledgebase, File -from api.settings import stat_logger, RetCode +from api.db.db_models import File +from api.settings import RetCode from api.utils.api_utils import get_json_result -from rag.nlp import search -from rag.utils.es_conn import ELASTICSEARCH @manager.route('/create', methods=['post']) @@ -65,6 +62,12 @@ def create(): def update(): req = request.json req["name"] = req["name"].strip() + if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id): + return get_json_result( + data=False, + retmsg='No authorization.', + retcode=RetCode.AUTHENTICATION_ERROR + ) try: if not KnowledgebaseService.query( created_by=current_user.id, id=req["kb_id"]): @@ -139,6 +142,12 @@ def list_kbs(): @validate_request("kb_id") def rm(): req = request.json + if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id): + return get_json_result( + data=False, + retmsg='No authorization.', + retcode=RetCode.AUTHENTICATION_ERROR + ) try: kbs = KnowledgebaseService.query( created_by=current_user.id, id=req["kb_id"]) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 40fd1188e..ed6220aed 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -38,7 +38,7 @@ from rag.utils.storage_factory import STORAGE_IMPL from rag.nlp import search, rag_tokenizer from api.db import FileType, TaskStatus, ParserType, LLMType -from api.db.db_models import DB, Knowledgebase, Tenant, Task +from api.db.db_models import DB, Knowledgebase, Tenant, Task, UserTenant from api.db.db_models import Document from api.db.services.common_service import CommonService from api.db.services.knowledgebase_service import KnowledgebaseService @@ -263,6 +263,33 @@ class DocumentService(CommonService): return return docs[0]["tenant_id"] + @classmethod + @DB.connection_context() + def accessible(cls, doc_id, user_id): + docs = cls.model.select( + cls.model.id).join( + Knowledgebase, on=( + Knowledgebase.id == cls.model.kb_id) + ).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) + ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1) + docs = docs.dicts() + if not docs: + return False + return True + + @classmethod + @DB.connection_context() + def accessible4deletion(cls, doc_id, user_id): + docs = cls.model.select( + cls.model.id).join( + Knowledgebase, on=( + Knowledgebase.id == cls.model.kb_id) + ).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1) + docs = docs.dicts() + if not docs: + return False + return True + @classmethod @DB.connection_context() def get_embd_id(cls, doc_id): diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index e79887ac9..602d6097a 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -14,7 +14,7 @@ # limitations under the License. # from api.db import StatusEnum, TenantPermission -from api.db.db_models import Knowledgebase, DB, Tenant, User +from api.db.db_models import Knowledgebase, DB, Tenant, User, UserTenant from api.db.services.common_service import CommonService @@ -182,3 +182,25 @@ class KnowledgebaseService(CommonService): kbs = kbs.paginate(page_number, items_per_page) return list(kbs.dicts()) + + @classmethod + @DB.connection_context() + def accessible(cls, kb_id, user_id): + docs = cls.model.select( + cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) + ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) + docs = docs.dicts() + if not docs: + return False + return True + + @classmethod + @DB.connection_context() + def accessible4deletion(cls, kb_id, user_id): + docs = cls.model.select( + cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1) + docs = docs.dicts() + if not docs: + return False + return True + diff --git a/api/settings.py b/api/settings.py index 9faf7c169..f48a5fe7a 100644 --- a/api/settings.py +++ b/api/settings.py @@ -14,6 +14,7 @@ # limitations under the License. # import os +from datetime import date from enum import IntEnum, Enum from api.utils.file_utils import get_project_base_directory from api.utils.log_utils import LoggerFactory, getLogger @@ -143,9 +144,8 @@ HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") SECRET_KEY = get_base_config( RAG_FLOW_SERVICE_NAME, - {}).get( - "secret_key", - "infiniflow") + {}).get("secret_key", str(date.today())) + TOKEN_EXPIRE_IN = get_base_config( RAG_FLOW_SERVICE_NAME, {}).get( "token_expires_in", 3600)