From e0c0bdeb0ad76cf23978a7a2a2b70d2c93b01af0 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Fri, 18 Oct 2024 11:30:19 +0800 Subject: [PATCH] add team tag to kb (#2890) ### What problem does this PR solve? #2834 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/tenant_app.py | 4 +-- api/apps/user_app.py | 3 ++- api/db/services/knowledgebase_service.py | 32 ++++++++++++++++++------ 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/api/apps/tenant_app.py b/api/apps/tenant_app.py index 3d61992ac..e852b2ca4 100644 --- a/api/apps/tenant_app.py +++ b/api/apps/tenant_app.py @@ -60,7 +60,7 @@ def create(tenant_id): role=UserTenantRole.INVITE, status=StatusEnum.VALID.value) - usr = list(usrs.dicts())[0] + usr = usrs[0].to_dict() usr = {k: v for k, v in usr.items() if k in ["id", "avatar", "email", "nickname"]} return get_json_result(data=usr) @@ -88,7 +88,7 @@ def tenant_list(): return server_error_response(e) -@manager.route("/agree/", methods=["GET"]) +@manager.route("/agree/", methods=["PUT"]) @login_required def agree(tenant_id): try: diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 4d7947c56..bfb9b291e 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -260,7 +260,8 @@ def setting_user(): update_dict["password"] = generate_password_hash(decrypt(new_password)) for k in request_data.keys(): - if k in ["password", "new_password"]: + if k in ["password", "new_password", "email", "status", "is_superuser", "login_channel", "is_anonymous", + "is_active", "is_authenticated", "last_login_time"]: continue update_dict[k] = request_data[k] diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index 9fe75d461..e79887ac9 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -14,7 +14,7 @@ # limitations under the License. # from api.db import StatusEnum, TenantPermission -from api.db.db_models import Knowledgebase, DB, Tenant +from api.db.db_models import Knowledgebase, DB, Tenant, User from api.db.services.common_service import CommonService @@ -25,10 +25,26 @@ class KnowledgebaseService(CommonService): @DB.connection_context() def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc): - kbs = cls.model.select().where( + fields = [ + cls.model.id, + cls.model.avatar, + cls.model.name, + cls.model.language, + cls.model.description, + cls.model.permission, + cls.model.doc_num, + cls.model.token_num, + cls.model.chunk_num, + cls.model.parser_id, + cls.model.embd_id, + User.nickname, + User.avatar.alias('tenant_avatar'), + cls.model.update_time + ] + kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where( ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | ( - cls.model.tenant_id == user_id)) + cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value) ) if desc: @@ -63,14 +79,14 @@ class KnowledgebaseService(CommonService): if count == -1: return kbs[offset:] - return kbs[offset:offset+count] + return kbs[offset:offset + count] @classmethod @DB.connection_context() def get_detail(cls, kb_id): fields = [ cls.model.id, - #Tenant.embd_id, + # Tenant.embd_id, cls.model.embd_id, cls.model.avatar, cls.model.name, @@ -83,14 +99,14 @@ class KnowledgebaseService(CommonService): cls.model.parser_id, cls.model.parser_config] kbs = cls.model.select(*fields).join(Tenant, on=( - (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where( + (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where( (cls.model.id == kb_id), (cls.model.status == StatusEnum.VALID.value) ) if not kbs: return d = kbs[0].to_dict() - #d["embd_id"] = kbs[0].tenant.embd_id + # d["embd_id"] = kbs[0].tenant.embd_id return d @classmethod @@ -146,7 +162,7 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def get_list(cls, joined_tenant_ids, user_id, - page_number, items_per_page, orderby, desc, id , name): + page_number, items_per_page, orderby, desc, id, name): kbs = cls.model.select() if id: kbs = kbs.where(cls.model.id == id)