fix: better memory usage from 800+ to 500+ (#11796)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
yihong 2024-12-20 14:51:43 +08:00 committed by GitHub
parent 52201d95b1
commit 7b03a0316d
5 changed files with 56 additions and 26 deletions
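
Note: every file in this commit applies the same technique: heavy third-party imports move from module scope into the functions that actually use them, so worker processes that never hit those code paths never load the packages. A minimal sketch of the idea; numpy here is only a stand-in for a heavy dependency, and the timing numbers will vary:

import sys
import time


def lazy_mean(values: list[float]) -> float:
    # Deferred import: numpy is loaded on the first call, not when this
    # module is imported.
    import numpy

    return float(numpy.mean(values))


if __name__ == "__main__":
    print("numpy loaded at startup?", "numpy" in sys.modules)  # False
    t0 = time.perf_counter()
    lazy_mean([1.0, 2.0, 3.0])  # first call pays the one-time import cost
    t1 = time.perf_counter()
    lazy_mean([4.0, 5.0, 6.0])  # later calls hit the sys.modules cache
    t2 = time.perf_counter()
    print(f"first call: {t1 - t0:.3f}s, second call: {t2 - t1:.6f}s")

Python caches modules in sys.modules, so only the first call pays the cost; every later import of the same module is effectively a dictionary lookup.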

View File

@@ -4,11 +4,10 @@ import json
 import logging
 import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast

 import google.auth.transport.requests
 import requests
-import vertexai.generative_models as glm
 from anthropic import AnthropicVertex, Stream
 from anthropic.types import (
     ContentBlockDeltaEvent,
@@ -19,8 +18,6 @@ from anthropic.types import (
     MessageStreamEvent,
 )
 from google.api_core import exceptions
-from google.cloud import aiplatform
-from google.oauth2 import service_account
 from PIL import Image

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -47,6 +44,9 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

+if TYPE_CHECKING:
+    import vertexai.generative_models as glm
+
 logger = logging.getLogger(__name__)
@@ -102,6 +102,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         :param stream: is stream response
         :return: full response or stream response chunk generator result
         """
+        from google.oauth2 import service_account
+
         # use Anthropic official SDK references
         # - https://github.com/anthropics/anthropic-sdk-python
         service_account_key = credentials.get("vertex_service_account_key", "")
@@ -406,13 +408,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         return text.rstrip()

-    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> "glm.Tool":
         """
         Convert tool messages to glm tools

         :param tools: tool messages
         :return: glm tools
         """
+        import vertexai.generative_models as glm
+
         return glm.Tool(
             function_declarations=[
                 glm.FunctionDeclaration(
@@ -473,6 +477,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
+        import vertexai.generative_models as glm
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+
         config_kwargs = model_parameters.copy()
         config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
@@ -522,7 +530,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         return self._handle_generate_response(model, credentials, response, prompt_messages)

     def _handle_generate_response(
-        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
     ) -> LLMResult:
         """
         Handle llm response
@@ -554,7 +562,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         return result

     def _handle_generate_stream_response(
-        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
     ) -> Generator:
         """
         Handle llm stream response
@@ -638,13 +646,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         return message_text

-    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
+    def _format_message_to_glm_content(self, message: PromptMessage) -> "glm.Content":
         """
         Format a single message into glm.Content for Google API

         :param message: one PromptMessage
         :return: glm Content representation of message
         """
+        import vertexai.generative_models as glm
+
         if isinstance(message, UserPromptMessage):
             glm_content = glm.Content(role="user", parts=[])
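
With glm gone from module scope, the annotations that mention it are quoted ("glm.Tool", "glm.GenerationResponse", "glm.Content") so the interpreter never evaluates them, while the TYPE_CHECKING block keeps type checkers working. A sketch of the resulting shape, assuming the vertexai SDK from google-cloud-aiplatform (the constructor usage mirrors the diff above):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never executed.
    import vertexai.generative_models as glm


def tools_from_names(names: list[str]) -> "glm.Tool":
    # The quoted return annotation keeps glm out of runtime evaluation,
    # so the real import can live inside the function body.
    import vertexai.generative_models as glm

    return glm.Tool(
        function_declarations=[
            glm.FunctionDeclaration(name=name, parameters={"type": "object", "properties": {}}) for name in names
        ]
    )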

View File

@@ -2,12 +2,9 @@ import base64
 import json
 import time
 from decimal import Decimal
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 import tiktoken
-from google.cloud import aiplatform
-from google.oauth2 import service_account
-from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel

 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
@@ -24,6 +21,11 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi

+if TYPE_CHECKING:
+    from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+else:
+    VertexTextEmbeddingModel = None
+

 class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
     """
@@ -48,6 +50,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
         :param input_type: input type
         :return: embeddings result
         """
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
         service_account_key = credentials.get("vertex_service_account_key", "")
         project_id = credentials["vertex_project_id"]
         location = credentials["vertex_location"]
@@ -100,6 +106,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
         :param credentials: model credentials
         :return:
         """
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
         try:
             service_account_key = credentials.get("vertex_service_account_key", "")
             project_id = credentials["vertex_project_id"]
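
The else branch that binds VertexTextEmbeddingModel = None is what keeps this module importable, presumably because the alias is still referenced in annotations that Python evaluates at import time, so the name must resolve even though the real class is only loaded inside the methods. The same trick in isolation (runnable without vertexai, since the TYPE_CHECKING block never executes at runtime):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
else:
    # The annotation below is evaluated when the function is defined, so the
    # name must exist at runtime; None is a harmless placeholder.
    VertexTextEmbeddingModel = None


def embed(client: VertexTextEmbeddingModel, texts: list[str]) -> list[list[float]]:
    # A real implementation would lazily import the SDK and call the client.
    raise NotImplementedError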

View File

@@ -1,18 +1,19 @@
 import re
 from typing import Optional

-import jieba
-from jieba.analyse import default_tfidf
-
-from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
-

 class JiebaKeywordTableHandler:
     def __init__(self):
-        default_tfidf.stop_words = STOPWORDS
+        import jieba.analyse
+
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
+        jieba.analyse.default_tfidf.stop_words = STOPWORDS

     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
+        import jieba
+
         keywords = jieba.analyse.extract_tags(
             sentence=text,
             topK=max_keywords_per_chunk,
@@ -22,6 +23,8 @@ class JiebaKeywordTableHandler:
     def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
         results = set()
         for token in tokens:
             results.add(token)
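
One subtlety: extract_keywords only runs `import jieba`, yet calls jieba.analyse.extract_tags. This works because __init__ executes first and its `import jieba.analyse` binds the analyse attribute onto the jieba package; after that, re-importing is just a cache lookup. A small demonstration, assuming jieba is installed:

import sys

import jieba.analyse  # what __init__ does: loads jieba and jieba.analyse

assert "jieba" in sys.modules and "jieba.analyse" in sys.modules

import jieba  # what extract_keywords does: now only a sys.modules lookup

# Importing the submodule bound it as an attribute on the parent package.
assert jieba.analyse is sys.modules["jieba.analyse"]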

View File

@@ -6,10 +6,8 @@ from contextlib import contextmanager
 from typing import Any

 import jieba.posseg as pseg
-import nltk
 import numpy
 import oracledb
-from nltk.corpus import stopwords
 from pydantic import BaseModel, model_validator

 from configs import dify_config
@@ -202,6 +200,10 @@ class OracleVector(BaseVector):
         return docs

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # lazy import
+        import nltk
+        from nltk.corpus import stopwords
+
         top_k = kwargs.get("top_k", 5)
         # just not implement fetch by score_threshold now, may be later
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
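
Deferring nltk only skips the import cost; the stopwords corpus is data that must still be present when full-text search actually runs. A hedged sketch of that call path (nltk.download and stopwords.words are real nltk APIs; the helper itself is illustrative):

def drop_stopwords(words: list[str]) -> list[str]:
    # lazy import, mirroring search_by_full_text above
    import nltk
    from nltk.corpus import stopwords

    try:
        stop = set(stopwords.words("english"))
    except LookupError:
        nltk.download("stopwords")  # one-time corpus fetch
        stop = set(stopwords.words("english"))
    return [word for word in words if word.lower() not in stop]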

View File

@@ -8,12 +8,6 @@ import docx
 import pandas as pd
 import pypdfium2  # type: ignore
 import yaml  # type: ignore
-from unstructured.partition.api import partition_via_api
-from unstructured.partition.email import partition_email
-from unstructured.partition.epub import partition_epub
-from unstructured.partition.msg import partition_msg
-from unstructured.partition.ppt import partition_ppt
-from unstructured.partition.pptx import partition_pptx

 from configs import dify_config
 from core.file import File, FileTransferMethod, file_manager
@@ -256,6 +250,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:

 def _extract_text_from_ppt(file_content: bytes) -> str:
+    from unstructured.partition.ppt import partition_ppt
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_ppt(file=file)
@@ -265,6 +261,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:

 def _extract_text_from_pptx(file_content: bytes) -> str:
+    from unstructured.partition.api import partition_via_api
+    from unstructured.partition.pptx import partition_pptx
+
     try:
         if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
             with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@@ -287,6 +286,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:

 def _extract_text_from_epub(file_content: bytes) -> str:
+    from unstructured.partition.epub import partition_epub
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_epub(file=file)
@@ -296,6 +297,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:

 def _extract_text_from_eml(file_content: bytes) -> str:
+    from unstructured.partition.email import partition_email
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_email(file=file)
@@ -305,6 +308,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:

 def _extract_text_from_msg(file_content: bytes) -> str:
+    from unstructured.partition.msg import partition_msg
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_msg(file=file)
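
A quick way to sanity-check numbers like the 800+ to 500+ in the title is to watch process RSS around an import. A sketch using psutil; numpy again stands in for heavy dependencies such as unstructured or the Google SDKs:

import psutil


def rss_mib() -> float:
    # Resident set size of the current process, in MiB.
    return psutil.Process().memory_info().rss / (1024 * 1024)


if __name__ == "__main__":
    before = rss_mib()
    import numpy  # stand-in for a heavy dependency

    after = rss_mib()
    print(f"RSS before import: {before:.1f} MiB, after: {after:.1f} MiB")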