diff --git a/api/extensions/storage/local_fs_storage.py b/api/extensions/storage/local_fs_storage.py index e458b3ce8a..b4146d7384 100644 --- a/api/extensions/storage/local_fs_storage.py +++ b/api/extensions/storage/local_fs_storage.py @@ -19,68 +19,48 @@ class LocalFsStorage(BaseStorage): folder = os.path.join(current_app.root_path, folder) self.folder = folder - def save(self, filename, data): + def _build_filepath(self, filename: str) -> str: + """Build the full file path based on the folder and filename.""" if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename + return self.folder + filename else: - filename = self.folder + "/" + filename + return self.folder + "/" + filename - folder = os.path.dirname(filename) + def save(self, filename, data): + filepath = self._build_filepath(filename) + folder = os.path.dirname(filepath) os.makedirs(folder, exist_ok=True) - - Path(os.path.join(os.getcwd(), filename)).write_bytes(data) + Path(os.path.join(os.getcwd(), filepath)).write_bytes(data) def load_once(self, filename: str) -> bytes: - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - if not os.path.exists(filename): + filepath = self._build_filepath(filename) + if not os.path.exists(filepath): raise FileNotFoundError("File not found") - - data = Path(filename).read_bytes() - return data + return Path(filepath).read_bytes() def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename + filepath = self._build_filepath(filename) - if not os.path.exists(filename): + def generate() -> Generator: + if not os.path.exists(filepath): raise FileNotFoundError("File not found") - - with open(filename, "rb") as f: + with open(filepath, "rb") as f: while chunk := f.read(4096): # Read in chunks of 4KB yield chunk return generate() def download(self, filename, target_filepath): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - if not os.path.exists(filename): + filepath = self._build_filepath(filename) + if not os.path.exists(filepath): raise FileNotFoundError("File not found") - - shutil.copyfile(filename, target_filepath) + shutil.copyfile(filepath, target_filepath) def exists(self, filename): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - return os.path.exists(filename) + filepath = self._build_filepath(filename) + return os.path.exists(filepath) def delete(self, filename): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - if os.path.exists(filename): - os.remove(filename) + filepath = self._build_filepath(filename) + if os.path.exists(filepath): + os.remove(filepath)