mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-14 06:45:58 +08:00
Add Authorization checks (#2221)
### What problem does this PR solve? Add Authorization checks #2203 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
parent
4f05803690
commit
0164856343
@ -18,6 +18,7 @@ from functools import partial
|
|||||||
from flask import request, Response
|
from flask import request, Response
|
||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
||||||
|
from api.settings import RetCode
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
||||||
from agent.canvas import Canvas
|
from agent.canvas import Canvas
|
||||||
@ -43,6 +44,10 @@ def canvas_list():
|
|||||||
@login_required
|
@login_required
|
||||||
def rm():
|
def rm():
|
||||||
for i in request.json["canvas_ids"]:
|
for i in request.json["canvas_ids"]:
|
||||||
|
if not UserCanvasService.query(user_id=current_user.id,id=i):
|
||||||
|
return get_json_result(
|
||||||
|
data=False, retmsg=f'Only owner of canvas authorized for this operation.',
|
||||||
|
retcode=RetCode.OPERATING_ERROR)
|
||||||
UserCanvasService.delete_by_id(i)
|
UserCanvasService.delete_by_id(i)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
@ -13,16 +13,20 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import json
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
from db.services.user_service import UserTenantService
|
||||||
from flask import request, Response
|
from flask import request, Response
|
||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
|
|
||||||
|
from api.db import LLMType
|
||||||
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
||||||
from api.db.services.llm_service import LLMBundle, TenantService
|
from api.db.services.llm_service import LLMBundle, TenantService
|
||||||
from api.db import LLMType
|
from api.settings import RetCode
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
import json
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/set', methods=['POST'])
|
@manager.route('/set', methods=['POST'])
|
||||||
@ -72,6 +76,14 @@ def get():
|
|||||||
e, conv = ConversationService.get_by_id(conv_id)
|
e, conv = ConversationService.get_by_id(conv_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(retmsg="Conversation not found!")
|
return get_data_error_result(retmsg="Conversation not found!")
|
||||||
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
|
for tenant in tenants:
|
||||||
|
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return get_json_result(
|
||||||
|
data=False, retmsg=f'Only owner of conversation authorized for this operation.',
|
||||||
|
retcode=RetCode.OPERATING_ERROR)
|
||||||
conv = conv.to_dict()
|
conv = conv.to_dict()
|
||||||
return get_json_result(data=conv)
|
return get_json_result(data=conv)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -84,6 +96,17 @@ def rm():
|
|||||||
conv_ids = request.json["conversation_ids"]
|
conv_ids = request.json["conversation_ids"]
|
||||||
try:
|
try:
|
||||||
for cid in conv_ids:
|
for cid in conv_ids:
|
||||||
|
exist, conv = ConversationService.get_by_id(cid)
|
||||||
|
if not exist:
|
||||||
|
return get_data_error_result(retmsg="Conversation not found!")
|
||||||
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
|
for tenant in tenants:
|
||||||
|
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return get_json_result(
|
||||||
|
data=False, retmsg=f'Only owner of conversation authorized for this operation.',
|
||||||
|
retcode=RetCode.OPERATING_ERROR)
|
||||||
ConversationService.delete_by_id(cid)
|
ConversationService.delete_by_id(cid)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -95,6 +118,10 @@ def rm():
|
|||||||
def list_convsersation():
|
def list_convsersation():
|
||||||
dialog_id = request.args["dialog_id"]
|
dialog_id = request.args["dialog_id"]
|
||||||
try:
|
try:
|
||||||
|
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
||||||
|
return get_json_result(
|
||||||
|
data=False, retmsg=f'Only owner of dialog authorized for this operation.',
|
||||||
|
retcode=RetCode.OPERATING_ERROR)
|
||||||
convs = ConversationService.query(
|
convs = ConversationService.query(
|
||||||
dialog_id=dialog_id,
|
dialog_id=dialog_id,
|
||||||
order_by=ConversationService.model.create_time,
|
order_by=ConversationService.model.create_time,
|
||||||
@ -107,7 +134,7 @@ def list_convsersation():
|
|||||||
|
|
||||||
@manager.route('/completion', methods=['POST'])
|
@manager.route('/completion', methods=['POST'])
|
||||||
@login_required
|
@login_required
|
||||||
#@validate_request("conversation_id", "messages")
|
@validate_request("conversation_id", "messages")
|
||||||
def completion():
|
def completion():
|
||||||
req = request.json
|
req = request.json
|
||||||
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
|
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
|
||||||
@ -141,7 +168,8 @@ def completion():
|
|||||||
nonlocal conv, message_id
|
nonlocal conv, message_id
|
||||||
if not conv.reference:
|
if not conv.reference:
|
||||||
conv.reference.append(ans["reference"])
|
conv.reference.append(ans["reference"])
|
||||||
else: conv.reference[-1] = ans["reference"]
|
else:
|
||||||
|
conv.reference[-1] = ans["reference"]
|
||||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"],
|
conv.message[-1] = {"role": "assistant", "content": ans["answer"],
|
||||||
"id": message_id, "prompt": ans.get("prompt", "")}
|
"id": message_id, "prompt": ans.get("prompt", "")}
|
||||||
ans["id"] = message_id
|
ans["id"] = message_id
|
||||||
@ -194,6 +222,7 @@ def tts():
|
|||||||
return get_data_error_result(retmsg="No default TTS model is set")
|
return get_data_error_result(retmsg="No default TTS model is set")
|
||||||
|
|
||||||
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
|
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
|
||||||
|
|
||||||
def stream_audio():
|
def stream_audio():
|
||||||
try:
|
try:
|
||||||
for chunk in tts_mdl.tts(text):
|
for chunk in tts_mdl.tts(text):
|
||||||
|
@ -19,7 +19,8 @@ from flask_login import login_required, current_user
|
|||||||
from api.db.services.dialog_service import DialogService
|
from api.db.services.dialog_service import DialogService
|
||||||
from api.db import StatusEnum
|
from api.db import StatusEnum
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.user_service import TenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
|
from api.settings import RetCode
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
@ -164,9 +165,19 @@ def list_dialogs():
|
|||||||
@validate_request("dialog_ids")
|
@validate_request("dialog_ids")
|
||||||
def rm():
|
def rm():
|
||||||
req = request.json
|
req = request.json
|
||||||
|
dialog_list=[]
|
||||||
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
try:
|
try:
|
||||||
DialogService.update_many_by_id(
|
for id in req["dialog_ids"]:
|
||||||
[{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
|
for tenant in tenants:
|
||||||
|
if DialogService.query(tenant_id=tenant.tenant_id, id=id):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return get_json_result(
|
||||||
|
data=False, retmsg=f'Only owner of dialog authorized for this operation.',
|
||||||
|
retcode=RetCode.OPERATING_ERROR)
|
||||||
|
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
|
||||||
|
DialogService.update_many_by_id(dialog_list)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
@ -35,7 +35,7 @@ from api.db.services.file2document_service import File2DocumentService
|
|||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.task_service import TaskService, queue_tasks
|
from api.db.services.task_service import TaskService, queue_tasks
|
||||||
from api.db.services.user_service import TenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from graphrag.mind_map_extractor import MindMapExtractor
|
from graphrag.mind_map_extractor import MindMapExtractor
|
||||||
from rag.app import naive
|
from rag.app import naive
|
||||||
from rag.nlp import search
|
from rag.nlp import search
|
||||||
@ -189,6 +189,15 @@ def list_docs():
|
|||||||
if not kb_id:
|
if not kb_id:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
|
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
|
||||||
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
|
for tenant in tenants:
|
||||||
|
if KnowledgebaseService.query(
|
||||||
|
tenant_id=tenant.tenant_id, id=kb_id):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return get_json_result(
|
||||||
|
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
|
||||||
|
retcode=RetCode.OPERATING_ERROR)
|
||||||
keywords = request.args.get("keywords", "")
|
keywords = request.args.get("keywords", "")
|
||||||
|
|
||||||
page_number = int(request.args.get("page", 1))
|
page_number = int(request.args.get("page", 1))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user