diff --git a/CHANGELOG.md b/CHANGELOG.md index b508d76fc..eaf0f6213 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,26 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.5.3] - 2024-12-31 + +### Added + +- **💬 Channel Reactions with Built-In Emoji Picker**: Easily express yourself in channel threads and messages with reactions, featuring an intuitive built-in emoji picker for seamless selection. +- **🧵 Threads for Channels**: Organize discussions within channels by creating threads, improving clarity and fostering focused conversations. +- **🔄 Reset Button for SVG Pan/Zoom**: Added a handy reset button to SVG Pan/Zoom, allowing users to quickly return diagrams or visuals to their default state without hassle. +- **⚡ Realtime Chat Save Environment Variable**: Introduced the ENABLE_REALTIME_CHAT_SAVE environment variable. Choose between faster responses by disabling realtime chat saving or ensuring chunk-by-chunk data persistency for critical operations. +- **🌍 Translation Enhancements**: Updated and refined translations across multiple languages, providing a smoother experience for international users. +- **📚 Improved Documentation**: Expanded documentation on functions, including clearer guidance on function plugins and detailed instructions for migrating to v0.5. This ensures users can adapt and harness new updates more effectively. (https://docs.openwebui.com/features/plugin/) + +### Fixed + +- **🛠️ Ollama Parameters Respected**: Resolved an issue where input parameters for Ollama were being ignored, ensuring precise and consistent model behavior. +- **🔧 Function Plugin Outlet Hook Reliability**: Fixed a bug causing issues with 'event_emitter' and outlet hooks in filter function plugins, guaranteeing smoother operation within custom extensions. +- **🖋️ Weird Custom Status Descriptions**: Adjusted the formatting and functionality for custom user statuses, ensuring they display correctly and intuitively. +- **🔗 Restored API Functionality**: Fixed a critical issue where APIs were not operational for certain configurations, ensuring uninterrupted access. +- **⏳ Custom Pipe Function Completion**: Resolved an issue where chats using specific custom pipe function plugins weren’t finishing properly, restoring consistent chat workflows. +- **✅ General Stability Enhancements**: Implemented various under-the-hood improvements to boost overall reliability, ensuring smoother and more consistent performance across the WebUI. + ## [0.5.2] - 2024-12-26 ### Added diff --git a/README.md b/README.md index 4ac495a47..8c7e4cdf7 100644 --- a/README.md +++ b/README.md @@ -185,6 +185,14 @@ If you want to try out the latest bleeding-edge features and are okay with occas docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --add-host=host.docker.internal:host-gateway --restart always ghcr.io/open-webui/open-webui:dev ``` +### Offline Mode + +If you are running Open WebUI in an offline environment, you can set the `HF_HUB_OFFLINE` environment variable to `1` to prevent attempts to download models from the internet. + +```bash +export HF_HUB_OFFLINE=1 +``` + ## What's Next? 🌟 Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/). diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index a5f848b62..8aafa15b6 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -311,6 +311,11 @@ RESET_CONFIG_ON_START = ( os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" ) + +ENABLE_REALTIME_CHAT_SAVE = ( + os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "True").lower() == "true" +) + #################################### # REDIS #################################### @@ -392,3 +397,6 @@ else: #################################### OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" + +if OFFLINE_MODE: + os.environ["HF_HUB_OFFLINE"] = "1" diff --git a/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py b/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py new file mode 100644 index 000000000..16fb0e85e --- /dev/null +++ b/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py @@ -0,0 +1,70 @@ +"""Update message & channel tables + +Revision ID: 3781e22d8b01 +Revises: 7826ab40b532 +Create Date: 2024-12-30 03:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa + +revision = "3781e22d8b01" +down_revision = "7826ab40b532" +branch_labels = None +depends_on = None + + +def upgrade(): + # Add 'type' column to the 'channel' table + op.add_column( + "channel", + sa.Column( + "type", + sa.Text(), + nullable=True, + ), + ) + + # Add 'parent_id' column to the 'message' table for threads + op.add_column( + "message", + sa.Column("parent_id", sa.Text(), nullable=True), + ) + + op.create_table( + "message_reaction", + sa.Column( + "id", sa.Text(), nullable=False, primary_key=True, unique=True + ), # Unique reaction ID + sa.Column("user_id", sa.Text(), nullable=False), # User who reacted + sa.Column( + "message_id", sa.Text(), nullable=False + ), # Message that was reacted to + sa.Column( + "name", sa.Text(), nullable=False + ), # Reaction name (e.g. "thumbs_up") + sa.Column( + "created_at", sa.BigInteger(), nullable=True + ), # Timestamp of when the reaction was added + ) + + op.create_table( + "channel_member", + sa.Column( + "id", sa.Text(), nullable=False, primary_key=True, unique=True + ), # Record ID for the membership row + sa.Column("channel_id", sa.Text(), nullable=False), # Associated channel + sa.Column("user_id", sa.Text(), nullable=False), # Associated user + sa.Column( + "created_at", sa.BigInteger(), nullable=True + ), # Timestamp of when the user joined the channel + ) + + +def downgrade(): + # Revert 'type' column addition to the 'channel' table + op.drop_column("channel", "type") + op.drop_column("message", "parent_id") + op.drop_table("message_reaction") + op.drop_table("channel_member") diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index bc36146cf..92f238c3a 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -21,6 +21,7 @@ class Channel(Base): id = Column(Text, primary_key=True) user_id = Column(Text) + type = Column(Text, nullable=True) name = Column(Text) description = Column(Text, nullable=True) @@ -38,9 +39,11 @@ class ChannelModel(BaseModel): id: str user_id: str - description: Optional[str] = None + type: Optional[str] = None name: str + description: Optional[str] = None + data: Optional[dict] = None meta: Optional[dict] = None access_control: Optional[dict] = None @@ -64,12 +67,13 @@ class ChannelForm(BaseModel): class ChannelTable: def insert_new_channel( - self, form_data: ChannelForm, user_id: str + self, type: Optional[str], form_data: ChannelForm, user_id: str ) -> Optional[ChannelModel]: with get_db() as db: channel = ChannelModel( **{ **form_data.model_dump(), + "type": type, "name": form_data.name.lower(), "id": str(uuid.uuid4()), "user_id": user_id, diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 87bcdb8d3..18f802afe 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -212,6 +212,15 @@ class ChatTable: return chat.chat.get("history", {}).get("messages", {}) or {} + def get_message_by_id_and_message_id( + self, id: str, message_id: str + ) -> Optional[dict]: + chat = self.get_chat_by_id(id) + if chat is None: + return None + + return chat.chat.get("history", {}).get("messages", {}).get(message_id, {}) + def upsert_message_to_chat_by_id_and_message_id( self, id: str, message_id: str, message: dict ) -> Optional[ChatModel]: diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index 2a4322d0d..87da2f3fb 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -17,6 +17,25 @@ from sqlalchemy.sql import exists #################### +class MessageReaction(Base): + __tablename__ = "message_reaction" + id = Column(Text, primary_key=True) + user_id = Column(Text) + message_id = Column(Text) + name = Column(Text) + created_at = Column(BigInteger) + + +class MessageReactionModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: str + user_id: str + message_id: str + name: str + created_at: int # timestamp in epoch + + class Message(Base): __tablename__ = "message" id = Column(Text, primary_key=True) @@ -24,6 +43,8 @@ class Message(Base): user_id = Column(Text) channel_id = Column(Text, nullable=True) + parent_id = Column(Text, nullable=True) + content = Column(Text) data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) @@ -39,6 +60,8 @@ class MessageModel(BaseModel): user_id: str channel_id: Optional[str] = None + parent_id: Optional[str] = None + content: str data: Optional[dict] = None meta: Optional[dict] = None @@ -54,10 +77,23 @@ class MessageModel(BaseModel): class MessageForm(BaseModel): content: str + parent_id: Optional[str] = None data: Optional[dict] = None meta: Optional[dict] = None +class Reactions(BaseModel): + name: str + user_ids: list[str] + count: int + + +class MessageResponse(MessageModel): + latest_reply_at: Optional[int] + reply_count: int + reactions: list[Reactions] + + class MessageTable: def insert_new_message( self, form_data: MessageForm, channel_id: str, user_id: str @@ -71,6 +107,7 @@ class MessageTable: "id": id, "user_id": user_id, "channel_id": channel_id, + "parent_id": form_data.parent_id, "content": form_data.content, "data": form_data.data, "meta": form_data.meta, @@ -85,10 +122,40 @@ class MessageTable: db.refresh(result) return MessageModel.model_validate(result) if result else None - def get_message_by_id(self, id: str) -> Optional[MessageModel]: + def get_message_by_id(self, id: str) -> Optional[MessageResponse]: with get_db() as db: message = db.get(Message, id) - return MessageModel.model_validate(message) if message else None + if not message: + return None + + reactions = self.get_reactions_by_message_id(id) + replies = self.get_replies_by_message_id(id) + + return MessageResponse( + **{ + **MessageModel.model_validate(message).model_dump(), + "latest_reply_at": replies[0].created_at if replies else None, + "reply_count": len(replies), + "reactions": reactions, + } + ) + + def get_replies_by_message_id(self, id: str) -> list[MessageModel]: + with get_db() as db: + all_messages = ( + db.query(Message) + .filter_by(parent_id=id) + .order_by(Message.created_at.desc()) + .all() + ) + return [MessageModel.model_validate(message) for message in all_messages] + + def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: + with get_db() as db: + return [ + message.user_id + for message in db.query(Message).filter_by(parent_id=id).all() + ] def get_messages_by_channel_id( self, channel_id: str, skip: int = 0, limit: int = 50 @@ -96,7 +163,7 @@ class MessageTable: with get_db() as db: all_messages = ( db.query(Message) - .filter_by(channel_id=channel_id) + .filter_by(channel_id=channel_id, parent_id=None) .order_by(Message.created_at.desc()) .offset(skip) .limit(limit) @@ -104,19 +171,27 @@ class MessageTable: ) return [MessageModel.model_validate(message) for message in all_messages] - def get_messages_by_user_id( - self, user_id: str, skip: int = 0, limit: int = 50 + def get_messages_by_parent_id( + self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 ) -> list[MessageModel]: with get_db() as db: + message = db.get(Message, parent_id) + + if not message: + return [] + all_messages = ( db.query(Message) - .filter_by(user_id=user_id) + .filter_by(channel_id=channel_id, parent_id=parent_id) .order_by(Message.created_at.desc()) .offset(skip) .limit(limit) .all() ) - return [MessageModel.model_validate(message) for message in all_messages] + + return [ + MessageModel.model_validate(message) for message in all_messages + ] + [MessageModel.model_validate(message)] def update_message_by_id( self, id: str, form_data: MessageForm @@ -131,9 +206,70 @@ class MessageTable: db.refresh(message) return MessageModel.model_validate(message) if message else None + def add_reaction_to_message( + self, id: str, user_id: str, name: str + ) -> Optional[MessageReactionModel]: + with get_db() as db: + reaction_id = str(uuid.uuid4()) + reaction = MessageReactionModel( + id=reaction_id, + user_id=user_id, + message_id=id, + name=name, + created_at=int(time.time_ns()), + ) + result = MessageReaction(**reaction.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + return MessageReactionModel.model_validate(result) if result else None + + def get_reactions_by_message_id(self, id: str) -> list[Reactions]: + with get_db() as db: + all_reactions = db.query(MessageReaction).filter_by(message_id=id).all() + + reactions = {} + for reaction in all_reactions: + if reaction.name not in reactions: + reactions[reaction.name] = { + "name": reaction.name, + "user_ids": [], + "count": 0, + } + reactions[reaction.name]["user_ids"].append(reaction.user_id) + reactions[reaction.name]["count"] += 1 + + return [Reactions(**reaction) for reaction in reactions.values()] + + def remove_reaction_by_id_and_user_id_and_name( + self, id: str, user_id: str, name: str + ) -> bool: + with get_db() as db: + db.query(MessageReaction).filter_by( + message_id=id, user_id=user_id, name=name + ).delete() + db.commit() + return True + + def delete_reactions_by_id(self, id: str) -> bool: + with get_db() as db: + db.query(MessageReaction).filter_by(message_id=id).delete() + db.commit() + return True + + def delete_replies_by_id(self, id: str) -> bool: + with get_db() as db: + db.query(Message).filter_by(parent_id=id).delete() + db.commit() + return True + def delete_message_by_id(self, id: str) -> bool: with get_db() as db: db.query(Message).filter_by(id=id).delete() + + # Delete all reactions to this message + db.query(MessageReaction).filter_by(message_id=id).delete() + db.commit() return True diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 17f1438da..c95367e6c 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -14,7 +14,7 @@ from langchain_core.documents import Document from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message -from open_webui.env import SRC_LOG_LEVELS +from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -375,6 +375,9 @@ def get_model_path(model: str, update_model: bool = False): local_files_only = not update_model + if OFFLINE_MODE: + local_files_only = True + snapshot_kwargs = { "cache_dir": cache_dir, "local_files_only": local_files_only, diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 292f33e78..da6a8d01f 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -11,7 +11,12 @@ from open_webui.socket.main import sio, get_user_ids_from_room from open_webui.models.users import Users, UserNameResponse from open_webui.models.channels import Channels, ChannelModel, ChannelForm -from open_webui.models.messages import Messages, MessageModel, MessageForm +from open_webui.models.messages import ( + Messages, + MessageModel, + MessageResponse, + MessageForm, +) from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT @@ -49,7 +54,7 @@ async def get_channels(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[ChannelModel]) async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)): try: - channel = Channels.insert_new_channel(form_data, user.id) + channel = Channels.insert_new_channel(None, form_data, user.id) return ChannelModel(**channel.model_dump()) except Exception as e: log.exception(e) @@ -134,11 +139,11 @@ async def delete_channel_by_id(id: str, user=Depends(get_admin_user)): ############################ -class MessageUserModel(MessageModel): +class MessageUserResponse(MessageResponse): user: UserNameResponse -@router.get("/{id}/messages", response_model=list[MessageUserModel]) +@router.get("/{id}/messages", response_model=list[MessageUserResponse]) async def get_channel_messages( id: str, skip: int = 0, limit: int = 50, user=Depends(get_verified_user) ): @@ -164,10 +169,16 @@ async def get_channel_messages( user = Users.get_user_by_id(message.user_id) users[message.user_id] = user + replies = Messages.get_replies_by_message_id(message.id) + latest_reply_at = replies[0].created_at if replies else None + messages.append( - MessageUserModel( + MessageUserResponse( **{ **message.model_dump(), + "reply_count": len(replies), + "latest_reply_at": latest_reply_at, + "reactions": Messages.get_reactions_by_message_id(message.id), "user": UserNameResponse(**users[message.user_id].model_dump()), } ) @@ -236,10 +247,17 @@ async def post_new_message( "message_id": message.id, "data": { "type": "message", - "data": { - **message.model_dump(), - "user": UserNameResponse(**user.model_dump()).model_dump(), - }, + "data": MessageUserResponse( + **{ + **message.model_dump(), + "reply_count": 0, + "latest_reply_at": None, + "reactions": Messages.get_reactions_by_message_id( + message.id + ), + "user": UserNameResponse(**user.model_dump()), + } + ).model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), @@ -251,6 +269,35 @@ async def post_new_message( to=f"channel:{channel.id}", ) + if message.parent_id: + # If this message is a reply, emit to the parent message as well + parent_message = Messages.get_message_by_id(message.parent_id) + + if parent_message: + await sio.emit( + "channel-events", + { + "channel_id": channel.id, + "message_id": parent_message.id, + "data": { + "type": "message:reply", + "data": MessageUserResponse( + **{ + **parent_message.model_dump(), + "user": UserNameResponse( + **Users.get_user_by_id( + parent_message.user_id + ).model_dump() + ), + } + ).model_dump(), + }, + "user": UserNameResponse(**user.model_dump()).model_dump(), + "channel": channel.model_dump(), + }, + to=f"channel:{channel.id}", + ) + active_user_ids = get_user_ids_from_room(f"channel:{channel.id}") background_tasks.add_task( @@ -269,6 +316,101 @@ async def post_new_message( ) +############################ +# GetChannelMessage +############################ + + +@router.get("/{id}/messages/{message_id}", response_model=Optional[MessageUserResponse]) +async def get_channel_message( + id: str, message_id: str, user=Depends(get_verified_user) +): + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role != "admin" and not has_access( + user.id, type="read", access_control=channel.access_control + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + message = Messages.get_message_by_id(message_id) + if not message: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if message.channel_id != id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + return MessageUserResponse( + **{ + **message.model_dump(), + "user": UserNameResponse( + **Users.get_user_by_id(message.user_id).model_dump() + ), + } + ) + + +############################ +# GetChannelThreadMessages +############################ + + +@router.get( + "/{id}/messages/{message_id}/thread", response_model=list[MessageUserResponse] +) +async def get_channel_thread_messages( + id: str, + message_id: str, + skip: int = 0, + limit: int = 50, + user=Depends(get_verified_user), +): + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role != "admin" and not has_access( + user.id, type="read", access_control=channel.access_control + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit) + users = {} + + messages = [] + for message in message_list: + if message.user_id not in users: + user = Users.get_user_by_id(message.user_id) + users[message.user_id] = user + + messages.append( + MessageUserResponse( + **{ + **message.model_dump(), + "reply_count": 0, + "latest_reply_at": None, + "reactions": Messages.get_reactions_by_message_id(message.id), + "user": UserNameResponse(**users[message.user_id].model_dump()), + } + ) + ) + + return messages + + ############################ # UpdateMessageById ############################ @@ -306,6 +448,8 @@ async def update_message_by_id( try: message = Messages.update_message_by_id(message_id, form_data) + message = Messages.get_message_by_id(message_id) + if message: await sio.emit( "channel-events", @@ -314,10 +458,14 @@ async def update_message_by_id( "message_id": message.id, "data": { "type": "message:update", - "data": { - **message.model_dump(), - "user": UserNameResponse(**user.model_dump()).model_dump(), - }, + "data": MessageUserResponse( + **{ + **message.model_dump(), + "user": UserNameResponse( + **user.model_dump() + ).model_dump(), + } + ).model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), @@ -333,6 +481,145 @@ async def update_message_by_id( ) +############################ +# AddReactionToMessage +############################ + + +class ReactionForm(BaseModel): + name: str + + +@router.post("/{id}/messages/{message_id}/reactions/add", response_model=bool) +async def add_reaction_to_message( + id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user) +): + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role != "admin" and not has_access( + user.id, type="read", access_control=channel.access_control + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + message = Messages.get_message_by_id(message_id) + if not message: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if message.channel_id != id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + try: + Messages.add_reaction_to_message(message_id, user.id, form_data.name) + message = Messages.get_message_by_id(message_id) + + await sio.emit( + "channel-events", + { + "channel_id": channel.id, + "message_id": message.id, + "data": { + "type": "message:reaction:add", + "data": { + **message.model_dump(), + "user": UserNameResponse( + **Users.get_user_by_id(message.user_id).model_dump() + ).model_dump(), + "name": form_data.name, + }, + }, + "user": UserNameResponse(**user.model_dump()).model_dump(), + "channel": channel.model_dump(), + }, + to=f"channel:{channel.id}", + ) + + return True + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + +############################ +# RemoveReactionById +############################ + + +@router.post("/{id}/messages/{message_id}/reactions/remove", response_model=bool) +async def remove_reaction_by_id_and_user_id_and_name( + id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user) +): + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role != "admin" and not has_access( + user.id, type="read", access_control=channel.access_control + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + message = Messages.get_message_by_id(message_id) + if not message: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if message.channel_id != id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + try: + Messages.remove_reaction_by_id_and_user_id_and_name( + message_id, user.id, form_data.name + ) + + message = Messages.get_message_by_id(message_id) + + await sio.emit( + "channel-events", + { + "channel_id": channel.id, + "message_id": message.id, + "data": { + "type": "message:reaction:remove", + "data": { + **message.model_dump(), + "user": UserNameResponse( + **Users.get_user_by_id(message.user_id).model_dump() + ).model_dump(), + "name": form_data.name, + }, + }, + "user": UserNameResponse(**user.model_dump()).model_dump(), + "channel": channel.model_dump(), + }, + to=f"channel:{channel.id}", + ) + + return True + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # DeleteMessageById ############################ @@ -386,6 +673,35 @@ async def delete_message_by_id( to=f"channel:{channel.id}", ) + if message.parent_id: + # If this message is a reply, emit to the parent message as well + parent_message = Messages.get_message_by_id(message.parent_id) + + if parent_message: + await sio.emit( + "channel-events", + { + "channel_id": channel.id, + "message_id": parent_message.id, + "data": { + "type": "message:reply", + "data": MessageUserResponse( + **{ + **parent_message.model_dump(), + "user": UserNameResponse( + **Users.get_user_by_id( + parent_message.user_id + ).model_dump() + ), + } + ).model_dump(), + }, + "user": UserNameResponse(**user.model_dump()).model_dump(), + "channel": channel.model_dump(), + }, + to=f"channel:{channel.id}", + ) + return True except Exception as e: log.exception(e) diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 3b1ba2945..9e36d98b7 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -348,7 +348,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)): result = Files.delete_file_by_id(id) if result: try: - Storage.delete_file(file.filename) + Storage.delete_file(file.path) except Exception as e: log.exception(e) log.error(f"Error deleting files") diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 04ebcf507..ad67cc31f 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -419,13 +419,6 @@ def remove_file_from_knowledge_by_id( collection_name=knowledge.id, filter={"file_id": form_data.file_id} ) - result = VECTOR_DB_CLIENT.query( - collection_name=knowledge.id, - filter={"file_id": form_data.file_id}, - ) - - Files.delete_file_by_id(form_data.file_id) - if knowledge: data = knowledge.data or {} file_ids = data.get("file_ids", []) @@ -527,6 +520,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse]) def add_files_to_knowledge_batch( + request: Request, id: str, form_data: list[KnowledgeFileIdForm], user=Depends(get_verified_user), @@ -562,7 +556,9 @@ def add_files_to_knowledge_batch( # Process files try: result = process_files_batch( - BatchProcessFilesForm(files=files, collection_name=id) + request=request, + form_data=BatchProcessFilesForm(files=files, collection_name=id), + user=user, ) except Exception as e: log.error( diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index d6ff463a9..c791bde84 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -1458,6 +1458,7 @@ class BatchProcessFilesResponse(BaseModel): @router.post("/process/files/batch") def process_files_batch( + request: Request, form_data: BatchProcessFilesForm, user=Depends(get_verified_user), ) -> BatchProcessFilesResponse: @@ -1504,7 +1505,10 @@ def process_files_batch( if all_docs: try: save_docs_to_vector_db( - docs=all_docs, collection_name=collection_name, add=True + request=request, + docs=all_docs, + collection_name=collection_name, + add=True, ) # Update all files with collection name diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index c0bb3c662..2d12f5803 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -237,6 +237,7 @@ async def channel_events(sid, data): "channel-events", { "channel_id": data["channel_id"], + "message_id": data.get("message_id", None), "data": event_data, "user": UserNameResponse(**SESSION_POOL[sid]).model_dump(), }, @@ -292,6 +293,34 @@ def get_event_emitter(request_info): event_data.get("data", {}), ) + if "type" in event_data and event_data["type"] == "message": + message = Chats.get_message_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + ) + + content = message.get("content", "") + content += event_data.get("data", {}).get("content", "") + + Chats.upsert_message_to_chat_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + { + "content": content, + }, + ) + + if "type" in event_data and event_data["type"] == "replace": + content = event_data.get("data", {}).get("content", "") + + Chats.upsert_message_to_chat_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + { + "content": content, + }, + ) + return __event_emitter__ diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index 76e4fc48f..ae3347682 100644 --- a/backend/open_webui/storage/provider.py +++ b/backend/open_webui/storage/provider.py @@ -147,8 +147,10 @@ class StorageProvider: return self._get_file_from_s3(file_path) return self._get_file_from_local(file_path) - def delete_file(self, filename: str) -> None: + def delete_file(self, file_path: str) -> None: """Deletes a file either from S3 or the local file system.""" + filename = file_path.split("/")[-1] + if self.storage_provider == "s3": self._delete_from_s3(filename) diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 1cfae4d54..7fce76419 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -157,6 +157,9 @@ async def generate_chat_completion( ) +chat_completion = generate_chat_completion + + async def chat_completed(request: Request, form_data: dict, user: Any): if not request.app.state.MODELS: await get_all_models(request) @@ -179,6 +182,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any): "chat_id": data["chat_id"], "message_id": data["id"], "session_id": data["session_id"], + "user_id": user.id, } ) @@ -187,6 +191,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any): "chat_id": data["chat_id"], "message_id": data["id"], "session_id": data["session_id"], + "user_id": user.id, } ) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 7d79e1d0b..0e32bf626 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -23,7 +23,7 @@ from open_webui.models.users import Users from open_webui.socket.main import ( get_event_call, get_event_emitter, - get_user_id_from_session_pool, + get_active_status_by_user_id, ) from open_webui.routers.tasks import ( generate_queries, @@ -65,6 +65,7 @@ from open_webui.env import ( SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL, + ENABLE_REALTIME_CHAT_SAVE, ) from open_webui.constants import TASKS @@ -560,7 +561,6 @@ def apply_params_to_form_data(form_data, model): if "frequency_penalty" in params: form_data["frequency_penalty"] = params["frequency_penalty"] - return form_data @@ -750,7 +750,7 @@ async def process_chat_response( ): async def background_tasks_handler(): message_map = Chats.get_messages_by_chat_id(metadata["chat_id"]) - message = message_map.get(metadata["message_id"]) + message = message_map.get(metadata["message_id"]) if message_map else None if message: messages = get_message_list(message_map, message.get("id")) @@ -896,7 +896,7 @@ async def process_chat_response( ) # Send a webhook notification if the user is not active - if get_user_id_from_session_pool(metadata["session_id"]) is None: + if get_active_status_by_user_id(user.id) is None: webhook_url = Users.get_user_webhook_url_by_id(user.id) if webhook_url: post_webhook( @@ -928,6 +928,11 @@ async def process_chat_response( # Handle as a background task async def post_response_handler(response, events): + message = Chats.get_message_by_id_and_message_id( + metadata["chat_id"], metadata["message_id"] + ) + content = message.get("content", "") if message else "" + try: for event in events: await event_emitter( @@ -946,9 +951,6 @@ async def process_chat_response( }, ) - assistant_message = get_last_assistant_message(form_data["messages"]) - content = assistant_message if assistant_message else "" - async for line in response.body_iterator: line = line.decode("utf-8") if isinstance(line, bytes) else line data = line @@ -977,7 +979,6 @@ async def process_chat_response( ) else: - value = ( data.get("choices", [])[0] .get("delta", {}) @@ -987,55 +988,85 @@ async def process_chat_response( if value: content = f"{content}{value}" - # Save message in the database - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { + if ENABLE_REALTIME_CHAT_SAVE: + # Save message in the database + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "content": content, + }, + ) + else: + data = { "content": content, - }, - ) + } + + await event_emitter( + { + "type": "chat:completion", + "data": data, + } + ) except Exception as e: done = "data: [DONE]" in line - title = Chats.get_chat_title_by_id(metadata["chat_id"]) if done: - data = {"done": True, "content": content, "title": title} - - # Send a webhook notification if the user is not active - if ( - get_user_id_from_session_pool(metadata["session_id"]) - is None - ): - webhook_url = Users.get_user_webhook_url_by_id(user.id) - if webhook_url: - post_webhook( - webhook_url, - f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", - { - "action": "chat", - "message": content, - "title": title, - "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", - }, - ) - + pass else: continue - await event_emitter( + title = Chats.get_chat_title_by_id(metadata["chat_id"]) + data = {"done": True, "content": content, "title": title} + + if not ENABLE_REALTIME_CHAT_SAVE: + # Save message in the database + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], { - "type": "chat:completion", - "data": data, - } + "content": content, + }, ) + # Send a webhook notification if the user is not active + if get_active_status_by_user_id(user.id) is None: + webhook_url = Users.get_user_webhook_url_by_id(user.id) + if webhook_url: + post_webhook( + webhook_url, + f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", + { + "action": "chat", + "message": content, + "title": title, + "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", + }, + ) + + await event_emitter( + { + "type": "chat:completion", + "data": data, + } + ) + await background_tasks_handler() except asyncio.CancelledError: print("Task was cancelled!") await event_emitter({"type": "task-cancelled"}) + if not ENABLE_REALTIME_CHAT_SAVE: + # Save message in the database + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "content": content, + }, + ) + if response.background is not None: await response.background() diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 0125a799c..fdc62f79f 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -160,6 +160,10 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict: # If there are advanced parameters in the payload, format them in Ollama's options field ollama_options = {} + if openai_payload.get("options"): + ollama_payload["options"] = openai_payload["options"] + ollama_options = openai_payload["options"] + # Handle parameters which map directly for param in ["temperature", "top_p", "seed"]: if param in openai_payload: diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py index d429db8aa..d6f7b0ac6 100644 --- a/backend/open_webui/utils/response.py +++ b/backend/open_webui/utils/response.py @@ -29,7 +29,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response) ( ( data.get("eval_count", 0) - / ((data.get("eval_duration", 0) / 1_000_000)) + / ((data.get("eval_duration", 0) / 10_000_000)) ) * 100 ), @@ -43,7 +43,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response) ( ( data.get("prompt_eval_count", 0) - / ((data.get("prompt_eval_duration", 0) / 1_000_000)) + / ((data.get("prompt_eval_duration", 0) / 10_000_000)) ) * 100 ), diff --git a/package-lock.json b/package-lock.json index e953bc91d..1afeaecbe 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.5.2", + "version": "0.5.3", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.5.2", + "version": "0.5.3", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", @@ -16,6 +16,7 @@ "@mediapipe/tasks-vision": "^0.10.17", "@pyscript/core": "^0.4.32", "@sveltejs/adapter-node": "^2.0.0", + "@sveltejs/svelte-virtual-list": "^3.0.1", "@tiptap/core": "^2.10.0", "@tiptap/extension-code-block-lowlight": "^2.10.0", "@tiptap/extension-highlight": "^2.10.0", @@ -2291,6 +2292,12 @@ "vite": "^5.0.3 || ^6.0.0" } }, + "node_modules/@sveltejs/svelte-virtual-list": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/@sveltejs/svelte-virtual-list/-/svelte-virtual-list-3.0.1.tgz", + "integrity": "sha512-aF9TptS7NKKS7/TqpsxQBSDJ9Q0XBYzBehCeIC5DzdMEgrJZpIYao9LRLnyyo6SVodpapm2B7FE/Lj+FSA5/SQ==", + "license": "LIL" + }, "node_modules/@sveltejs/vite-plugin-svelte": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/@sveltejs/vite-plugin-svelte/-/vite-plugin-svelte-3.1.1.tgz", diff --git a/package.json b/package.json index ed988fb60..71fd418a7 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.5.2", + "version": "0.5.3", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -50,7 +50,6 @@ "type": "module", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", - "codemirror-lang-hcl": "^0.0.0-beta.2", "@codemirror/lang-python": "^6.1.6", "@codemirror/language-data": "^6.5.1", "@codemirror/theme-one-dark": "^6.1.2", @@ -58,6 +57,7 @@ "@mediapipe/tasks-vision": "^0.10.17", "@pyscript/core": "^0.4.32", "@sveltejs/adapter-node": "^2.0.0", + "@sveltejs/svelte-virtual-list": "^3.0.1", "@tiptap/core": "^2.10.0", "@tiptap/extension-code-block-lowlight": "^2.10.0", "@tiptap/extension-highlight": "^2.10.0", @@ -69,6 +69,7 @@ "async": "^3.2.5", "bits-ui": "^0.19.7", "codemirror": "^6.0.1", + "codemirror-lang-hcl": "^0.0.0-beta.2", "crc-32": "^1.2.2", "dayjs": "^1.11.10", "dompurify": "^3.1.6", diff --git a/src/lib/apis/channels/index.ts b/src/lib/apis/channels/index.ts index 607a8f900..f16b43505 100644 --- a/src/lib/apis/channels/index.ts +++ b/src/lib/apis/channels/index.ts @@ -1,4 +1,5 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; +import { t } from 'i18next'; type ChannelForm = { name: string; @@ -207,7 +208,48 @@ export const getChannelMessages = async ( return res; }; +export const getChannelThreadMessages = async ( + token: string = '', + channel_id: string, + message_id: string, + skip: number = 0, + limit: number = 50 +) => { + let error = null; + + const res = await fetch( + `${WEBUI_API_BASE_URL}/channels/${channel_id}/messages/${message_id}/thread?skip=${skip}&limit=${limit}`, + { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + type MessageForm = { + parent_id?: string; content: string; data?: object; meta?: object; @@ -285,6 +327,86 @@ export const updateMessage = async ( return res; }; +export const addReaction = async ( + token: string = '', + channel_id: string, + message_id: string, + name: string +) => { + let error = null; + + const res = await fetch( + `${WEBUI_API_BASE_URL}/channels/${channel_id}/messages/${message_id}/reactions/add`, + { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ name }) + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const removeReaction = async ( + token: string = '', + channel_id: string, + message_id: string, + name: string +) => { + let error = null; + + const res = await fetch( + `${WEBUI_API_BASE_URL}/channels/${channel_id}/messages/${message_id}/reactions/remove`, + { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ name }) + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const deleteMessage = async (token: string = '', channel_id: string, message_id: string) => { let error = null; diff --git a/src/lib/components/channel/Channel.svelte b/src/lib/components/channel/Channel.svelte index a27c88540..b205afcb3 100644 --- a/src/lib/components/channel/Channel.svelte +++ b/src/lib/components/channel/Channel.svelte @@ -1,14 +1,19 @@ @@ -166,40 +204,99 @@ : ''} w-full max-w-full flex flex-col" id="channel-container" > - + + + -
- {#if channel} -
{ - scrollEnd = Math.abs(messagesContainerElement.scrollTop) <= 50; - }} +
+ {#if channel} +
{ + scrollEnd = Math.abs(messagesContainerElement.scrollTop) <= 50; + }} + > + {#key id} + { + threadId = id; + }} + onLoad={async () => { + const newMessages = await getChannelMessages( + localStorage.token, + id, + messages.length + ); + + messages = [...messages, ...newMessages]; + + if (newMessages.length < 50) { + top = true; + return; + } + }} + /> + {/key} +
+ {/if} +
+ +
+ +
+ + + {#if !largeScreen} + {#if threadId !== null} + { + threadId = null; + }} + > +
+ { + threadId = null; + }} + /> +
+
+ {/if} + {:else if threadId !== null} + - {#key id} - + +
+ + + +
+ { - const newMessages = await getChannelMessages(localStorage.token, id, messages.length); - - messages = [...messages, ...newMessages]; - - if (newMessages.length < 50) { - top = true; - return; - } + onClose={() => { + threadId = null; }} /> - {/key} -
+
+
{/if} - - -
- -
+
diff --git a/src/lib/components/channel/MessageInput.svelte b/src/lib/components/channel/MessageInput.svelte index c4290b24e..c0605da8c 100644 --- a/src/lib/components/channel/MessageInput.svelte +++ b/src/lib/components/channel/MessageInput.svelte @@ -23,6 +23,8 @@ export let placeholder = $i18n.t('Send a Message'); export let transparentBackground = false; + export let id = null; + let draggedOver = false; let recording = false; @@ -37,7 +39,7 @@ export let onSubmit: Function; export let onChange: Function; export let scrollEnd = true; - export let scrollToBottom: Function; + export let scrollToBottom: Function = () => {}; const screenCaptureHandler = async () => { try { @@ -257,7 +259,7 @@ await tick(); - const chatInputElement = document.getElementById('chat-input'); + const chatInputElement = document.getElementById(`chat-input-${id}`); chatInputElement?.focus(); }; @@ -267,7 +269,7 @@ onMount(async () => { window.setTimeout(() => { - const chatInput = document.getElementById('chat-input'); + const chatInput = document.getElementById(`chat-input-${id}`); chatInput?.focus(); }, 0); @@ -313,7 +315,7 @@ filesInputElement.value = ''; }} /> -
+
-
+
{#if typingUsers.length > 0}
- + {typingUsers.map((user) => user.name).join(', ')} {$i18n.t('is typing...')} @@ -373,7 +375,7 @@ recording = false; await tick(); - document.getElementById('chat-input')?.focus(); + document.getElementById(`chat-input-${id}`)?.focus(); }} on:confirm={async (e) => { const { text, filename } = e.detail; @@ -381,7 +383,7 @@ recording = false; await tick(); - document.getElementById('chat-input')?.focus(); + document.getElementById(`chat-input-${id}`)?.focus(); }} /> {:else} @@ -392,7 +394,7 @@ }} >
{#if files.length > 0} @@ -478,61 +480,21 @@
- {#if $settings?.richTextInput ?? true} -
- 0 || - navigator.msMaxTouchPoints > 0 - )} - {placeholder} - largeTextAsFile={$settings?.largeTextAsFile ?? false} - on:keydown={async (e) => { - e = e.detail.event; - const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac - if ( - !$mobile || - !( - 'ontouchstart' in window || - navigator.maxTouchPoints > 0 || - navigator.msMaxTouchPoints > 0 - ) - ) { - // Prevent Enter key from creating a new line - // Uses keyCode '13' for Enter key for chinese/japanese keyboards - if (e.keyCode === 13 && !e.shiftKey) { - e.preventDefault(); - } - - // Submit the content when Enter key is pressed - if (content !== '' && e.keyCode === 13 && !e.shiftKey) { - submitHandler(); - } - } - - if (e.key === 'Escape') { - console.log('Escape'); - } - }} - on:paste={async (e) => { - e = e.detail.event; - console.log(e); - }} - /> -
- {:else} -