diff --git a/api/ragflow_server.py b/api/ragflow_server.py
index 091cf9a58..f832bdf53 100644
--- a/api/ragflow_server.py
+++ b/api/ragflow_server.py
@@ -15,10 +15,8 @@
 #
 
 import logging
-import inspect
 from api.utils.log_utils import initRootLogger
-
-initRootLogger(inspect.getfile(inspect.currentframe()))
+initRootLogger("ragflow_server")
 for module in ["pdfminer"]:
     module_logger = logging.getLogger(module)
     module_logger.setLevel(logging.WARNING)
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index c8fe5ccfa..fefd7958c 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -14,9 +14,10 @@
 # limitations under the License.
 #
 import logging
-import inspect
+import sys
 from api.utils.log_utils import initRootLogger
-initRootLogger(inspect.getfile(inspect.currentframe()))
+CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
+initRootLogger(f"task_executor_{CONSUMER_NO}")
 for module in ["pdfminer"]:
     module_logger = logging.getLogger(module)
     module_logger.setLevel(logging.WARNING)
@@ -25,7 +26,7 @@ for module in ["peewee"]:
     module_logger = logging.getLogger(module)
     module_logger.handlers.clear()
     module_logger.propagate = True
-import datetime
+from datetime import datetime
 import json
 import os
 import hashlib
@@ -33,7 +34,7 @@ import copy
 import re
 import sys
 import time
-from concurrent.futures import ThreadPoolExecutor
+import threading
 from functools import partial
 from io import BytesIO
 from multiprocessing.context import TimeoutError
@@ -78,9 +79,14 @@ FACTORY = {
     ParserType.KG.value: knowledge_graph
 }
 
-CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
+CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
 PAYLOAD: Payload | None = None
-
+BOOT_AT = datetime.now().isoformat()
+DONE_TASKS = 0
+RETRY_TASKS = 0
+PENDING_TASKS = 0
+HEAD_CREATED_AT = ""
+HEAD_DETAIL = ""
 
 
 def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
@@ -199,8 +205,8 @@ def build(row):
         md5.update((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8"))
         d["id"] = md5.hexdigest()
-        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
-        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
+        d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
+        d["create_timestamp_flt"] = datetime.now().timestamp()
         if not d.get("image"):
             d["img_id"] = ""
             d["page_num_list"] = json.dumps([])
@@ -333,8 +339,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
         md5 = hashlib.md5()
         md5.update((content + str(d["doc_id"])).encode("utf-8"))
         d["id"] = md5.hexdigest()
-        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
-        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
+        d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
+        d["create_timestamp_flt"] = datetime.now().timestamp()
         d[vctr_nm] = vctr.tolist()
         d["content_with_weight"] = content
         d["content_ltks"] = rag_tokenizer.tokenize(content)
@@ -403,7 +409,7 @@ def main():
         logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
-            callback(-1, f"Insert chunk error, detail info please check {LOG_FILE}. Please also check ES status!")
+            callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
             docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
             logging.error('Insert chunk error: ' + str(es_r))
         else:
@@ -420,24 +426,44 @@ def report_status():
-    global CONSUMER_NAME
+    global CONSUMER_NAME, BOOT_AT, DONE_TASKS, RETRY_TASKS, PENDING_TASKS, HEAD_CREATED_AT, HEAD_DETAIL
+    REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
     while True:
         try:
-            obj = REDIS_CONN.get("TASKEXE")
-            if not obj: obj = {}
-            else: obj = json.loads(obj)
-            if CONSUMER_NAME not in obj: obj[CONSUMER_NAME] = []
-            obj[CONSUMER_NAME].append(timer())
-            obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
-            REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
+            now = datetime.now()
+            PENDING_TASKS = REDIS_CONN.queue_length(SVR_QUEUE_NAME)
+            if PENDING_TASKS > 0:
+                head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
+                if head_info is not None:
+                    seconds = int(head_info[0].split("-")[0])/1000
+                    HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
+                    HEAD_DETAIL = head_info[1]
+
+            heartbeat = json.dumps({
+                "name": CONSUMER_NAME,
+                "now": now.isoformat(),
+                "boot_at": BOOT_AT,
+                "done": DONE_TASKS,
+                "retry": RETRY_TASKS,
+                "pending": PENDING_TASKS,
+                "head_created_at": HEAD_CREATED_AT,
+                "head_detail": HEAD_DETAIL,
+            })
+            REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
+            logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
+
+            expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30)
+            if expired > 0:
+                REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
         except Exception:
             logging.exception("report_status got exception")
         time.sleep(30)
 
 
 if __name__ == "__main__":
-    exe = ThreadPoolExecutor(max_workers=1)
-    exe.submit(report_status)
+    background_thread = threading.Thread(target=report_status)
+    background_thread.daemon = True
+    background_thread.start()
     while True:
         main()
diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py
index 7529bee32..b77d4eba2 100644
--- a/rag/utils/redis_conn.py
+++ b/rag/utils/redis_conn.py
@@ -90,6 +90,69 @@ class RedisDB:
             self.__open__()
         return False
 
+    def sadd(self, key: str, member: str):
+        try:
+            self.REDIS.sadd(key, member)
+            return True
+        except Exception as e:
+            logging.warning("[EXCEPTION]sadd" + str(key) + "||" + str(e))
+            self.__open__()
+        return False
+
+    def srem(self, key: str, member: str):
+        try:
+            self.REDIS.srem(key, member)
+            return True
+        except Exception as e:
+            logging.warning("[EXCEPTION]srem" + str(key) + "||" + str(e))
+            self.__open__()
+        return False
+
+    def smembers(self, key: str):
+        try:
+            res = self.REDIS.smembers(key)
+            return res
+        except Exception as e:
+            logging.warning("[EXCEPTION]smembers" + str(key) + "||" + str(e))
+            self.__open__()
+        return None
+
+    def zadd(self, key: str, member: str, score: float):
+        try:
+            self.REDIS.zadd(key, {member: score})
+            return True
+        except Exception as e:
+            logging.warning("[EXCEPTION]zadd" + str(key) + "||" + str(e))
+            self.__open__()
+        return False
+
+    def zcount(self, key: str, min: float, max: float):
+        try:
+            res = self.REDIS.zcount(key, min, max)
+            return res
+        except Exception as e:
+            logging.warning("[EXCEPTION]zcount" + str(key) + "||" + str(e))
+            self.__open__()
+        return 0
+
+    def zpopmin(self, key: str, count: int):
+        try:
+            res = self.REDIS.zpopmin(key, count)
+            return res
+        except Exception as e:
+            logging.warning("[EXCEPTION]zpopmin" + str(key) + "||" + str(e))
+            self.__open__()
+        return None
+
+    def zrangebyscore(self, key: str, min: float, max: float):
+        try:
+            res = self.REDIS.zrangebyscore(key, min, max)
+            return res
+        except Exception as e:
+            logging.warning("[EXCEPTION]zrangebyscore" + str(key) + "||" + str(e))
+            self.__open__()
+        return None
+
     def transaction(self, key, value, exp=3600):
         try:
             pipeline = self.REDIS.pipeline(transaction=True)
@@ -162,4 +225,22 @@ class RedisDB:
             logging.exception("xpending_range: " + consumer_name + " got exception")
             self.__open__()
 
+    def queue_length(self, queue) -> int:
+        for _ in range(3):
+            try:
+                num = self.REDIS.xlen(queue)
+                return num
+            except Exception:
+                logging.exception("queue_length " + str(queue) + " got exception")
+        return 0
+
+    def queue_head(self, queue):
+        for _ in range(3):
+            try:
+                ent = self.REDIS.xrange(queue, count=1)
+                return ent[0]
+            except Exception:
+                logging.exception("queue_head " + str(queue) + " got exception")
+        return None
+
 REDIS_CONN = RedisDB()
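
The heartbeat layout introduced above is: a Redis set "TASKEXE" holding the registered consumer names, plus one sorted set per consumer whose members are JSON heartbeat snapshots scored by their UNIX timestamp. Below is a minimal sketch of how a monitoring script could read that data back using only the RedisDB wrappers added in this patch; the collect_heartbeats helper and the 30-minute window are illustrative assumptions, not part of the commit.

import json
import time

from rag.utils.redis_conn import REDIS_CONN  # singleton defined above


def collect_heartbeats(window_seconds: float = 60 * 30):
    """Gather the recent heartbeats of every registered task executor."""
    now = time.time()
    report = {}
    for consumer in REDIS_CONN.smembers("TASKEXE") or []:
        # Members may come back as bytes depending on the client config.
        name = consumer.decode() if isinstance(consumer, bytes) else consumer
        # report_status() already trims entries older than 30 minutes via
        # zpopmin; filtering by score here keeps the sketch self-contained.
        entries = REDIS_CONN.zrangebyscore(name, now - window_seconds, now) or []
        report[name] = [json.loads(e) for e in entries]
    return report


if __name__ == "__main__":
    for name, beats in collect_heartbeats().items():
        latest = beats[-1] if beats else None
        pending = latest["pending"] if latest else "n/a"
        print(f"{name}: {len(beats)} heartbeat(s), pending tasks: {pending}")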