Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git, synced 2025-08-12 07:09:02 +08:00
chore: apply ruff's pyupgrade linter rules to modernize Python code with targeted version (#2419)

This commit is contained in:
parent 589099a005
commit 063191889d
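The hunks below are mechanical rewrites from ruff's pyupgrade (UP) rule family, applied against the project's targeted Python version. Four patterns account for almost all of the changes: dropping `# -*- coding:utf-8 -*-` cookies (UTF-8 has been the default source encoding since PEP 3120), replacing typing.List/Dict/Tuple with the PEP 585 builtin generics, moving ABCs such as Generator and Sequence from typing to collections.abc, and collapsing re-yield loops into `yield from`. A condensed before/after sketch (illustrative names, not a file from this diff):

from collections.abc import Generator  # moved from typing; the typing alias is deprecated


def read_legacy(sources):
    # pre-pyupgrade shape, kept here only for contrast
    for item in sources:
        yield item


def read(sources: list[dict[str, str]]) -> Generator:
    # PEP 585 builtin generics in the signature, PEP 380 delegation in the body
    yield from sources
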
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import os
 
 from werkzeug.exceptions import Unauthorized

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import os
 
 import dotenv

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
 from datetime import datetime

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging
 
 from flask import request

@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union
 
 import flask_login
 from flask import Response, stream_with_context

@@ -169,8 +169,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')

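The compact_response hunk above shows the one behavioral-looking rewrite in this commit, ruff's yield-in-for-loop rule: a for loop that only re-yields its items becomes PEP 380 delegation. A minimal sketch with generic names (not from this diff):

def passthrough(source):
    # equivalent to: for chunk in source: yield chunk
    # yield from additionally forwards send()/throw() to the inner iterator,
    # though for a plain streaming loop like this the behavior is identical
    yield from source

The same two-line-to-one-line collapse repeats in every compact_response hunk that follows.
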
@@ -1,6 +1,7 @@
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union
 
 from flask import Response, stream_with_context
 from flask_login import current_user

@@ -246,8 +247,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 
 from flask import request
 from flask_login import current_user

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, NotFound

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
 from decimal import Decimal
 

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import flask_login
 from flask import current_app, request
 from flask_restful import Resource, reqparse

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import flask_restful
 from flask import current_app, request
 from flask_login import current_user

@@ -1,6 +1,4 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
-from typing import List
 
 from flask import request
 from flask_login import current_user

@@ -71,7 +69,7 @@ class DocumentResource(Resource):
 
         return document
 
-    def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
+    def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')

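From here on, most hunks swap typing.List, Tuple, and Dict in annotations for the PEP 585 builtin generics, legal since Python 3.9. A small illustration (hypothetical function, not from this diff):

from typing import Optional


def batch_ids(rows: list[dict[str, str]], limit: Optional[int] = None) -> list[str]:
    # list/dict/tuple subscript directly on 3.9+; no typing import needed for them
    return [row["id"] for row in rows][:limit]

Note that Optional and Union still come from typing; the diff leaves those alone apart from moving ABCs such as Generator and Sequence to collections.abc, whose typing aliases PEP 585 also deprecates.
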
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import uuid
 from datetime import datetime
 

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging
 
 from flask import request

@@ -1,8 +1,8 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
+from collections.abc import Generator
 from datetime import datetime
-from typing import Generator, Union
+from typing import Union
 
 from flask import Response, stream_with_context
 from flask_login import current_user

@@ -164,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_login import current_user
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from libs.exception import BaseHTTPException
 
 

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
 
 from flask_login import current_user

@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union
 
 from flask import Response, stream_with_context
 from flask_login import current_user

@@ -123,8 +123,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 
 from flask import current_app

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_login import current_user
 from flask_restful import Resource, fields, marshal_with
 from sqlalchemy import and_

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from functools import wraps
 
 from flask import current_app, request

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 
 import json
 import logging

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
 
 import pytz

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask import current_app
 from flask_login import current_user
 from flask_restful import Resource, abort, fields, marshal_with, reqparse

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging
 
 from flask import request

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 from functools import wraps
 

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 
 from flask import current_app

@@ -1,6 +1,7 @@
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union
 
 from flask import Response, stream_with_context
 from flask_restful import reqparse

@@ -182,8 +183,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask import request
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from libs.exception import BaseHTTPException
 
 

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_restful import fields, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
 from functools import wraps
 

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 
 from flask import current_app

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging
 
 from flask import request

@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union
 
 from flask import Response, stream_with_context
 from flask_restful import reqparse

@@ -154,8 +154,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from libs.exception import BaseHTTPException
 
 

@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union
 
 from flask import Response, stream_with_context
 from flask_restful import fields, marshal_with, reqparse

@@ -160,8 +160,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import uuid
 
 from flask import request

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 
 from flask import current_app
 from flask_restful import fields, marshal_with

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from functools import wraps
 
 from flask import request

@@ -1,5 +1,5 @@
 import logging
-from typing import List, Optional
+from typing import Optional
 
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.model_runtime.callbacks.base_callback import Callback

@@ -17,7 +17,7 @@ class AgentLLMCallback(Callback):
 
     def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
                          prompt_messages: list[PromptMessage], model_parameters: dict,
-                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                          stream: bool = True, user: Optional[str] = None) -> None:
         """
         Before invoke callback

@@ -38,7 +38,7 @@ class AgentLLMCallback(Callback):
 
     def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
                      prompt_messages: list[PromptMessage], model_parameters: dict,
-                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                      stream: bool = True, user: Optional[str] = None):
         """
         On new chunk callback

@@ -58,7 +58,7 @@ class AgentLLMCallback(Callback):
 
     def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
         """
         After invoke callback

@@ -80,7 +80,7 @@ class AgentLLMCallback(Callback):
 
     def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
         """
         Invoke error callback

@@ -1,4 +1,4 @@
-from typing import List, cast
+from typing import cast
 
 from core.entities.application_entities import ModelConfigEntity
 from core.model_runtime.entities.message_entities import PromptMessage

@@ -8,7 +8,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 
 class CalcTokenMixin:
 
-    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
+    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
         """
         Got the rest tokens available for the model after excluding messages tokens and completion max tokens
 

@@ -1,4 +1,5 @@
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any, Optional, Union
 
 from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
 from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message

@@ -42,7 +43,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:

@@ -85,7 +86,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
 
     def real_plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:

@@ -146,7 +147,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
 
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:

@@ -158,7 +159,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             model_config: ModelConfigEntity,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+            extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
             system_message: Optional[SystemMessage] = SystemMessage(
                 content="You are a helpful AI assistant."
             ),

@@ -1,4 +1,5 @@
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any, Optional, Union
 
 from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
 from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message

@@ -51,7 +52,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
             model_config: ModelConfigEntity,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+            extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
             system_message: Optional[SystemMessage] = SystemMessage(
                 content="You are a helpful AI assistant."
             ),

@@ -125,7 +126,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:

@@ -207,7 +208,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
     def return_stopped_response(
         self,
         early_stopping_method: str,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         **kwargs: Any,
     ) -> AgentFinish:
         try:

@@ -215,7 +216,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         except ValueError:
             return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
 
-    def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
+    def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
         # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
         rest_tokens = self.get_message_rest_tokens(
             self.model_config,

@@ -264,7 +265,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         return new_messages
 
     def predict_new_summary(
-            self, messages: List[BaseMessage], existing_summary: str
+            self, messages: list[BaseMessage], existing_summary: str
     ) -> str:
         new_lines = get_buffer_string(
             messages,

@@ -275,7 +276,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
         return chain.predict(summary=existing_summary, new_lines=new_lines)
 
-    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
+    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
 
         Official documentation: https://github.com/openai/openai-cookbook/blob/

@@ -1,5 +1,6 @@
 import re
-from typing import Any, List, Optional, Sequence, Tuple, Union, cast
+from collections.abc import Sequence
+from typing import Any, Optional, Union, cast
 
 from langchain import BasePromptTemplate, PromptTemplate
 from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent

@@ -68,7 +69,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:

@@ -125,8 +126,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
             suffix: str = SUFFIX,
             human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[List[str]] = None,
-            memory_prompts: Optional[List[BasePromptTemplate]] = None,
+            input_variables: Optional[list[str]] = None,
+            memory_prompts: Optional[list[BasePromptTemplate]] = None,
     ) -> BasePromptTemplate:
         tool_strings = []
         for tool in tools:

@@ -153,7 +154,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
             tools: Sequence[BaseTool],
             prefix: str = PREFIX,
             format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[List[str]] = None,
+            input_variables: Optional[list[str]] = None,
     ) -> PromptTemplate:
         """Create prompt in the style of the zero shot agent.
 

@@ -180,7 +181,7 @@ Thought: {agent_scratchpad}
         return PromptTemplate(template=template, input_variables=input_variables)
 
     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
+        self, intermediate_steps: list[tuple[AgentAction, str]]
     ) -> str:
         agent_scratchpad = ""
         for action, observation in intermediate_steps:

@@ -213,8 +214,8 @@ Thought: {agent_scratchpad}
             suffix: str = SUFFIX,
             human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[List[str]] = None,
-            memory_prompts: Optional[List[BasePromptTemplate]] = None,
+            input_variables: Optional[list[str]] = None,
+            memory_prompts: Optional[list[BasePromptTemplate]] = None,
             **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools."""

@@ -1,5 +1,6 @@
 import re
-from typing import Any, List, Optional, Sequence, Tuple, Union, cast
+from collections.abc import Sequence
+from typing import Any, Optional, Union, cast
 
 from langchain import BasePromptTemplate, PromptTemplate
 from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent

@@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:

@@ -127,7 +128,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                           "I don't know how to respond to that."}, "")
 
-    def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
+    def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
         if len(intermediate_steps) >= 2 and self.summary_model_config:
             should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
             should_summary_messages = [AIMessage(content=observation)

@@ -154,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
 
     def predict_new_summary(
-            self, messages: List[BaseMessage], existing_summary: str
+            self, messages: list[BaseMessage], existing_summary: str
     ) -> str:
         new_lines = get_buffer_string(
             messages,

@@ -173,8 +174,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             suffix: str = SUFFIX,
             human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[List[str]] = None,
-            memory_prompts: Optional[List[BasePromptTemplate]] = None,
+            input_variables: Optional[list[str]] = None,
+            memory_prompts: Optional[list[BasePromptTemplate]] = None,
     ) -> BasePromptTemplate:
         tool_strings = []
         for tool in tools:

@@ -200,7 +201,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             tools: Sequence[BaseTool],
             prefix: str = PREFIX,
             format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[List[str]] = None,
+            input_variables: Optional[list[str]] = None,
     ) -> PromptTemplate:
         """Create prompt in the style of the zero shot agent.
 

@@ -227,7 +228,7 @@ Thought: {agent_scratchpad}
         return PromptTemplate(template=template, input_variables=input_variables)
 
     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
+        self, intermediate_steps: list[tuple[AgentAction, str]]
     ) -> str:
         agent_scratchpad = ""
         for action, observation in intermediate_steps:

@@ -260,8 +261,8 @@ Thought: {agent_scratchpad}
             suffix: str = SUFFIX,
             human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[List[str]] = None,
-            memory_prompts: Optional[List[BasePromptTemplate]] = None,
+            input_variables: Optional[list[str]] = None,
+            memory_prompts: Optional[list[BasePromptTemplate]] = None,
             agent_llm_callback: Optional[AgentLLMCallback] = None,
             **kwargs: Any,
     ) -> Agent:

@@ -1,5 +1,6 @@
 import time
-from typing import Generator, List, Optional, Tuple, Union, cast
+from collections.abc import Generator
+from typing import Optional, Union, cast
 
 from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.application_entities import (

@@ -84,7 +85,7 @@ class AppRunner:
         return rest_tokens
 
     def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
-                              prompt_messages: List[PromptMessage]):
+                              prompt_messages: list[PromptMessage]):
         # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
         model_type_instance = model_config.provider_model_bundle.model_type_instance
         model_type_instance = cast(LargeLanguageModel, model_type_instance)

@@ -126,7 +127,7 @@ class AppRunner:
                                  query: Optional[str] = None,
                                  context: Optional[str] = None,
                                  memory: Optional[TokenBufferMemory] = None) \
-            -> Tuple[List[PromptMessage], Optional[List[str]]]:
+            -> tuple[list[PromptMessage], Optional[list[str]]]:
         """
         Organize prompt messages
         :param context:

@@ -295,7 +296,7 @@ class AppRunner:
                                          tenant_id: str,
                                          app_orchestration_config_entity: AppOrchestrationConfigEntity,
                                          inputs: dict,
-                                         query: str) -> Tuple[bool, dict, str]:
+                                         query: str) -> tuple[bool, dict, str]:
         """
         Process sensitive_word_avoidance.
         :param app_id: app id

@@ -1,7 +1,8 @@
 import json
 import logging
 import time
-from typing import Generator, Optional, Union, cast
+from collections.abc import Generator
+from typing import Optional, Union, cast
 
 from pydantic import BaseModel
 

@@ -118,7 +119,7 @@ class GenerateTaskPipeline:
                 }
 
                 self._task_state.llm_result.message.content = annotation.content
-            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
+            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                 if isinstance(event, QueueMessageEndEvent):
                     self._task_state.llm_result = event.llm_result
                 else:

@@ -201,7 +202,7 @@ class GenerateTaskPipeline:
                 data = self._error_to_stream_response_data(self._handle_error(event))
                 yield self._yield_response(data)
                 break
-            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
+            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                 if isinstance(event, QueueMessageEndEvent):
                     self._task_state.llm_result = event.llm_result
                 else:

@@ -353,7 +354,7 @@ class GenerateTaskPipeline:
 
                 yield self._yield_response(response)
 
-            elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
+            elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
                 chunk = event.chunk
                 delta_text = chunk.delta.message.content
                 if delta_text is None:

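The three isinstance rewrites above lean on PEP 604: since Python 3.10, `X | Y` is accepted wherever a tuple of types used to be required, including isinstance checks. A sketch with stand-in event classes (not the real Queue* classes from this codebase):

class StopEvent: ...
class EndEvent: ...


def is_terminal(event: object) -> bool:
    # same result as isinstance(event, (StopEvent, EndEvent)); the | form
    # needs Python 3.10+ at runtime, which the commit's targeted version
    # evidently allows, since ruff only emits this rewrite for 3.10+
    return isinstance(event, StopEvent | EndEvent)

The rewrite is purely syntactic; the tuple form keeps working on newer Pythons.
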
@@ -1,7 +1,7 @@
 import logging
 import threading
 import time
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from flask import Flask, current_app
 from pydantic import BaseModel

@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
 
 class ModerationRule(BaseModel):
     type: str
-    config: Dict[str, Any]
+    config: dict[str, Any]
 
 
 class OutputModerationHandler(BaseModel):

@@ -2,7 +2,8 @@ import json
 import logging
 import threading
 import uuid
-from typing import Any, Generator, Optional, Tuple, Union, cast
+from collections.abc import Generator
+from typing import Any, Optional, Union, cast
 
 from flask import Flask, current_app
 from pydantic import ValidationError

@@ -585,7 +586,7 @@ class ApplicationManager:
         return AppOrchestrationConfigEntity(**properties)
 
     def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
-            -> Tuple[Conversation, Message]:
+            -> tuple[Conversation, Message]:
         """
         Initialize generate records
         :param application_generate_entity: application generate entity

@@ -1,7 +1,8 @@
 import queue
 import time
+from collections.abc import Generator
 from enum import Enum
-from typing import Any, Generator
+from typing import Any
 
 from sqlalchemy.orm import DeclarativeMeta
 

@@ -1,7 +1,7 @@
 import json
 import logging
 import time
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Any, Optional, Union, cast
 
 from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler

@@ -37,7 +37,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._message_agent_thought = None
 
     @property
-    def agent_loops(self) -> List[AgentLoop]:
+    def agent_loops(self) -> list[AgentLoop]:
         return self._agent_loops
 
     def clear_agent_loops(self) -> None:

@@ -95,14 +95,14 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
 
     def on_chat_model_start(
         self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
         **kwargs: Any
     ) -> Any:
         pass
 
     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
     ) -> None:
         pass
 

@@ -120,7 +120,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
 
     def on_tool_start(
         self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
         input_str: str,
         **kwargs: Any,
     ) -> None:

@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.input import print_text

@@ -21,7 +21,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
     def on_tool_start(
         self,
         tool_name: str,
-        tool_inputs: Dict[str, Any],
+        tool_inputs: dict[str, Any],
     ) -> None:
         """Do nothing."""
         print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)

@@ -29,7 +29,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
     def on_tool_end(
         self,
         tool_name: str,
-        tool_inputs: Dict[str, Any],
+        tool_inputs: dict[str, Any],
         tool_outputs: str,
     ) -> None:
         """If not the final action, print out observation."""

@@ -1,4 +1,3 @@
-from typing import List
 
 from langchain.schema import Document
 

@@ -40,7 +39,7 @@ class DatasetIndexToolCallbackHandler:
         db.session.add(dataset_query)
         db.session.commit()
 
-    def on_tool_end(self, documents: List[Document]) -> None:
+    def on_tool_end(self, documents: list[Document]) -> None:
         """Handle tool end."""
         for document in documents:
             doc_id = document.metadata['doc_id']

@@ -55,7 +54,7 @@ class DatasetIndexToolCallbackHandler:
 
         db.session.commit()
 
-    def return_retriever_resource_info(self, resource: List):
+    def return_retriever_resource_info(self, resource: list):
         """Handle return_retriever_resource_info."""
         if resource and len(resource) > 0:
             for item in resource:

@@ -1,6 +1,6 @@
 import os
 import sys
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.input import print_text

@@ -16,8 +16,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
 
     def on_chat_model_start(
         self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
         **kwargs: Any
     ) -> Any:
         print_text("\n[on_chat_model_start]\n", color='blue')

@@ -26,7 +26,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
             print_text(str(sub_message) + "\n", color='blue')
 
     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
     ) -> None:
         """Print out the prompts."""
         print_text("\n[on_llm_start]\n", color='blue')

@@ -48,13 +48,13 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')
 
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
     ) -> None:
         """Print out that we are entering a chain."""
         chain_type = serialized['id'][-1]
         print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
 
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
         print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')
 

@@ -66,7 +66,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
 
     def on_tool_start(
         self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
         input_str: str,
         **kwargs: Any,
     ) -> None:

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 from langchain import LLMChain as LCLLMChain
 from langchain.callbacks.manager import CallbackManagerForChainRun

@@ -16,12 +16,12 @@ class LLMChain(LCLLMChain):
     model_config: ModelConfigEntity
     """The language model instance to use."""
     llm: BaseLanguageModel = FakeLLM(response="")
-    parameters: Dict[str, Any] = {}
+    parameters: dict[str, Any] = {}
     agent_llm_callback: Optional[AgentLLMCallback] = None
 
     def generate(
         self,
-        input_list: List[Dict[str, Any]],
+        input_list: list[dict[str, Any]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> LLMResult:
         """Generate LLM result from inputs."""

@@ -1,6 +1,6 @@
 import tempfile
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 import requests
 from flask import current_app

@@ -28,7 +28,7 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
 
 class FileExtractor:
     @classmethod
-    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
+    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
         with tempfile.TemporaryDirectory() as temp_dir:
             suffix = Path(upload_file.key).suffix
             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"

@@ -37,7 +37,7 @@ class FileExtractor:
             return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
 
     @classmethod
-    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
+    def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
         response = requests.get(url, headers={
             "User-Agent": USER_AGENT
         })

@@ -53,7 +53,7 @@ class FileExtractor:
     @classmethod
     def load_from_file(cls, file_path: str, return_text: bool = False,
                        upload_file: Optional[UploadFile] = None,
-                       is_automatic: bool = False) -> Union[List[Document], str]:
+                       is_automatic: bool = False) -> Union[list[Document], str]:
         input_file = Path(file_path)
         delimiter = '\n'
         file_extension = input_file.suffix.lower()

@@ -1,6 +1,6 @@
 import csv
 import logging
-from typing import Dict, List, Optional
+from typing import Optional
 
 from langchain.document_loaders import CSVLoader as LCCSVLoader
 from langchain.document_loaders.helpers import detect_file_encodings

@@ -14,7 +14,7 @@ class CSVLoader(LCCSVLoader):
             self,
             file_path: str,
             source_column: Optional[str] = None,
-            csv_args: Optional[Dict] = None,
+            csv_args: Optional[dict] = None,
             encoding: Optional[str] = None,
             autodetect_encoding: bool = True,
     ):

@@ -24,7 +24,7 @@ class CSVLoader(LCCSVLoader):
         self.csv_args = csv_args or {}
         self.autodetect_encoding = autodetect_encoding
 
-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         """Load data into document objects."""
         try:
             with open(self.file_path, newline="", encoding=self.encoding) as csvfile:

@@ -1,5 +1,4 @@
 import logging
-from typing import List
 
 from langchain.document_loaders.base import BaseLoader
 from langchain.schema import Document

@@ -23,7 +22,7 @@ class ExcelLoader(BaseLoader):
         """Initialize with file path."""
         self._file_path = file_path
 
-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         data = []
         keys = []
         wb = load_workbook(filename=self._file_path, read_only=True)

@@ -1,5 +1,4 @@
 import logging
-from typing import List
 
 from bs4 import BeautifulSoup
 from langchain.document_loaders.base import BaseLoader

@@ -23,7 +22,7 @@ class HTMLLoader(BaseLoader):
         """Initialize with file path."""
         self._file_path = file_path
 
-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         return [Document(page_content=self._load_as_text())]
 
     def _load_as_text(self) -> str:

@@ -1,6 +1,6 @@
 import logging
 import re
-from typing import List, Optional, Tuple, cast
+from typing import Optional, cast
 
 from langchain.document_loaders.base import BaseLoader
 from langchain.document_loaders.helpers import detect_file_encodings

@@ -42,7 +42,7 @@ class MarkdownLoader(BaseLoader):
         self._encoding = encoding
         self._autodetect_encoding = autodetect_encoding
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> list[Document]:
|
||||||
tups = self.parse_tups(self._file_path)
|
tups = self.parse_tups(self._file_path)
|
||||||
documents = []
|
documents = []
|
||||||
for header, value in tups:
|
for header, value in tups:
|
||||||
@ -54,13 +54,13 @@ class MarkdownLoader(BaseLoader):
|
|||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
|
def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]:
|
||||||
"""Convert a markdown file to a dictionary.
|
"""Convert a markdown file to a dictionary.
|
||||||
|
|
||||||
The keys are the headers and the values are the text under each header.
|
The keys are the headers and the values are the text under each header.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
markdown_tups: List[Tuple[Optional[str], str]] = []
|
markdown_tups: list[tuple[Optional[str], str]] = []
|
||||||
lines = markdown_text.split("\n")
|
lines = markdown_text.split("\n")
|
||||||
|
|
||||||
current_header = None
|
current_header = None
|
||||||
@ -103,11 +103,11 @@ class MarkdownLoader(BaseLoader):
|
|||||||
content = re.sub(pattern, r"\1", content)
|
content = re.sub(pattern, r"\1", content)
|
||||||
return content
|
return content
|
||||||
|
|
||||||
def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
|
def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]:
|
||||||
"""Parse file into tuples."""
|
"""Parse file into tuples."""
|
||||||
content = ""
|
content = ""
|
||||||
try:
|
try:
|
||||||
with open(filepath, "r", encoding=self._encoding) as f:
|
with open(filepath, encoding=self._encoding) as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
except UnicodeDecodeError as e:
|
except UnicodeDecodeError as e:
|
||||||
if self._autodetect_encoding:
|
if self._autodetect_encoding:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
@ -67,7 +67,7 @@ class NotionLoader(BaseLoader):
|
|||||||
document_model=document_model
|
document_model=document_model
|
||||||
)
|
)
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> list[Document]:
|
||||||
self.update_last_edited_time(
|
self.update_last_edited_time(
|
||||||
self._document_model
|
self._document_model
|
||||||
)
|
)
|
||||||
@ -78,7 +78,7 @@ class NotionLoader(BaseLoader):
|
|||||||
|
|
||||||
def _load_data_as_documents(
|
def _load_data_as_documents(
|
||||||
self, notion_obj_id: str, notion_page_type: str
|
self, notion_obj_id: str, notion_page_type: str
|
||||||
) -> List[Document]:
|
) -> list[Document]:
|
||||||
docs = []
|
docs = []
|
||||||
if notion_page_type == 'database':
|
if notion_page_type == 'database':
|
||||||
# get all the pages in the database
|
# get all the pages in the database
|
||||||
@ -94,8 +94,8 @@ class NotionLoader(BaseLoader):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
def _get_notion_database_data(
|
def _get_notion_database_data(
|
||||||
self, database_id: str, query_dict: Dict[str, Any] = {}
|
self, database_id: str, query_dict: dict[str, Any] = {}
|
||||||
) -> List[Document]:
|
) -> list[Document]:
|
||||||
"""Get all the pages from a Notion database."""
|
"""Get all the pages from a Notion database."""
|
||||||
res = requests.post(
|
res = requests.post(
|
||||||
DATABASE_URL_TMPL.format(database_id=database_id),
|
DATABASE_URL_TMPL.format(database_id=database_id),
|
||||||
@ -149,12 +149,12 @@ class NotionLoader(BaseLoader):
|
|||||||
|
|
||||||
return database_content_list
|
return database_content_list
|
||||||
|
|
||||||
def _get_notion_block_data(self, page_id: str) -> List[str]:
|
def _get_notion_block_data(self, page_id: str) -> list[str]:
|
||||||
result_lines_arr = []
|
result_lines_arr = []
|
||||||
cur_block_id = page_id
|
cur_block_id = page_id
|
||||||
while True:
|
while True:
|
||||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
||||||
query_dict: Dict[str, Any] = {}
|
query_dict: dict[str, Any] = {}
|
||||||
|
|
||||||
res = requests.request(
|
res = requests.request(
|
||||||
"GET",
|
"GET",
|
||||||
@ -216,7 +216,7 @@ class NotionLoader(BaseLoader):
|
|||||||
cur_block_id = block_id
|
cur_block_id = block_id
|
||||||
while True:
|
while True:
|
||||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
||||||
query_dict: Dict[str, Any] = {}
|
query_dict: dict[str, Any] = {}
|
||||||
|
|
||||||
res = requests.request(
|
res = requests.request(
|
||||||
"GET",
|
"GET",
|
||||||
@ -280,7 +280,7 @@ class NotionLoader(BaseLoader):
|
|||||||
cur_block_id = block_id
|
cur_block_id = block_id
|
||||||
while not done:
|
while not done:
|
||||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
||||||
query_dict: Dict[str, Any] = {}
|
query_dict: dict[str, Any] = {}
|
||||||
|
|
||||||
res = requests.request(
|
res = requests.request(
|
||||||
"GET",
|
"GET",
|
||||||
@ -346,7 +346,7 @@ class NotionLoader(BaseLoader):
|
|||||||
else:
|
else:
|
||||||
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
|
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
|
||||||
|
|
||||||
query_dict: Dict[str, Any] = {}
|
query_dict: dict[str, Any] = {}
|
||||||
|
|
||||||
res = requests.request(
|
res = requests.request(
|
||||||
"GET",
|
"GET",
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langchain.document_loaders import PyPDFium2Loader
|
from langchain.document_loaders import PyPDFium2Loader
|
||||||
from langchain.document_loaders.base import BaseLoader
|
from langchain.document_loaders.base import BaseLoader
|
||||||
@ -28,7 +28,7 @@ class PdfLoader(BaseLoader):
|
|||||||
self._file_path = file_path
|
self._file_path = file_path
|
||||||
self._upload_file = upload_file
|
self._upload_file = upload_file
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> list[Document]:
|
||||||
plaintext_file_key = ''
|
plaintext_file_key = ''
|
||||||
plaintext_file_exists = False
|
plaintext_file_exists = False
|
||||||
if self._upload_file:
|
if self._upload_file:
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from langchain.document_loaders.base import BaseLoader
|
from langchain.document_loaders.base import BaseLoader
|
||||||
@ -24,7 +23,7 @@ class UnstructuredEmailLoader(BaseLoader):
|
|||||||
self._file_path = file_path
|
self._file_path = file_path
|
||||||
self._api_url = api_url
|
self._api_url = api_url
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> list[Document]:
|
||||||
from unstructured.partition.email import partition_email
|
from unstructured.partition.email import partition_email
|
||||||
elements = partition_email(filename=self._file_path, api_url=self._api_url)
|
elements = partition_email(filename=self._file_path, api_url=self._api_url)
|
||||||
|
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from langchain.document_loaders.base import BaseLoader
|
from langchain.document_loaders.base import BaseLoader
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
@ -34,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader):
|
|||||||
self._file_path = file_path
|
self._file_path = file_path
|
||||||
self._api_url = api_url
|
self._api_url = api_url
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> list[Document]:
|
||||||
from unstructured.partition.md import partition_md
|
from unstructured.partition.md import partition_md
|
||||||
|
|
||||||
elements = partition_md(filename=self._file_path, api_url=self._api_url)
|
elements = partition_md(filename=self._file_path, api_url=self._api_url)
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from langchain.document_loaders.base import BaseLoader
|
from langchain.document_loaders.base import BaseLoader
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
@ -24,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader):
|
|||||||
self._file_path = file_path
|
self._file_path = file_path
|
||||||
self._api_url = api_url
|
self._api_url = api_url
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> list[Document]:
|
||||||
from unstructured.partition.msg import partition_msg
|
from unstructured.partition.msg import partition_msg
|
||||||
|
|
||||||
elements = partition_msg(filename=self._file_path, api_url=self._api_url)
|
elements = partition_msg(filename=self._file_path, api_url=self._api_url)
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from langchain.document_loaders.base import BaseLoader
|
from langchain.document_loaders.base import BaseLoader
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
@ -23,7 +22,7 @@ class UnstructuredPPTLoader(BaseLoader):
|
|||||||
self._file_path = file_path
|
self._file_path = file_path
|
||||||
self._api_url = api_url
|
self._api_url = api_url
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> list[Document]:
|
||||||
from unstructured.partition.ppt import partition_ppt
|
from unstructured.partition.ppt import partition_ppt
|
||||||
|
|
||||||
elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
|
elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from langchain.document_loaders.base import BaseLoader
|
from langchain.document_loaders.base import BaseLoader
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
@ -22,7 +21,7 @@ class UnstructuredPPTXLoader(BaseLoader):
|
|||||||
self._file_path = file_path
|
self._file_path = file_path
|
||||||
self._api_url = api_url
|
self._api_url = api_url
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> list[Document]:
|
||||||
from unstructured.partition.pptx import partition_pptx
|
from unstructured.partition.pptx import partition_pptx
|
||||||
|
|
||||||
elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
|
elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from langchain.document_loaders.base import BaseLoader
|
from langchain.document_loaders.base import BaseLoader
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
@ -24,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader):
|
|||||||
self._file_path = file_path
|
self._file_path = file_path
|
||||||
self._api_url = api_url
|
self._api_url = api_url
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> list[Document]:
|
||||||
from unstructured.partition.text import partition_text
|
from unstructured.partition.text import partition_text
|
||||||
|
|
||||||
elements = partition_text(filename=self._file_path, api_url=self._api_url)
|
elements = partition_text(filename=self._file_path, api_url=self._api_url)
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from langchain.document_loaders.base import BaseLoader
|
from langchain.document_loaders.base import BaseLoader
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
@ -24,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader):
|
|||||||
self._file_path = file_path
|
self._file_path = file_path
|
||||||
self._api_url = api_url
|
self._api_url = api_url
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> list[Document]:
|
||||||
from unstructured.partition.xml import partition_xml
|
from unstructured.partition.xml import partition_xml
|
||||||
|
|
||||||
elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
|
elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
|
||||||
|
@@ -1,4 +1,5 @@
-from typing import Any, Dict, Optional, Sequence, cast
+from collections.abc import Sequence
+from typing import Any, Optional, cast

 from langchain.schema import Document
 from sqlalchemy import func
@@ -22,10 +23,10 @@ class DatasetDocumentStore:
         self._document_id = document_id

     @classmethod
-    def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore":
+    def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
         return cls(**config_dict)

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """Serialize to dict."""
         return {
             "dataset_id": self._dataset.id,
@@ -40,7 +41,7 @@ class DatasetDocumentStore:
         return self._user_id

     @property
-    def docs(self) -> Dict[str, Document]:
+    def docs(self) -> dict[str, Document]:
         document_segments = db.session.query(DocumentSegment).filter(
             DocumentSegment.dataset_id == self._dataset.id
         ).all()

@@ -1,6 +1,6 @@
 import base64
 import logging
-from typing import List, Optional, cast
+from typing import Optional, cast

 import numpy as np
 from langchain.embeddings.base import Embeddings
@@ -21,7 +21,7 @@ class CacheEmbedding(Embeddings):
         self._model_instance = model_instance
         self._user = user

-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Embed search docs in batches of 10."""
         text_embeddings = []
         try:
@@ -52,7 +52,7 @@ class CacheEmbedding(Embeddings):

         return text_embeddings

-    def embed_query(self, text: str) -> List[float]:
+    def embed_query(self, text: str) -> list[float]:
         """Embed query text."""
         # use doc embedding cache or store if not exists
         hash = helper.generate_text_hash(text)

@@ -1,8 +1,9 @@
 import datetime
 import json
 import logging
+from collections.abc import Iterator
 from json import JSONDecodeError
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Optional

 from pydantic import BaseModel

@@ -135,7 +136,7 @@ class ProviderConfiguration(BaseModel):
             if self.provider.provider_credential_schema else []
         )

-    def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
+    def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
         """
         Validate custom credentials.
         :param credentials: provider credentials
@@ -282,7 +283,7 @@ class ProviderConfiguration(BaseModel):
         return None

     def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-            -> Tuple[ProviderModel, dict]:
+            -> tuple[ProviderModel, dict]:
         """
         Validate custom model credentials.

@@ -711,7 +712,7 @@ class ProviderConfigurations(BaseModel):
     Model class for provider configuration dict.
     """
     tenant_id: str
-    configurations: Dict[str, ProviderConfiguration] = {}
+    configurations: dict[str, ProviderConfiguration] = {}

     def __init__(self, tenant_id: str):
         super().__init__(tenant_id=tenant_id)
@@ -759,7 +760,7 @@ class ProviderConfigurations(BaseModel):

         return all_models

-    def to_list(self) -> List[ProviderConfiguration]:
+    def to_list(self) -> list[ProviderConfiguration]:
         """
         Convert to list.


@@ -61,7 +61,7 @@ class Extensible:

                 builtin_file_path = os.path.join(subdir_path, '__builtin__')
                 if os.path.exists(builtin_file_path):
-                    with open(builtin_file_path, 'r', encoding='utf-8') as f:
+                    with open(builtin_file_path, encoding='utf-8') as f:
                         position = int(f.read().strip())

                 if (extension_name + '.py') not in file_names:
@@ -93,7 +93,7 @@ class Extensible:
                 json_path = os.path.join(subdir_path, 'schema.json')
                 json_data = {}
                 if os.path.exists(json_path):
-                    with open(json_path, 'r', encoding='utf-8') as f:
+                    with open(json_path, encoding='utf-8') as f:
                         json_data = json.load(f)

                 extensions[extension_name] = ModuleExtension(

@@ -2,7 +2,7 @@ import json
 import logging
 from datetime import datetime
 from mimetypes import guess_extension
-from typing import List, Optional, Tuple, Union, cast
+from typing import Optional, Union, cast

 from core.app_runner.app_runner import AppRunner
 from core.application_queue_manager import ApplicationQueueManager
@@ -50,7 +50,7 @@ class BaseAssistantApplicationRunner(AppRunner):
                  message: Message,
                  user_id: str,
                  memory: Optional[TokenBufferMemory] = None,
-                 prompt_messages: Optional[List[PromptMessage]] = None,
+                 prompt_messages: Optional[list[PromptMessage]] = None,
                  variables_pool: Optional[ToolRuntimeVariablePool] = None,
                  db_variables: Optional[ToolConversationVariables] = None,
                  model_instance: ModelInstance = None
@@ -122,7 +122,7 @@ class BaseAssistantApplicationRunner(AppRunner):

         return app_orchestration_config

-    def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
+    def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
         """
         Handle tool response
         """
@@ -140,7 +140,7 @@ class BaseAssistantApplicationRunner(AppRunner):

         return result

-    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
+    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
         """
         convert tool to prompt message tool
         """
@@ -325,7 +325,7 @@ class BaseAssistantApplicationRunner(AppRunner):

         return prompt_tool

-    def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
+    def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
         """
         Extract tool response binary
         """
@@ -356,7 +356,7 @@ class BaseAssistantApplicationRunner(AppRunner):

         return result

-    def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
+    def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
         """
         Create message file

@@ -404,7 +404,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         return result

     def create_agent_thought(self, message_id: str, message: str,
-                             tool_name: str, tool_input: str, messages_ids: List[str]
+                             tool_name: str, tool_input: str, messages_ids: list[str]
                              ) -> MessageAgentThought:
         """
         Create agent thought
@@ -449,7 +449,7 @@ class BaseAssistantApplicationRunner(AppRunner):
                            thought: str,
                            observation: str,
                            answer: str,
-                           messages_ids: List[str],
+                           messages_ids: list[str],
                            llm_usage: LLMUsage = None) -> MessageAgentThought:
         """
         Save agent thought
@@ -505,7 +505,7 @@ class BaseAssistantApplicationRunner(AppRunner):

         db.session.commit()

-    def get_history_prompt_messages(self) -> List[PromptMessage]:
+    def get_history_prompt_messages(self) -> list[PromptMessage]:
         """
         Get history prompt messages
         """
@@ -516,7 +516,7 @@ class BaseAssistantApplicationRunner(AppRunner):

         return self.history_prompt_messages

-    def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
+    def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
         """
         Transform tool message into agent thought
         """

@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Dict, Generator, List, Literal, Union
|
from collections.abc import Generator
|
||||||
|
from typing import Literal, Union
|
||||||
|
|
||||||
from core.application_queue_manager import PublishFrom
|
from core.application_queue_manager import PublishFrom
|
||||||
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
|
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
|
||||||
@ -29,7 +30,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
def run(self, conversation: Conversation,
|
def run(self, conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
query: str,
|
query: str,
|
||||||
inputs: Dict[str, str],
|
inputs: dict[str, str],
|
||||||
) -> Union[Generator, LLMResult]:
|
) -> Union[Generator, LLMResult]:
|
||||||
"""
|
"""
|
||||||
Run Cot agent application
|
Run Cot agent application
|
||||||
@ -37,7 +38,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
app_orchestration_config = self.app_orchestration_config
|
app_orchestration_config = self.app_orchestration_config
|
||||||
self._repack_app_orchestration_config(app_orchestration_config)
|
self._repack_app_orchestration_config(app_orchestration_config)
|
||||||
|
|
||||||
agent_scratchpad: List[AgentScratchpadUnit] = []
|
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||||
|
|
||||||
# check model mode
|
# check model mode
|
||||||
if self.app_orchestration_config.model_config.mode == "completion":
|
if self.app_orchestration_config.model_config.mode == "completion":
|
||||||
@ -56,7 +57,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
prompt_messages = self.history_prompt_messages
|
prompt_messages = self.history_prompt_messages
|
||||||
|
|
||||||
# convert tools into ModelRuntime Tool format
|
# convert tools into ModelRuntime Tool format
|
||||||
prompt_messages_tools: List[PromptMessageTool] = []
|
prompt_messages_tools: list[PromptMessageTool] = []
|
||||||
tool_instances = {}
|
tool_instances = {}
|
||||||
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
||||||
try:
|
try:
|
||||||
@ -83,7 +84,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
}
|
}
|
||||||
final_answer = ''
|
final_answer = ''
|
||||||
|
|
||||||
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
|
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||||
if not final_llm_usage_dict['usage']:
|
if not final_llm_usage_dict['usage']:
|
||||||
final_llm_usage_dict['usage'] = usage
|
final_llm_usage_dict['usage'] = usage
|
||||||
else:
|
else:
|
||||||
@ -493,7 +494,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
if not next_iteration.find("{{observation}}") >= 0:
|
if not next_iteration.find("{{observation}}") >= 0:
|
||||||
raise ValueError("{{observation}} is required in next_iteration")
|
raise ValueError("{{observation}} is required in next_iteration")
|
||||||
|
|
||||||
def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
|
def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
|
||||||
"""
|
"""
|
||||||
convert agent scratchpad list to str
|
convert agent scratchpad list to str
|
||||||
"""
|
"""
|
||||||
@ -506,13 +507,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
|
def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
|
||||||
prompt_messages: List[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
tools: List[PromptMessageTool],
|
tools: list[PromptMessageTool],
|
||||||
agent_scratchpad: List[AgentScratchpadUnit],
|
agent_scratchpad: list[AgentScratchpadUnit],
|
||||||
agent_prompt_message: AgentPromptEntity,
|
agent_prompt_message: AgentPromptEntity,
|
||||||
instruction: str,
|
instruction: str,
|
||||||
input: str,
|
input: str,
|
||||||
) -> List[PromptMessage]:
|
) -> list[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
organize chain of thought prompt messages, a standard prompt message is like:
|
organize chain of thought prompt messages, a standard prompt message is like:
|
||||||
Respond to the human as helpfully and accurately as possible.
|
Respond to the human as helpfully and accurately as possible.
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Generator, List, Tuple, Union
|
from collections.abc import Generator
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
from core.application_queue_manager import PublishFrom
|
from core.application_queue_manager import PublishFrom
|
||||||
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
|
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
|
||||||
@ -44,7 +45,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# convert tools into ModelRuntime Tool format
|
# convert tools into ModelRuntime Tool format
|
||||||
prompt_messages_tools: List[PromptMessageTool] = []
|
prompt_messages_tools: list[PromptMessageTool] = []
|
||||||
tool_instances = {}
|
tool_instances = {}
|
||||||
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
||||||
try:
|
try:
|
||||||
@ -70,13 +71,13 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
|
|
||||||
# continue to run until there is not any tool call
|
# continue to run until there is not any tool call
|
||||||
function_call_state = True
|
function_call_state = True
|
||||||
agent_thoughts: List[MessageAgentThought] = []
|
agent_thoughts: list[MessageAgentThought] = []
|
||||||
llm_usage = {
|
llm_usage = {
|
||||||
'usage': None
|
'usage': None
|
||||||
}
|
}
|
||||||
final_answer = ''
|
final_answer = ''
|
||||||
|
|
||||||
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
|
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||||
if not final_llm_usage_dict['usage']:
|
if not final_llm_usage_dict['usage']:
|
||||||
final_llm_usage_dict['usage'] = usage
|
final_llm_usage_dict['usage'] = usage
|
||||||
else:
|
else:
|
||||||
@ -117,7 +118,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
callbacks=[],
|
callbacks=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []
|
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||||
|
|
||||||
# save full response
|
# save full response
|
||||||
response = ''
|
response = ''
|
||||||
@ -364,7 +365,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
|
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||||
"""
|
"""
|
||||||
Extract tool calls from llm result chunk
|
Extract tool calls from llm result chunk
|
||||||
|
|
||||||
@ -381,7 +382,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
|
|
||||||
return tool_calls
|
return tool_calls
|
||||||
|
|
||||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
|
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||||
"""
|
"""
|
||||||
Extract blocking tool calls from llm result
|
Extract blocking tool calls from llm result
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional, cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ class DatasetRetrievalFeature:
|
|||||||
return_resource: bool,
|
return_resource: bool,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
hit_callback: DatasetIndexToolCallbackHandler) \
|
hit_callback: DatasetIndexToolCallbackHandler) \
|
||||||
-> Optional[List[BaseTool]]:
|
-> Optional[list[BaseTool]]:
|
||||||
"""
|
"""
|
||||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||||
:param tenant_id: tenant id
|
:param tenant_id: tenant id
|
||||||
|
@ -2,7 +2,7 @@ import concurrent
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
|
|
||||||
@ -62,7 +62,7 @@ class ExternalDataFetchFeature:
|
|||||||
app_id: str,
|
app_id: str,
|
||||||
external_data_tool: ExternalDataVariableEntity,
|
external_data_tool: ExternalDataVariableEntity,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str) -> Tuple[Optional[str], Optional[str]]:
|
query: str) -> tuple[Optional[str], Optional[str]]:
|
||||||
"""
|
"""
|
||||||
Query external data tool.
|
Query external data tool.
|
||||||
:param flask_app: flask app
|
:param flask_app: flask app
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from core.entities.application_entities import AppOrchestrationConfigEntity
|
from core.entities.application_entities import AppOrchestrationConfigEntity
|
||||||
from core.moderation.base import ModerationAction, ModerationException
|
from core.moderation.base import ModerationAction, ModerationException
|
||||||
@ -13,7 +12,7 @@ class ModerationFeature:
|
|||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
app_orchestration_config_entity: AppOrchestrationConfigEntity,
|
app_orchestration_config_entity: AppOrchestrationConfigEntity,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str) -> Tuple[bool, dict, str]:
|
query: str) -> tuple[bool, dict, str]:
|
||||||
"""
|
"""
|
||||||
Process sensitive_word_avoidance.
|
Process sensitive_word_avoidance.
|
||||||
:param app_id: app id
|
:param app_id: app id
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@ -15,8 +15,8 @@ class MessageFileParser:
|
|||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.app_id = app_id
|
self.app_id = app_id
|
||||||
|
|
||||||
def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
|
def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig,
|
||||||
user: Union[Account, EndUser]) -> List[FileObj]:
|
user: Union[Account, EndUser]) -> list[FileObj]:
|
||||||
"""
|
"""
|
||||||
validate and transform files arg
|
validate and transform files arg
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ class MessageFileParser:
|
|||||||
# return all file objs
|
# return all file objs
|
||||||
return new_files
|
return new_files
|
||||||
|
|
||||||
def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
|
def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]:
|
||||||
"""
|
"""
|
||||||
transform message files
|
transform message files
|
||||||
|
|
||||||
@ -110,8 +110,8 @@ class MessageFileParser:
|
|||||||
# return all file objs
|
# return all file objs
|
||||||
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
|
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
|
||||||
|
|
||||||
def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
|
def _to_file_objs(self, files: list[Union[dict, MessageFile]],
|
||||||
file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
|
file_upload_config: dict) -> dict[FileType, list[FileObj]]:
|
||||||
"""
|
"""
|
||||||
transform files to file objs
|
transform files to file objs
|
||||||
|
|
||||||
@ -119,7 +119,7 @@ class MessageFileParser:
|
|||||||
:param file_upload_config:
|
:param file_upload_config:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
type_file_objs: Dict[FileType, List[FileObj]] = {
|
type_file_objs: dict[FileType, list[FileObj]] = {
|
||||||
# Currently only support image
|
# Currently only support image
|
||||||
FileType.IMAGE: []
|
FileType.IMAGE: []
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List
|
from typing import Any
|
||||||
|
|
||||||
from langchain.schema import BaseRetriever, Document
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ class BaseIndex(ABC):
|
|||||||
def search(
|
def search(
|
||||||
self, query: str,
|
self, query: str,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> List[Document]:
|
) -> list[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self) -> None:
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Set
|
|
||||||
|
|
||||||
import jieba
|
import jieba
|
||||||
from jieba.analyse import default_tfidf
|
from jieba.analyse import default_tfidf
|
||||||
@ -12,7 +11,7 @@ class JiebaKeywordTableHandler:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
default_tfidf.stop_words = STOPWORDS
|
default_tfidf.stop_words = STOPWORDS
|
||||||
|
|
||||||
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
|
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
|
||||||
"""Extract keywords with JIEBA tfidf."""
|
"""Extract keywords with JIEBA tfidf."""
|
||||||
keywords = jieba.analyse.extract_tags(
|
keywords = jieba.analyse.extract_tags(
|
||||||
sentence=text,
|
sentence=text,
|
||||||
@ -21,7 +20,7 @@ class JiebaKeywordTableHandler:
|
|||||||
|
|
||||||
return set(self._expand_tokens_with_subtokens(keywords))
|
return set(self._expand_tokens_with_subtokens(keywords))
|
||||||
|
|
||||||
def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
|
def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
|
||||||
"""Get subtokens from a list of tokens., filtering for stopwords."""
|
"""Get subtokens from a list of tokens., filtering for stopwords."""
|
||||||
results = set()
|
results = set()
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from langchain.schema import BaseRetriever, Document
|
from langchain.schema import BaseRetriever, Document
|
||||||
from pydantic import BaseModel, Extra, Field
|
from pydantic import BaseModel, Extra, Field
|
||||||
@ -116,7 +116,7 @@ class KeywordTableIndex(BaseIndex):
|
|||||||
def search(
|
def search(
|
||||||
self, query: str,
|
self, query: str,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> List[Document]:
|
) -> list[Document]:
|
||||||
keyword_table = self._get_dataset_keyword_table()
|
keyword_table = self._get_dataset_keyword_table()
|
||||||
|
|
||||||
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
|
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
|
||||||
@ -221,7 +221,7 @@ class KeywordTableIndex(BaseIndex):
|
|||||||
keywords = keyword_table_handler.extract_keywords(query)
|
keywords = keyword_table_handler.extract_keywords(query)
|
||||||
|
|
||||||
# go through text chunks in order of most matching keywords
|
# go through text chunks in order of most matching keywords
|
||||||
chunk_indices_count: Dict[str, int] = defaultdict(int)
|
chunk_indices_count: dict[str, int] = defaultdict(int)
|
||||||
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
|
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
|
||||||
for keyword in keywords:
|
for keyword in keywords:
|
||||||
for node_id in keyword_table[keyword]:
|
for node_id in keyword_table[keyword]:
|
||||||
@ -235,7 +235,7 @@ class KeywordTableIndex(BaseIndex):
|
|||||||
|
|
||||||
return sorted_chunk_indices[: k]
|
return sorted_chunk_indices[: k]
|
||||||
|
|
||||||
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]):
|
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
|
||||||
document_segment = db.session.query(DocumentSegment).filter(
|
document_segment = db.session.query(DocumentSegment).filter(
|
||||||
DocumentSegment.dataset_id == dataset_id,
|
DocumentSegment.dataset_id == dataset_id,
|
||||||
DocumentSegment.index_node_id == node_id
|
DocumentSegment.index_node_id == node_id
|
||||||
@ -244,7 +244,7 @@ class KeywordTableIndex(BaseIndex):
|
|||||||
document_segment.keywords = keywords
|
document_segment.keywords = keywords
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def create_segment_keywords(self, node_id: str, keywords: List[str]):
|
def create_segment_keywords(self, node_id: str, keywords: list[str]):
|
||||||
keyword_table = self._get_dataset_keyword_table()
|
keyword_table = self._get_dataset_keyword_table()
|
||||||
self._update_segment_keywords(self.dataset.id, node_id, keywords)
|
self._update_segment_keywords(self.dataset.id, node_id, keywords)
|
||||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||||
@ -266,7 +266,7 @@ class KeywordTableIndex(BaseIndex):
|
|||||||
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
|
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
|
||||||
self._save_dataset_keyword_table(keyword_table)
|
self._save_dataset_keyword_table(keyword_table)
|
||||||
|
|
||||||
def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
|
def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
|
||||||
keyword_table = self._get_dataset_keyword_table()
|
keyword_table = self._get_dataset_keyword_table()
|
||||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||||
self._save_dataset_keyword_table(keyword_table)
|
self._save_dataset_keyword_table(keyword_table)
|
||||||
@ -282,7 +282,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
|
|||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
def get_relevant_documents(self, query: str) -> list[Document]:
|
||||||
"""Get documents relevant for a query.
|
"""Get documents relevant for a query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -293,7 +293,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
|
|||||||
"""
|
"""
|
||||||
return self.index.search(query, **self.search_kwargs)
|
return self.index.search(query, **self.search_kwargs)
|
||||||
|
|
||||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
async def aget_relevant_documents(self, query: str) -> list[Document]:
|
||||||
raise NotImplementedError("KeywordTableRetriever does not support async")
|
raise NotImplementedError("KeywordTableRetriever does not support async")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Any, List, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.schema import BaseRetriever, Document
|
from langchain.schema import BaseRetriever, Document
|
||||||
@ -43,13 +43,13 @@ class BaseVectorIndex(BaseIndex):
|
|||||||
def search_by_full_text_index(
|
def search_by_full_text_index(
|
||||||
self, query: str,
|
self, query: str,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> List[Document]:
|
) -> list[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self, query: str,
|
self, query: str,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> List[Document]:
|
) -> list[Document]:
|
||||||
vector_store = self._get_vector_store()
|
vector_store = self._get_vector_store()
|
||||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
@ -160,6 +160,6 @@ class MilvusVectorIndex(BaseVectorIndex):
|
|||||||
],
|
],
|
||||||
))
|
))
|
||||||
|
|
||||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
|
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
# milvus/zilliz doesn't support bm25 search
|
# milvus/zilliz doesn't support bm25 search
|
||||||
return []
|
return []
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, List, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
import qdrant_client
|
import qdrant_client
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
@ -210,7 +210,7 @@ class QdrantVectorIndex(BaseVectorIndex):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
|
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
vector_store = self._get_vector_store()
|
vector_store = self._get_vector_store()
|
||||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import weaviate
|
import weaviate
|
||||||
@ -172,7 +172,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
|
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
vector_store = self._get_vector_store()
|
vector_store = self._get_vector_store()
|
||||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||||
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
|
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
|
||||||
|
@ -5,7 +5,7 @@ import re
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Optional, cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
@@ -40,7 +40,7 @@ class IndexingRunner:
         self.storage = storage
         self.model_manager = ModelManager()

-    def run(self, dataset_documents: List[DatasetDocument]):
+    def run(self, dataset_documents: list[DatasetDocument]):
         """Run the indexing process."""
         for dataset_document in dataset_documents:
             try:
@@ -238,7 +238,7 @@ class IndexingRunner:
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

-    def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
+    def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                               indexing_technique: str = 'economy') -> dict:
         """
@@ -494,7 +494,7 @@ class IndexingRunner:
            "preview": preview_texts
        }

-    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
+    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
         # load file
         if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
             return []
@@ -526,7 +526,7 @@ class IndexingRunner:
         )

         # replace doc id to document model id
-        text_docs = cast(List[Document], text_docs)
+        text_docs = cast(list[Document], text_docs)
         for text_doc in text_docs:
             # remove invalid symbol
             text_doc.page_content = self.filter_string(text_doc.page_content)
@@ -540,7 +540,7 @@ class IndexingRunner:
         text = re.sub(r'\|>', '>', text)
         text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
         # Unicode  U+FFFE
-        text = re.sub(u'\uFFFE', '', text)
+        text = re.sub('\uFFFE', '', text)
         return text

     def _get_splitter(self, processing_rule: DatasetProcessRule,
@@ -577,9 +577,9 @@ class IndexingRunner:

         return character_splitter

-    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
+    def _step_split(self, text_docs: list[Document], splitter: TextSplitter,
                     dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
-            -> List[Document]:
+            -> list[Document]:
         """
         Split the text documents into documents and save them to the document segment.
         """
@@ -624,9 +624,9 @@ class IndexingRunner:

         return documents

-    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
+    def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter,
                             processing_rule: DatasetProcessRule, tenant_id: str,
-                            document_form: str, document_language: str) -> List[Document]:
+                            document_form: str, document_language: str) -> list[Document]:
         """
         Split the text documents into nodes.
         """
@@ -699,8 +699,8 @@ class IndexingRunner:

            all_qa_documents.extend(format_documents)

-    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
-                                         processing_rule: DatasetProcessRule) -> List[Document]:
+    def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter,
+                                         processing_rule: DatasetProcessRule) -> list[Document]:
         """
         Split the text documents into nodes.
         """
@@ -770,7 +770,7 @@ class IndexingRunner:
            for q, a in matches if q and a
        ]

-    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
+    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
         """
         Build the index for the document.
         """
@@ -877,7 +877,7 @@ class IndexingRunner:
         DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
         db.session.commit()

-    def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
+    def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset):
         """
         Batch add segments index processing
         """
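Every change in this file is one of two pyupgrade rewrites: `typing.List[...]` becomes the builtin generic `list[...]` (PEP 585, Python 3.9+), and the legacy `u''` string prefix, a no-op since Python 3, is dropped. A minimal sketch of the annotation rewrite, with hypothetical names not taken from the diff:

    # Before (pre-3.9): the typing alias was required.
    #   from typing import List
    #   def keep_short(docs: List[str]) -> List[str]: ...

    # After (PEP 585): subscript the builtin type directly.
    def keep_short(docs: list[str]) -> list[str]:
        # Keep only documents under 100 characters.
        return [d for d in docs if len(d) < 100]

    print(keep_short(["ok", "x" * 200]))  # ['ok']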
@@ -1,4 +1,5 @@
-from typing import IO, Generator, List, Optional, Union, cast
+from collections.abc import Generator
+from typing import IO, Optional, Union, cast

 from core.entities.provider_configuration import ProviderModelBundle
 from core.errors.error import ProviderTokenNotInitError
@@ -47,7 +48,7 @@ class ModelInstance:
         return credentials

     def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
-                   tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                   tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                    stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
             -> Union[LLMResult, Generator]:
         """
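Here `Generator` moves from `typing` to `collections.abc`; since Python 3.9 the ABCs are generic themselves and the `typing` aliases are deprecated. A sketch of the resulting style (the counter below is illustrative, not Dify code):

    from collections.abc import Generator

    def count_up(limit: int) -> Generator[int, None, None]:
        # Yields ints; the two None parameters are the send and return types.
        n = 0
        while n < limit:
            yield n
            n += 1

    print(list(count_up(3)))  # [0, 1, 2]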
@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import List, Optional
+from typing import Optional

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
@@ -23,7 +23,7 @@ class Callback(ABC):

     def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
                          prompt_messages: list[PromptMessage], model_parameters: dict,
-                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                          stream: bool = True, user: Optional[str] = None) -> None:
         """
         Before invoke callback
@@ -42,7 +42,7 @@ class Callback(ABC):

     def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
                      prompt_messages: list[PromptMessage], model_parameters: dict,
-                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                      stream: bool = True, user: Optional[str] = None):
         """
         On new chunk callback
@@ -62,7 +62,7 @@ class Callback(ABC):

     def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
         """
         After invoke callback
@@ -82,7 +82,7 @@ class Callback(ABC):

     def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
         """
         Invoke error callback
@@ -1,7 +1,7 @@
 import json
 import logging
 import sys
-from typing import List, Optional
+from typing import Optional

 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
 class LoggingCallback(Callback):
     def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
                          prompt_messages: list[PromptMessage], model_parameters: dict,
-                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                          stream: bool = True, user: Optional[str] = None) -> None:
         """
         Before invoke callback
@@ -60,7 +60,7 @@ class LoggingCallback(Callback):

     def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
                      prompt_messages: list[PromptMessage], model_parameters: dict,
-                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                      stream: bool = True, user: Optional[str] = None):
         """
         On new chunk callback
@@ -81,7 +81,7 @@ class LoggingCallback(Callback):

     def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
         """
         After invoke callback
@@ -113,7 +113,7 @@ class LoggingCallback(Callback):

     def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
         """
         Invoke error callback
@@ -1,8 +1,7 @@
-from typing import Dict

 from core.model_runtime.entities.model_entities import DefaultParameterName

-PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
+PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
     DefaultParameterName.TEMPERATURE: {
         'label': {
             'en_US': 'Temperature',
@@ -153,7 +153,7 @@ class AIModel(ABC):
         # read _position.yaml file
         position_map = {}
         if os.path.exists(position_file_path):
-            with open(position_file_path, 'r', encoding='utf-8') as f:
+            with open(position_file_path, encoding='utf-8') as f:
                 positions = yaml.safe_load(f)
                 # convert list to dict with key as model provider name, value as index
                 position_map = {position: index for index, position in enumerate(positions)}
@@ -161,7 +161,7 @@
         # traverse all model_schema_yaml_paths
         for model_schema_yaml_path in model_schema_yaml_paths:
             # read yaml data from yaml file
-            with open(model_schema_yaml_path, 'r', encoding='utf-8') as f:
+            with open(model_schema_yaml_path, encoding='utf-8') as f:
                 yaml_data = yaml.safe_load(f)

             new_parameter_rules = []
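Dropping `'r'` from `open()` is behavior-preserving: text-mode reading is the default, so only the `encoding` keyword carries information. A self-contained check, using a throwaway file name invented for the example:

    # Create the file so the example runs on its own.
    with open('demo.txt', 'w', encoding='utf-8') as f:
        f.write('hello')

    # 'r' is the default mode, so these two reads are equivalent.
    with open('demo.txt', 'r', encoding='utf-8') as f:
        old_style = f.read()
    with open('demo.txt', encoding='utf-8') as f:
        new_style = f.read()
    assert old_style == new_style == 'hello'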
@@ -3,7 +3,8 @@ import os
 import re
 import time
 from abc import abstractmethod
-from typing import Generator, List, Optional, Union
+from collections.abc import Generator
+from typing import Optional, Union

 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.callbacks.logging_callback import LoggingCallback
@@ -29,7 +30,7 @@ class LargeLanguageModel(AIModel):

     def invoke(self, model: str, credentials: dict,
               prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
-               tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+               tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
               stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
             -> Union[LLMResult, Generator]:
         """
@@ -122,7 +123,7 @@ class LargeLanguageModel(AIModel):
     def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
                                  prompt_messages: list[PromptMessage], model_parameters: dict,
                                  tools: Optional[list[PromptMessageTool]] = None,
-                                 stop: Optional[List[str]] = None, stream: bool = True,
+                                 stop: Optional[list[str]] = None, stream: bool = True,
                                  user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
         """
         Invoke result generator
@@ -186,7 +187,7 @@ class LargeLanguageModel(AIModel):
     @abstractmethod
     def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         """
@@ -218,7 +219,7 @@ class LargeLanguageModel(AIModel):
         """
         raise NotImplementedError

-    def enforce_stop_tokens(self, text: str, stop: List[str]) -> str:
+    def enforce_stop_tokens(self, text: str, stop: list[str]) -> str:
         """Cut off the text as soon as any stop words occur."""
         return re.split("|".join(stop), text, maxsplit=1)[0]

@@ -329,7 +330,7 @@ class LargeLanguageModel(AIModel):
     def _trigger_before_invoke_callbacks(self, model: str, credentials: dict,
                                         prompt_messages: list[PromptMessage], model_parameters: dict,
                                         tools: Optional[list[PromptMessageTool]] = None,
-                                         stop: Optional[List[str]] = None, stream: bool = True,
+                                         stop: Optional[list[str]] = None, stream: bool = True,
                                         user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
         """
         Trigger before invoke callbacks
@@ -367,7 +368,7 @@ class LargeLanguageModel(AIModel):
     def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict,
                                     prompt_messages: list[PromptMessage], model_parameters: dict,
                                     tools: Optional[list[PromptMessageTool]] = None,
-                                     stop: Optional[List[str]] = None, stream: bool = True,
+                                     stop: Optional[list[str]] = None, stream: bool = True,
                                     user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
         """
         Trigger new chunk callbacks
@@ -406,7 +407,7 @@ class LargeLanguageModel(AIModel):
     def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict,
                                        prompt_messages: list[PromptMessage], model_parameters: dict,
                                        tools: Optional[list[PromptMessageTool]] = None,
-                                        stop: Optional[List[str]] = None, stream: bool = True,
+                                        stop: Optional[list[str]] = None, stream: bool = True,
                                        user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
         """
         Trigger after invoke callbacks
@@ -446,7 +447,7 @@ class LargeLanguageModel(AIModel):
     def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict,
                                        prompt_messages: list[PromptMessage], model_parameters: dict,
                                        tools: Optional[list[PromptMessageTool]] = None,
-                                        stop: Optional[List[str]] = None, stream: bool = True,
+                                        stop: Optional[list[str]] = None, stream: bool = True,
                                        user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
         """
         Trigger invoke error callbacks
@@ -527,7 +528,7 @@ class LargeLanguageModel(AIModel):
                raise ValueError(
                    f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
            elif parameter_rule.type == ParameterType.FLOAT:
-                if not isinstance(parameter_value, (float, int)):
+                if not isinstance(parameter_value, float | int):
                    raise ValueError(f"Model Parameter {parameter_name} should be float.")

                # validate parameter value precision
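The `isinstance(parameter_value, float | int)` rewrite in the last hunk leans on PEP 604 (Python 3.10+): `X | Y` evaluates to a `types.UnionType`, which `isinstance` accepts just as it accepts a tuple of types. A quick demonstration of the equivalence:

    # Requires Python 3.10+; older interpreters need the tuple form.
    for value in (1, 2.5, '3'):
        assert isinstance(value, (float, int)) == isinstance(value, float | int)
    print(isinstance('3', float | int))  # False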
@@ -1,7 +1,6 @@
 import importlib
 import os
 from abc import ABC, abstractmethod
-from typing import Dict

 import yaml

@@ -12,7 +11,7 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel

 class ModelProvider(ABC):
     provider_schema: ProviderEntity = None
-    model_instance_map: Dict[str, AIModel] = {}
+    model_instance_map: dict[str, AIModel] = {}

     @abstractmethod
     def validate_provider_credentials(self, credentials: dict) -> None:
@@ -47,7 +46,7 @@ class ModelProvider(ABC):
         yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
         yaml_data = {}
         if os.path.exists(yaml_path):
-            with open(yaml_path, 'r', encoding='utf-8') as f:
+            with open(yaml_path, encoding='utf-8') as f:
                 yaml_data = yaml.safe_load(f)

         try:
@@ -1,4 +1,5 @@
-from typing import Generator, List, Optional, Union
+from collections.abc import Generator
+from typing import Optional, Union

 import anthropic
 from anthropic import Anthropic, Stream
@@ -29,7 +30,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):

     def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         """
@@ -90,7 +91,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):

     def _generate(self, model: str, credentials: dict,
                  prompt_messages: list[PromptMessage], model_parameters: dict,
-                  stop: Optional[List[str]] = None, stream: bool = True,
+                  stop: Optional[list[str]] = None, stream: bool = True,
                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
@@ -255,7 +256,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):

         return message_text

-    def _convert_messages_to_prompt_anthropic(self, messages: List[PromptMessage]) -> str:
+    def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str:
         """
         Format a list of messages into a full prompt for the Anthropic model

@@ -1,6 +1,7 @@
 import copy
 import logging
-from typing import Generator, List, Optional, Union, cast
+from collections.abc import Generator
+from typing import Optional, Union, cast

 import tiktoken
 from openai import AzureOpenAI, Stream
@@ -34,7 +35,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

     def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:

@@ -121,7 +122,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         return ai_model_entity.entity if ai_model_entity else None

     def _generate(self, model: str, credentials: dict,
-                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
+                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:

         client = AzureOpenAI(**self._to_credential_kwargs(credentials))
@@ -239,7 +240,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

     def _chat_generate(self, model: str, credentials: dict,
                       prompt_messages: list[PromptMessage], model_parameters: dict,
-                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                       stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:

         client = AzureOpenAI(**self._to_credential_kwargs(credentials))
@@ -537,7 +538,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return num_tokens

-    def _num_tokens_from_messages(self, credentials: dict, messages: List[PromptMessage],
+    def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

Some files were not shown because too many files have changed in this diff.