### What problem does this PR solve?


### Type of change

- [x] Refactoring
This commit is contained in:
Kevin Hu 2025-04-15 10:20:33 +08:00 committed by GitHub
parent 7a34159737
commit 5af2d57086
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 78 additions and 15 deletions

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import binascii import binascii
from datetime import datetime
import logging import logging
import re import re
import time import time
@ -31,6 +32,7 @@ from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle, TenantLLMService from api.db.services.llm_service import LLMBundle, TenantLLMService
from api.utils import current_timestamp, datetime_format
from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp.search import index_name from rag.nlp.search import index_name
@ -42,6 +44,39 @@ from rag.utils.tavily_conn import Tavily
class DialogService(CommonService): class DialogService(CommonService):
model = Dialog model = Dialog
@classmethod
def save(cls, **kwargs):
"""Save a new record to database.
This method creates a new record in the database with the provided field values,
forcing an insert operation rather than an update.
Args:
**kwargs: Record field values as keyword arguments.
Returns:
Model instance: The created record object.
"""
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@classmethod
def update_many_by_id(cls, data_list):
"""Update multiple records by their IDs.
This method updates multiple records in the database, identified by their IDs.
It automatically updates the update_time and update_date fields for each record.
Args:
data_list (list): List of dictionaries containing record data to update.
Each dictionary must include an 'id' field.
"""
with DB.atomic():
for data in data_list:
data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
cls.model.update(data).where(cls.model.id == data["id"]).execute()
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name): def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
@ -434,11 +469,11 @@ Please write the SQL, only SQL, without any other explanations or text.
Table name: {}; Table name: {};
Table of database fields are as follows: Table of database fields are as follows:
{} {}
Question are as follows: Question are as follows:
{} {}
Please write the SQL, only SQL, without any other explanations or text. Please write the SQL, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows: The SQL error you provided last time is as follows:
{} {}
@ -461,7 +496,7 @@ Please write the SQL, only SQL, without any other explanations or text.
# compose Markdown table # compose Markdown table
columns = ( columns = (
"|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") "|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
) )
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
@ -557,4 +592,4 @@ def ask(question, kb_ids, tenant_id):
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}): for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
answer = ans answer = ans
yield {"answer": answer, "reference": {}} yield {"answer": answer, "reference": {}}
yield decorate_answer(answer) yield decorate_answer(answer)

View File

@ -30,10 +30,10 @@ from rag.settings import MINIO
class UserService(CommonService): class UserService(CommonService):
"""Service class for managing user-related database operations. """Service class for managing user-related database operations.
This class extends CommonService to provide specialized functionality for user management, This class extends CommonService to provide specialized functionality for user management,
including authentication, user creation, updates, and deletions. including authentication, user creation, updates, and deletions.
Attributes: Attributes:
model: The User model class for database operations. model: The User model class for database operations.
""" """
@ -43,10 +43,10 @@ class UserService(CommonService):
@DB.connection_context() @DB.connection_context()
def filter_by_id(cls, user_id): def filter_by_id(cls, user_id):
"""Retrieve a user by their ID. """Retrieve a user by their ID.
Args: Args:
user_id: The unique identifier of the user. user_id: The unique identifier of the user.
Returns: Returns:
User object if found, None otherwise. User object if found, None otherwise.
""" """
@ -60,11 +60,11 @@ class UserService(CommonService):
@DB.connection_context() @DB.connection_context()
def query_user(cls, email, password): def query_user(cls, email, password):
"""Authenticate a user with email and password. """Authenticate a user with email and password.
Args: Args:
email: User's email address. email: User's email address.
password: User's password in plain text. password: User's password in plain text.
Returns: Returns:
User object if authentication successful, None otherwise. User object if authentication successful, None otherwise.
""" """
@ -111,10 +111,10 @@ class UserService(CommonService):
class TenantService(CommonService): class TenantService(CommonService):
"""Service class for managing tenant-related database operations. """Service class for managing tenant-related database operations.
This class extends CommonService to provide functionality for tenant management, This class extends CommonService to provide functionality for tenant management,
including tenant information retrieval and credit management. including tenant information retrieval and credit management.
Attributes: Attributes:
model: The Tenant model class for database operations. model: The Tenant model class for database operations.
""" """
@ -170,15 +170,24 @@ class TenantService(CommonService):
class UserTenantService(CommonService): class UserTenantService(CommonService):
"""Service class for managing user-tenant relationship operations. """Service class for managing user-tenant relationship operations.
This class extends CommonService to handle the many-to-many relationship This class extends CommonService to handle the many-to-many relationship
between users and tenants, managing user roles and tenant memberships. between users and tenants, managing user roles and tenant memberships.
Attributes: Attributes:
model: The UserTenant model class for database operations. model: The UserTenant model class for database operations.
""" """
model = UserTenant model = UserTenant
@classmethod
@DB.connection_context()
def filter_by_id(cls, user_tenant_id):
try:
user_tenant = cls.model.select().where((cls.model.id == user_tenant_id) & (cls.model.status == StatusEnum.VALID.value)).get()
return user_tenant
except peewee.DoesNotExist:
return None
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def save(cls, **kwargs): def save(cls, **kwargs):
@ -191,6 +200,7 @@ class UserTenantService(CommonService):
@DB.connection_context() @DB.connection_context()
def get_by_tenant_id(cls, tenant_id): def get_by_tenant_id(cls, tenant_id):
fields = [ fields = [
cls.model.id,
cls.model.user_id, cls.model.user_id,
cls.model.status, cls.model.status,
cls.model.role, cls.model.role,
@ -222,3 +232,21 @@ class UserTenantService(CommonService):
return list(cls.model.select(*fields) return list(cls.model.select(*fields)
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value))) .join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
.where(cls.model.status == StatusEnum.VALID.value).dicts()) .where(cls.model.status == StatusEnum.VALID.value).dicts())
@classmethod
@DB.connection_context()
def get_num_members(cls, user_id: str):
cnt_members = cls.model.select(peewee.fn.COUNT(cls.model.id)).where(cls.model.tenant_id == user_id).scalar()
return cnt_members
@classmethod
@DB.connection_context()
def filter_by_tenant_and_user_id(cls, tenant_id, user_id):
try:
user_tenant = cls.model.select().where(
(cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value) &
(cls.model.user_id == user_id)
).first()
return user_tenant
except peewee.DoesNotExist:
return None

View File

@ -52,7 +52,7 @@ def chunks_format(reference):
def llm_id2llm_type(llm_id): def llm_id2llm_type(llm_id):
from api.db.services.llm_service import TenantLLMService from api.db.services.llm_service import TenantLLMService
llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id) llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
llm_factories = settings.FACTORY_LLM_INFOS llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories: for llm_factory in llm_factories: