added alternative storage client instantiation method, corrected filepaths, added missing type hinting

This commit is contained in:
kahghi 2025-01-19 22:42:29 +08:00
parent 1764de41f3
commit 49f31ddcd8
2 changed files with 23 additions and 24 deletions

View File

@ -590,6 +590,7 @@ S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None) S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None) GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
GCS_PROJECT_ID = os.environ.get("GCS_PROJECT_ID", None)
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_JSON", None) GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_JSON", None)
#################################### ####################################

View File

@ -3,7 +3,6 @@ import shutil
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import BinaryIO, Tuple from typing import BinaryIO, Tuple
from io import BytesIO
import boto3 import boto3
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
@ -14,6 +13,7 @@ from open_webui.config import (
S3_REGION_NAME, S3_REGION_NAME,
S3_SECRET_ACCESS_KEY, S3_SECRET_ACCESS_KEY,
GCS_BUCKET_NAME, GCS_BUCKET_NAME,
GCS_PROJECT_ID,
GOOGLE_APPLICATION_CREDENTIALS_JSON, GOOGLE_APPLICATION_CREDENTIALS_JSON,
STORAGE_PROVIDER, STORAGE_PROVIDER,
UPLOAD_DIR, UPLOAD_DIR,
@ -145,41 +145,41 @@ class S3StorageProvider(StorageProvider):
class GCSStorageProvider(StorageProvider): class GCSStorageProvider(StorageProvider):
def __init__(self): def __init__(self):
self.gcs_client = storage.Client.from_service_account_info(info=json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)) if GCS_PROJECT_ID:
self.bucket_name = self.gcs_client.bucket(GCS_BUCKET_NAME) self.gcs_client = storage.Client(project=GCS_PROJECT_ID)
if GOOGLE_APPLICATION_CREDENTIALS_JSON:
self.gcs_client = storage.Client.from_service_account_info(info=json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON))
self.bucket_name = GCS_BUCKET_NAME
self.bucket = self.gcs_client.bucket(GCS_BUCKET_NAME)
def upload_file(self, file: BinaryIO, filename: str): def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to GCS storage.""" """Handles uploading of the file to GCS storage."""
contents, _ = LocalStorageProvider.upload_file(file, filename) contents, file_path = LocalStorageProvider.upload_file(file, filename)
try: try:
# Get the blob (object in the bucket) blob = self.bucket.blob(filename)
blob = self.bucket_name.blob(filename) blob.upload_from_filename(file_path)
# Upload the file to the bucket return contents, "gs://" + self.bucket_name + "/" + filename
blob.upload_from_file(BytesIO(contents))
return contents, _
except GoogleCloudError as e: except GoogleCloudError as e:
raise RuntimeError(f"Error uploading file to GCS: {e}") raise RuntimeError(f"Error uploading file to GCS: {e}")
def get_file(self, file_path:str) -> str: def get_file(self, file_path:str) -> str:
"""Handles downloading of the file from GCS storage.""" """Handles downloading of the file from GCS storage."""
try: try:
local_file_path=file_path.removeprefix(UPLOAD_DIR + "/") filename = file_path.removeprefix("gs://").split("/")[1]
# Get the blob (object in the bucket) local_file_path = f"{UPLOAD_DIR}/{filename}"
blob = self.bucket_name.blob(local_file_path) blob = self.bucket.blob(filename)
# Download the file to a local destination blob.download_to_filename(local_file_path)
blob.download_to_filename(file_path)
return file_path return local_file_path
except NotFound as e: except NotFound as e:
raise RuntimeError(f"Error downloading file from GCS: {e}") raise RuntimeError(f"Error downloading file from GCS: {e}")
def delete_file(self, file_path:str) -> None: def delete_file(self, file_path:str) -> None:
"""Handles deletion of the file from GCS storage.""" """Handles deletion of the file from GCS storage."""
try: try:
local_file_path = file_path.removeprefix(UPLOAD_DIR + "/") filename = file_path.removeprefix("gs://").split("/")[1]
# Get the blob (object in the bucket) blob = self.bucket.blob(filename)
blob = self.bucket_name.blob(local_file_path)
# Delete the file
blob.delete() blob.delete()
except NotFound as e: except NotFound as e:
raise RuntimeError(f"Error deleting file from GCS: {e}") raise RuntimeError(f"Error deleting file from GCS: {e}")
@ -190,10 +190,8 @@ class GCSStorageProvider(StorageProvider):
def delete_all_files(self) -> None: def delete_all_files(self) -> None:
"""Handles deletion of all files from GCS storage.""" """Handles deletion of all files from GCS storage."""
try: try:
# List all objects in the bucket blobs = self.bucket.list_blobs()
blobs = self.bucket_name.list_blobs()
# Delete all files
for blob in blobs: for blob in blobs:
blob.delete() blob.delete()