From 286cdc41ab3637044a266de53d4ffaef7313514c Mon Sep 17 00:00:00 2001 From: "Junjie.M" <118170653@qq.com> Date: Sat, 8 Feb 2025 16:19:41 +0800 Subject: [PATCH] reasoning model unified think tag is (#13392) Co-authored-by: crazywoola <427733928@qq.com> --- .../__base/large_language_model.py | 21 +++------------ .../model_providers/ollama/llm/llm.py | 1 - .../openai_api_compatible/llm/llm.py | 1 - .../model_providers/xinference/llm/llm.py | 1 - web/app/components/base/markdown.tsx | 27 ++++++++++++++++--- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index b45b2ca025..f377f12919 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -30,11 +30,6 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) -HTML_THINKING_TAG = ( - '
' - " Thinking... " -) - class LargeLanguageModel(AIModel): """ @@ -408,7 +403,7 @@ if you are not sure about the structure. def _wrap_thinking_by_reasoning_content(self, delta: dict, is_reasoning: bool) -> tuple[str, bool]: """ If the reasoning response is from delta.get("reasoning_content"), we wrap - it with HTML details tag. + it with HTML think tag. :param delta: delta dictionary from LLM streaming response :param is_reasoning: is reasoning @@ -420,25 +415,15 @@ if you are not sure about the structure. if reasoning_content: if not is_reasoning: - content = HTML_THINKING_TAG + reasoning_content + content = "\n" + reasoning_content is_reasoning = True else: content = reasoning_content elif is_reasoning: - content = "
" + content + content = "\n" + content is_reasoning = False return content, is_reasoning - def _wrap_thinking_by_tag(self, content: str) -> str: - """ - if the reasoning response is a ... block from delta.get("content"), - we replace to . - - :param content: delta.get("content") - :return: processed_content - """ - return content.replace("", HTML_THINKING_TAG).replace("", "") - def _invoke_result_generator( self, model: str, diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index b640914b39..3ae728d4b3 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -367,7 +367,6 @@ class OllamaLargeLanguageModel(LargeLanguageModel): # transform assistant message to prompt message text = chunk_json["response"] - text = self._wrap_thinking_by_tag(text) assistant_prompt_message = AssistantPromptMessage(content=text) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 7f79da267f..cab552af25 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -528,7 +528,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): delta_content, is_reasoning_started = self._wrap_thinking_by_reasoning_content( delta, is_reasoning_started ) - delta_content = self._wrap_thinking_by_tag(delta_content) assistant_message_tool_calls = None diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index fcf452d627..7c77f7a0e1 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -654,7 +654,6 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if function_call: assistant_message_tool_calls += [self._extract_response_function_call(function_call)] - delta_content = self._wrap_thinking_by_tag(delta_content) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( content=delta_content or "", tool_calls=assistant_message_tool_calls diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx index b26d9df30e..b98df4b7b2 100644 --- a/web/app/components/base/markdown.tsx +++ b/web/app/components/base/markdown.tsx @@ -10,6 +10,7 @@ import SyntaxHighlighter from 'react-syntax-highlighter' import { atelierHeathLight } from 'react-syntax-highlighter/dist/esm/styles/hljs' import { Component, memo, useMemo, useRef, useState } from 'react' import type { CodeComponent } from 'react-markdown/lib/ast-to-react' +import { flow } from 'lodash/fp' import cn from '@/utils/classnames' import CopyBtn from '@/app/components/base/copy-btn' import SVGBtn from '@/app/components/base/svg' @@ -58,9 +59,24 @@ const getCorrectCapitalizationLanguageName = (language: string) => { const preprocessLaTeX = (content: string) => { if (typeof content !== 'string') return content - return content.replace(/\\\[(.*?)\\\]/g, (_, equation) => `$$${equation}$$`) - .replace(/\\\((.*?)\\\)/g, (_, equation) => `$$${equation}$$`) - .replace(/(^|[^\\])\$(.+?)\$/g, (_, prefix, equation) => `${prefix}$${equation}$`) + + return flow([ + (str: string) => str.replace(/\\\[(.*?)\\\]/g, (_, equation) => `$$${equation}$$`), + (str: string) => str.replace(/\\\((.*?)\\\)/g, (_, equation) => `$$${equation}$$`), + (str: string) => str.replace(/(^|[^\\])\$(.+?)\$/g, (_, prefix, equation) => `${prefix}$${equation}$`), + ])(content) +} + +const preprocessThinkTag = (content: string) => { + if (!content.trim().startsWith('\n')) + return content + + return flow([ + (str: string) => str.replace('\n', '
Thinking
\n'), + (str: string) => str.includes('\n') + ? str.replace('\n', '\n
') + : `${str}\n`, + ])(content) } export function PreCode(props: { children: any }) { @@ -225,7 +241,10 @@ const Link = ({ node, ...props }: any) => { } export function Markdown(props: { content: string; className?: string }) { - const latexContent = preprocessLaTeX(props.content) + const latexContent = flow([ + preprocessThinkTag, + preprocessLaTeX, + ])(props.content) return (