diff --git a/.github/workflows/format-build-frontend.yaml b/.github/workflows/format-build-frontend.yaml
index 121266bf6..eec1305e4 100644
--- a/.github/workflows/format-build-frontend.yaml
+++ b/.github/workflows/format-build-frontend.yaml
@@ -29,6 +29,9 @@ jobs:
       - name: Format Frontend
         run: npm run format
 
+      - name: Run i18next
+        run: npm run i18n:parse
+
       - name: Check for Changes After Format
         run: git diff --exit-code
 
diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index dfceaacc1..32c331654 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -53,3 +53,134 @@ jobs:
           name: compose-logs
           path: compose-logs.txt
           if-no-files-found: ignore
+
+  migration_test:
+    name: Run Migration Tests
+    runs-on: ubuntu-latest
+    services:
+      postgres:
+        image: postgres
+        env:
+          POSTGRES_PASSWORD: postgres
+        options: >-
+          --health-cmd pg_isready
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+        ports:
+          - 5432:5432
+#      mysql:
+#        image: mysql
+#        env:
+#          MYSQL_ROOT_PASSWORD: mysql
+#          MYSQL_DATABASE: mysql
+#        options: >-
+#          --health-cmd "mysqladmin ping -h localhost"
+#          --health-interval 10s
+#          --health-timeout 5s
+#          --health-retries 5
+#        ports:
+#          - 3306:3306
+    steps:
+      - name: Checkout Repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.11'
+
+      - name: Set up uv
+        uses: yezz123/setup-uv@v4
+        with:
+          uv-venv: venv
+
+      - name: Activate virtualenv
+        run: |
+          . venv/bin/activate
+          echo PATH=$PATH >> $GITHUB_ENV
+
+      - name: Install dependencies
+        run: |
+          uv pip install -r backend/requirements.txt
+
+      - name: Test backend with SQLite
+        id: sqlite
+        env:
+          WEBUI_SECRET_KEY: secret-key
+          GLOBAL_LOG_LEVEL: debug
+        run: |
+          cd backend
+          uvicorn main:app --port "8080" --forwarded-allow-ips '*' &
+          UVICORN_PID=$!
+          # Wait up to 20 seconds for the server to start
+          for i in {1..20}; do
+            curl -s http://localhost:8080/api/config > /dev/null && break
+            sleep 1
+            if [ $i -eq 20 ]; then
+              echo "Server failed to start"
+              kill -9 $UVICORN_PID
+              exit 1
+            fi
+          done
+          # Check that the server is still running after 5 seconds
+          sleep 5
+          if ! kill -0 $UVICORN_PID; then
+            echo "Server has stopped"
+            exit 1
+          fi
+
+
+      - name: Test backend with Postgres
+        if: success() || steps.sqlite.conclusion == 'failure'
+        env:
+          WEBUI_SECRET_KEY: secret-key
+          GLOBAL_LOG_LEVEL: debug
+          DATABASE_URL: postgresql://postgres:postgres@localhost:5432/postgres
+        run: |
+          cd backend
+          uvicorn main:app --port "8081" --forwarded-allow-ips '*' &
+          UVICORN_PID=$!
+          # Wait up to 20 seconds for the server to start
+          for i in {1..20}; do
+            curl -s http://localhost:8081/api/config > /dev/null && break
+            sleep 1
+            if [ $i -eq 20 ]; then
+              echo "Server failed to start"
+              kill -9 $UVICORN_PID
+              exit 1
+            fi
+          done
+          # Check that the server is still running after 5 seconds
+          sleep 5
+          if ! kill -0 $UVICORN_PID; then
+            echo "Server has stopped"
+            exit 1
+          fi
+
+#      - name: Test backend with MySQL
+#        if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
+#        env:
+#          WEBUI_SECRET_KEY: secret-key
+#          GLOBAL_LOG_LEVEL: debug
+#          DATABASE_URL: mysql://root:mysql@localhost:3306/mysql
+#        run: |
+#          cd backend
+#          uvicorn main:app --port "8083" --forwarded-allow-ips '*' &
+#          UVICORN_PID=$!
+#          # Wait up to 20 seconds for the server to start
+#          for i in {1..20}; do
+#            curl -s http://localhost:8083/api/config > /dev/null && break
+#            sleep 1
+#            if [ $i -eq 20 ]; then
+#              echo "Server failed to start"
+#              kill -9 $UVICORN_PID
+#              exit 1
+#            fi
+#          done
+#          # Check that the server is still running after 5 seconds
+#          sleep 5
+#          if ! kill -0 $UVICORN_PID; then
+#            echo "Server has stopped"
+#            exit 1
+#          fi
diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index d9e378303..b5d1e68d6 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -171,6 +171,7 @@ async def fetch_url(url, key):
 
 
 def merge_models_lists(model_lists):
+    log.info(f"merge_models_lists {model_lists}")
     merged_list = []
 
     for idx, models in enumerate(model_lists):
@@ -199,14 +200,16 @@ async def get_all_models():
         ]
 
         responses = await asyncio.gather(*tasks)
+        log.info(f"get_all_models:responses() {responses}")
+
         models = {
             "data": merge_models_lists(
                 list(
                     map(
                         lambda response: (
                             response["data"]
-                            if response and "data" in response
-                            else None
+                            if (response and "data" in response)
+                            else (response if isinstance(response, list) else None)
                         ),
                         responses,
                     )
diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index a33a29659..f147152b7 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -31,6 +31,11 @@ from langchain_community.document_loaders import (
 )
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
+import validators
+import urllib.parse
+import socket
+
+
 from pydantic import BaseModel
 from typing import Optional
 import mimetypes
@@ -84,6 +89,7 @@ from config import (
     CHUNK_SIZE,
     CHUNK_OVERLAP,
     RAG_TEMPLATE,
+    ENABLE_LOCAL_WEB_FETCH,
 )
 
 from constants import ERROR_MESSAGES
@@ -391,16 +397,16 @@ def query_doc_handler(
             return query_doc_with_hybrid_search(
                 collection_name=form_data.collection_name,
                 query=form_data.query,
-                embeddings_function=app.state.EMBEDDING_FUNCTION,
-                reranking_function=app.state.sentence_transformer_rf,
+                embedding_function=app.state.EMBEDDING_FUNCTION,
                 k=form_data.k if form_data.k else app.state.TOP_K,
+                reranking_function=app.state.sentence_transformer_rf,
                 r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
             )
         else:
             return query_doc(
                 collection_name=form_data.collection_name,
                 query=form_data.query,
-                embeddings_function=app.state.EMBEDDING_FUNCTION,
+                embedding_function=app.state.EMBEDDING_FUNCTION,
                 k=form_data.k if form_data.k else app.state.TOP_K,
             )
     except Exception as e:
@@ -429,16 +435,16 @@ def query_collection_handler(
             return query_collection_with_hybrid_search(
                 collection_names=form_data.collection_names,
                 query=form_data.query,
-                embeddings_function=app.state.EMBEDDING_FUNCTION,
-                reranking_function=app.state.sentence_transformer_rf,
+                embedding_function=app.state.EMBEDDING_FUNCTION,
                 k=form_data.k if form_data.k else app.state.TOP_K,
+                reranking_function=app.state.sentence_transformer_rf,
                 r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
             )
         else:
             return query_collection(
                 collection_names=form_data.collection_names,
                 query=form_data.query,
-                embeddings_function=app.state.EMBEDDING_FUNCTION,
+                embedding_function=app.state.EMBEDDING_FUNCTION,
                 k=form_data.k if form_data.k else app.state.TOP_K,
             )
 
@@ -454,7 +460,7 @@ def query_collection_handler(
 def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
     try:
-        loader = WebBaseLoader(form_data.url)
+        loader = get_web_loader(form_data.url)
         data = loader.load()
 
         collection_name = form_data.collection_name
@@ -475,6 +481,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
         )
 
 
+def get_web_loader(url: str):
+    # Check if the URL is valid
+    if isinstance(validators.url(url), validators.ValidationError):
+        raise ValueError(ERROR_MESSAGES.INVALID_URL)
+    if not ENABLE_LOCAL_WEB_FETCH:
+        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+        parsed_url = urllib.parse.urlparse(url)
+        # Get IPv4 and IPv6 addresses
+        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+        # Check if any of the resolved addresses are private
+        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+        for ip in ipv4_addresses:
+            if validators.ipv4(ip, private=True):
+                raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        for ip in ipv6_addresses:
+            if validators.ipv6(ip, private=True):
+                raise ValueError(ERROR_MESSAGES.INVALID_URL)
+    return WebBaseLoader(url)
+
+
+def resolve_hostname(hostname):
+    # Get address information
+    addr_info = socket.getaddrinfo(hostname, None)
+
+    # Extract IP addresses from address information
+    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
+    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
+
+    return ipv4_addresses, ipv6_addresses
+
+
 def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
 
     text_splitter = RecursiveCharacterTextSplitter(
diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index eb9d5c84b..10f1f7bed 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -35,6 +35,7 @@ def query_doc(
     try:
         collection = CHROMA_CLIENT.get_collection(name=collection_name)
         query_embeddings = embedding_function(query)
+
         result = collection.query(
             query_embeddings=[query_embeddings],
             n_results=k,
@@ -76,9 +77,9 @@ def query_doc_with_hybrid_search(
 
         compressor = RerankCompressor(
             embedding_function=embedding_function,
+            top_n=k,
             reranking_function=reranking_function,
             r_score=r,
-            top_n=k,
         )
 
         compression_retriever = ContextualCompressionRetriever(
@@ -91,6 +92,7 @@ def query_doc_with_hybrid_search(
             "documents": [[d.page_content for d in result]],
             "metadatas": [[d.metadata for d in result]],
         }
+        log.info(f"query_doc_with_hybrid_search:result {result}")
 
         return result
     except Exception as e:
@@ -167,7 +169,6 @@ def query_collection_with_hybrid_search(
     reranking_function,
     r: float,
 ):
-
     results = []
     for collection_name in collection_names:
         try:
@@ -182,7 +183,6 @@ def query_collection_with_hybrid_search(
             results.append(result)
         except:
             pass
-
     return merge_and_sort_query_results(results, k=k, reverse=True)
 
 
@@ -443,13 +443,15 @@ class ChromaRetriever(BaseRetriever):
         metadatas = results["metadatas"][0]
         documents = results["documents"][0]
 
-        return [
-            Document(
-                metadata=metadatas[idx],
-                page_content=documents[idx],
+        results = []
+        for idx in range(len(ids)):
+            results.append(
+                Document(
+                    metadata=metadatas[idx],
+                    page_content=documents[idx],
+                )
             )
-            for idx in range(len(ids))
-        ]
+        return results
 
 
 import operator
@@ -465,9 +467,9 @@ from sentence_transformers import util
 
 class RerankCompressor(BaseDocumentCompressor):
     embedding_function: Any
+    top_n: int
     reranking_function: Any
     r_score: float
-    top_n: int
 
     class Config:
         extra = Extra.forbid
@@ -479,7 +481,9 @@ class RerankCompressor(BaseDocumentCompressor):
         query: str,
         callbacks: Optional[Callbacks] = None,
     ) -> Sequence[Document]:
-        if self.reranking_function:
+        reranking = self.reranking_function is not None
+
+        if reranking:
             scores = self.reranking_function.predict(
                 [(query, doc.page_content) for doc in documents]
             )
@@ -496,9 +500,7 @@ class RerankCompressor(BaseDocumentCompressor):
                 (d, s) for d, s in docs_with_scores if s >= self.r_score
             ]
 
-        reverse = self.reranking_function is not None
-        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse)
-
+        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
         final_results = []
         for doc, doc_score in result[: self.top_n]:
             metadata = doc.metadata
diff --git a/backend/config.py b/backend/config.py
index f864062d9..09880b12e 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -168,7 +168,11 @@ except:
 
 STATIC_DIR = str(Path(os.getenv("STATIC_DIR", "./static")).resolve())
 
-shutil.copyfile(f"{FRONTEND_BUILD_DIR}/favicon.png", f"{STATIC_DIR}/favicon.png")
+frontend_favicon = f"{FRONTEND_BUILD_DIR}/favicon.png"
+if os.path.exists(frontend_favicon):
+    shutil.copyfile(frontend_favicon, f"{STATIC_DIR}/favicon.png")
+else:
+    logging.warning(f"Frontend favicon not found at {frontend_favicon}")
 
 ####################################
 # CUSTOM_NAME
@@ -516,6 +520,8 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
 RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
 RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
 
+ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
+
 ####################################
 # Transcribe
 ####################################
diff --git a/backend/constants.py b/backend/constants.py
index a26945756..3fdf506fa 100644
--- a/backend/constants.py
+++ b/backend/constants.py
@@ -71,3 +71,7 @@ class ERROR_MESSAGES(str, Enum):
     EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
 
     DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
+
+    INVALID_URL = (
+        "Oops! The URL you provided is invalid. Please double-check and try again."
+    )
diff --git a/backend/dev.sh b/backend/dev.sh
old mode 100644
new mode 100755
diff --git a/backend/main.py b/backend/main.py
index 1b2772627..91cce711b 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -318,11 +318,16 @@ async def get_manifest_json():
 app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
 app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
 
-app.mount(
-    "/",
-    SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
-    name="spa-static-files",
-)
+if os.path.exists(FRONTEND_BUILD_DIR):
+    app.mount(
+        "/",
+        SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
+        name="spa-static-files",
+    )
+else:
+    log.warning(
+        f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
+    )
 
 
 @app.on_event("shutdown")
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 336cae17a..eb509c6ed 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -19,8 +19,8 @@ psycopg2-binary
 pymysql
 bcrypt
 
-litellm==1.35.17
-litellm[proxy]==1.35.17
+litellm==1.35.28
+litellm[proxy]==1.35.28
 
 boto3
 
@@ -43,6 +43,7 @@ pandas
 openpyxl
 pyxlsb
 xlrd
+validators
 
 opencv-python-headless
 rapidocr-onnxruntime
diff --git a/src/lib/apis/streaming/index.ts b/src/lib/apis/streaming/index.ts
index 19e4d8fc7..aad42b2b6 100644
--- a/src/lib/apis/streaming/index.ts
+++ b/src/lib/apis/streaming/index.ts
@@ -73,7 +73,11 @@ async function* streamLargeDeltasAsRandomChunks(
 		const chunkSize = Math.min(Math.floor(Math.random() * 3) + 1, content.length);
 		const chunk = content.slice(0, chunkSize);
 		yield { done: false, value: chunk };
-		await sleep(5);
+		// Do not sleep if the tab is hidden
+		// Timers are throttled to 1s in hidden tabs
+		if (document?.visibilityState !== 'hidden') {
+			await sleep(5);
+		}
 		content = content.slice(chunkSize);
 	}
 }
diff --git a/src/lib/components/admin/UserChatsModal.svelte b/src/lib/components/admin/UserChatsModal.svelte
index 9b1cacb33..67fa367cd 100644
--- a/src/lib/components/admin/UserChatsModal.svelte
+++ b/src/lib/components/admin/UserChatsModal.svelte
@@ -70,7 +70,7 @@
 >
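
A few of the backend hunks above are easier to review with a standalone illustration. First, the `get_all_models()` change in `backend/apps/openai/main.py`: the rewritten lambda now also accepts endpoints that return a bare model list instead of the usual `{"data": [...]}` envelope. A minimal sketch of that normalization; the `extract_model_list` helper name is hypothetical, but its three branches mirror the lambda as written:

```python
from typing import Optional, Union


def extract_model_list(response: Union[dict, list, None]) -> Optional[list]:
    # {"data": [...]} envelope -> unwrap it.
    if response and "data" in response:
        return response["data"]
    # Bare list -> pass through unchanged; anything else -> None (dropped later).
    return response if isinstance(response, list) else None


assert extract_model_list({"data": [{"id": "gpt-4"}]}) == [{"id": "gpt-4"}]
assert extract_model_list([{"id": "gpt-4"}]) == [{"id": "gpt-4"}]
assert extract_model_list(None) is None
```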
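Second, the new `get_web_loader`/`resolve_hostname` pair in `backend/apps/rag/main.py` is the security-relevant piece of the patch: with `ENABLE_LOCAL_WEB_FETCH` left at its default of false, URLs that resolve to private addresses are rejected before `WebBaseLoader` ever fetches them. The sketch below replays those checks outside FastAPI. The `is_safe_url` name is hypothetical, and it assumes a recent `validators` release (the patch itself relies on `validators.ValidationError` and the `private=` keyword). As the in-line comment in the patch concedes, this still leaves a DNS-rebinding window, since `WebBaseLoader` re-resolves the hostname at fetch time.

```python
import socket
import urllib.parse

import validators


def resolve_hostname(hostname):
    # Gather every address the hostname resolves to, split by family.
    addr_info = socket.getaddrinfo(hostname, None)
    ipv4 = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
    ipv6 = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
    return ipv4, ipv6


def is_safe_url(url: str) -> bool:
    # validators.url returns a (falsy) ValidationError instance rather than
    # raising, so the isinstance check in the patch means "not a valid URL".
    if isinstance(validators.url(url), validators.ValidationError):
        return False
    hostname = urllib.parse.urlparse(url).hostname
    ipv4, ipv6 = resolve_hostname(hostname)
    # Reject anything resolving to a private range.
    if any(validators.ipv4(ip, private=True) for ip in ipv4):
        return False
    if any(validators.ipv6(ip, private=True) for ip in ipv6):
        return False
    return True


if __name__ == "__main__":
    print(is_safe_url("http://127.0.0.1:8080"))  # False: loopback is private
    print(is_safe_url("not-a-url"))              # False: fails URL validation
```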
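Finally, the `RerankCompressor` hunks in `backend/apps/rag/utils.py` drop the conditional `reverse` flag: scored documents are now always sorted descending and truncated to `top_n`, whether the scores came from the reranking model or from cosine similarity. A tiny sketch with made-up scores:

```python
import operator

# Hypothetical (document, score) pairs standing in for reranker output.
docs_with_scores = [("doc-a", 0.12), ("doc-b", 0.87), ("doc-c", 0.55)]
top_n = 2

# Always sort by score, highest first, then keep the top_n entries.
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
print(result[:top_n])  # [('doc-b', 0.87), ('doc-c', 0.55)]
```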