diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 8f2697f21..b6407b52d 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -29,6 +29,7 @@ import time import traceback from concurrent.futures import ThreadPoolExecutor import threading +import uuid from werkzeug.serving import run_simple from api import settings @@ -47,17 +48,17 @@ from rag.utils.redis_conn import RedisDistributedLock stop_event = threading.Event() def update_progress(): - redis_lock = RedisDistributedLock("update_progress", timeout=60) + lock_value = str(uuid.uuid4()) + redis_lock = RedisDistributedLock("update_progress", lock_value=lock_value, timeout=60) + logging.info(f"update_progress lock_value: {lock_value}") while not stop_event.is_set(): try: - if not redis_lock.acquire(): - continue - DocumentService.update_progress() + if redis_lock.acquire(): + DocumentService.update_progress() + redis_lock.release() stop_event.wait(6) except Exception: logging.exception("update_progress exception") - finally: - redis_lock.release() def signal_handler(sig, frame): logging.info("Received interrupt signal, shutting down...") diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index a83febbed..f505b5064 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -48,10 +48,27 @@ class RedisMsg: @singleton class RedisDB: + lua_delete_if_equal = None + LUA_DELETE_IF_EQUAL_SCRIPT = """ + local current_value = redis.call('get', KEYS[1]) + if current_value and current_value == ARGV[1] then + redis.call('del', KEYS[1]) + return 1 + end + return 0 + """ + def __init__(self): self.REDIS = None self.config = settings.REDIS self.__open__() + self.register_scripts() + + def register_scripts(self) -> None: + cls = self.__class__ + client = self.REDIS + if cls.lua_delete_if_equal is None: + cls.lua_delete_if_equal = client.register_script(cls.LUA_DELETE_IF_EQUAL_SCRIPT) def __open__(self): try: @@ -277,6 +294,12 @@ class RedisDB: ) return None + def delete_if_equal(self, key: str, expected_value: str) -> bool: + """ + Do follwing atomically: + Delete a key if its value is equals to the given one, do nothing otherwise. + """ + return bool(self.lua_delete_if_equal(keys=[key], args=[expected_value], client=self.REDIS)) REDIS_CONN = RedisDB() @@ -292,7 +315,8 @@ class RedisDistributedLock: self.lock = Lock(REDIS_CONN.REDIS, lock_key, timeout=timeout, blocking_timeout=blocking_timeout) def acquire(self): - return self.lock.acquire() + REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value) + return self.lock.acquire(token=self.lock_value) def release(self): return self.lock.release()