diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 9ef14777..77f36270 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -14,6 +14,7 @@ # limitations under the License. # import binascii +from datetime import datetime import logging import re import time @@ -31,6 +32,7 @@ from api.db.services.common_service import CommonService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.llm_service import LLMBundle, TenantLLMService +from api.utils import current_timestamp, datetime_format from rag.app.resume import forbidden_select_fields4resume from rag.app.tag import label_question from rag.nlp.search import index_name @@ -42,6 +44,39 @@ from rag.utils.tavily_conn import Tavily class DialogService(CommonService): model = Dialog + @classmethod + def save(cls, **kwargs): + """Save a new record to database. + + This method creates a new record in the database with the provided field values, + forcing an insert operation rather than an update. + + Args: + **kwargs: Record field values as keyword arguments. + + Returns: + Model instance: The created record object. + """ + sample_obj = cls.model(**kwargs).save(force_insert=True) + return sample_obj + + @classmethod + def update_many_by_id(cls, data_list): + """Update multiple records by their IDs. + + This method updates multiple records in the database, identified by their IDs. + It automatically updates the update_time and update_date fields for each record. + + Args: + data_list (list): List of dictionaries containing record data to update. + Each dictionary must include an 'id' field. + """ + with DB.atomic(): + for data in data_list: + data["update_time"] = current_timestamp() + data["update_date"] = datetime_format(datetime.now()) + cls.model.update(data).where(cls.model.id == data["id"]).execute() + @classmethod @DB.connection_context() def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name): @@ -434,11 +469,11 @@ Please write the SQL, only SQL, without any other explanations or text. Table name: {}; Table of database fields are as follows: {} - + Question are as follows: {} Please write the SQL, only SQL, without any other explanations or text. - + The SQL error you provided last time is as follows: {} @@ -461,7 +496,7 @@ Please write the SQL, only SQL, without any other explanations or text. # compose Markdown table columns = ( - "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") + "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") ) line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") @@ -557,4 +592,4 @@ def ask(question, kb_ids, tenant_id): for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}): answer = ans yield {"answer": answer, "reference": {}} - yield decorate_answer(answer) + yield decorate_answer(answer) \ No newline at end of file diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index facab201..1edd46c1 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -30,10 +30,10 @@ from rag.settings import MINIO class UserService(CommonService): """Service class for managing user-related database operations. - + This class extends CommonService to provide specialized functionality for user management, including authentication, user creation, updates, and deletions. - + Attributes: model: The User model class for database operations. """ @@ -43,10 +43,10 @@ class UserService(CommonService): @DB.connection_context() def filter_by_id(cls, user_id): """Retrieve a user by their ID. - + Args: user_id: The unique identifier of the user. - + Returns: User object if found, None otherwise. """ @@ -60,11 +60,11 @@ class UserService(CommonService): @DB.connection_context() def query_user(cls, email, password): """Authenticate a user with email and password. - + Args: email: User's email address. password: User's password in plain text. - + Returns: User object if authentication successful, None otherwise. """ @@ -111,10 +111,10 @@ class UserService(CommonService): class TenantService(CommonService): """Service class for managing tenant-related database operations. - + This class extends CommonService to provide functionality for tenant management, including tenant information retrieval and credit management. - + Attributes: model: The Tenant model class for database operations. """ @@ -170,15 +170,24 @@ class TenantService(CommonService): class UserTenantService(CommonService): """Service class for managing user-tenant relationship operations. - + This class extends CommonService to handle the many-to-many relationship between users and tenants, managing user roles and tenant memberships. - + Attributes: model: The UserTenant model class for database operations. """ model = UserTenant + @classmethod + @DB.connection_context() + def filter_by_id(cls, user_tenant_id): + try: + user_tenant = cls.model.select().where((cls.model.id == user_tenant_id) & (cls.model.status == StatusEnum.VALID.value)).get() + return user_tenant + except peewee.DoesNotExist: + return None + @classmethod @DB.connection_context() def save(cls, **kwargs): @@ -191,6 +200,7 @@ class UserTenantService(CommonService): @DB.connection_context() def get_by_tenant_id(cls, tenant_id): fields = [ + cls.model.id, cls.model.user_id, cls.model.status, cls.model.role, @@ -222,3 +232,21 @@ class UserTenantService(CommonService): return list(cls.model.select(*fields) .join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value))) .where(cls.model.status == StatusEnum.VALID.value).dicts()) + + @classmethod + @DB.connection_context() + def get_num_members(cls, user_id: str): + cnt_members = cls.model.select(peewee.fn.COUNT(cls.model.id)).where(cls.model.tenant_id == user_id).scalar() + return cnt_members + + @classmethod + @DB.connection_context() + def filter_by_tenant_and_user_id(cls, tenant_id, user_id): + try: + user_tenant = cls.model.select().where( + (cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value) & + (cls.model.user_id == user_id) + ).first() + return user_tenant + except peewee.DoesNotExist: + return None \ No newline at end of file diff --git a/rag/prompts.py b/rag/prompts.py index 5af74f60..4b3dd866 100644 --- a/rag/prompts.py +++ b/rag/prompts.py @@ -52,7 +52,7 @@ def chunks_format(reference): def llm_id2llm_type(llm_id): from api.db.services.llm_service import TenantLLMService - llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id) + llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id) llm_factories = settings.FACTORY_LLM_INFOS for llm_factory in llm_factories: