Rework task executor heartbeat (#3430)

### What problem does this PR solve?

Rework the task executor heartbeat and print it to the console. Each executor now reports its boot time, done/retry/pending task counts, and the creation time of the task at the head of the queue, instead of appending bare timestamps to a shared `TASKEXE` key.
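
Every 30 seconds each executor serializes a heartbeat like the one below into a per-consumer Redis sorted set, scored by timestamp, and logs it to the console (field values here are invented for illustration):

```json
{
  "name": "task_consumer_0",
  "now": "2024-11-15T14:44:25",
  "boot_at": "2024-11-15T14:43:55",
  "done": 12,
  "retry": 1,
  "pending": 3,
  "head_created_at": "2024-11-15T14:43:58",
  "head_detail": "..."
}
```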

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
Zhichang Yu 2024-11-15 14:43:55 +08:00 committed by GitHub
parent 48e060aa53
commit a854bc22d1
3 changed files with 129 additions and 24 deletions

Server bootstrap: the root logger now gets the fixed name `ragflow_server` instead of a name derived from the source file path via `inspect`.

```diff
@@ -15,10 +15,8 @@
 #
 import logging
-import inspect
 from api.utils.log_utils import initRootLogger
-initRootLogger(inspect.getfile(inspect.currentframe()))
+initRootLogger("ragflow_server")

 for module in ["pdfminer"]:
     module_logger = logging.getLogger(module)
     module_logger.setLevel(logging.WARNING)
```

Task executor: the logger is named per consumer (`task_executor_<n>`), `datetime` is imported directly, and `report_status()` is reworked to publish a rich heartbeat into a per-consumer sorted set. The loop now runs on a daemon thread instead of a single-worker `ThreadPoolExecutor`, so it cannot keep the process alive on shutdown.

```diff
@@ -14,9 +14,10 @@
 # limitations under the License.
 #
 import logging
-import inspect
+import sys
 from api.utils.log_utils import initRootLogger
-initRootLogger(inspect.getfile(inspect.currentframe()))
+CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
+initRootLogger(f"task_executor_{CONSUMER_NO}")

 for module in ["pdfminer"]:
     module_logger = logging.getLogger(module)
     module_logger.setLevel(logging.WARNING)
@@ -25,7 +26,7 @@ for module in ["peewee"]:
     module_logger.handlers.clear()
     module_logger.propagate = True

-import datetime
+from datetime import datetime
 import json
 import os
 import hashlib
@@ -33,7 +34,7 @@ import copy
 import re
 import sys
 import time
-from concurrent.futures import ThreadPoolExecutor
+import threading
 from functools import partial
 from io import BytesIO
 from multiprocessing.context import TimeoutError
@@ -78,9 +79,14 @@ FACTORY = {
     ParserType.KG.value: knowledge_graph
 }

-CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
+CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
 PAYLOAD: Payload | None = None
+BOOT_AT = datetime.now().isoformat()
+DONE_TASKS = 0
+RETRY_TASKS = 0
+PENDING_TASKS = 0
+HEAD_CREATED_AT = ""
+HEAD_DETAIL = ""

 def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
     global PAYLOAD
@@ -199,8 +205,8 @@ def build(row):
             md5.update((ck["content_with_weight"] +
                         str(d["doc_id"])).encode("utf-8"))
             d["id"] = md5.hexdigest()
-            d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
-            d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
+            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
+            d["create_timestamp_flt"] = datetime.now().timestamp()
             if not d.get("image"):
                 d["img_id"] = ""
                 d["page_num_list"] = json.dumps([])
@@ -333,8 +339,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
         md5 = hashlib.md5()
         md5.update((content + str(d["doc_id"])).encode("utf-8"))
         d["id"] = md5.hexdigest()
-        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
-        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
+        d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
+        d["create_timestamp_flt"] = datetime.now().timestamp()
         d[vctr_nm] = vctr.tolist()
         d["content_with_weight"] = content
         d["content_ltks"] = rag_tokenizer.tokenize(content)
@@ -403,7 +409,7 @@ def main():
         logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
-            callback(-1, f"Insert chunk error, detail info please check {LOG_FILE}. Please also check ES status!")
+            callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
             docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
             logging.error('Insert chunk error: ' + str(es_r))
         else:
@@ -420,24 +426,44 @@ def main():
 def report_status():
-    global CONSUMER_NAME
+    global CONSUMER_NAME, BOOT_AT, DONE_TASKS, RETRY_TASKS, PENDING_TASKS, HEAD_CREATED_AT, HEAD_DETAIL
+    REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
     while True:
         try:
-            obj = REDIS_CONN.get("TASKEXE")
-            if not obj: obj = {}
-            else: obj = json.loads(obj)
-            if CONSUMER_NAME not in obj: obj[CONSUMER_NAME] = []
-            obj[CONSUMER_NAME].append(timer())
-            obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
-            REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
+            now = datetime.now()
+            PENDING_TASKS = REDIS_CONN.queue_length(SVR_QUEUE_NAME)
+            if PENDING_TASKS > 0:
+                head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
+                if head_info is not None:
+                    seconds = int(head_info[0].split("-")[0]) / 1000
+                    HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
+                    HEAD_DETAIL = head_info[1]
+            heartbeat = json.dumps({
+                "name": CONSUMER_NAME,
+                "now": now.isoformat(),
+                "boot_at": BOOT_AT,
+                "done": DONE_TASKS,
+                "retry": RETRY_TASKS,
+                "pending": PENDING_TASKS,
+                "head_created_at": HEAD_CREATED_AT,
+                "head_detail": HEAD_DETAIL,
+            })
+            REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
+            logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
+
+            expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
+            if expired > 0:
+                REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
         except Exception:
             logging.exception("report_status got exception")
         time.sleep(30)

 if __name__ == "__main__":
-    exe = ThreadPoolExecutor(max_workers=1)
-    exe.submit(report_status)
+    background_thread = threading.Thread(target=report_status)
+    background_thread.daemon = True
+    background_thread.start()
+
     while True:
         main()
```
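
For display in a console or dashboard, the recorded heartbeats can be read back with the new `zrangebyscore` helper. A minimal sketch, assuming the wrapper is importable as `rag.utils.redis_conn.REDIS_CONN`; `recent_heartbeats` is a hypothetical helper written here for illustration, not part of this PR:

```python
import json
from datetime import datetime, timedelta

from rag.utils.redis_conn import REDIS_CONN


def recent_heartbeats(consumer_name: str, minutes: int = 30):
    """Collect the heartbeats one executor reported within the last `minutes`."""
    cutoff = (datetime.now() - timedelta(minutes=minutes)).timestamp()
    # Members are JSON strings scored by report timestamp, as written by report_status().
    entries = REDIS_CONN.zrangebyscore(consumer_name, cutoff, "+inf") or []
    return [json.loads(e) for e in entries]


# Consumer names follow "task_consumer_<n>"; "0" is the default first executor.
for hb in recent_heartbeats("task_consumer_0"):
    print(hb["now"], "done:", hb["done"], "retry:", hb["retry"], "pending:", hb["pending"])
```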

RedisDB: new set and sorted-set helpers (`sadd`, `srem`, `smembers`, `zadd`, `zcount`, `zpopmin`, `zrangebyscore`) plus `queue_length()` and `queue_head()` for inspecting the task stream; each wrapper logs the failure and reopens the connection or retries.

```diff
@@ -90,6 +90,69 @@ class RedisDB:
             self.__open__()
         return False

+    def sadd(self, key: str, member: str):
+        try:
+            self.REDIS.sadd(key, member)
+            return True
+        except Exception as e:
+            logging.warning("[EXCEPTION]sadd" + str(key) + "||" + str(e))
+            self.__open__()
+        return False
+
+    def srem(self, key: str, member: str):
+        try:
+            self.REDIS.srem(key, member)
+            return True
+        except Exception as e:
+            logging.warning("[EXCEPTION]srem" + str(key) + "||" + str(e))
+            self.__open__()
+        return False
+
+    def smembers(self, key: str):
+        try:
+            res = self.REDIS.smembers(key)
+            return res
+        except Exception as e:
+            logging.warning("[EXCEPTION]smembers" + str(key) + "||" + str(e))
+            self.__open__()
+        return None
+
+    def zadd(self, key: str, member: str, score: float):
+        try:
+            self.REDIS.zadd(key, {member: score})
+            return True
+        except Exception as e:
+            logging.warning("[EXCEPTION]zadd" + str(key) + "||" + str(e))
+            self.__open__()
+        return False
+
+    def zcount(self, key: str, min: float, max: float):
+        try:
+            res = self.REDIS.zcount(key, min, max)
+            return res
+        except Exception as e:
+            logging.warning("[EXCEPTION]zcount" + str(key) + "||" + str(e))
+            self.__open__()
+        return 0
+
+    def zpopmin(self, key: str, count: int):
+        try:
+            res = self.REDIS.zpopmin(key, count)
+            return res
+        except Exception as e:
+            logging.warning("[EXCEPTION]zpopmin" + str(key) + "||" + str(e))
+            self.__open__()
+        return None
+
+    def zrangebyscore(self, key: str, min: float, max: float):
+        try:
+            res = self.REDIS.zrangebyscore(key, min, max)
+            return res
+        except Exception as e:
+            logging.warning("[EXCEPTION]zrangebyscore" + str(key) + "||" + str(e))
+            self.__open__()
+        return None
+
     def transaction(self, key, value, exp=3600):
         try:
             pipeline = self.REDIS.pipeline(transaction=True)
@@ -162,4 +225,22 @@ class RedisDB:
             logging.exception("xpending_range: " + consumer_name + " got exception")
             self.__open__()

+    def queue_length(self, queue) -> int:
+        for _ in range(3):
+            try:
+                num = self.REDIS.xlen(queue)
+                return num
+            except Exception:
+                logging.exception("queue_length" + str(queue) + " got exception")
+        return 0
+
+    def queue_head(self, queue):
+        # Return the oldest (id, fields) entry of the stream, or None on failure,
+        # matching the `is not None` check in report_status().
+        for _ in range(3):
+            try:
+                ent = self.REDIS.xrange(queue, count=1)
+                return ent[0]
+            except Exception:
+                logging.exception("queue_head" + str(queue) + " got exception")
+        return None
+

 REDIS_CONN = RedisDB()
```
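
`queue_head` returns the first entry of the task stream as an `(id, fields)` pair from `XRANGE`. Since Redis stream IDs encode the enqueue time, `report_status()` can recover when the oldest pending task was created. A toy illustration (the entry values below are invented):

```python
from datetime import datetime

# Shape of an XRANGE entry as returned by queue_head(): (entry_id, fields).
entry_id, fields = "1731652993000-0", {"message": "..."}

# A Redis stream ID is "<milliseconds since epoch>-<sequence>", which is
# exactly what report_status() splits on to compute HEAD_CREATED_AT.
ms = int(entry_id.split("-")[0])
print(datetime.fromtimestamp(ms / 1000).isoformat())  # e.g. "2024-11-15T..."
```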