diff --git a/README.md b/README.md index e83324ead..c8ad50037 100644 --- a/README.md +++ b/README.md @@ -170,7 +170,7 @@ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/wa In the last part of the command, replace `open-webui` with your container name if it is different. -Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/migration/). +Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/tutorials/migration/). ### Using the Dev Branch 馃寵 diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index e0a40a1f5..9d62f32d2 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -27,7 +27,6 @@ from fastapi.responses import FileResponse, StreamingResponse from pydantic import BaseModel from starlette.background import BackgroundTask - from open_webui.utils.payload import ( apply_model_params_to_body_openai, apply_model_system_prompt_to_body, @@ -47,7 +46,6 @@ app.add_middleware( allow_headers=["*"], ) - app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER @@ -407,20 +405,25 @@ async def generate_chat_completion( url = app.state.config.OPENAI_API_BASE_URLS[idx] key = app.state.config.OPENAI_API_KEYS[idx] + is_o1 = payload["model"].lower().startswith("o1-") # Change max_completion_tokens to max_tokens (Backward compatible) - if "api.openai.com" not in url and not payload["model"].lower().startswith("o1-"): + if "api.openai.com" not in url and not is_o1: if "max_completion_tokens" in payload: # Remove "max_completion_tokens" from the payload payload["max_tokens"] = payload["max_completion_tokens"] del payload["max_completion_tokens"] else: - if payload["model"].lower().startswith("o1-") and "max_tokens" in payload: + if is_o1 and "max_tokens" in payload: payload["max_completion_tokens"] = payload["max_tokens"] del payload["max_tokens"] if "max_tokens" in payload and "max_completion_tokens" in payload: del payload["max_tokens"] + # Fix: O1 does not support the "system" parameter, Modify "system" to "user" + if is_o1 and payload["messages"][0]["role"] == "system": + payload["messages"][0]["role"] = "user" + # Convert the modified body back to JSON payload = json.dumps(payload) diff --git a/backend/open_webui/apps/retrieval/loader/main.py b/backend/open_webui/apps/retrieval/loader/main.py new file mode 100644 index 000000000..f0e8f804e --- /dev/null +++ b/backend/open_webui/apps/retrieval/loader/main.py @@ -0,0 +1,190 @@ +import requests +import logging +import ftfy + +from langchain_community.document_loaders import ( + BSHTMLLoader, + CSVLoader, + Docx2txtLoader, + OutlookMessageLoader, + PyPDFLoader, + TextLoader, + UnstructuredEPubLoader, + UnstructuredExcelLoader, + UnstructuredMarkdownLoader, + UnstructuredPowerPointLoader, + UnstructuredRSTLoader, + UnstructuredXMLLoader, + YoutubeLoader, +) +from langchain_core.documents import Document +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + +known_source_ext = [ + "go", + "py", + "java", + "sh", + "bat", + "ps1", + "cmd", + "js", + "ts", + "css", + "cpp", + "hpp", + "h", + "c", + "cs", + "sql", + "log", + "ini", + "pl", + "pm", + "r", + "dart", + "dockerfile", + "env", + "php", + "hs", + "hsc", + "lua", + "nginxconf", + "conf", + "m", + "mm", + "plsql", + "perl", + "rb", + "rs", + "db2", + "scala", + "bash", + "swift", + "vue", + "svelte", + "msg", 
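A note on the openai/main.py hunk above: the o1 handling is easiest to read as a single normalization step applied to the outgoing payload. Below is a minimal standalone sketch of the same logic; the helper name and the sample call are illustrative, not part of the patch.

```python
def normalize_o1_payload(payload: dict, url: str) -> dict:
    """Sketch of the o1 handling above: map token limits and downgrade the system role."""
    is_o1 = payload["model"].lower().startswith("o1-")

    if "api.openai.com" not in url and not is_o1:
        # Non-OpenAI backends still expect max_tokens.
        if "max_completion_tokens" in payload:
            payload["max_tokens"] = payload.pop("max_completion_tokens")
    else:
        if is_o1 and "max_tokens" in payload:
            payload["max_completion_tokens"] = payload.pop("max_tokens")
        if "max_tokens" in payload and "max_completion_tokens" in payload:
            del payload["max_tokens"]

    # o1 models reject the "system" role, so the first message is downgraded to "user".
    if is_o1 and payload.get("messages") and payload["messages"][0]["role"] == "system":
        payload["messages"][0]["role"] = "user"

    return payload


# Illustrative call (hypothetical values):
# normalize_o1_payload(
#     {"model": "o1-mini", "max_tokens": 256,
#      "messages": [{"role": "system", "content": "hi"}]},
#     "https://api.openai.com/v1",
# )
```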
+ "ex", + "exs", + "erl", + "tsx", + "jsx", + "hs", + "lhs", +] + + +class TikaLoader: + def __init__(self, url, file_path, mime_type=None): + self.url = url + self.file_path = file_path + self.mime_type = mime_type + + def load(self) -> list[Document]: + with open(self.file_path, "rb") as f: + data = f.read() + + if self.mime_type is not None: + headers = {"Content-Type": self.mime_type} + else: + headers = {} + + endpoint = self.url + if not endpoint.endswith("/"): + endpoint += "/" + endpoint += "tika/text" + + r = requests.put(endpoint, data=data, headers=headers) + + if r.ok: + raw_metadata = r.json() + text = raw_metadata.get("X-TIKA:content", "") + + if "Content-Type" in raw_metadata: + headers["Content-Type"] = raw_metadata["Content-Type"] + + log.info("Tika extracted text: %s", text) + + return [Document(page_content=text, metadata=headers)] + else: + raise Exception(f"Error calling Tika: {r.reason}") + + +class Loader: + def __init__(self, engine: str = "", **kwargs): + self.engine = engine + self.kwargs = kwargs + + def load( + self, filename: str, file_content_type: str, file_path: str + ) -> list[Document]: + loader = self._get_loader(filename, file_content_type, file_path) + docs = loader.load() + + return [ + Document( + page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata + ) + for doc in docs + ] + + def _get_loader(self, filename: str, file_content_type: str, file_path: str): + file_ext = filename.split(".")[-1].lower() + + if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"): + if file_ext in known_source_ext or ( + file_content_type and file_content_type.find("text/") >= 0 + ): + loader = TextLoader(file_path, autodetect_encoding=True) + else: + loader = TikaLoader( + url=self.kwargs.get("TIKA_SERVER_URL"), + file_path=file_path, + mime_type=file_content_type, + ) + else: + if file_ext == "pdf": + loader = PyPDFLoader( + file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES") + ) + elif file_ext == "csv": + loader = CSVLoader(file_path) + elif file_ext == "rst": + loader = UnstructuredRSTLoader(file_path, mode="elements") + elif file_ext == "xml": + loader = UnstructuredXMLLoader(file_path) + elif file_ext in ["htm", "html"]: + loader = BSHTMLLoader(file_path, open_encoding="unicode_escape") + elif file_ext == "md": + loader = UnstructuredMarkdownLoader(file_path) + elif file_content_type == "application/epub+zip": + loader = UnstructuredEPubLoader(file_path) + elif ( + file_content_type + == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + or file_ext == "docx" + ): + loader = Docx2txtLoader(file_path) + elif file_content_type in [ + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ] or file_ext in ["xls", "xlsx"]: + loader = UnstructuredExcelLoader(file_path) + elif file_content_type in [ + "application/vnd.ms-powerpoint", + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ] or file_ext in ["ppt", "pptx"]: + loader = UnstructuredPowerPointLoader(file_path) + elif file_ext == "msg": + loader = OutlookMessageLoader(file_path) + elif file_ext in known_source_ext or ( + file_content_type and file_content_type.find("text/") >= 0 + ): + loader = TextLoader(file_path, autodetect_encoding=True) + else: + loader = TextLoader(file_path, autodetect_encoding=True) + + return loader diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/retrieval/main.py similarity index 69% rename from 
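The new Loader facade above consolidates what used to be the get_loader helper in rag/main.py; extraction-engine settings now arrive as keyword arguments instead of being read from app state. A hedged usage sketch — the Tika URL and file path are placeholders:

```python
from open_webui.apps.retrieval.loader.main import Loader

# Placeholder values -- in the app these come from app.state.config.
loader = Loader(
    engine="tika",                       # or "" to fall back to the per-extension loaders
    TIKA_SERVER_URL="http://tika:9998",  # assumed Tika endpoint
    PDF_EXTRACT_IMAGES=False,
)

docs = loader.load(
    filename="report.pdf",
    file_content_type="application/pdf",
    file_path="/app/backend/data/uploads/report.pdf",
)

for doc in docs:
    # Text has already been run through ftfy.fix_text by Loader.load.
    print(doc.metadata, doc.page_content[:80])
```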
backend/open_webui/apps/rag/main.py rename to backend/open_webui/apps/retrieval/main.py index 7b476c056..a3e828978 100644 --- a/backend/open_webui/apps/rag/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -3,35 +3,40 @@ import logging import mimetypes import os import shutil -import socket -import urllib.parse + import uuid from datetime import datetime from pathlib import Path from typing import Iterator, Optional, Sequence, Union - -import numpy as np -import torch -import requests -import validators - from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel -from open_webui.apps.rag.search.main import SearchResult -from open_webui.apps.rag.search.brave import search_brave -from open_webui.apps.rag.search.duckduckgo import search_duckduckgo -from open_webui.apps.rag.search.google_pse import search_google_pse -from open_webui.apps.rag.search.jina_search import search_jina -from open_webui.apps.rag.search.searchapi import search_searchapi -from open_webui.apps.rag.search.searxng import search_searxng -from open_webui.apps.rag.search.serper import search_serper -from open_webui.apps.rag.search.serply import search_serply -from open_webui.apps.rag.search.serpstack import search_serpstack -from open_webui.apps.rag.search.tavily import search_tavily -from open_webui.apps.rag.utils import ( +from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT + +# Information retrieval models +from open_webui.apps.retrieval.model.colbert import ColBERT + +# Document loaders +from open_webui.apps.retrieval.loader.main import Loader + +# Web search engines +from open_webui.apps.retrieval.web.main import SearchResult +from open_webui.apps.retrieval.web.utils import get_web_loader +from open_webui.apps.retrieval.web.brave import search_brave +from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo +from open_webui.apps.retrieval.web.google_pse import search_google_pse +from open_webui.apps.retrieval.web.jina_search import search_jina +from open_webui.apps.retrieval.web.searchapi import search_searchapi +from open_webui.apps.retrieval.web.searxng import search_searxng +from open_webui.apps.retrieval.web.serper import search_serper +from open_webui.apps.retrieval.web.serply import search_serply +from open_webui.apps.retrieval.web.serpstack import search_serpstack +from open_webui.apps.retrieval.web.tavily import search_tavily + + +from open_webui.apps.retrieval.utils import ( get_embedding_function, get_model_path, query_collection, @@ -39,6 +44,7 @@ from open_webui.apps.rag.utils import ( query_doc, query_doc_with_hybrid_search, ) + from open_webui.apps.webui.models.documents import DocumentForm, Documents from open_webui.apps.webui.models.files import Files from open_webui.config import ( @@ -98,28 +104,13 @@ from open_webui.utils.misc import ( sanitize_filename, ) from open_webui.utils.utils import get_admin_user, get_verified_user -from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import ( - BSHTMLLoader, - CSVLoader, - Docx2txtLoader, - OutlookMessageLoader, - PyPDFLoader, - TextLoader, - UnstructuredEPubLoader, - UnstructuredExcelLoader, - UnstructuredMarkdownLoader, - UnstructuredPowerPointLoader, - UnstructuredRSTLoader, - UnstructuredXMLLoader, - WebBaseLoader, YoutubeLoader, ) from langchain_core.documents import Document -from 
colbert.infra import ColBERTConfig -from colbert.modeling.checkpoint import Checkpoint + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -200,83 +191,6 @@ def update_reranking_model( ): if reranking_model: if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]): - - class ColBERT: - def __init__(self, name) -> None: - print("ColBERT: Loading model", name) - self.device = "cuda" if torch.cuda.is_available() else "cpu" - - if DOCKER: - # This is a workaround for the issue with the docker container - # where the torch extension is not loaded properly - # and the following error is thrown: - # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory - - lock_file = "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock" - if os.path.exists(lock_file): - os.remove(lock_file) - - self.ckpt = Checkpoint( - name, - colbert_config=ColBERTConfig(model_name=name), - ).to(self.device) - pass - - def calculate_similarity_scores( - self, query_embeddings, document_embeddings - ): - - query_embeddings = query_embeddings.to(self.device) - document_embeddings = document_embeddings.to(self.device) - - # Validate dimensions to ensure compatibility - if query_embeddings.dim() != 3: - raise ValueError( - f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}." - ) - if document_embeddings.dim() != 3: - raise ValueError( - f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}." - ) - if query_embeddings.size(0) not in [1, document_embeddings.size(0)]: - raise ValueError( - "There should be either one query or queries equal to the number of documents." - ) - - # Transpose the query embeddings to align for matrix multiplication - transposed_query_embeddings = query_embeddings.permute(0, 2, 1) - # Compute similarity scores using batch matrix multiplication - computed_scores = torch.matmul( - document_embeddings, transposed_query_embeddings - ) - # Apply max pooling to extract the highest semantic similarity across each document's sequence - maximum_scores = torch.max(computed_scores, dim=1).values - - # Sum up the maximum scores across features to get the overall document relevance scores - final_scores = maximum_scores.sum(dim=1) - - normalized_scores = torch.softmax(final_scores, dim=0) - - return normalized_scores.detach().cpu().numpy().astype(np.float32) - - def predict(self, sentences): - - query = sentences[0][0] - docs = [i[1] for i in sentences] - - # Embedding the documents - embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0] - # Embedding the queries - embedded_queries = self.ckpt.queryFromText([query], bsize=32) - embedded_query = embedded_queries[0] - - # Calculate retrieval scores for the query against all documents - scores = self.calculate_similarity_scores( - embedded_query.unsqueeze(0), embedded_docs - ) - - return scores - try: app.state.sentence_transformer_rf = ColBERT( get_model_path(reranking_model, auto_update) @@ -332,10 +246,10 @@ app.add_middleware( class CollectionNameForm(BaseModel): - collection_name: Optional[str] = "test" + collection_name: Optional[str] = None -class UrlForm(CollectionNameForm): +class ProcessUrlForm(CollectionNameForm): url: str @@ -707,103 +621,266 @@ async def update_query_settings( } -class QueryDocForm(BaseModel): - collection_name: str - query: str - k: Optional[int] = None - r: Optional[float] = None - hybrid: Optional[bool] = None 
+#################################### +# +# Document process and retrieval +# +#################################### -@app.post("/query/doc") -def query_doc_handler( - form_data: QueryDocForm, +def save_docs_to_vector_db( + docs, + collection_name, + metadata: Optional[dict] = None, + overwrite: bool = False, + split: bool = True, +) -> bool: + log.info(f"save_docs_to_vector_db {docs} {collection_name}") + + if split: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, + add_start_index=True, + ) + docs = text_splitter.split_documents(docs) + + if len(docs) == 0: + raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) + + texts = [doc.page_content for doc in docs] + metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs] + + # ChromaDB does not like datetime formats + # for meta-data so convert them to string. + for metadata in metadatas: + for key, value in metadata.items(): + if isinstance(value, datetime): + metadata[key] = str(value) + + try: + if overwrite: + if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): + log.info(f"deleting existing collection {collection_name}") + VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) + + if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): + log.info(f"collection {collection_name} already exists") + return True + else: + embedding_function = get_embedding_function( + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, + app.state.sentence_transformer_ef, + app.state.config.OPENAI_API_KEY, + app.state.config.OPENAI_API_BASE_URL, + app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + ) + + embeddings = embedding_function( + list(map(lambda x: x.replace("\n", " "), texts)) + ) + + VECTOR_DB_CLIENT.insert( + collection_name=collection_name, + items=[ + { + "id": str(uuid.uuid4()), + "text": text, + "vector": embeddings[idx], + "metadata": metadatas[idx], + } + for idx, text in enumerate(texts) + ], + ) + + return True + except Exception as e: + log.exception(e) + return False + + +class ProcessFileForm(BaseModel): + file_id: str + collection_name: Optional[str] = None + + +@app.post("/process/file") +def process_file( + form_data: ProcessFileForm, user=Depends(get_verified_user), ): try: - if app.state.config.ENABLE_RAG_HYBRID_SEARCH: - return query_doc_with_hybrid_search( - collection_name=form_data.collection_name, - query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - reranking_function=app.state.sentence_transformer_rf, - r=( - form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD - ), + file = Files.get_file_by_id(form_data.file_id) + file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}") + + collection_name = form_data.collection_name + if collection_name is None: + with open(file_path, "rb") as f: + collection_name = calculate_sha256(f)[:63] + + loader = Loader( + engine=app.state.config.CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, + PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, + ) + docs = loader.load(file.filename, file.meta.get("content_type"), file_path) + raw_text_content = " ".join([doc.page_content for doc in docs]) + + Files.update_files_metadata_by_id( + form_data.file_id, + { + "content": { + "text": raw_text_content, + } + }, + ) + + try: + result = save_docs_to_vector_db( + docs, + collection_name, + 
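save_docs_to_vector_db above becomes the single write path for every ingestion route: optionally split, merge per-file metadata, stringify datetimes for the vector DB, then embed and insert. A compact sketch of the pre-insert steps, with placeholder chunk settings standing in for app.state.config:

```python
from datetime import datetime
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document

# Illustrative chunk settings -- the real values come from app.state.config.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=1500, chunk_overlap=100, add_start_index=True
)

docs = [
    Document(
        page_content="Some long text ...",
        metadata={"name": "demo", "created": datetime.now()},
    )
]
chunks = splitter.split_documents(docs)

# Mirror of the datetime sanitisation above: vector DB metadata must be primitives.
metadatas = [
    {k: str(v) if isinstance(v, datetime) else v for k, v in d.metadata.items()}
    for d in chunks
]
texts = [d.page_content.replace("\n", " ") for d in chunks]
```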
{ + "file_id": form_data.file_id, + "name": file.meta.get("name", file.filename), + }, ) - else: - return query_doc( - collection_name=form_data.collection_name, - query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, + + if result: + return { + "status": True, + "collection_name": collection_name, + "filename": file.meta.get("name", file.filename), + } + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=e, ) except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) + if "No pandoc was found" in str(e): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) -class QueryCollectionsForm(BaseModel): - collection_names: list[str] - query: str - k: Optional[int] = None - r: Optional[float] = None - hybrid: Optional[bool] = None +class ProcessTextForm(BaseModel): + name: str + content: str + collection_name: Optional[str] = None -@app.post("/query/collection") -def query_collection_handler( - form_data: QueryCollectionsForm, +@app.post("/process/text") +def process_text( + form_data: ProcessTextForm, user=Depends(get_verified_user), ): - try: - if app.state.config.ENABLE_RAG_HYBRID_SEARCH: - return query_collection_with_hybrid_search( - collection_names=form_data.collection_names, - query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - reranking_function=app.state.sentence_transformer_rf, - r=( - form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD - ), - ) - else: - return query_collection( - collection_names=form_data.collection_names, - query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - ) + collection_name = form_data.collection_name + if collection_name is None: + collection_name = calculate_sha256_string(form_data.content) - except Exception as e: - log.exception(e) + docs = [ + Document( + page_content=form_data.content, + metadata={"name": form_data.name, "created_by": user.id}, + ) + ] + result = save_docs_to_vector_db(docs, collection_name) + + if result: + return {"status": True, "collection_name": collection_name} + else: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(), ) -@app.post("/youtube") -def store_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)): +@app.get("/process/dir") +def process_docs_dir(user=Depends(get_admin_user)): + for path in Path(DOCS_DIR).rglob("./**/*"): + try: + if path.is_file() and not path.name.startswith("."): + tags = extract_folders_after_data_docs(path) + filename = path.name + file_content_type = mimetypes.guess_type(path) + + with open(path, "rb") as f: + collection_name = calculate_sha256(f)[:63] + + loader = Loader( + engine=app.state.config.CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, + PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, + ) + docs = loader.load(filename, file_content_type[0], str(path)) + + try: + result = save_docs_to_vector_db(docs, collection_name) + + if result: + 
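For reference, the new /process/text route accepts a ProcessTextForm (name, content, optional collection_name) and returns the computed collection name. An illustrative client call follows; the /retrieval/api/v1 prefix and the token are assumptions about the deployment, since the app mount point is outside this diff.

```python
import requests

# Assumed mount prefix and placeholder token -- adjust to your deployment.
BASE = "http://localhost:8080/retrieval/api/v1"
TOKEN = "YOUR_API_TOKEN"

resp = requests.post(
    f"{BASE}/process/text",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={"name": "notes.txt", "content": "Some text to index."},
)
print(resp.json())  # e.g. {"status": true, "collection_name": "<sha256 of the content>"}
```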
sanitized_filename = sanitize_filename(filename) + doc = Documents.get_doc_by_name(sanitized_filename) + + if doc is None: + doc = Documents.insert_new_doc( + user.id, + DocumentForm( + **{ + "name": sanitized_filename, + "title": filename, + "collection_name": collection_name, + "filename": filename, + "content": ( + json.dumps( + { + "tags": list( + map( + lambda name: {"name": name}, + tags, + ) + ) + } + ) + if len(tags) + else "{}" + ), + } + ), + ) + except Exception as e: + log.exception(e) + pass + + except Exception as e: + log.exception(e) + + return True + + +@app.post("/process/youtube") +def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)): try: + collection_name = form_data.collection_name + if not collection_name: + collection_name = calculate_sha256_string(form_data.url)[:63] + loader = YoutubeLoader.from_youtube_url( form_data.url, add_video_info=True, language=app.state.config.YOUTUBE_LOADER_LANGUAGE, translation=app.state.YOUTUBE_LOADER_TRANSLATION, ) - data = loader.load() + docs = loader.load() + save_docs_to_vector_db(docs, collection_name, overwrite=True) - collection_name = form_data.collection_name - if collection_name == "": - collection_name = calculate_sha256_string(form_data.url)[:63] - - store_data_in_vector_db(data, collection_name, overwrite=True) return { "status": True, "collection_name": collection_name, @@ -817,21 +894,21 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)): ) -@app.post("/web") -def store_web(form_data: UrlForm, user=Depends(get_verified_user)): - # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" +@app.post("/process/web") +def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): try: + collection_name = form_data.collection_name + if not collection_name: + collection_name = calculate_sha256_string(form_data.url)[:63] + loader = get_web_loader( form_data.url, verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) - data = loader.load() + docs = loader.load() + save_docs_to_vector_db(docs, collection_name, overwrite=True) - collection_name = form_data.collection_name - if collection_name == "": - collection_name = calculate_sha256_string(form_data.url)[:63] - - store_data_in_vector_db(data, collection_name, overwrite=True) return { "status": True, "collection_name": collection_name, @@ -845,53 +922,6 @@ def store_web(form_data: UrlForm, user=Depends(get_verified_user)): ) -def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True): - # Check if the URL is valid - if not validate_url(url): - raise ValueError(ERROR_MESSAGES.INVALID_URL) - return SafeWebBaseLoader( - url, - verify_ssl=verify_ssl, - requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS, - continue_on_failure=True, - ) - - -def validate_url(url: Union[str, Sequence[str]]): - if isinstance(url, str): - if isinstance(validators.url(url), validators.ValidationError): - raise ValueError(ERROR_MESSAGES.INVALID_URL) - if not ENABLE_RAG_LOCAL_WEB_FETCH: - # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses - parsed_url = urllib.parse.urlparse(url) - # Get IPv4 and IPv6 addresses - ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname) - # Check if any of the resolved addresses are private - # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader - for ip in ipv4_addresses: - if 
validators.ipv4(ip, private=True): - raise ValueError(ERROR_MESSAGES.INVALID_URL) - for ip in ipv6_addresses: - if validators.ipv6(ip, private=True): - raise ValueError(ERROR_MESSAGES.INVALID_URL) - return True - elif isinstance(url, Sequence): - return all(validate_url(u) for u in url) - else: - return False - - -def resolve_hostname(hostname): - # Get address information - addr_info = socket.getaddrinfo(hostname, None) - - # Extract IP addresses from address information - ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET] - ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6] - - return ipv4_addresses, ipv6_addresses - - def search_web(engine: str, query: str) -> list[SearchResult]: """Search the web using a search engine and return the results as a list of SearchResult objects. Will look for a search engine API key in environment variables in the following order: @@ -1007,8 +1037,8 @@ def search_web(engine: str, query: str) -> list[SearchResult]: raise Exception("No search engine API key found in environment variables") -@app.post("/web/search") -def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)): +@app.post("/process/web/search") +def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)): try: logging.info( f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}" @@ -1026,15 +1056,16 @@ def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)): ) try: - urls = [result.link for result in web_results] - loader = get_web_loader(urls) - data = loader.load() - collection_name = form_data.collection_name if collection_name == "": collection_name = calculate_sha256_string(form_data.query)[:63] - store_data_in_vector_db(data, collection_name, overwrite=True) + urls = [result.link for result in web_results] + + loader = get_web_loader(urls) + docs = loader.load() + save_docs_to_vector_db(docs, collection_name, overwrite=True) + return { "status": True, "collection_name": collection_name, @@ -1048,449 +1079,92 @@ def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)): ) -def store_data_in_vector_db( - data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False -) -> bool: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, - add_start_index=True, - ) - - docs = text_splitter.split_documents(data) - - if len(docs) > 0: - log.info(f"store_data_in_vector_db {docs}") - return store_docs_in_vector_db(docs, collection_name, metadata, overwrite), None - else: - raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) +class QueryDocForm(BaseModel): + collection_name: str + query: str + k: Optional[int] = None + r: Optional[float] = None + hybrid: Optional[bool] = None -def store_text_in_vector_db( - text, metadata, collection_name, overwrite: bool = False -) -> bool: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, - add_start_index=True, - ) - docs = text_splitter.create_documents([text], metadatas=[metadata]) - return store_docs_in_vector_db(docs, collection_name, overwrite=overwrite) - - -def store_docs_in_vector_db( - docs, collection_name, metadata: Optional[dict] = None, overwrite: bool = False -) -> bool: - log.info(f"store_docs_in_vector_db {docs} {collection_name}") - - texts = [doc.page_content for doc in docs] - metadatas = 
[{**doc.metadata, **(metadata if metadata else {})} for doc in docs] - - # ChromaDB does not like datetime formats - # for meta-data so convert them to string. - for metadata in metadatas: - for key, value in metadata.items(): - if isinstance(value, datetime): - metadata[key] = str(value) - - try: - if overwrite: - if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): - log.info(f"deleting existing collection {collection_name}") - VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) - - if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): - log.info(f"collection {collection_name} already exists") - return True - else: - embedding_function = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, - ) - - embedding_texts = embedding_function( - list(map(lambda x: x.replace("\n", " "), texts)) - ) - - VECTOR_DB_CLIENT.insert( - collection_name=collection_name, - items=[ - { - "id": str(uuid.uuid4()), - "text": text, - "vector": embedding_texts[idx], - "metadata": metadatas[idx], - } - for idx, text in enumerate(texts) - ], - ) - - return True - except Exception as e: - log.exception(e) - return False - - -class TikaLoader: - def __init__(self, file_path, mime_type=None): - self.file_path = file_path - self.mime_type = mime_type - - def load(self) -> list[Document]: - with open(self.file_path, "rb") as f: - data = f.read() - - if self.mime_type is not None: - headers = {"Content-Type": self.mime_type} - else: - headers = {} - - endpoint = app.state.config.TIKA_SERVER_URL - if not endpoint.endswith("/"): - endpoint += "/" - endpoint += "tika/text" - - r = requests.put(endpoint, data=data, headers=headers) - - if r.ok: - raw_metadata = r.json() - text = raw_metadata.get("X-TIKA:content", "") - - if "Content-Type" in raw_metadata: - headers["Content-Type"] = raw_metadata["Content-Type"] - - log.info("Tika extracted text: %s", text) - - return [Document(page_content=text, metadata=headers)] - else: - raise Exception(f"Error calling Tika: {r.reason}") - - -def get_loader(filename: str, file_content_type: str, file_path: str): - file_ext = filename.split(".")[-1].lower() - known_type = True - - known_source_ext = [ - "go", - "py", - "java", - "sh", - "bat", - "ps1", - "cmd", - "js", - "ts", - "css", - "cpp", - "hpp", - "h", - "c", - "cs", - "sql", - "log", - "ini", - "pl", - "pm", - "r", - "dart", - "dockerfile", - "env", - "php", - "hs", - "hsc", - "lua", - "nginxconf", - "conf", - "m", - "mm", - "plsql", - "perl", - "rb", - "rs", - "db2", - "scala", - "bash", - "swift", - "vue", - "svelte", - "msg", - "ex", - "exs", - "erl", - "tsx", - "jsx", - "hs", - "lhs", - ] - - if ( - app.state.config.CONTENT_EXTRACTION_ENGINE == "tika" - and app.state.config.TIKA_SERVER_URL - ): - if file_ext in known_source_ext or ( - file_content_type and file_content_type.find("text/") >= 0 - ): - loader = TextLoader(file_path, autodetect_encoding=True) - else: - loader = TikaLoader(file_path, file_content_type) - else: - if file_ext == "pdf": - loader = PyPDFLoader( - file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES - ) - elif file_ext == "csv": - loader = CSVLoader(file_path) - elif file_ext == "rst": - loader = UnstructuredRSTLoader(file_path, mode="elements") - elif file_ext == "xml": - loader = UnstructuredXMLLoader(file_path) - elif file_ext in 
["htm", "html"]: - loader = BSHTMLLoader(file_path, open_encoding="unicode_escape") - elif file_ext == "md": - loader = UnstructuredMarkdownLoader(file_path) - elif file_content_type == "application/epub+zip": - loader = UnstructuredEPubLoader(file_path) - elif ( - file_content_type - == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - or file_ext == "docx" - ): - loader = Docx2txtLoader(file_path) - elif file_content_type in [ - "application/vnd.ms-excel", - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - ] or file_ext in ["xls", "xlsx"]: - loader = UnstructuredExcelLoader(file_path) - elif file_content_type in [ - "application/vnd.ms-powerpoint", - "application/vnd.openxmlformats-officedocument.presentationml.presentation", - ] or file_ext in ["ppt", "pptx"]: - loader = UnstructuredPowerPointLoader(file_path) - elif file_ext == "msg": - loader = OutlookMessageLoader(file_path) - elif file_ext in known_source_ext or ( - file_content_type and file_content_type.find("text/") >= 0 - ): - loader = TextLoader(file_path, autodetect_encoding=True) - else: - loader = TextLoader(file_path, autodetect_encoding=True) - known_type = False - - return loader, known_type - - -@app.post("/doc") -def store_doc( - collection_name: Optional[str] = Form(None), - file: UploadFile = File(...), - user=Depends(get_verified_user), -): - # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" - - log.info(f"file.content_type: {file.content_type}") - try: - unsanitized_filename = file.filename - filename = os.path.basename(unsanitized_filename) - - file_path = f"{UPLOAD_DIR}/{filename}" - - contents = file.file.read() - with open(file_path, "wb") as f: - f.write(contents) - f.close() - - f = open(file_path, "rb") - if collection_name is None: - collection_name = calculate_sha256(f)[:63] - f.close() - - loader, known_type = get_loader(filename, file.content_type, file_path) - data = loader.load() - - try: - result = store_data_in_vector_db(data, collection_name) - - if result: - return { - "status": True, - "collection_name": collection_name, - "filename": filename, - "known_type": known_type, - } - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=e, - ) - except Exception as e: - log.exception(e) - if "No pandoc was found" in str(e): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, - ) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -class ProcessDocForm(BaseModel): - file_id: str - collection_name: Optional[str] = None - - -@app.post("/process/doc") -def process_doc( - form_data: ProcessDocForm, +@app.post("/query/doc") +def query_doc_handler( + form_data: QueryDocForm, user=Depends(get_verified_user), ): try: - file = Files.get_file_by_id(form_data.file_id) - file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}") - - f = open(file_path, "rb") - - collection_name = form_data.collection_name - if collection_name is None: - collection_name = calculate_sha256(f)[:63] - f.close() - - loader, known_type = get_loader( - file.filename, file.meta.get("content_type"), file_path - ) - data = loader.load() - - try: - result = store_data_in_vector_db( - data, - collection_name, - { - "file_id": form_data.file_id, - "name": file.meta.get("name", file.filename), - }, + if app.state.config.ENABLE_RAG_HYBRID_SEARCH: + return query_doc_with_hybrid_search( + 
collection_name=form_data.collection_name, + query=form_data.query, + embedding_function=app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else app.state.config.TOP_K, + reranking_function=app.state.sentence_transformer_rf, + r=( + form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD + ), ) - - if result: - return { - "status": True, - "collection_name": collection_name, - "known_type": known_type, - "filename": file.meta.get("name", file.filename), - } - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=e, + else: + return query_doc( + collection_name=form_data.collection_name, + query=form_data.query, + embedding_function=app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else app.state.config.TOP_K, ) except Exception as e: log.exception(e) - if "No pandoc was found" in str(e): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, - ) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -class TextRAGForm(BaseModel): - name: str - content: str - collection_name: Optional[str] = None - - -@app.post("/text") -def store_text( - form_data: TextRAGForm, - user=Depends(get_verified_user), -): - collection_name = form_data.collection_name - if collection_name is None: - collection_name = calculate_sha256_string(form_data.content) - - result = store_text_in_vector_db( - form_data.content, - metadata={"name": form_data.name, "created_by": user.id}, - collection_name=collection_name, - ) - - if result: - return {"status": True, "collection_name": collection_name} - else: raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=ERROR_MESSAGES.DEFAULT(), + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), ) -@app.get("/scan") -def scan_docs_dir(user=Depends(get_admin_user)): - for path in Path(DOCS_DIR).rglob("./**/*"): - try: - if path.is_file() and not path.name.startswith("."): - tags = extract_folders_after_data_docs(path) - filename = path.name - file_content_type = mimetypes.guess_type(path) +class QueryCollectionsForm(BaseModel): + collection_names: list[str] + query: str + k: Optional[int] = None + r: Optional[float] = None + hybrid: Optional[bool] = None - f = open(path, "rb") - collection_name = calculate_sha256(f)[:63] - f.close() - loader, known_type = get_loader( - filename, file_content_type[0], str(path) - ) - data = loader.load() +@app.post("/query/collection") +def query_collection_handler( + form_data: QueryCollectionsForm, + user=Depends(get_verified_user), +): + try: + if app.state.config.ENABLE_RAG_HYBRID_SEARCH: + return query_collection_with_hybrid_search( + collection_names=form_data.collection_names, + query=form_data.query, + embedding_function=app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else app.state.config.TOP_K, + reranking_function=app.state.sentence_transformer_rf, + r=( + form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD + ), + ) + else: + return query_collection( + collection_names=form_data.collection_names, + query=form_data.query, + embedding_function=app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else app.state.config.TOP_K, + ) - try: - result = store_data_in_vector_db(data, collection_name) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) - if 
result: - sanitized_filename = sanitize_filename(filename) - doc = Documents.get_doc_by_name(sanitized_filename) - if doc is None: - doc = Documents.insert_new_doc( - user.id, - DocumentForm( - **{ - "name": sanitized_filename, - "title": filename, - "collection_name": collection_name, - "filename": filename, - "content": ( - json.dumps( - { - "tags": list( - map( - lambda name: {"name": name}, - tags, - ) - ) - } - ) - if len(tags) - else "{}" - ), - } - ), - ) - except Exception as e: - log.exception(e) - pass - - except Exception as e: - log.exception(e) - - return True +#################################### +# +# Vector DB operations +# +#################################### @app.post("/reset/db") @@ -1543,33 +1217,6 @@ def reset(user=Depends(get_admin_user)) -> bool: return True -class SafeWebBaseLoader(WebBaseLoader): - """WebBaseLoader with enhanced error handling for URLs.""" - - def lazy_load(self) -> Iterator[Document]: - """Lazy load text from the url(s) in web_path with error handling.""" - for path in self.web_paths: - try: - soup = self._scrape(path, bs_kwargs=self.bs_kwargs) - text = soup.get_text(**self.bs_get_text_kwargs) - - # Build metadata - metadata = {"source": path} - if title := soup.find("title"): - metadata["title"] = title.get_text() - if description := soup.find("meta", attrs={"name": "description"}): - metadata["description"] = description.get( - "content", "No description found." - ) - if html := soup.find("html"): - metadata["language"] = html.get("lang", "No language found.") - - yield Document(page_content=text, metadata=metadata) - except Exception as e: - # Log the error and continue with the next URL - log.error(f"Error loading {path}: {e}") - - if ENV == "dev": @app.get("/ef") diff --git a/backend/open_webui/apps/retrieval/model/colbert.py b/backend/open_webui/apps/retrieval/model/colbert.py new file mode 100644 index 000000000..ea3204cb8 --- /dev/null +++ b/backend/open_webui/apps/retrieval/model/colbert.py @@ -0,0 +1,81 @@ +import os +import torch +import numpy as np +from colbert.infra import ColBERTConfig +from colbert.modeling.checkpoint import Checkpoint + + +class ColBERT: + def __init__(self, name, **kwargs) -> None: + print("ColBERT: Loading model", name) + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + DOCKER = kwargs.get("env") == "docker" + if DOCKER: + # This is a workaround for the issue with the docker container + # where the torch extension is not loaded properly + # and the following error is thrown: + # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory + + lock_file = ( + "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock" + ) + if os.path.exists(lock_file): + os.remove(lock_file) + + self.ckpt = Checkpoint( + name, + colbert_config=ColBERTConfig(model_name=name), + ).to(self.device) + pass + + def calculate_similarity_scores(self, query_embeddings, document_embeddings): + + query_embeddings = query_embeddings.to(self.device) + document_embeddings = document_embeddings.to(self.device) + + # Validate dimensions to ensure compatibility + if query_embeddings.dim() != 3: + raise ValueError( + f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}." + ) + if document_embeddings.dim() != 3: + raise ValueError( + f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}." 
+ ) + if query_embeddings.size(0) not in [1, document_embeddings.size(0)]: + raise ValueError( + "There should be either one query or queries equal to the number of documents." + ) + + # Transpose the query embeddings to align for matrix multiplication + transposed_query_embeddings = query_embeddings.permute(0, 2, 1) + # Compute similarity scores using batch matrix multiplication + computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings) + # Apply max pooling to extract the highest semantic similarity across each document's sequence + maximum_scores = torch.max(computed_scores, dim=1).values + + # Sum up the maximum scores across features to get the overall document relevance scores + final_scores = maximum_scores.sum(dim=1) + + normalized_scores = torch.softmax(final_scores, dim=0) + + return normalized_scores.detach().cpu().numpy().astype(np.float32) + + def predict(self, sentences): + + query = sentences[0][0] + docs = [i[1] for i in sentences] + + # Embedding the documents + embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0] + # Embedding the queries + embedded_queries = self.ckpt.queryFromText([query], bsize=32) + embedded_query = embedded_queries[0] + + # Calculate retrieval scores for the query against all documents + scores = self.calculate_similarity_scores( + embedded_query.unsqueeze(0), embedded_docs + ) + + return scores diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/retrieval/utils.py similarity index 99% rename from backend/open_webui/apps/rag/utils.py rename to backend/open_webui/apps/retrieval/utils.py index f9443d380..1fa30e6a0 100644 --- a/backend/open_webui/apps/rag/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -15,7 +15,7 @@ from open_webui.apps.ollama.main import ( GenerateEmbeddingsForm, generate_ollama_embeddings, ) -from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT +from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/rag/vector/connector.py b/backend/open_webui/apps/retrieval/vector/connector.py similarity index 52% rename from backend/open_webui/apps/rag/vector/connector.py rename to backend/open_webui/apps/retrieval/vector/connector.py index 073becdbe..5b203271f 100644 --- a/backend/open_webui/apps/rag/vector/connector.py +++ b/backend/open_webui/apps/retrieval/vector/connector.py @@ -1,5 +1,5 @@ -from open_webui.apps.rag.vector.dbs.chroma import ChromaClient -from open_webui.apps.rag.vector.dbs.milvus import MilvusClient +from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient +from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient from open_webui.config import VECTOR_DB diff --git a/backend/open_webui/apps/rag/vector/dbs/chroma.py b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py similarity index 98% rename from backend/open_webui/apps/rag/vector/dbs/chroma.py rename to backend/open_webui/apps/retrieval/vector/dbs/chroma.py index 5f9420108..fe065f868 100644 --- a/backend/open_webui/apps/rag/vector/dbs/chroma.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py @@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches from typing import Optional -from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult +from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( CHROMA_DATA_PATH, 
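The relocated ColBERT reranker above (model/colbert.py) scores documents by late interaction: token-level similarities, max-pooled per query token, summed, then softmaxed. The tensor shapes are easier to follow with a small self-contained check using random embeddings of illustrative sizes:

```python
import torch

# One query with 32 token embeddings, four documents with 180 token embeddings each (dim 128).
query_embeddings = torch.randn(1, 32, 128)
document_embeddings = torch.randn(4, 180, 128)

# (4, 180, 128) x (1, 128, 32) -> (4, 180, 32): token-level similarities per document.
token_scores = torch.matmul(document_embeddings, query_embeddings.permute(0, 2, 1))

# MaxSim: best document token per query token, then sum over query tokens.
doc_scores = token_scores.max(dim=1).values.sum(dim=1)  # shape (4,)

# The class above additionally softmaxes the scores before returning them.
print(torch.softmax(doc_scores, dim=0))
```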
CHROMA_HTTP_HOST, diff --git a/backend/open_webui/apps/rag/vector/dbs/milvus.py b/backend/open_webui/apps/retrieval/vector/dbs/milvus.py similarity index 96% rename from backend/open_webui/apps/rag/vector/dbs/milvus.py rename to backend/open_webui/apps/retrieval/vector/dbs/milvus.py index 33ec6035a..77300acf2 100644 --- a/backend/open_webui/apps/rag/vector/dbs/milvus.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/milvus.py @@ -4,7 +4,7 @@ import json from typing import Optional -from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult +from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( MILVUS_URI, ) @@ -98,7 +98,10 @@ class MilvusClient: index_params = self.client.prepare_index_params() index_params.add_index( - field_name="vector", index_type="HNSW", metric_type="COSINE", params={} + field_name="vector", + index_type="HNSW", + metric_type="COSINE", + params={"M": 16, "efConstruction": 100}, ) self.client.create_collection( diff --git a/backend/open_webui/apps/rag/vector/main.py b/backend/open_webui/apps/retrieval/vector/main.py similarity index 100% rename from backend/open_webui/apps/rag/vector/main.py rename to backend/open_webui/apps/retrieval/vector/main.py diff --git a/backend/open_webui/apps/rag/search/brave.py b/backend/open_webui/apps/retrieval/web/brave.py similarity index 93% rename from backend/open_webui/apps/rag/search/brave.py rename to backend/open_webui/apps/retrieval/web/brave.py index 2eb256b4b..f988b3b08 100644 --- a/backend/open_webui/apps/rag/search/brave.py +++ b/backend/open_webui/apps/retrieval/web/brave.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.rag.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/rag/search/duckduckgo.py b/backend/open_webui/apps/retrieval/web/duckduckgo.py similarity index 95% rename from backend/open_webui/apps/rag/search/duckduckgo.py rename to backend/open_webui/apps/retrieval/web/duckduckgo.py index a8a580aca..11e512296 100644 --- a/backend/open_webui/apps/rag/search/duckduckgo.py +++ b/backend/open_webui/apps/retrieval/web/duckduckgo.py @@ -1,7 +1,7 @@ import logging from typing import Optional -from open_webui.apps.rag.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from duckduckgo_search import DDGS from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/rag/search/google_pse.py b/backend/open_webui/apps/retrieval/web/google_pse.py similarity index 94% rename from backend/open_webui/apps/rag/search/google_pse.py rename to backend/open_webui/apps/retrieval/web/google_pse.py index a7f75a6c6..61b919583 100644 --- a/backend/open_webui/apps/rag/search/google_pse.py +++ b/backend/open_webui/apps/retrieval/web/google_pse.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.rag.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/rag/search/jina_search.py b/backend/open_webui/apps/retrieval/web/jina_search.py similarity index 94% rename from 
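On the Milvus hunk above: the HNSW index now gets explicit parameters rather than an empty dict. M caps each node's graph degree and efConstruction sets the candidate-list size used while building the index, trading build time and memory for recall. A hedged standalone sketch of the same declaration with pymilvus — the local URI and collection are placeholders:

```python
from pymilvus import MilvusClient

client = MilvusClient("./data/vector_db/milvus.db")  # placeholder local URI

index_params = client.prepare_index_params()
index_params.add_index(
    field_name="vector",
    index_type="HNSW",
    metric_type="COSINE",
    params={"M": 16, "efConstruction": 100},  # values introduced by this diff
)
# index_params is then passed to client.create_collection(...), as in the hunk above.
```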
backend/open_webui/apps/rag/search/jina_search.py rename to backend/open_webui/apps/retrieval/web/jina_search.py index 41cde679d..487bbc948 100644 --- a/backend/open_webui/apps/rag/search/jina_search.py +++ b/backend/open_webui/apps/retrieval/web/jina_search.py @@ -1,7 +1,7 @@ import logging import requests -from open_webui.apps.rag.search.main import SearchResult +from open_webui.apps.retrieval.web.main import SearchResult from open_webui.env import SRC_LOG_LEVELS from yarl import URL diff --git a/backend/open_webui/apps/rag/search/main.py b/backend/open_webui/apps/retrieval/web/main.py similarity index 100% rename from backend/open_webui/apps/rag/search/main.py rename to backend/open_webui/apps/retrieval/web/main.py diff --git a/backend/open_webui/apps/rag/search/searchapi.py b/backend/open_webui/apps/retrieval/web/searchapi.py similarity index 93% rename from backend/open_webui/apps/rag/search/searchapi.py rename to backend/open_webui/apps/retrieval/web/searchapi.py index 9ec9a0747..412dc6b69 100644 --- a/backend/open_webui/apps/rag/search/searchapi.py +++ b/backend/open_webui/apps/retrieval/web/searchapi.py @@ -3,7 +3,7 @@ from typing import Optional from urllib.parse import urlencode import requests -from open_webui.apps.rag.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/rag/search/searxng.py b/backend/open_webui/apps/retrieval/web/searxng.py similarity index 97% rename from backend/open_webui/apps/rag/search/searxng.py rename to backend/open_webui/apps/retrieval/web/searxng.py index 26c534aa3..cb1eaf91d 100644 --- a/backend/open_webui/apps/rag/search/searxng.py +++ b/backend/open_webui/apps/retrieval/web/searxng.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.rag.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/rag/search/serper.py b/backend/open_webui/apps/retrieval/web/serper.py similarity index 93% rename from backend/open_webui/apps/rag/search/serper.py rename to backend/open_webui/apps/retrieval/web/serper.py index ed7cc2c5f..436fa167e 100644 --- a/backend/open_webui/apps/rag/search/serper.py +++ b/backend/open_webui/apps/retrieval/web/serper.py @@ -3,7 +3,7 @@ import logging from typing import Optional import requests -from open_webui.apps.rag.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/rag/search/serply.py b/backend/open_webui/apps/retrieval/web/serply.py similarity index 95% rename from backend/open_webui/apps/rag/search/serply.py rename to backend/open_webui/apps/retrieval/web/serply.py index 260e9b30e..1c2521c47 100644 --- a/backend/open_webui/apps/rag/search/serply.py +++ b/backend/open_webui/apps/retrieval/web/serply.py @@ -3,7 +3,7 @@ from typing import Optional from urllib.parse import urlencode import requests -from open_webui.apps.rag.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = 
logging.getLogger(__name__) diff --git a/backend/open_webui/apps/rag/search/serpstack.py b/backend/open_webui/apps/retrieval/web/serpstack.py similarity index 94% rename from backend/open_webui/apps/rag/search/serpstack.py rename to backend/open_webui/apps/retrieval/web/serpstack.py index 962c1a5b3..b655934de 100644 --- a/backend/open_webui/apps/rag/search/serpstack.py +++ b/backend/open_webui/apps/retrieval/web/serpstack.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.rag.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/rag/search/tavily.py b/backend/open_webui/apps/retrieval/web/tavily.py similarity index 94% rename from backend/open_webui/apps/rag/search/tavily.py rename to backend/open_webui/apps/retrieval/web/tavily.py index a619d29ed..03b0be75a 100644 --- a/backend/open_webui/apps/rag/search/tavily.py +++ b/backend/open_webui/apps/retrieval/web/tavily.py @@ -1,7 +1,7 @@ import logging import requests -from open_webui.apps.rag.search.main import SearchResult +from open_webui.apps.retrieval.web.main import SearchResult from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/rag/search/testdata/brave.json b/backend/open_webui/apps/retrieval/web/testdata/brave.json similarity index 100% rename from backend/open_webui/apps/rag/search/testdata/brave.json rename to backend/open_webui/apps/retrieval/web/testdata/brave.json diff --git a/backend/open_webui/apps/rag/search/testdata/google_pse.json b/backend/open_webui/apps/retrieval/web/testdata/google_pse.json similarity index 100% rename from backend/open_webui/apps/rag/search/testdata/google_pse.json rename to backend/open_webui/apps/retrieval/web/testdata/google_pse.json diff --git a/backend/open_webui/apps/rag/search/testdata/searchapi.json b/backend/open_webui/apps/retrieval/web/testdata/searchapi.json similarity index 100% rename from backend/open_webui/apps/rag/search/testdata/searchapi.json rename to backend/open_webui/apps/retrieval/web/testdata/searchapi.json diff --git a/backend/open_webui/apps/rag/search/testdata/searxng.json b/backend/open_webui/apps/retrieval/web/testdata/searxng.json similarity index 100% rename from backend/open_webui/apps/rag/search/testdata/searxng.json rename to backend/open_webui/apps/retrieval/web/testdata/searxng.json diff --git a/backend/open_webui/apps/rag/search/testdata/serper.json b/backend/open_webui/apps/retrieval/web/testdata/serper.json similarity index 100% rename from backend/open_webui/apps/rag/search/testdata/serper.json rename to backend/open_webui/apps/retrieval/web/testdata/serper.json diff --git a/backend/open_webui/apps/rag/search/testdata/serply.json b/backend/open_webui/apps/retrieval/web/testdata/serply.json similarity index 100% rename from backend/open_webui/apps/rag/search/testdata/serply.json rename to backend/open_webui/apps/retrieval/web/testdata/serply.json diff --git a/backend/open_webui/apps/rag/search/testdata/serpstack.json b/backend/open_webui/apps/retrieval/web/testdata/serpstack.json similarity index 100% rename from backend/open_webui/apps/rag/search/testdata/serpstack.json rename to backend/open_webui/apps/retrieval/web/testdata/serpstack.json diff --git a/backend/open_webui/apps/retrieval/web/utils.py b/backend/open_webui/apps/retrieval/web/utils.py new file mode 
100644 index 000000000..2df98b33c --- /dev/null +++ b/backend/open_webui/apps/retrieval/web/utils.py @@ -0,0 +1,97 @@ +import socket +import urllib.parse +import validators +from typing import Union, Sequence, Iterator + +from langchain_community.document_loaders import ( + WebBaseLoader, +) +from langchain_core.documents import Document + + +from open_webui.constants import ERROR_MESSAGES +from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH +from open_webui.env import SRC_LOG_LEVELS + +import logging + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def validate_url(url: Union[str, Sequence[str]]): + if isinstance(url, str): + if isinstance(validators.url(url), validators.ValidationError): + raise ValueError(ERROR_MESSAGES.INVALID_URL) + if not ENABLE_RAG_LOCAL_WEB_FETCH: + # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses + parsed_url = urllib.parse.urlparse(url) + # Get IPv4 and IPv6 addresses + ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname) + # Check if any of the resolved addresses are private + # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader + for ip in ipv4_addresses: + if validators.ipv4(ip, private=True): + raise ValueError(ERROR_MESSAGES.INVALID_URL) + for ip in ipv6_addresses: + if validators.ipv6(ip, private=True): + raise ValueError(ERROR_MESSAGES.INVALID_URL) + return True + elif isinstance(url, Sequence): + return all(validate_url(u) for u in url) + else: + return False + + +def resolve_hostname(hostname): + # Get address information + addr_info = socket.getaddrinfo(hostname, None) + + # Extract IP addresses from address information + ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET] + ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6] + + return ipv4_addresses, ipv6_addresses + + +class SafeWebBaseLoader(WebBaseLoader): + """WebBaseLoader with enhanced error handling for URLs.""" + + def lazy_load(self) -> Iterator[Document]: + """Lazy load text from the url(s) in web_path with error handling.""" + for path in self.web_paths: + try: + soup = self._scrape(path, bs_kwargs=self.bs_kwargs) + text = soup.get_text(**self.bs_get_text_kwargs) + + # Build metadata + metadata = {"source": path} + if title := soup.find("title"): + metadata["title"] = title.get_text() + if description := soup.find("meta", attrs={"name": "description"}): + metadata["description"] = description.get( + "content", "No description found." 
+ ) + if html := soup.find("html"): + metadata["language"] = html.get("lang", "No language found.") + + yield Document(page_content=text, metadata=metadata) + except Exception as e: + # Log the error and continue with the next URL + log.error(f"Error loading {path}: {e}") + + +def get_web_loader( + url: Union[str, Sequence[str]], + verify_ssl: bool = True, + requests_per_second: int = 2, +): + # Check if the URL is valid + if not validate_url(url): + raise ValueError(ERROR_MESSAGES.INVALID_URL) + return SafeWebBaseLoader( + url, + verify_ssl=verify_ssl, + requests_per_second=requests_per_second, + continue_on_failure=True, + ) diff --git a/backend/open_webui/apps/webui/models/files.py b/backend/open_webui/apps/webui/models/files.py index 7fba74479..cf572ac78 100644 --- a/backend/open_webui/apps/webui/models/files.py +++ b/backend/open_webui/apps/webui/models/files.py @@ -97,6 +97,17 @@ class FilesTable: for file in db.query(File).filter_by(user_id=user_id).all() ] + def update_files_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]: + with get_db() as db: + try: + file = db.query(File).filter_by(id=id).first() + file.meta = {**file.meta, **meta} + db.commit() + + return FileModel.model_validate(file) + except Exception: + return None + def delete_file_by_id(self, id: str) -> bool: with get_db() as db: try: diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/apps/webui/routers/files.py index 1a326bcd8..f46a7992d 100644 --- a/backend/open_webui/apps/webui/routers/files.py +++ b/backend/open_webui/apps/webui/routers/files.py @@ -171,6 +171,19 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): ) +@router.get("/{id}/content/text") +async def get_file_text_content_by_id(id: str, user=Depends(get_verified_user)): + file = Files.get_file_by_id(id) + + if file and (file.user_id == user.id or user.role == "admin"): + return {"text": file.meta.get("content", {}).get("text", None)} + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + @router.get("/{id}/content/{file_name}", response_model=Optional[FileModel]) async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): file = Files.get_file_by_id(id) diff --git a/backend/open_webui/apps/webui/routers/memories.py b/backend/open_webui/apps/webui/routers/memories.py index d659833bc..ccf84a9d4 100644 --- a/backend/open_webui/apps/webui/routers/memories.py +++ b/backend/open_webui/apps/webui/routers/memories.py @@ -4,7 +4,7 @@ import logging from typing import Optional from open_webui.apps.webui.models.memories import Memories, MemoryModel -from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT +from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.utils import get_verified_user from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index f531a8728..2518599ca 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -921,7 +921,7 @@ CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") #################################### -# RAG +# Information Retrieval (RAG) #################################### # RAG Content Extraction diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index dadae0e04..40fac171f 100644 --- a/backend/open_webui/main.py +++ 
b/backend/open_webui/main.py @@ -16,37 +16,45 @@ from typing import Optional import aiohttp import requests - -from open_webui.apps.audio.main import app as audio_app -from open_webui.apps.images.main import app as images_app -from open_webui.apps.ollama.main import app as ollama_app from open_webui.apps.ollama.main import ( - GenerateChatCompletionForm, + app as ollama_app, + get_all_models as get_ollama_models, generate_chat_completion as generate_ollama_chat_completion, generate_openai_chat_completion as generate_ollama_openai_chat_completion, + GenerateChatCompletionForm, ) -from open_webui.apps.ollama.main import get_all_models as get_ollama_models -from open_webui.apps.openai.main import app as openai_app from open_webui.apps.openai.main import ( + app as openai_app, generate_chat_completion as generate_openai_chat_completion, + get_all_models as get_openai_models, ) -from open_webui.apps.openai.main import get_all_models as get_openai_models -from open_webui.apps.rag.main import app as rag_app -from open_webui.apps.rag.utils import get_rag_context, rag_template -from open_webui.apps.socket.main import app as socket_app, periodic_usage_pool_cleanup -from open_webui.apps.socket.main import get_event_call, get_event_emitter -from open_webui.apps.webui.internal.db import Session -from open_webui.apps.webui.main import app as webui_app + +from open_webui.apps.retrieval.main import app as retrieval_app +from open_webui.apps.retrieval.utils import get_rag_context, rag_template + +from open_webui.apps.socket.main import ( + app as socket_app, + periodic_usage_pool_cleanup, + get_event_call, + get_event_emitter, +) + from open_webui.apps.webui.main import ( + app as webui_app, generate_function_chat_completion, get_pipe_models, ) +from open_webui.apps.webui.internal.db import Session + from open_webui.apps.webui.models.auths import Auths from open_webui.apps.webui.models.functions import Functions from open_webui.apps.webui.models.models import Models from open_webui.apps.webui.models.users import UserModel, Users + from open_webui.apps.webui.utils import load_function_module_by_id +from open_webui.apps.audio.main import app as audio_app +from open_webui.apps.images.main import app as images_app from authlib.integrations.starlette_client import OAuth from authlib.oidc.core import UserInfo @@ -492,11 +500,11 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: contexts, citations = get_rag_context( files=files, messages=body["messages"], - embedding_function=rag_app.state.EMBEDDING_FUNCTION, - k=rag_app.state.config.TOP_K, - reranking_function=rag_app.state.sentence_transformer_rf, - r=rag_app.state.config.RELEVANCE_THRESHOLD, - hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH, + embedding_function=retrieval_app.state.EMBEDDING_FUNCTION, + k=retrieval_app.state.config.TOP_K, + reranking_function=retrieval_app.state.sentence_transformer_rf, + r=retrieval_app.state.config.RELEVANCE_THRESHOLD, + hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH, ) log.debug(f"rag_contexts: {contexts}, citations: {citations}") @@ -609,7 +617,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if prompt is None: raise Exception("No user message found") if ( - rag_app.state.config.RELEVANCE_THRESHOLD == 0 + retrieval_app.state.config.RELEVANCE_THRESHOLD == 0 and context_string.strip() == "" ): log.debug( @@ -621,14 +629,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if model["owned_by"] == "ollama": body["messages"] = 
prepend_to_first_user_message_content( rag_template( - rag_app.state.config.RAG_TEMPLATE, context_string, prompt + retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt ), body["messages"], ) else: body["messages"] = add_or_update_system_message( rag_template( - rag_app.state.config.RAG_TEMPLATE, context_string, prompt + retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt ), body["messages"], ) @@ -762,10 +770,22 @@ class PipelineMiddleware(BaseHTTPMiddleware): # Parse string to JSON data = json.loads(body_str) if body_str else {} - user = get_current_user( - request, - get_http_authorization_cred(request.headers["Authorization"]), - ) + try: + user = get_current_user( + request, + get_http_authorization_cred(request.headers["Authorization"]), + ) + except KeyError as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Not authenticated"}, + ) try: data = filter_pipeline(data, user) @@ -838,7 +858,7 @@ async def check_url(request: Request, call_next): async def update_embedding_function(request: Request, call_next): response = await call_next(request) if "/embedding/update" in request.url.path: - webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION + webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION return response @@ -866,11 +886,12 @@ app.mount("/openai", openai_app) app.mount("/images/api/v1", images_app) app.mount("/audio/api/v1", audio_app) -app.mount("/rag/api/v1", rag_app) +app.mount("/retrieval/api/v1", retrieval_app) app.mount("/api/v1", webui_app) -webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION + +webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION async def get_all_models(): @@ -2055,7 +2076,7 @@ async def get_app_config(request: Request): "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, **( { - "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH, + "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH, "enable_image_generation": images_app.state.config.ENABLED, "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, "enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING, @@ -2081,8 +2102,8 @@ async def get_app_config(request: Request): }, }, "file": { - "max_size": rag_app.state.config.FILE_MAX_SIZE, - "max_count": rag_app.state.config.FILE_MAX_COUNT, + "max_size": retrieval_app.state.config.FILE_MAX_SIZE, + "max_count": retrieval_app.state.config.FILE_MAX_COUNT, }, "permissions": {**webui_app.state.config.USER_PERMISSIONS}, } @@ -2154,7 +2175,8 @@ async def get_app_changelog(): @app.get("/api/version/updates") async def get_app_latest_release_version(): try: - async with aiohttp.ClientSession(trust_env=True) as session: + timeout = aiohttp.ClientTimeout(total=1) + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( "https://api.github.com/repos/open-webui/open-webui/releases/latest" ) as response: @@ -2164,10 +2186,7 @@ async def get_app_latest_release_version(): return {"current": VERSION, "latest": latest_version[1:]} except aiohttp.ClientError: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED, - ) + return {"current": VERSION, "latest": VERSION} ############################ diff --git a/backend/requirements.txt 
b/backend/requirements.txt index 764e41d3d..a6933d20a 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -46,6 +46,8 @@ sentence-transformers==3.0.1 colbert-ai==0.2.21 einops==0.8.0 + +ftfy==6.2.3 pypdf==4.3.1 docx2txt==0.8 python-pptx==1.0.0 diff --git a/pyproject.toml b/pyproject.toml index d02281d52..1df284f80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,8 @@ dependencies = [ "colbert-ai==0.2.21", "einops==0.8.0", + + "ftfy==6.2.3", "pypdf==4.3.1", "docx2txt==0.8", "python-pptx==1.0.0", diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/retrieval/index.ts similarity index 88% rename from src/lib/apis/rag/index.ts rename to src/lib/apis/retrieval/index.ts index 3c0dba4b5..cf86e951c 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/retrieval/index.ts @@ -170,284 +170,6 @@ export const updateQuerySettings = async (token: string, settings: QuerySettings return res; }; -export const processDocToVectorDB = async (token: string, file_id: string) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/process/doc`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - file_id: file_id - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => { - const data = new FormData(); - data.append('file', file); - data.append('collection_name', collection_name); - - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/doc`, { - method: 'POST', - headers: { - Accept: 'application/json', - authorization: `Bearer ${token}` - }, - body: data - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const uploadWebToVectorDB = async (token: string, collection_name: string, url: string) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/web`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - url: url, - collection_name: collection_name - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const uploadYoutubeTranscriptionToVectorDB = async (token: string, url: string) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/youtube`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - url: url - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const queryDoc = async ( - token: string, - collection_name: string, - query: string, - k: number | null = null -) => { - let error = null; - - const res = await 
fetch(`${RAG_API_BASE_URL}/query/doc`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - collection_name: collection_name, - query: query, - k: k - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const queryCollection = async ( - token: string, - collection_names: string, - query: string, - k: number | null = null -) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - collection_names: collection_names, - query: query, - k: k - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const scanDocs = async (token: string) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/scan`, { - method: 'GET', - headers: { - Accept: 'application/json', - authorization: `Bearer ${token}` - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const resetUploadDir = async (token: string) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/reset/uploads`, { - method: 'POST', - headers: { - Accept: 'application/json', - authorization: `Bearer ${token}` - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const resetVectorDB = async (token: string) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/reset/db`, { - method: 'POST', - headers: { - Accept: 'application/json', - authorization: `Bearer ${token}` - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - export const getEmbeddingConfig = async (token: string) => { let error = null; @@ -578,14 +300,140 @@ export const updateRerankingConfig = async (token: string, payload: RerankingMod return res; }; -export const runWebSearch = async ( +export interface SearchDocument { + status: boolean; + collection_name: string; + filenames: string[]; +} + +export const processFile = async (token: string, file_id: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/process/file`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + file_id: file_id + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const processDocsDir = async (token: string) => { + let error = null; + + const res = await 
fetch(`${RAG_API_BASE_URL}/process/dir`, { + method: 'GET', + headers: { + Accept: 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const processYoutubeVideo = async (token: string, url: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/process/youtube`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + url: url + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const processWeb = async (token: string, collection_name: string, url: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/process/web`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + url: url, + collection_name: collection_name + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const processWebSearch = async ( token: string, query: string, collection_name?: string ): Promise => { let error = null; - const res = await fetch(`${RAG_API_BASE_URL}/web/search`, { + const res = await fetch(`${RAG_API_BASE_URL}/process/web/search`, { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -613,8 +461,128 @@ export const runWebSearch = async ( return res; }; -export interface SearchDocument { - status: boolean; - collection_name: string; - filenames: string[]; -} +export const queryDoc = async ( + token: string, + collection_name: string, + query: string, + k: number | null = null +) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/query/doc`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + collection_name: collection_name, + query: query, + k: k + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const queryCollection = async ( + token: string, + collection_names: string, + query: string, + k: number | null = null +) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + collection_names: collection_names, + query: query, + k: k + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const resetUploadDir = async (token: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/reset/uploads`, { + method: 
'POST', + headers: { + Accept: 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const resetVectorDB = async (token: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/reset/db`, { + method: 'POST', + headers: { + Accept: 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index e06edce9d..c10b60aa0 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -7,7 +7,7 @@ import { deleteAllFiles, deleteFileById } from '$lib/apis/files'; import { getQuerySettings, - scanDocs, + processDocsDir, updateQuerySettings, resetVectorDB, getEmbeddingConfig, @@ -17,7 +17,7 @@ resetUploadDir, getRAGConfig, updateRAGConfig - } from '$lib/apis/rag'; + } from '$lib/apis/retrieval'; import ResetUploadDirConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; import ResetVectorDBConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; @@ -63,7 +63,7 @@ const scanHandler = async () => { scanDirLoading = true; - const res = await scanDocs(localStorage.token); + const res = await processDocsDir(localStorage.token); scanDirLoading = false; if (res) { diff --git a/src/lib/components/admin/Settings/WebSearch.svelte b/src/lib/components/admin/Settings/WebSearch.svelte index 15eba096b..0a0c2eb16 100644 --- a/src/lib/components/admin/Settings/WebSearch.svelte +++ b/src/lib/components/admin/Settings/WebSearch.svelte @@ -1,5 +1,5 @@
@@ -37,17 +35,7 @@
 		class="h-14 {className} flex items-center space-x-3 {colorClassName} rounded-xl border border-gray-100 dark:border-gray-800 text-left"
 		type="button"
 		on:click={async () => {
-			if (clickHandler === null) {
-				if (url) {
-					if (type === 'file') {
-						window.open(`${url}/content`, '_blank').focus();
-					} else {
-						window.open(`${url}`, '_blank').focus();
-					}
-				}
-			} else {
-				clickHandler();
-			}
+			dispatch('click');
 		}}
 	>
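The frontend hunks above and below move the document-ingestion client from `$lib/apis/rag` to `$lib/apis/retrieval` and rename its calls (`processDocToVectorDB` becomes `processFile`, `runWebSearch` becomes `processWebSearch`), while `RAG_API_BASE_URL` now points at `/retrieval/api/v1`. A minimal sketch of the renamed endpoints exercised directly from Python follows; the base URL, token, and file id are placeholders, and the `/api/v1/files/.../content/text` path assumes the files router keeps its usual mount point:

```python
# Sketch of the renamed retrieval endpoints from the hunks above.
# BASE_URL, TOKEN, and FILE_ID are placeholders, not values from the patch.
import requests

BASE_URL = "http://localhost:8080"
TOKEN = "sk-..."  # a valid Open WebUI API token
HEADERS = {"Authorization": f"Bearer {TOKEN}"}

# /rag/api/v1/process/doc -> /retrieval/api/v1/process/file
FILE_ID = "example-file-id"
r = requests.post(
    f"{BASE_URL}/retrieval/api/v1/process/file",
    headers=HEADERS,
    json={"file_id": FILE_ID},
)
r.raise_for_status()

# /rag/api/v1/web/search -> /retrieval/api/v1/process/web/search
search = requests.post(
    f"{BASE_URL}/retrieval/api/v1/process/web/search",
    headers=HEADERS,
    json={"query": "open webui"},
)
print(search.json())  # expected to match the SearchDocument shape

# New in this diff: extracted text for an uploaded file
# (path assumes the files router stays mounted under /api/v1/files).
text = requests.get(
    f"{BASE_URL}/api/v1/files/{FILE_ID}/content/text",
    headers=HEADERS,
)
print(text.json().get("text"))
```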
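Further up, the new `backend/open_webui/apps/retrieval/web/utils.py` refuses to fetch URLs that resolve to private addresses unless `ENABLE_RAG_LOCAL_WEB_FETCH` is set, and its `SafeWebBaseLoader` logs and skips URLs that fail to load instead of aborting the whole batch. Below is a standalone sketch of that resolution check using only the standard library; `check_public` is an illustrative helper, not a function from the patch:

```python
# Standalone illustration of the private-address check performed by
# validate_url()/resolve_hostname() in retrieval/web/utils.py.
# check_public() is a hypothetical helper, not part of the patch.
import ipaddress
import socket
import urllib.parse


def check_public(url: str) -> bool:
    """Return True only if every resolved address for the URL is public."""
    hostname = urllib.parse.urlparse(url).hostname
    if not hostname:
        return False
    try:
        addr_info = socket.getaddrinfo(hostname, None)
    except socket.gaierror:
        return False
    for info in addr_info:
        ip = ipaddress.ip_address(info[4][0])
        # Like the patch itself notes, this check runs once at validation
        # time, so it does not defend against DNS rebinding later on.
        if ip.is_private or ip.is_loopback or ip.is_link_local:
            return False
    return True


print(check_public("https://example.com"))    # typically True
print(check_public("http://127.0.0.1:8080"))  # False: loopback address
```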
diff --git a/src/lib/components/documents/AddDocModal.svelte b/src/lib/components/documents/AddDocModal.svelte index 10164be97..8c4d478f7 100644 --- a/src/lib/components/documents/AddDocModal.svelte +++ b/src/lib/components/documents/AddDocModal.svelte @@ -3,16 +3,13 @@ import dayjs from 'dayjs'; import { onMount, getContext } from 'svelte'; - import { createNewDoc, getDocs, tagDocByName, updateDocByName } from '$lib/apis/documents'; + import { getDocs } from '$lib/apis/documents'; import Modal from '../common/Modal.svelte'; import { documents } from '$lib/stores'; - import TagInput from '../common/Tags/TagInput.svelte'; - import Tags from '../common/Tags.svelte'; - import { addTagById } from '$lib/apis/chats'; - import { uploadDocToVectorDB } from '$lib/apis/rag'; - import { transformFileName } from '$lib/utils'; import { SUPPORTED_FILE_EXTENSIONS, SUPPORTED_FILE_TYPE } from '$lib/constants'; + import Tags from '../common/Tags.svelte'; + const i18n = getContext('i18n'); export let show = false; diff --git a/src/lib/components/workspace/Documents.svelte b/src/lib/components/workspace/Documents.svelte index 38f46f745..0fa50278c 100644 --- a/src/lib/components/workspace/Documents.svelte +++ b/src/lib/components/workspace/Documents.svelte @@ -8,7 +8,7 @@ import { createNewDoc, deleteDocByName, getDocs } from '$lib/apis/documents'; import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants'; - import { processDocToVectorDB, uploadDocToVectorDB } from '$lib/apis/rag'; + import { processFile } from '$lib/apis/retrieval'; import { blobToFile, transformFileName } from '$lib/utils'; import Checkbox from '$lib/components/common/Checkbox.svelte'; @@ -74,7 +74,7 @@ return null; }); - const res = await processDocToVectorDB(localStorage.token, uploadedFile.id).catch((error) => { + const res = await processFile(localStorage.token, uploadedFile.id).catch((error) => { toast.error(error); return null; }); diff --git a/src/lib/constants.ts b/src/lib/constants.ts index ad7b5c29e..8820c0d99 100644 --- a/src/lib/constants.ts +++ b/src/lib/constants.ts @@ -11,7 +11,7 @@ export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama`; export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai`; export const AUDIO_API_BASE_URL = `${WEBUI_BASE_URL}/audio/api/v1`; export const IMAGES_API_BASE_URL = `${WEBUI_BASE_URL}/images/api/v1`; -export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`; +export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/retrieval/api/v1`; export const WEBUI_VERSION = APP_VERSION; export const WEBUI_BUILD_HASH = APP_BUILD_HASH; diff --git a/src/lib/i18n/locales/ca-ES/translation.json b/src/lib/i18n/locales/ca-ES/translation.json index 25907c0f2..bd6179558 100644 --- a/src/lib/i18n/locales/ca-ES/translation.json +++ b/src/lib/i18n/locales/ca-ES/translation.json @@ -9,7 +9,7 @@ "{{user}}'s Chats": "Els xats de {{user}}", "{{webUIName}} Backend Required": "El Backend de {{webUIName}} 茅s necessari", "*Prompt node ID(s) are required for image generation": "*Els identificadors de nodes d'indicacions s贸n necessaris per a la generaci贸 d'imatges", - "A new version (v{{LATEST_VERSION}}) is now available.": "", + "A new version (v{{LATEST_VERSION}}) is now available.": "Hi ha una nova versi贸 disponible (v{{LATEST_VERSION}}).", "A task model is used when performing tasks such as generating titles for chats and web search queries": "Un model de tasca s'utilitza quan es realitzen tasques com ara generar t铆tols per a xats i consultes de cerca per a la web", "a user": "un usuari", 
"About": "Sobre", @@ -466,7 +466,7 @@ "Oops! Looks like the URL is invalid. Please double-check and try again.": "Ui! Sembla que l'URL no 茅s v脿lida. Si us plau, revisa-la i torna-ho a provar.", "Oops! There was an error in the previous response. Please try again or contact admin.": "Ui! Hi ha hagut un error en la resposta anterior. Torna a provar-ho o contacta amb un administrador", "Oops! You're using an unsupported method (frontend only). Please serve the WebUI from the backend.": "Ui! Est脿s utilitzant un m猫tode no suportat (nom茅s frontend). Si us plau, serveix la WebUI des del backend.", - "Open file": "", + "Open file": "Obrir arxiu", "Open new chat": "Obre un xat nou", "Open WebUI version (v{{OPEN_WEBUI_VERSION}}) is lower than required version (v{{REQUIRED_VERSION}})": "La versi贸 d'Open WebUI (v{{OPEN_WEBUI_VERSION}}) 茅s inferior a la versi贸 requerida (v{{REQUIRED_VERSION}})", "OpenAI": "OpenAI", @@ -478,7 +478,7 @@ "Other": "Altres", "Output format": "Format de sortida", "Overview": "Vista general", - "page": "", + "page": "p脿gina", "Password": "Contrasenya", "PDF document (.pdf)": "Document PDF (.pdf)", "PDF Extract Images (OCR)": "Extreu imatges del PDF (OCR)", @@ -497,7 +497,7 @@ "Plain text (.txt)": "Text pla (.txt)", "Playground": "Zona de jocs", "Please carefully review the following warnings:": "Si us plau, revisa els seg眉ents avisos amb cura:", - "Please select a reason": "", + "Please select a reason": "Si us plau, selecciona una ra贸", "Positive attitude": "Actitud positiva", "Previous 30 days": "30 dies anteriors", "Previous 7 days": "7 dies anteriors", @@ -704,7 +704,7 @@ "Unpin": "Alliberar", "Update": "Actualitzar", "Update and Copy Link": "Actualitzar i copiar l'enlla莽", - "Update for the latest features and improvements.": "", + "Update for the latest features and improvements.": "Actualitza per a les darreres caracter铆stiques i millores.", "Update password": "Actualitzar la contrasenya", "Updated at": "Actualitzat", "Upload": "Pujar", diff --git a/src/lib/utils/rag/index.ts b/src/lib/utils/rag/index.ts index ba1f29f88..6523bb7df 100644 --- a/src/lib/utils/rag/index.ts +++ b/src/lib/utils/rag/index.ts @@ -1,4 +1,4 @@ -import { getRAGTemplate } from '$lib/apis/rag'; +import { getRAGTemplate } from '$lib/apis/retrieval'; export const RAGTemplate = async (token: string, context: string, query: string) => { let template = await getRAGTemplate(token).catch(() => { diff --git a/src/routes/(app)/+layout.svelte b/src/routes/(app)/+layout.svelte index ad2f085cf..83a53dffd 100644 --- a/src/routes/(app)/+layout.svelte +++ b/src/routes/(app)/+layout.svelte @@ -206,10 +206,10 @@ const now = new Date(); if (now - dismissedUpdateToast > 24 * 60 * 60 * 1000) { - await checkForVersionUpdates(); + checkForVersionUpdates(); } } else { - await checkForVersionUpdates(); + checkForVersionUpdates(); } } await tick();