Add Authorization checks (#2221)

### What problem does this PR solve?

Add Authorization checks
#2203

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
LiuHua 2024-09-04 10:36:15 +08:00 committed by GitHub
parent 4f05803690
commit 0164856343
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 79 additions and 25 deletions

View File

@ -18,6 +18,7 @@ from functools import partial
from flask import request, Response from flask import request, Response
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.settings import RetCode
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas from agent.canvas import Canvas
@ -43,6 +44,10 @@ def canvas_list():
@login_required @login_required
def rm(): def rm():
for i in request.json["canvas_ids"]: for i in request.json["canvas_ids"]:
if not UserCanvasService.query(user_id=current_user.id,id=i):
return get_json_result(
data=False, retmsg=f'Only owner of canvas authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
UserCanvasService.delete_by_id(i) UserCanvasService.delete_by_id(i)
return get_json_result(data=True) return get_json_result(data=True)

View File

@ -13,16 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import json
from copy import deepcopy from copy import deepcopy
from db.services.user_service import UserTenantService
from flask import request, Response from flask import request, Response
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db import LLMType
from api.db.services.dialog_service import DialogService, ConversationService, chat from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.db.services.llm_service import LLMBundle, TenantService from api.db.services.llm_service import LLMBundle, TenantService
from api.db import LLMType from api.settings import RetCode
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
import json from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
@manager.route('/set', methods=['POST']) @manager.route('/set', methods=['POST'])
@ -72,6 +76,14 @@ def get():
e, conv = ConversationService.get_by_id(conv_id) e, conv = ConversationService.get_by_id(conv_id)
if not e: if not e:
return get_data_error_result(retmsg="Conversation not found!") return get_data_error_result(retmsg="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of conversation authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
conv = conv.to_dict() conv = conv.to_dict()
return get_json_result(data=conv) return get_json_result(data=conv)
except Exception as e: except Exception as e:
@ -84,6 +96,17 @@ def rm():
conv_ids = request.json["conversation_ids"] conv_ids = request.json["conversation_ids"]
try: try:
for cid in conv_ids: for cid in conv_ids:
exist, conv = ConversationService.get_by_id(cid)
if not exist:
return get_data_error_result(retmsg="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of conversation authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
ConversationService.delete_by_id(cid) ConversationService.delete_by_id(cid)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
@ -95,6 +118,10 @@ def rm():
def list_convsersation(): def list_convsersation():
dialog_id = request.args["dialog_id"] dialog_id = request.args["dialog_id"]
try: try:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
return get_json_result(
data=False, retmsg=f'Only owner of dialog authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
convs = ConversationService.query( convs = ConversationService.query(
dialog_id=dialog_id, dialog_id=dialog_id,
order_by=ConversationService.model.create_time, order_by=ConversationService.model.create_time,
@ -107,7 +134,7 @@ def list_convsersation():
@manager.route('/completion', methods=['POST']) @manager.route('/completion', methods=['POST'])
@login_required @login_required
#@validate_request("conversation_id", "messages") @validate_request("conversation_id", "messages")
def completion(): def completion():
req = request.json req = request.json
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [ # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
@ -141,7 +168,8 @@ def completion():
nonlocal conv, message_id nonlocal conv, message_id
if not conv.reference: if not conv.reference:
conv.reference.append(ans["reference"]) conv.reference.append(ans["reference"])
else: conv.reference[-1] = ans["reference"] else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"], conv.message[-1] = {"role": "assistant", "content": ans["answer"],
"id": message_id, "prompt": ans.get("prompt", "")} "id": message_id, "prompt": ans.get("prompt", "")}
ans["id"] = message_id ans["id"] = message_id
@ -194,6 +222,7 @@ def tts():
return get_data_error_result(retmsg="No default TTS model is set") return get_data_error_result(retmsg="No default TTS model is set")
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id) tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
def stream_audio(): def stream_audio():
try: try:
for chunk in tts_mdl.tts(text): for chunk in tts_mdl.tts(text):

View File

@ -19,7 +19,8 @@ from flask_login import login_required, current_user
from api.db.services.dialog_service import DialogService from api.db.services.dialog_service import DialogService
from api.db import StatusEnum from api.db import StatusEnum
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService, UserTenantService
from api.settings import RetCode
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
@ -164,9 +165,19 @@ def list_dialogs():
@validate_request("dialog_ids") @validate_request("dialog_ids")
def rm(): def rm():
req = request.json req = request.json
dialog_list=[]
tenants = UserTenantService.query(user_id=current_user.id)
try: try:
DialogService.update_many_by_id( for id in req["dialog_ids"]:
[{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]]) for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of dialog authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
DialogService.update_many_by_id(dialog_list)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

View File

@ -35,7 +35,7 @@ from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, queue_tasks from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService, UserTenantService
from graphrag.mind_map_extractor import MindMapExtractor from graphrag.mind_map_extractor import MindMapExtractor
from rag.app import naive from rag.app import naive
from rag.nlp import search from rag.nlp import search
@ -189,6 +189,15 @@ def list_docs():
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR) data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if KnowledgebaseService.query(
tenant_id=tenant.tenant_id, id=kb_id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
keywords = request.args.get("keywords", "") keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 1)) page_number = int(request.args.get("page", 1))