add owner check for team work (#2892)

### What problem does this PR solve?

#2834

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu 2024-10-18 13:48:57 +08:00 committed by GitHub
parent 8fdfa0f669
commit c760f058df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 117 additions and 13 deletions

View File

@ -45,6 +45,9 @@ class ExeSQLParam(ComponentParamBase):
self.check_positive_integer(self.port, "IP Port")
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.top_n, "Number of records")
if self.database == "rag_flow":
if self.host == "ragflow-mysql": raise ValueError("The host is not accessible.")
if self.password == "infini_rag_flow": raise ValueError("The host is not accessible.")
class ExeSQL(ComponentBase, ABC):

View File

@ -209,9 +209,17 @@ def list_docs():
@manager.route('/infos', methods=['POST'])
@login_required
def docinfos():
req = request.json
doc_ids = req["doc_ids"]
for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
docs = DocumentService.get_by_ids(doc_ids)
return get_json_result(data=list(docs.dicts()))
@ -242,11 +250,17 @@ def thumbnails():
def change_status():
req = request.json
if str(req["status"]) not in ["0", "1"]:
get_json_result(
return get_json_result(
data=False,
retmsg='"Status" must be either 0 or 1!',
retcode=RetCode.ARGUMENT_ERROR)
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
@ -285,6 +299,15 @@ def rm():
req = request.json
doc_ids = req["doc_id"]
if isinstance(doc_ids, str): doc_ids = [doc_ids]
for doc_id in doc_ids:
if not DocumentService.accessible4deletion(doc_id, current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, current_user.id)
@ -323,6 +346,13 @@ def rm():
@validate_request("doc_ids", "run")
def run():
req = request.json
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
for id in req["doc_ids"]:
info = {"run": str(req["run"]), "progress": 0}
@ -356,6 +386,12 @@ def run():
@validate_request("doc_id", "name")
def rename():
req = request.json
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
@ -416,6 +452,13 @@ def get(doc_id):
@validate_request("doc_id", "parser_id")
def change_parser():
req = request.json
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from elasticsearch_dsl import Q
from flask import request
from flask_login import login_required, current_user
@ -23,14 +22,12 @@ from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid, get_format_time
from api.db import StatusEnum, UserTenantRole, FileSource
from api.utils import get_uuid
from api.db import StatusEnum, FileSource
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import Knowledgebase, File
from api.settings import stat_logger, RetCode
from api.db.db_models import File
from api.settings import RetCode
from api.utils.api_utils import get_json_result
from rag.nlp import search
from rag.utils.es_conn import ELASTICSEARCH
@manager.route('/create', methods=['post'])
@ -65,6 +62,12 @@ def create():
def update():
req = request.json
req["name"] = req["name"].strip()
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
if not KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"]):
@ -139,6 +142,12 @@ def list_kbs():
@validate_request("kb_id")
def rm():
req = request.json
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
kbs = KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"])

View File

@ -38,7 +38,7 @@ from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search, rag_tokenizer
from api.db import FileType, TaskStatus, ParserType, LLMType
from api.db.db_models import DB, Knowledgebase, Tenant, Task
from api.db.db_models import DB, Knowledgebase, Tenant, Task, UserTenant
from api.db.db_models import Document
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -263,6 +263,33 @@ class DocumentService(CommonService):
return
return docs[0]["tenant_id"]
@classmethod
@DB.connection_context()
def accessible(cls, doc_id, user_id):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True
@classmethod
@DB.connection_context()
def accessible4deletion(cls, doc_id, user_id):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True
@classmethod
@DB.connection_context()
def get_embd_id(cls, doc_id):

View File

@ -14,7 +14,7 @@
# limitations under the License.
#
from api.db import StatusEnum, TenantPermission
from api.db.db_models import Knowledgebase, DB, Tenant, User
from api.db.db_models import Knowledgebase, DB, Tenant, User, UserTenant
from api.db.services.common_service import CommonService
@ -182,3 +182,25 @@ class KnowledgebaseService(CommonService):
kbs = kbs.paginate(page_number, items_per_page)
return list(kbs.dicts())
@classmethod
@DB.connection_context()
def accessible(cls, kb_id, user_id):
docs = cls.model.select(
cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True
@classmethod
@DB.connection_context()
def accessible4deletion(cls, kb_id, user_id):
docs = cls.model.select(
cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
import os
from datetime import date
from enum import IntEnum, Enum
from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger
@ -143,9 +144,8 @@ HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
SECRET_KEY = get_base_config(
RAG_FLOW_SERVICE_NAME,
{}).get(
"secret_key",
"infiniflow")
{}).get("secret_key", str(date.today()))
TOKEN_EXPIRE_IN = get_base_config(
RAG_FLOW_SERVICE_NAME, {}).get(
"token_expires_in", 3600)