fix bug of inserting cites (#76)

This commit is contained in:
KevinHuSh 2024-02-27 17:51:54 +08:00 committed by GitHub
parent 4568a4b2cb
commit 1567e881de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 31 additions and 16 deletions

View File

@ -208,9 +208,9 @@ def user_register(user_id, user):
for llm in LLMService.query(fid=LLM_FACTORY): for llm in LLMService.query(fid=LLM_FACTORY):
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY}) tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
if not UserService.save(**user):return if not UserService.insert(**user):return
TenantService.save(**tenant) TenantService.insert(**tenant)
UserTenantService.save(**usr_tenant) UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm) TenantLLMService.insert_many(tenant_llm)
return UserService.query(email=user["email"]) return UserService.query(email=user["email"])

View File

@ -58,16 +58,16 @@ def init_superuser():
if not UserService.save(**user_info): if not UserService.save(**user_info):
print("【ERROR】can't init admin.") print("【ERROR】can't init admin.")
return return
TenantService.save(**tenant) TenantService.insert(**tenant)
UserTenantService.save(**usr_tenant) UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm) TenantLLMService.insert_many(tenant_llm)
UserService.save(**user_info) print("【INFO】Super user initialized. user name: admin, password: admin. Changing the password after logining is strongly recomanded.")
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}) msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
if msg.find("ERROR: ") == 0: if msg.find("ERROR: ") == 0:
print("【ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg) print("【ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg)
embd_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["embd_id"]) embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
v,c = embd_mdl.encode(["Hello!"]) v,c = embd_mdl.encode(["Hello!"])
if c == 0: if c == 0:
print("【ERROR】: '{}' dosen't work...".format(tenant["embd_id"])) print("【ERROR】: '{}' dosen't work...".format(tenant["embd_id"]))

View File

@ -18,7 +18,7 @@ from datetime import datetime
import peewee import peewee
from api.db.db_models import DB from api.db.db_models import DB
from api.utils import datetime_format from api.utils import datetime_format, current_timestamp, get_uuid
class CommonService: class CommonService:
@ -66,27 +66,42 @@ class CommonService:
sample_obj = cls.model(**kwargs).save(force_insert=True) sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj return sample_obj
@classmethod
@DB.connection_context()
def insert(cls, **kwargs):
if "id" not in kwargs:
kwargs["id"] = get_uuid()
kwargs["create_time"] = current_timestamp()
kwargs["create_date"] = datetime_format(datetime.now())
kwargs["update_time"] = current_timestamp()
kwargs["update_date"] = datetime_format(datetime.now())
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def insert_many(cls, data_list, batch_size=100): def insert_many(cls, data_list, batch_size=100):
with DB.atomic(): with DB.atomic():
for d in data_list: d["create_time"] = datetime_format(datetime.now()) for d in data_list:
d["create_time"] = current_timestamp()
d["create_date"] = datetime_format(datetime.now())
for i in range(0, len(data_list), batch_size): for i in range(0, len(data_list), batch_size):
cls.model.insert_many(data_list[i:i + batch_size]).execute() cls.model.insert_many(data_list[i:i + batch_size]).execute()
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def update_many_by_id(cls, data_list): def update_many_by_id(cls, data_list):
cur = datetime_format(datetime.now())
with DB.atomic(): with DB.atomic():
for data in data_list: for data in data_list:
data["update_time"] = cur data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
cls.model.update(data).where(cls.model.id == data["id"]).execute() cls.model.update(data).where(cls.model.id == data["id"]).execute()
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def update_by_id(cls, pid, data): def update_by_id(cls, pid, data):
data["update_time"] = datetime_format(datetime.now()) data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
num = cls.model.update(data).where(cls.model.id == pid).execute() num = cls.model.update(data).where(cls.model.id == pid).execute()
return num return num

View File

@ -17,16 +17,16 @@ database:
name: 'rag_flow' name: 'rag_flow'
user: 'root' user: 'root'
passwd: 'infini_rag_flow' passwd: 'infini_rag_flow'
host: '123.60.95.134' host: '127.0.0.1'
port: 5455 port: 5455
max_connections: 100 max_connections: 100
stale_timeout: 30 stale_timeout: 30
minio: minio:
user: 'rag_flow' user: 'rag_flow'
passwd: 'infini_rag_flow' passwd: 'infini_rag_flow'
host: '123.60.95.134:9000' host: '127.0.0.1:9000'
es: es:
hosts: 'http://123.60.95.134:9200' hosts: 'http://127.0.0.1:9200'
user_default_llm: user_default_llm:
factory: '通义千问' factory: '通义千问'
chat_model: 'qwen-plus' chat_model: 'qwen-plus'

View File

@ -226,7 +226,7 @@ class Dealer:
continue continue
if i not in cites: if i not in cites:
continue continue
assert int(cites[i]) < len(chunk_v) for c in cites[i]: assert int(c) < len(chunk_v)
res += "##%s$$" % "$".join(cites[i]) res += "##%s$$" % "$".join(cites[i])
return res return res