diff --git a/api/db/init_data.py b/api/db/init_data.py index fe993033..aa2e3efc 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -148,7 +148,7 @@ def init_llm_factory(): pass break for kb_id in KnowledgebaseService.get_all_ids(): - KnowledgebaseService.update_by_id(kb_id, {"doc_num": DocumentService.get_kb_doc_count(kb_id)}) + KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=DocumentService.get_kb_doc_count(kb_id)) diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index 979cba60..278632f8 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -394,3 +394,30 @@ class KnowledgebaseService(CommonService): data["doc_num"] = cls.model.doc_num + 1 num = cls.model.update(data).where(cls.model.id == kb_id).execute() return num + + @classmethod + @DB.connection_context() + def update_document_number_in_init(cls, kb_id, doc_num): + """ + Only use this function when init system + """ + ok, kb = cls.get_by_id(kb_id) + if not ok: + return + kb.doc_num = doc_num + + dirty_fields = kb.dirty_fields + if cls.model._meta.combined.get("update_time") in dirty_fields: + dirty_fields.remove(cls.model._meta.combined["update_time"]) + + if cls.model._meta.combined.get("update_date") in dirty_fields: + dirty_fields.remove(cls.model._meta.combined["update_date"]) + + try: + kb.save(only=dirty_fields) + except ValueError as e: + if str(e) == "no data to save!": + pass # that's OK + else: + raise e +