mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-20 05:00:01 +08:00
Refa. (#7022)
### What problem does this PR solve? ### Type of change - [x] Refactoring
This commit is contained in:
parent
7a34159737
commit
5af2d57086
@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import binascii
|
import binascii
|
||||||
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
@ -31,6 +32,7 @@ from api.db.services.common_service import CommonService
|
|||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.langfuse_service import TenantLangfuseService
|
from api.db.services.langfuse_service import TenantLangfuseService
|
||||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
||||||
|
from api.utils import current_timestamp, datetime_format
|
||||||
from rag.app.resume import forbidden_select_fields4resume
|
from rag.app.resume import forbidden_select_fields4resume
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.nlp.search import index_name
|
from rag.nlp.search import index_name
|
||||||
@ -42,6 +44,39 @@ from rag.utils.tavily_conn import Tavily
|
|||||||
class DialogService(CommonService):
|
class DialogService(CommonService):
|
||||||
model = Dialog
|
model = Dialog
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def save(cls, **kwargs):
|
||||||
|
"""Save a new record to database.
|
||||||
|
|
||||||
|
This method creates a new record in the database with the provided field values,
|
||||||
|
forcing an insert operation rather than an update.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Record field values as keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model instance: The created record object.
|
||||||
|
"""
|
||||||
|
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
||||||
|
return sample_obj
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_many_by_id(cls, data_list):
|
||||||
|
"""Update multiple records by their IDs.
|
||||||
|
|
||||||
|
This method updates multiple records in the database, identified by their IDs.
|
||||||
|
It automatically updates the update_time and update_date fields for each record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_list (list): List of dictionaries containing record data to update.
|
||||||
|
Each dictionary must include an 'id' field.
|
||||||
|
"""
|
||||||
|
with DB.atomic():
|
||||||
|
for data in data_list:
|
||||||
|
data["update_time"] = current_timestamp()
|
||||||
|
data["update_date"] = datetime_format(datetime.now())
|
||||||
|
cls.model.update(data).where(cls.model.id == data["id"]).execute()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
|
def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
|
||||||
@ -461,7 +496,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
|||||||
|
|
||||||
# compose Markdown table
|
# compose Markdown table
|
||||||
columns = (
|
columns = (
|
||||||
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||||
)
|
)
|
||||||
|
|
||||||
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
||||||
|
@ -179,6 +179,15 @@ class UserTenantService(CommonService):
|
|||||||
"""
|
"""
|
||||||
model = UserTenant
|
model = UserTenant
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def filter_by_id(cls, user_tenant_id):
|
||||||
|
try:
|
||||||
|
user_tenant = cls.model.select().where((cls.model.id == user_tenant_id) & (cls.model.status == StatusEnum.VALID.value)).get()
|
||||||
|
return user_tenant
|
||||||
|
except peewee.DoesNotExist:
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def save(cls, **kwargs):
|
def save(cls, **kwargs):
|
||||||
@ -191,6 +200,7 @@ class UserTenantService(CommonService):
|
|||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_tenant_id(cls, tenant_id):
|
def get_by_tenant_id(cls, tenant_id):
|
||||||
fields = [
|
fields = [
|
||||||
|
cls.model.id,
|
||||||
cls.model.user_id,
|
cls.model.user_id,
|
||||||
cls.model.status,
|
cls.model.status,
|
||||||
cls.model.role,
|
cls.model.role,
|
||||||
@ -222,3 +232,21 @@ class UserTenantService(CommonService):
|
|||||||
return list(cls.model.select(*fields)
|
return list(cls.model.select(*fields)
|
||||||
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
|
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
|
||||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_num_members(cls, user_id: str):
|
||||||
|
cnt_members = cls.model.select(peewee.fn.COUNT(cls.model.id)).where(cls.model.tenant_id == user_id).scalar()
|
||||||
|
return cnt_members
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def filter_by_tenant_and_user_id(cls, tenant_id, user_id):
|
||||||
|
try:
|
||||||
|
user_tenant = cls.model.select().where(
|
||||||
|
(cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value) &
|
||||||
|
(cls.model.user_id == user_id)
|
||||||
|
).first()
|
||||||
|
return user_tenant
|
||||||
|
except peewee.DoesNotExist:
|
||||||
|
return None
|
@ -52,7 +52,7 @@ def chunks_format(reference):
|
|||||||
def llm_id2llm_type(llm_id):
|
def llm_id2llm_type(llm_id):
|
||||||
from api.db.services.llm_service import TenantLLMService
|
from api.db.services.llm_service import TenantLLMService
|
||||||
|
|
||||||
llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
|
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
|
||||||
|
|
||||||
llm_factories = settings.FACTORY_LLM_INFOS
|
llm_factories = settings.FACTORY_LLM_INFOS
|
||||||
for llm_factory in llm_factories:
|
for llm_factory in llm_factories:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user