From 9541d7e7bcb77e25ad0292eda27fa147cc2a2c03 Mon Sep 17 00:00:00 2001
From: Zhichang Yu
Date: Fri, 22 Nov 2024 12:00:25 +0800
Subject: [PATCH] Added TRACE_MALLOC_DELTA and TRACE_MALLOC_FULL (#3555)

### What problem does this PR solve?

Added TRACE_MALLOC_DELTA and TRACE_MALLOC_FULL to debug task_executor.py heap. Relates to #3518

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
---
 rag/svr/task_executor.py | 33 ++++++++++++++++++++++++++++++++-
 rag/utils/es_conn.py     |  2 ++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 3a8972a1d..2b3314632 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -22,7 +22,8 @@ import sys
 from api.utils.log_utils import initRootLogger
 
 CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
-initRootLogger(f"task_executor_{CONSUMER_NO}")
+CONSUMER_NAME = "task_executor_" + CONSUMER_NO
+initRootLogger(CONSUMER_NAME)
 for module in ["pdfminer"]:
     module_logger = logging.getLogger(module)
     module_logger.setLevel(logging.WARNING)
@@ -44,6 +45,7 @@ from functools import partial
 from io import BytesIO
 from multiprocessing.context import TimeoutError
 from timeit import default_timer as timer
+import tracemalloc
 
 import numpy as np
 
@@ -490,14 +492,43 @@ def report_status():
             logging.exception("report_status got exception")
         time.sleep(30)
 
+def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
+    msg = ""
+    if dump_full:
+        stats2 = snapshot2.statistics('lineno')
+        msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
+        for stat in stats2[:10]:
+            msg += f"{stat}\n"
+    stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
+    msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id-1} to snapshot {snapshot_id}:\n"
+    for stat in stats1_vs_2[:10]:
+        msg += f"{stat}\n"
+    msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
+    for stat in stats1_vs_2[:3]:
+        msg += '\n'.join(stat.traceback.format())
+    logging.info(msg)
+
 def main():
     settings.init_settings()
     background_thread = threading.Thread(target=report_status)
     background_thread.daemon = True
     background_thread.start()
 
+    TRACE_MALLOC_DELTA = int(os.environ.get('TRACE_MALLOC_DELTA', "0"))
+    TRACE_MALLOC_FULL = int(os.environ.get('TRACE_MALLOC_FULL', "0"))
+    if TRACE_MALLOC_DELTA > 0:
+        if TRACE_MALLOC_FULL < TRACE_MALLOC_DELTA:
+            TRACE_MALLOC_FULL = TRACE_MALLOC_DELTA
+        tracemalloc.start()
+        snapshot1 = tracemalloc.take_snapshot()
     while True:
         handle_task()
+        num_tasks = DONE_TASKS + FAILED_TASKS
+        if TRACE_MALLOC_DELTA> 0 and num_tasks > 0 and num_tasks % TRACE_MALLOC_DELTA == 0:
+            snapshot2 = tracemalloc.take_snapshot()
+            analyze_heap(snapshot1, snapshot2, int(num_tasks/TRACE_MALLOC_DELTA), num_tasks % TRACE_MALLOC_FULL == 0)
+            snapshot1 = snapshot2
+            snapshot2 = None
 
 if __name__ == "__main__":
     main()
diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py
index 651c7a6b0..6354defc1 100644
--- a/rag/utils/es_conn.py
+++ b/rag/utils/es_conn.py
@@ -237,6 +237,7 @@ class ESConnection(DocStoreConnection):
         res = []
         for _ in range(ATTEMPT_TIME):
             try:
+                res = []
                 r = self.es.bulk(index=(indexName), operations=operations,
                                  refresh=False, timeout="60s")
                 if re.search(r"False", str(r["errors"]), re.IGNORECASE):
@@ -248,6 +249,7 @@ class ESConnection(DocStoreConnection):
                             res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
                 return res
             except Exception as e:
+                res.append(str(e))
                 logging.warning("ESConnection.insert got exception: " + str(e))
                 res = []
                 if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):