mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-23 06:30:00 +08:00
make task resumable (#2132)
### What problem does this PR solve? ### Type of change - [x] Performance Improvement
This commit is contained in:
parent
074d4f5031
commit
5daed10136
@ -217,7 +217,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
answer = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
||||
answer = ans
|
||||
yield {"answer": answer, "reference": {}, "prompt": prompt}
|
||||
yield {"answer": answer, "reference": {}}
|
||||
yield decorate_answer(answer)
|
||||
else:
|
||||
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
|
||||
|
@ -11,13 +11,13 @@ fi
|
||||
|
||||
function task_exe(){
|
||||
while [ 1 -eq 1 ];do
|
||||
$PY rag/svr/task_executor.py ;
|
||||
$PY rag/svr/task_executor.py $1;
|
||||
done
|
||||
}
|
||||
|
||||
for ((i=0;i<WS;i++))
|
||||
do
|
||||
task_exe &
|
||||
task_exe $i &
|
||||
done
|
||||
|
||||
while [ 1 -eq 1 ];do
|
||||
|
@ -74,9 +74,12 @@ FACTORY = {
|
||||
ParserType.KG.value: knowledge_graph
|
||||
}
|
||||
|
||||
CONSUMEER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
|
||||
PAYLOAD = None
|
||||
|
||||
def set_progress(task_id, from_page=0, to_page=-1,
|
||||
prog=None, msg="Processing..."):
|
||||
global PAYLOAD
|
||||
if prog is not None and prog < 0:
|
||||
msg = "[ERROR]" + msg
|
||||
cancel = TaskService.do_cancel(task_id)
|
||||
@ -97,22 +100,28 @@ def set_progress(task_id, from_page=0, to_page=-1,
|
||||
|
||||
close_connection()
|
||||
if cancel:
|
||||
sys.exit()
|
||||
if PAYLOAD:
|
||||
PAYLOAD.ack()
|
||||
PAYLOAD = None
|
||||
os._exit(0)
|
||||
|
||||
|
||||
def collect():
|
||||
global CONSUMEER_NAME, PAYLOAD
|
||||
try:
|
||||
payload = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", "rag_flow_svr_task_consumer")
|
||||
if not payload:
|
||||
PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
|
||||
if not PAYLOAD:
|
||||
PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
|
||||
if not PAYLOAD:
|
||||
time.sleep(1)
|
||||
return pd.DataFrame()
|
||||
except Exception as e:
|
||||
cron_logger.error("Get task event from queue exception:" + str(e))
|
||||
return pd.DataFrame()
|
||||
|
||||
msg = payload.get_message()
|
||||
payload.ack()
|
||||
if not msg: return pd.DataFrame()
|
||||
msg = PAYLOAD.get_message()
|
||||
if not msg:
|
||||
return pd.DataFrame()
|
||||
|
||||
if TaskService.do_cancel(msg["id"]):
|
||||
cron_logger.info("Task {} has been canceled.".format(msg["id"]))
|
||||
@ -378,20 +387,21 @@ def main():
|
||||
|
||||
|
||||
def report_status():
|
||||
id = "0" if len(sys.argv) < 2 else sys.argv[1]
|
||||
global CONSUMEER_NAME
|
||||
while True:
|
||||
try:
|
||||
obj = REDIS_CONN.get("TASKEXE")
|
||||
if not obj: obj = {}
|
||||
else: obj = json.load(obj)
|
||||
if id not in obj: obj[id] = []
|
||||
obj[id].append(timer()*1000)
|
||||
obj[id] = obj[id][-60:]
|
||||
else: obj = json.loads(obj)
|
||||
if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
|
||||
obj[CONSUMEER_NAME].append(timer())
|
||||
obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
|
||||
REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
|
||||
except Exception as e:
|
||||
print("[Exception]:", str(e))
|
||||
time.sleep(60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
peewee_logger = logging.getLogger('peewee')
|
||||
peewee_logger.propagate = False
|
||||
@ -403,3 +413,6 @@ if __name__ == "__main__":
|
||||
|
||||
while True:
|
||||
main()
|
||||
if PAYLOAD:
|
||||
PAYLOAD.ack()
|
||||
PAYLOAD = None
|
||||
|
@ -107,7 +107,7 @@ class RedisDB:
|
||||
payload = {"message": json.dumps(message)}
|
||||
pipeline = self.REDIS.pipeline()
|
||||
pipeline.xadd(queue, payload)
|
||||
pipeline.expire(queue, exp)
|
||||
#pipeline.expire(queue, exp)
|
||||
pipeline.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
@ -143,8 +143,22 @@ class RedisDB:
|
||||
if 'key' in str(e):
|
||||
pass
|
||||
else:
|
||||
logging.warning("[EXCEPTION]consumer" + str(queue_name) + "||" + str(e))
|
||||
logging.warning("[EXCEPTION]consumer: " + str(queue_name) + "||" + str(e))
|
||||
return None
|
||||
|
||||
def get_unacked_for(self, consumer_name, queue_name, group_name):
|
||||
try:
|
||||
group_info = self.REDIS.xinfo_groups(queue_name)
|
||||
if not any(e["name"] == group_name for e in group_info):
|
||||
return
|
||||
pendings = self.REDIS.xpending_range(queue_name, group_name, min=0, max=10000000000000, count=1, consumername=consumer_name)
|
||||
if not pendings: return
|
||||
msg_id = pendings[0]["message_id"]
|
||||
msg = self.REDIS.xrange(queue_name, min=msg_id, count=1)
|
||||
_, payload = msg[0]
|
||||
return Payload(self.REDIS, queue_name, group_name, msg_id, payload)
|
||||
except Exception as e:
|
||||
logging.warning("[EXCEPTION]xpending_range" + consumer_name + "||" + str(e))
|
||||
self.__open__()
|
||||
|
||||
REDIS_CONN = RedisDB()
|
||||
|
Loading…
x
Reference in New Issue
Block a user