make task resumable (#2132)

### What problem does this PR solve?

### Type of change


- [x] Performance Improvement
This commit is contained in:
Kevin Hu 2024-08-28 14:06:27 +08:00 committed by GitHub
parent 074d4f5031
commit 5daed10136
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 43 additions and 16 deletions

View File

@ -217,7 +217,7 @@ def chat(dialog, messages, stream=True, **kwargs):
answer = ""
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
answer = ans
yield {"answer": answer, "reference": {}, "prompt": prompt}
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
else:
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)

View File

@ -11,13 +11,13 @@ fi
function task_exe(){
while [ 1 -eq 1 ];do
$PY rag/svr/task_executor.py ;
$PY rag/svr/task_executor.py $1;
done
}
for ((i=0;i<WS;i++))
do
task_exe &
task_exe $i &
done
while [ 1 -eq 1 ];do

View File

@ -74,9 +74,12 @@ FACTORY = {
ParserType.KG.value: knowledge_graph
}
CONSUMEER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
PAYLOAD = None
def set_progress(task_id, from_page=0, to_page=-1,
prog=None, msg="Processing..."):
global PAYLOAD
if prog is not None and prog < 0:
msg = "[ERROR]" + msg
cancel = TaskService.do_cancel(task_id)
@ -97,22 +100,28 @@ def set_progress(task_id, from_page=0, to_page=-1,
close_connection()
if cancel:
sys.exit()
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
os._exit(0)
def collect():
global CONSUMEER_NAME, PAYLOAD
try:
payload = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", "rag_flow_svr_task_consumer")
if not payload:
PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
if not PAYLOAD:
PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
if not PAYLOAD:
time.sleep(1)
return pd.DataFrame()
except Exception as e:
cron_logger.error("Get task event from queue exception:" + str(e))
return pd.DataFrame()
msg = payload.get_message()
payload.ack()
if not msg: return pd.DataFrame()
msg = PAYLOAD.get_message()
if not msg:
return pd.DataFrame()
if TaskService.do_cancel(msg["id"]):
cron_logger.info("Task {} has been canceled.".format(msg["id"]))
@ -378,20 +387,21 @@ def main():
def report_status():
id = "0" if len(sys.argv) < 2 else sys.argv[1]
global CONSUMEER_NAME
while True:
try:
obj = REDIS_CONN.get("TASKEXE")
if not obj: obj = {}
else: obj = json.load(obj)
if id not in obj: obj[id] = []
obj[id].append(timer()*1000)
obj[id] = obj[id][-60:]
else: obj = json.loads(obj)
if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
obj[CONSUMEER_NAME].append(timer())
obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
except Exception as e:
print("[Exception]:", str(e))
time.sleep(60)
if __name__ == "__main__":
peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
@ -403,3 +413,6 @@ if __name__ == "__main__":
while True:
main()
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None

View File

@ -107,7 +107,7 @@ class RedisDB:
payload = {"message": json.dumps(message)}
pipeline = self.REDIS.pipeline()
pipeline.xadd(queue, payload)
pipeline.expire(queue, exp)
#pipeline.expire(queue, exp)
pipeline.execute()
return True
except Exception as e:
@ -143,8 +143,22 @@ class RedisDB:
if 'key' in str(e):
pass
else:
logging.warning("[EXCEPTION]consumer" + str(queue_name) + "||" + str(e))
logging.warning("[EXCEPTION]consumer: " + str(queue_name) + "||" + str(e))
return None
def get_unacked_for(self, consumer_name, queue_name, group_name):
try:
group_info = self.REDIS.xinfo_groups(queue_name)
if not any(e["name"] == group_name for e in group_info):
return
pendings = self.REDIS.xpending_range(queue_name, group_name, min=0, max=10000000000000, count=1, consumername=consumer_name)
if not pendings: return
msg_id = pendings[0]["message_id"]
msg = self.REDIS.xrange(queue_name, min=msg_id, count=1)
_, payload = msg[0]
return Payload(self.REDIS, queue_name, group_name, msg_id, payload)
except Exception as e:
logging.warning("[EXCEPTION]xpending_range" + consumer_name + "||" + str(e))
self.__open__()
REDIS_CONN = RedisDB()