feat(tools/cogview): Updated cogview tool to support cogview-3 and the latest cogview-3-plus (#8382)

This commit is contained in:
Waffle 2024-09-22 10:14:14 +08:00 committed by GitHub
parent 0665268578
commit 740fad06c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
108 changed files with 6513 additions and 405 deletions

View File

@ -1,7 +1,8 @@
from .__version__ import __version__
from ._client import ZhipuAI
from .core._errors import (
from .core import (
APIAuthenticationError,
APIConnectionError,
APIInternalError,
APIReachLimitError,
APIRequestFailedError,

View File

@ -1 +1 @@
__version__ = "v2.0.1"
__version__ = "v2.1.0"

View File

@ -9,15 +9,13 @@ from httpx import Timeout
from typing_extensions import override
from . import api_resource
from .core import _jwt_token
from .core._base_type import NOT_GIVEN, NotGiven
from .core._errors import ZhipuAIError
from .core._http_client import ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient
from .core import NOT_GIVEN, ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient, NotGiven, ZhipuAIError, _jwt_token
class ZhipuAI(HttpClient):
chat: api_resource.chat
chat: api_resource.chat.Chat
api_key: str
_disable_token_cache: bool = True
def __init__(
self,
@ -28,10 +26,15 @@ class ZhipuAI(HttpClient):
max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
http_client: httpx.Client | None = None,
custom_headers: Mapping[str, str] | None = None,
disable_token_cache: bool = True,
_strict_response_validation: bool = False,
) -> None:
if api_key is None:
raise ZhipuAIError("No api_key provided, please provide it through parameters or environment variables")
api_key = os.environ.get("ZHIPUAI_API_KEY")
if api_key is None:
raise ZhipuAIError("未提供api_key请通过参数或环境变量提供")
self.api_key = api_key
self._disable_token_cache = disable_token_cache
if base_url is None:
base_url = os.environ.get("ZHIPUAI_BASE_URL")
@ -42,21 +45,31 @@ class ZhipuAI(HttpClient):
super().__init__(
version=__version__,
base_url=base_url,
max_retries=max_retries,
timeout=timeout,
custom_httpx_client=http_client,
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
)
self.chat = api_resource.chat.Chat(self)
self.images = api_resource.images.Images(self)
self.embeddings = api_resource.embeddings.Embeddings(self)
self.files = api_resource.files.Files(self)
self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
self.batches = api_resource.Batches(self)
self.knowledge = api_resource.Knowledge(self)
self.tools = api_resource.Tools(self)
self.videos = api_resource.Videos(self)
self.assistant = api_resource.Assistant(self)
@property
@override
def _auth_headers(self) -> dict[str, str]:
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}
if self._disable_token_cache:
return {"Authorization": f"Bearer {api_key}"}
else:
return {"Authorization": f"Bearer {_jwt_token.generate_token(api_key)}"}
def __del__(self) -> None:
if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"):

View File

@ -1,5 +1,34 @@
from .chat import chat
from .assistant import (
Assistant,
)
from .batches import Batches
from .chat import (
AsyncCompletions,
Chat,
Completions,
)
from .embeddings import Embeddings
from .files import Files
from .fine_tuning import fine_tuning
from .files import Files, FilesWithRawResponse
from .fine_tuning import FineTuning
from .images import Images
from .knowledge import Knowledge
from .tools import Tools
from .videos import (
Videos,
)
__all__ = [
"Videos",
"AsyncCompletions",
"Chat",
"Completions",
"Images",
"Embeddings",
"Files",
"FilesWithRawResponse",
"FineTuning",
"Batches",
"Knowledge",
"Tools",
"Assistant",
]

View File

@ -0,0 +1,3 @@
from .assistant import Assistant
__all__ = ["Assistant"]

View File

@ -0,0 +1,122 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
StreamResponse,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.assistant import AssistantCompletion
from ...types.assistant.assistant_conversation_resp import ConversationUsageListResp
from ...types.assistant.assistant_support_resp import AssistantSupportResp
if TYPE_CHECKING:
from ..._client import ZhipuAI
from ...types.assistant import assistant_conversation_params, assistant_create_params
__all__ = ["Assistant"]
class Assistant(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def conversation(
self,
assistant_id: str,
model: str,
messages: list[assistant_create_params.ConversationMessage],
*,
stream: bool = True,
conversation_id: Optional[str] = None,
attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None,
metadata: dict | None = None,
request_id: str = None,
user_id: str = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> StreamResponse[AssistantCompletion]:
body = deepcopy_minimal(
{
"assistant_id": assistant_id,
"model": model,
"messages": messages,
"stream": stream,
"conversation_id": conversation_id,
"attachments": attachments,
"metadata": metadata,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/assistant",
body=maybe_transform(body, assistant_create_params.AssistantParameters),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=AssistantCompletion,
stream=stream or True,
stream_cls=StreamResponse[AssistantCompletion],
)
def query_support(
self,
*,
assistant_id_list: list[str] = None,
request_id: str = None,
user_id: str = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AssistantSupportResp:
body = deepcopy_minimal(
{
"assistant_id_list": assistant_id_list,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/assistant/list",
body=body,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=AssistantSupportResp,
)
def query_conversation_usage(
self,
assistant_id: str,
page: int = 1,
page_size: int = 10,
*,
request_id: str = None,
user_id: str = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ConversationUsageListResp:
body = deepcopy_minimal(
{
"assistant_id": assistant_id,
"page": page,
"page_size": page_size,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/assistant/conversation/list",
body=maybe_transform(body, assistant_conversation_params.ConversationParameters),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=ConversationUsageListResp,
)

View File

@ -0,0 +1,146 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Literal, Optional
import httpx
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options, maybe_transform
from ..core.pagination import SyncCursorPage
from ..types import batch_create_params, batch_list_params
from ..types.batch import Batch
if TYPE_CHECKING:
from .._client import ZhipuAI
class Batches(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
completion_window: str | None = None,
endpoint: Literal["/v1/chat/completions", "/v1/embeddings"],
input_file_id: str,
metadata: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
auto_delete_input_file: bool = True,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Batch:
return self._post(
"/batches",
body=maybe_transform(
{
"completion_window": completion_window,
"endpoint": endpoint,
"input_file_id": input_file_id,
"metadata": metadata,
"auto_delete_input_file": auto_delete_input_file,
},
batch_create_params.BatchCreateParams,
),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Batch,
)
def retrieve(
self,
batch_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Batch:
"""
Retrieves a batch.
Args:
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not batch_id:
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
return self._get(
f"/batches/{batch_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Batch,
)
def list(
self,
*,
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> SyncCursorPage[Batch]:
"""List your organization's batches.
Args:
after: A cursor for use in pagination.
`after` is an object ID that defines your place
in the list. For instance, if you make a list request and receive 100 objects,
ending with obj_foo, your subsequent call can include after=obj_foo in order to
fetch the next page of the list.
limit: A limit on the number of objects to be returned. Limit can range between 1 and
100, and the default is 20.
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
return self._get_api_list(
"/batches",
page=SyncCursorPage[Batch],
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"after": after,
"limit": limit,
},
batch_list_params.BatchListParams,
),
),
model=Batch,
)
def cancel(
self,
batch_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Batch:
"""
Cancels an in-progress batch.
Args:
batch_id: The ID of the batch to cancel.
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not batch_id:
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
return self._post(
f"/batches/{batch_id}/cancel",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Batch,
)

View File

@ -0,0 +1,5 @@
from .async_completions import AsyncCompletions
from .chat import Chat
from .completions import Completions
__all__ = ["AsyncCompletions", "Chat", "Completions"]

View File

@ -1,13 +1,25 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Literal, Optional, Union
import httpx
from ...core._base_api import BaseAPI
from ...core._base_type import NOT_GIVEN, Headers, NotGiven
from ...core._http_client import make_user_request_input
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
drop_prefix_image_data,
make_request_options,
maybe_transform,
)
from ...types.chat.async_chat_completion import AsyncCompletion, AsyncTaskStatus
from ...types.chat.code_geex import code_geex_params
from ...types.sensitive_word_check import SensitiveWordCheckRequest
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..._client import ZhipuAI
@ -22,6 +34,7 @@ class AsyncCompletions(BaseAPI):
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
user_id: Optional[str] | NotGiven = NOT_GIVEN,
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
@ -29,50 +42,74 @@ class AsyncCompletions(BaseAPI):
seed: int | NotGiven = NOT_GIVEN,
messages: Union[str, list[str], list[int], list[list[int]], None],
stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
tools: Optional[object] | NotGiven = NOT_GIVEN,
tool_choice: str | NotGiven = NOT_GIVEN,
meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AsyncTaskStatus:
_cast_type = AsyncTaskStatus
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
if temperature is not None and temperature != NOT_GIVEN:
if temperature <= 0:
do_sample = False
temperature = 0.01
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间do_sample重写为:false参数top_p temperture不生效") # noqa: E501
if temperature >= 1:
temperature = 0.99
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间")
if top_p is not None and top_p != NOT_GIVEN:
if top_p >= 1:
top_p = 0.99
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
if top_p <= 0:
top_p = 0.01
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
if disable_strict_validation:
_cast_type = object
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
if isinstance(messages, list):
for item in messages:
if item.get("content"):
item["content"] = drop_prefix_image_data(item["content"])
body = {
"model": model,
"request_id": request_id,
"user_id": user_id,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"max_tokens": max_tokens,
"seed": seed,
"messages": messages,
"stop": stop,
"sensitive_word_check": sensitive_word_check,
"tools": tools,
"tool_choice": tool_choice,
"meta": meta,
"extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
}
return self._post(
"/async/chat/completions",
body={
"model": model,
"request_id": request_id,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"max_tokens": max_tokens,
"seed": seed,
"messages": messages,
"stop": stop,
"sensitive_word_check": sensitive_word_check,
"tools": tools,
"tool_choice": tool_choice,
},
options=make_user_request_input(extra_headers=extra_headers, timeout=timeout),
body=body,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=_cast_type,
enable_stream=False,
stream=False,
)
def retrieve_completion_result(
self,
id: str,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Union[AsyncCompletion, AsyncTaskStatus]:
_cast_type = Union[AsyncCompletion, AsyncTaskStatus]
if disable_strict_validation:
_cast_type = object
return self._get(
path=f"/async-result/{id}",
cast_type=_cast_type,
options=make_user_request_input(extra_headers=extra_headers, timeout=timeout),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
)

View File

@ -1,17 +1,18 @@
from typing import TYPE_CHECKING
from ...core._base_api import BaseAPI
from ...core import BaseAPI, cached_property
from .async_completions import AsyncCompletions
from .completions import Completions
if TYPE_CHECKING:
from ..._client import ZhipuAI
pass
class Chat(BaseAPI):
completions: Completions
@cached_property
def completions(self) -> Completions:
return Completions(self._client)
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
self.completions = Completions(client)
self.asyncCompletions = AsyncCompletions(client)
@cached_property
def asyncCompletions(self) -> AsyncCompletions: # noqa: N802
return AsyncCompletions(self._client)

View File

@ -1,15 +1,28 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Literal, Optional, Union
import httpx
from ...core._base_api import BaseAPI
from ...core._base_type import NOT_GIVEN, Headers, NotGiven
from ...core._http_client import make_user_request_input
from ...core._sse_client import StreamResponse
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
StreamResponse,
deepcopy_minimal,
drop_prefix_image_data,
make_request_options,
maybe_transform,
)
from ...types.chat.chat_completion import Completion
from ...types.chat.chat_completion_chunk import ChatCompletionChunk
from ...types.chat.code_geex import code_geex_params
from ...types.sensitive_word_check import SensitiveWordCheckRequest
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..._client import ZhipuAI
@ -24,6 +37,7 @@ class Completions(BaseAPI):
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
user_id: Optional[str] | NotGiven = NOT_GIVEN,
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
@ -32,23 +46,43 @@ class Completions(BaseAPI):
seed: int | NotGiven = NOT_GIVEN,
messages: Union[str, list[str], list[int], object, None],
stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
tools: Optional[object] | NotGiven = NOT_GIVEN,
tool_choice: str | NotGiven = NOT_GIVEN,
meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Completion | StreamResponse[ChatCompletionChunk]:
_cast_type = Completion
_stream_cls = StreamResponse[ChatCompletionChunk]
if disable_strict_validation:
_cast_type = object
_stream_cls = StreamResponse[object]
return self._post(
"/chat/completions",
body={
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
if temperature is not None and temperature != NOT_GIVEN:
if temperature <= 0:
do_sample = False
temperature = 0.01
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间do_sample重写为:false参数top_p temperture不生效") # noqa: E501
if temperature >= 1:
temperature = 0.99
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间")
if top_p is not None and top_p != NOT_GIVEN:
if top_p >= 1:
top_p = 0.99
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
if top_p <= 0:
top_p = 0.01
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
if isinstance(messages, list):
for item in messages:
if item.get("content"):
item["content"] = drop_prefix_image_data(item["content"])
body = deepcopy_minimal(
{
"model": model,
"request_id": request_id,
"user_id": user_id,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
@ -60,11 +94,15 @@ class Completions(BaseAPI):
"stream": stream,
"tools": tools,
"tool_choice": tool_choice,
},
options=make_user_request_input(
extra_headers=extra_headers,
),
cast_type=_cast_type,
enable_stream=stream or False,
stream_cls=_stream_cls,
"meta": meta,
"extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
}
)
return self._post(
"/chat/completions",
body=body,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Completion,
stream=stream or False,
stream_cls=StreamResponse[ChatCompletionChunk],
)

View File

@ -4,9 +4,7 @@ from typing import TYPE_CHECKING, Optional, Union
import httpx
from ..core._base_api import BaseAPI
from ..core._base_type import NOT_GIVEN, Headers, NotGiven
from ..core._http_client import make_user_request_input
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
from ..types.embeddings import EmbeddingsResponded
if TYPE_CHECKING:
@ -22,10 +20,13 @@ class Embeddings(BaseAPI):
*,
input: Union[str, list[str], list[int], list[list[int]]],
model: Union[str],
dimensions: Union[int] | NotGiven = NOT_GIVEN,
encoding_format: str | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> EmbeddingsResponded:
@ -37,11 +38,13 @@ class Embeddings(BaseAPI):
body={
"input": input,
"model": model,
"dimensions": dimensions,
"encoding_format": encoding_format,
"user": user,
"request_id": request_id,
"sensitive_word_check": sensitive_word_check,
},
options=make_user_request_input(extra_headers=extra_headers, timeout=timeout),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=_cast_type,
enable_stream=False,
stream=False,
)

View File

@ -1,19 +1,30 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from collections.abc import Mapping
from typing import TYPE_CHECKING, Literal, cast
import httpx
from ..core._base_api import BaseAPI
from ..core._base_type import NOT_GIVEN, FileTypes, Headers, NotGiven
from ..core._files import is_file_content
from ..core._http_client import make_user_request_input
from ..types.file_object import FileObject, ListOfFileObject
from ..core import (
NOT_GIVEN,
BaseAPI,
Body,
FileTypes,
Headers,
NotGiven,
_legacy_binary_response,
_legacy_response,
deepcopy_minimal,
extract_files,
make_request_options,
maybe_transform,
)
from ..types.files import FileDeleted, FileObject, ListOfFileObject, UploadDetail, file_create_params
if TYPE_CHECKING:
from .._client import ZhipuAI
__all__ = ["Files"]
__all__ = ["Files", "FilesWithRawResponse"]
class Files(BaseAPI):
@ -23,30 +34,69 @@ class Files(BaseAPI):
def create(
self,
*,
file: FileTypes,
purpose: str,
file: FileTypes = None,
upload_detail: list[UploadDetail] = None,
purpose: Literal["fine-tune", "retrieval", "batch"],
knowledge_id: str = None,
sentence_size: int = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FileObject:
if not is_file_content(file):
prefix = f"Expected file input `{file!r}`"
raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead."
) from None
files = [("file", file)]
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
if not file and not upload_detail:
raise ValueError("At least one of `file` and `upload_detail` must be provided.")
body = deepcopy_minimal(
{
"file": file,
"upload_detail": upload_detail,
"purpose": purpose,
"knowledge_id": knowledge_id,
"sentence_size": sentence_size,
}
)
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
if files:
# It should be noted that the actual Content-Type header that will be
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return self._post(
"/files",
body={
"purpose": purpose,
},
body=maybe_transform(body, file_create_params.FileCreateParams),
files=files,
options=make_user_request_input(extra_headers=extra_headers, timeout=timeout),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FileObject,
)
# def retrieve(
# self,
# file_id: str,
# *,
# extra_headers: Headers | None = None,
# extra_body: Body | None = None,
# timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
# ) -> FileObject:
# """
# Returns information about a specific file.
#
# Args:
# file_id: The ID of the file to retrieve information about
# extra_headers: Send extra headers
#
# extra_body: Add additional JSON properties to the request
#
# timeout: Override the client-level default timeout for this request, in seconds
# """
# if not file_id:
# raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
# return self._get(
# f"/files/{file_id}",
# options=make_request_options(
# extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
# ),
# cast_type=FileObject,
# )
def list(
self,
*,
@ -55,13 +105,15 @@ class Files(BaseAPI):
after: str | NotGiven = NOT_GIVEN,
order: str | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ListOfFileObject:
return self._get(
"/files",
cast_type=ListOfFileObject,
options=make_user_request_input(
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query={
"purpose": purpose,
@ -71,3 +123,72 @@ class Files(BaseAPI):
},
),
)
def delete(
self,
file_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FileDeleted:
"""
Delete a file.
Args:
file_id: The ID of the file to delete
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not file_id:
raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
return self._delete(
f"/files/{file_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FileDeleted,
)
def content(
self,
file_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> _legacy_response.HttpxBinaryResponseContent:
"""
Returns the contents of the specified file.
Args:
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not file_id:
raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
extra_headers = {"Accept": "application/binary", **(extra_headers or {})}
return self._get(
f"/files/{file_id}/content",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=_legacy_binary_response.HttpxBinaryResponseContent,
)
class FilesWithRawResponse:
def __init__(self, files: Files) -> None:
self._files = files
self.create = _legacy_response.to_raw_response_wrapper(
files.create,
)
self.list = _legacy_response.to_raw_response_wrapper(
files.list,
)
self.content = _legacy_response.to_raw_response_wrapper(
files.content,
)

View File

@ -0,0 +1,5 @@
from .fine_tuning import FineTuning
from .jobs import Jobs
from .models import FineTunedModels
__all__ = ["Jobs", "FineTunedModels", "FineTuning"]

View File

@ -1,15 +1,18 @@
from typing import TYPE_CHECKING
from ...core._base_api import BaseAPI
from ...core import BaseAPI, cached_property
from .jobs import Jobs
from .models import FineTunedModels
if TYPE_CHECKING:
from ..._client import ZhipuAI
pass
class FineTuning(BaseAPI):
jobs: Jobs
@cached_property
def jobs(self) -> Jobs:
return Jobs(self._client)
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
self.jobs = Jobs(client)
@cached_property
def models(self) -> FineTunedModels:
return FineTunedModels(self._client)

View File

@ -0,0 +1,3 @@
from .jobs import Jobs
__all__ = ["Jobs"]

View File

@ -4,13 +4,23 @@ from typing import TYPE_CHECKING, Optional
import httpx
from ...core._base_api import BaseAPI
from ...core._base_type import NOT_GIVEN, Headers, NotGiven
from ...core._http_client import make_user_request_input
from ...types.fine_tuning import FineTuningJob, FineTuningJobEvent, ListOfFineTuningJob, job_create_params
from ....core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
make_request_options,
)
from ....types.fine_tuning import (
FineTuningJob,
FineTuningJobEvent,
ListOfFineTuningJob,
job_create_params,
)
if TYPE_CHECKING:
from ..._client import ZhipuAI
from ...._client import ZhipuAI
__all__ = ["Jobs"]
@ -29,6 +39,7 @@ class Jobs(BaseAPI):
request_id: Optional[str] | NotGiven = NOT_GIVEN,
validation_file: Optional[str] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
return self._post(
@ -41,7 +52,7 @@ class Jobs(BaseAPI):
"validation_file": validation_file,
"request_id": request_id,
},
options=make_user_request_input(extra_headers=extra_headers, timeout=timeout),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTuningJob,
)
@ -50,11 +61,12 @@ class Jobs(BaseAPI):
fine_tuning_job_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
return self._get(
f"/fine_tuning/jobs/{fine_tuning_job_id}",
options=make_user_request_input(extra_headers=extra_headers, timeout=timeout),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTuningJob,
)
@ -64,13 +76,15 @@ class Jobs(BaseAPI):
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ListOfFineTuningJob:
return self._get(
"/fine_tuning/jobs",
cast_type=ListOfFineTuningJob,
options=make_user_request_input(
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query={
"after": after,
@ -79,6 +93,24 @@ class Jobs(BaseAPI):
),
)
def cancel(
self,
fine_tuning_job_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # noqa: E501
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
if not fine_tuning_job_id:
raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}")
return self._post(
f"/fine_tuning/jobs/{fine_tuning_job_id}/cancel",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTuningJob,
)
def list_events(
self,
fine_tuning_job_id: str,
@ -86,13 +118,15 @@ class Jobs(BaseAPI):
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJobEvent:
return self._get(
f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
cast_type=FineTuningJobEvent,
options=make_user_request_input(
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query={
"after": after,
@ -100,3 +134,19 @@ class Jobs(BaseAPI):
},
),
)
def delete(
self,
fine_tuning_job_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
if not fine_tuning_job_id:
raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}")
return self._delete(
f"/fine_tuning/jobs/{fine_tuning_job_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTuningJob,
)

View File

@ -0,0 +1,3 @@
from .fine_tuned_models import FineTunedModels
__all__ = ["FineTunedModels"]

View File

@ -0,0 +1,41 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import httpx
from ....core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
make_request_options,
)
from ....types.fine_tuning.models import FineTunedModelsStatus
if TYPE_CHECKING:
from ...._client import ZhipuAI
__all__ = ["FineTunedModels"]
class FineTunedModels(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def delete(
self,
fine_tuned_model: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTunedModelsStatus:
if not fine_tuned_model:
raise ValueError(f"Expected a non-empty value for `fine_tuned_model` but received {fine_tuned_model!r}")
return self._delete(
f"fine_tuning/fine_tuned_models/{fine_tuned_model}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTunedModelsStatus,
)

View File

@ -4,10 +4,9 @@ from typing import TYPE_CHECKING, Optional
import httpx
from ..core._base_api import BaseAPI
from ..core._base_type import NOT_GIVEN, Body, Headers, NotGiven
from ..core._http_client import make_user_request_input
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
from ..types.image import ImagesResponded
from ..types.sensitive_word_check import SensitiveWordCheckRequest
if TYPE_CHECKING:
from .._client import ZhipuAI
@ -27,8 +26,10 @@ class Images(BaseAPI):
response_format: Optional[str] | NotGiven = NOT_GIVEN,
size: Optional[str] | NotGiven = NOT_GIVEN,
style: Optional[str] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
user_id: Optional[str] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
disable_strict_validation: Optional[bool] | None = None,
@ -45,12 +46,14 @@ class Images(BaseAPI):
"n": n,
"quality": quality,
"response_format": response_format,
"sensitive_word_check": sensitive_word_check,
"size": size,
"style": style,
"user": user,
"user_id": user_id,
"request_id": request_id,
},
options=make_user_request_input(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=_cast_type,
enable_stream=False,
stream=False,
)

View File

@ -0,0 +1,3 @@
from .knowledge import Knowledge
__all__ = ["Knowledge"]

View File

@ -0,0 +1,3 @@
from .document import Document
__all__ = ["Document"]

View File

@ -0,0 +1,217 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import TYPE_CHECKING, Literal, Optional, cast
import httpx
from ....core import (
NOT_GIVEN,
BaseAPI,
Body,
FileTypes,
Headers,
NotGiven,
deepcopy_minimal,
extract_files,
make_request_options,
maybe_transform,
)
from ....types.files import UploadDetail, file_create_params
from ....types.knowledge.document import DocumentData, DocumentObject, document_edit_params, document_list_params
from ....types.knowledge.document.document_list_resp import DocumentPage
if TYPE_CHECKING:
from ...._client import ZhipuAI
__all__ = ["Document"]
class Document(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
file: FileTypes = None,
custom_separator: Optional[list[str]] = None,
upload_detail: list[UploadDetail] = None,
purpose: Literal["retrieval"],
knowledge_id: str = None,
sentence_size: int = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> DocumentObject:
if not file and not upload_detail:
raise ValueError("At least one of `file` and `upload_detail` must be provided.")
body = deepcopy_minimal(
{
"file": file,
"upload_detail": upload_detail,
"purpose": purpose,
"custom_separator": custom_separator,
"knowledge_id": knowledge_id,
"sentence_size": sentence_size,
}
)
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
if files:
# It should be noted that the actual Content-Type header that will be
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return self._post(
"/files",
body=maybe_transform(body, file_create_params.FileCreateParams),
files=files,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=DocumentObject,
)
def edit(
self,
document_id: str,
knowledge_type: str,
*,
custom_separator: Optional[list[str]] = None,
sentence_size: Optional[int] = None,
callback_url: Optional[str] = None,
callback_header: Optional[dict[str, str]] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> httpx.Response:
"""
Args:
document_id: 知识id
knowledge_type: 知识类型:
1:文章知识: 支持pdf,url,docx
2.问答知识-文档: 支持pdf,url,docx
3.问答知识-表格: 支持xlsx
4.商品库-表格: 支持xlsx
5.自定义: 支持pdf,url,docx
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
:param knowledge_type:
:param document_id:
:param timeout:
:param extra_body:
:param callback_header:
:param sentence_size:
:param extra_headers:
:param callback_url:
:param custom_separator:
"""
if not document_id:
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
body = deepcopy_minimal(
{
"id": document_id,
"knowledge_type": knowledge_type,
"custom_separator": custom_separator,
"sentence_size": sentence_size,
"callback_url": callback_url,
"callback_header": callback_header,
}
)
return self._put(
f"/document/{document_id}",
body=maybe_transform(body, document_edit_params.DocumentEditParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=httpx.Response,
)
def list(
self,
knowledge_id: str,
*,
purpose: str | NotGiven = NOT_GIVEN,
page: str | NotGiven = NOT_GIVEN,
limit: str | NotGiven = NOT_GIVEN,
order: Literal["desc", "asc"] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> DocumentPage:
return self._get(
"/files",
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"knowledge_id": knowledge_id,
"purpose": purpose,
"page": page,
"limit": limit,
"order": order,
},
document_list_params.DocumentListParams,
),
),
cast_type=DocumentPage,
)
def delete(
self,
document_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> httpx.Response:
"""
Delete a file.
Args:
document_id: 知识id
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not document_id:
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
return self._delete(
f"/document/{document_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=httpx.Response,
)
def retrieve(
self,
document_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> DocumentData:
"""
Args:
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not document_id:
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
return self._get(
f"/document/{document_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=DocumentData,
)

View File

@ -0,0 +1,173 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Literal, Optional
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
cached_property,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.knowledge import KnowledgeInfo, KnowledgeUsed, knowledge_create_params, knowledge_list_params
from ...types.knowledge.knowledge_list_resp import KnowledgePage
from .document import Document
if TYPE_CHECKING:
from ..._client import ZhipuAI
__all__ = ["Knowledge"]
class Knowledge(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
@cached_property
def document(self) -> Document:
return Document(self._client)
def create(
self,
embedding_id: int,
name: str,
*,
customer_identifier: Optional[str] = None,
description: Optional[str] = None,
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
bucket_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> KnowledgeInfo:
body = deepcopy_minimal(
{
"embedding_id": embedding_id,
"name": name,
"customer_identifier": customer_identifier,
"description": description,
"background": background,
"icon": icon,
"bucket_id": bucket_id,
}
)
return self._post(
"/knowledge",
body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=KnowledgeInfo,
)
def modify(
self,
knowledge_id: str,
embedding_id: int,
*,
name: str,
description: Optional[str] = None,
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> httpx.Response:
body = deepcopy_minimal(
{
"id": knowledge_id,
"embedding_id": embedding_id,
"name": name,
"description": description,
"background": background,
"icon": icon,
}
)
return self._put(
f"/knowledge/{knowledge_id}",
body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=httpx.Response,
)
def query(
self,
*,
page: int | NotGiven = 1,
size: int | NotGiven = 10,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> KnowledgePage:
return self._get(
"/knowledge",
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"page": page,
"size": size,
},
knowledge_list_params.KnowledgeListParams,
),
),
cast_type=KnowledgePage,
)
def delete(
self,
knowledge_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> httpx.Response:
"""
Delete a file.
Args:
knowledge_id: 知识库ID
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not knowledge_id:
raise ValueError("Expected a non-empty value for `knowledge_id`")
return self._delete(
f"/knowledge/{knowledge_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=httpx.Response,
)
def used(
self,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> KnowledgeUsed:
"""
Returns the contents of the specified file.
Args:
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
return self._get(
"/knowledge/capacity",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=KnowledgeUsed,
)

View File

@ -0,0 +1,3 @@
from .tools import Tools
__all__ = ["Tools"]

View File

@ -0,0 +1,65 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Literal, Optional, Union
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
StreamResponse,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.tools import WebSearch, WebSearchChunk, tools_web_search_params
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..._client import ZhipuAI
__all__ = ["Tools"]
class Tools(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def web_search(
self,
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
messages: Union[str, list[str], list[int], object, None],
scope: Optional[str] | NotGiven = NOT_GIVEN,
location: Optional[str] | NotGiven = NOT_GIVEN,
recent_days: Optional[int] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> WebSearch | StreamResponse[WebSearchChunk]:
body = deepcopy_minimal(
{
"model": model,
"request_id": request_id,
"messages": messages,
"stream": stream,
"scope": scope,
"location": location,
"recent_days": recent_days,
}
)
return self._post(
"/tools",
body=maybe_transform(body, tools_web_search_params.WebSearchParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=WebSearch,
stream=stream or False,
stream_cls=StreamResponse[WebSearchChunk],
)

View File

@ -0,0 +1,7 @@
from .videos import (
Videos,
)
__all__ = [
"Videos",
]

View File

@ -0,0 +1,77 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.sensitive_word_check import SensitiveWordCheckRequest
from ...types.video import VideoObject, video_create_params
if TYPE_CHECKING:
from ..._client import ZhipuAI
__all__ = ["Videos"]
class Videos(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def generations(
self,
model: str,
*,
prompt: str = None,
image_url: str = None,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
request_id: str = None,
user_id: str = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> VideoObject:
if not model and not model:
raise ValueError("At least one of `model` and `prompt` must be provided.")
body = deepcopy_minimal(
{
"model": model,
"prompt": prompt,
"image_url": image_url,
"sensitive_word_check": sensitive_word_check,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/videos/generations",
body=maybe_transform(body, video_create_params.VideoCreateParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=VideoObject,
)
def retrieve_videos_result(
self,
id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> VideoObject:
if not id:
raise ValueError("At least one of `id` must be provided.")
return self._get(
f"/async-result/{id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=VideoObject,
)

View File

@ -0,0 +1,108 @@
from ._base_api import BaseAPI
from ._base_compat import (
PYDANTIC_V2,
ConfigDict,
GenericModel,
cached_property,
field_get_default,
get_args,
get_model_config,
get_model_fields,
get_origin,
is_literal_type,
is_union,
parse_obj,
)
from ._base_models import BaseModel, construct_type
from ._base_type import (
NOT_GIVEN,
Body,
FileTypes,
Headers,
IncEx,
ModelT,
NotGiven,
Query,
)
from ._constants import (
ZHIPUAI_DEFAULT_LIMITS,
ZHIPUAI_DEFAULT_MAX_RETRIES,
ZHIPUAI_DEFAULT_TIMEOUT,
)
from ._errors import (
APIAuthenticationError,
APIConnectionError,
APIInternalError,
APIReachLimitError,
APIRequestFailedError,
APIResponseError,
APIResponseValidationError,
APIServerFlowExceedError,
APIStatusError,
APITimeoutError,
ZhipuAIError,
)
from ._files import is_file_content
from ._http_client import HttpClient, make_request_options
from ._sse_client import StreamResponse
from ._utils import (
deepcopy_minimal,
drop_prefix_image_data,
extract_files,
is_given,
is_list,
is_mapping,
maybe_transform,
parse_date,
parse_datetime,
)
__all__ = [
"BaseModel",
"construct_type",
"BaseAPI",
"NOT_GIVEN",
"Headers",
"NotGiven",
"Body",
"IncEx",
"ModelT",
"Query",
"FileTypes",
"PYDANTIC_V2",
"ConfigDict",
"GenericModel",
"get_args",
"is_union",
"parse_obj",
"get_origin",
"is_literal_type",
"get_model_config",
"get_model_fields",
"field_get_default",
"is_file_content",
"ZhipuAIError",
"APIStatusError",
"APIRequestFailedError",
"APIAuthenticationError",
"APIReachLimitError",
"APIInternalError",
"APIServerFlowExceedError",
"APIResponseError",
"APIResponseValidationError",
"APITimeoutError",
"make_request_options",
"HttpClient",
"ZHIPUAI_DEFAULT_TIMEOUT",
"ZHIPUAI_DEFAULT_MAX_RETRIES",
"ZHIPUAI_DEFAULT_LIMITS",
"is_list",
"is_mapping",
"parse_date",
"parse_datetime",
"is_given",
"maybe_transform",
"deepcopy_minimal",
"extract_files",
"StreamResponse",
]

View File

@ -16,3 +16,4 @@ class BaseAPI:
self._post = client.post
self._put = client.put
self._patch = client.patch
self._get_api_list = client.get_api_list

View File

@ -0,0 +1,209 @@
from __future__ import annotations
from collections.abc import Callable
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, overload
import pydantic
from pydantic.fields import FieldInfo
from typing_extensions import Self
from ._base_type import StrBytesIntFloat
_T = TypeVar("_T")
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
# --------------- Pydantic v2 compatibility ---------------
# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
# v1 re-exports
if TYPE_CHECKING:
def parse_date(value: date | StrBytesIntFloat) -> date: ...
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: ...
def get_args(t: type[Any]) -> tuple[Any, ...]: ...
def is_union(tp: type[Any] | None) -> bool: ...
def get_origin(t: type[Any]) -> type[Any] | None: ...
def is_literal_type(type_: type[Any]) -> bool: ...
def is_typeddict(type_: type[Any]) -> bool: ...
else:
if PYDANTIC_V2:
from pydantic.v1.typing import ( # noqa: I001
get_args as get_args, # noqa: PLC0414
is_union as is_union, # noqa: PLC0414
get_origin as get_origin, # noqa: PLC0414
is_typeddict as is_typeddict, # noqa: PLC0414
is_literal_type as is_literal_type, # noqa: PLC0414
)
from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414
else:
from pydantic.typing import ( # noqa: I001
get_args as get_args, # noqa: PLC0414
is_union as is_union, # noqa: PLC0414
get_origin as get_origin, # noqa: PLC0414
is_typeddict as is_typeddict, # noqa: PLC0414
is_literal_type as is_literal_type, # noqa: PLC0414
)
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414
# refactored config
if TYPE_CHECKING:
from pydantic import ConfigDict
else:
if PYDANTIC_V2:
from pydantic import ConfigDict
else:
# TODO: provide an error message here?
ConfigDict = None
# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
if PYDANTIC_V2:
return model.model_validate(value)
else:
# pyright: ignore[reportDeprecated, reportUnnecessaryCast]
return cast(_ModelT, model.parse_obj(value))
def field_is_required(field: FieldInfo) -> bool:
if PYDANTIC_V2:
return field.is_required()
return field.required # type: ignore
def field_get_default(field: FieldInfo) -> Any:
value = field.get_default()
if PYDANTIC_V2:
from pydantic_core import PydanticUndefined
if value == PydanticUndefined:
return None
return value
return value
def field_outer_type(field: FieldInfo) -> Any:
if PYDANTIC_V2:
return field.annotation
return field.outer_type_ # type: ignore
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
if PYDANTIC_V2:
return model.model_config
return model.__config__ # type: ignore
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
if PYDANTIC_V2:
return model.model_fields
return model.__fields__ # type: ignore
def model_copy(model: _ModelT) -> _ModelT:
if PYDANTIC_V2:
return model.model_copy()
return model.copy() # type: ignore
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
if PYDANTIC_V2:
return model.model_dump_json(indent=indent)
return model.json(indent=indent) # type: ignore
def model_dump(
model: pydantic.BaseModel,
*,
exclude_unset: bool = False,
exclude_defaults: bool = False,
) -> dict[str, Any]:
if PYDANTIC_V2:
return model.model_dump(
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
)
return cast(
"dict[str, Any]",
model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
),
)
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
if PYDANTIC_V2:
return model.model_validate(data)
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
# generic models
if TYPE_CHECKING:
class GenericModel(pydantic.BaseModel): ...
else:
if PYDANTIC_V2:
# there no longer needs to be a distinction in v2 but
# we still have to create our own subclass to avoid
# inconsistent MRO ordering errors
class GenericModel(pydantic.BaseModel): ...
else:
import pydantic.generics
class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
# cached properties
if TYPE_CHECKING:
cached_property = property
# we define a separate type (copied from typeshed)
# that represents that `cached_property` is `set`able
# at runtime, which differs from `@property`.
#
# this is a separate type as editors likely special case
# `@property` and we don't want to cause issues just to have
# more helpful internal types.
class typed_cached_property(Generic[_T]): # noqa: N801
func: Callable[[Any], _T]
attrname: str | None
def __init__(self, func: Callable[[Any], _T]) -> None: ...
@overload
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...
@overload
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
raise NotImplementedError()
def __set_name__(self, owner: type[Any], name: str) -> None: ...
# __set__ is not defined at runtime, but @cached_property is designed to be settable
def __set__(self, instance: object, value: _T) -> None: ...
else:
try:
from functools import cached_property
except ImportError:
from cached_property import cached_property
typed_cached_property = cached_property

View File

@ -0,0 +1,671 @@
from __future__ import annotations
import inspect
import os
from collections.abc import Callable
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeGuard, TypeVar, cast
import pydantic
import pydantic.generics
from pydantic.fields import FieldInfo
from typing_extensions import (
ParamSpec,
Protocol,
override,
runtime_checkable,
)
from ._base_compat import (
PYDANTIC_V2,
ConfigDict,
field_get_default,
get_args,
get_model_config,
get_model_fields,
get_origin,
is_literal_type,
is_union,
parse_obj,
)
from ._base_compat import (
GenericModel as BaseGenericModel,
)
from ._base_type import (
IncEx,
ModelT,
)
from ._utils import (
PropertyInfo,
coerce_boolean,
extract_type_arg,
is_annotated_type,
is_list,
is_mapping,
parse_date,
parse_datetime,
strip_annotated_type,
)
if TYPE_CHECKING:
from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema
__all__ = ["BaseModel", "GenericModel"]
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")
_T = TypeVar("_T")
P = ParamSpec("P")
@runtime_checkable
class _ConfigProtocol(Protocol):
allow_population_by_field_name: bool
class BaseModel(pydantic.BaseModel):
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
)
else:
@property
@override
def model_fields_set(self) -> set[str]:
# a forwards-compat shim for pydantic v2
return self.__fields_set__ # type: ignore
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore
def to_dict(
self,
*,
mode: Literal["json", "python"] = "python",
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> dict[str, object]:
"""Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
mode:
If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
exclude_none: Whether to exclude fields that have a value of `None` from the output.
warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
""" # noqa: E501
return self.model_dump(
mode=mode,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)
def to_json(
self,
*,
indent: int | None = 2,
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> str:
"""Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
""" # noqa: E501
return self.model_dump_json(
indent=indent,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)
@override
def __str__(self) -> str:
# mypy complains about an invalid self arg
return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc]
# Override the 'construct' method in a way that supports recursive parsing without validation.
# Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
@classmethod
@override
def construct(
cls: type[ModelT],
_fields_set: set[str] | None = None,
**values: object,
) -> ModelT:
m = cls.__new__(cls)
fields_values: dict[str, object] = {}
config = get_model_config(cls)
populate_by_name = (
config.allow_population_by_field_name
if isinstance(config, _ConfigProtocol)
else config.get("populate_by_name")
)
if _fields_set is None:
_fields_set = set()
model_fields = get_model_fields(cls)
for name, field in model_fields.items():
key = field.alias
if key is None or (key not in values and populate_by_name):
key = name
if key in values:
fields_values[name] = _construct_field(value=values[key], field=field, key=key)
_fields_set.add(name)
else:
fields_values[name] = field_get_default(field)
_extra = {}
for key, value in values.items():
if key not in model_fields:
if PYDANTIC_V2:
_extra[key] = value
else:
_fields_set.add(key)
fields_values[key] = value
object.__setattr__(m, "__dict__", fields_values) # noqa: PLC2801
if PYDANTIC_V2:
# these properties are copied from Pydantic's `model_construct()` method
object.__setattr__(m, "__pydantic_private__", None) # noqa: PLC2801
object.__setattr__(m, "__pydantic_extra__", _extra) # noqa: PLC2801
object.__setattr__(m, "__pydantic_fields_set__", _fields_set) # noqa: PLC2801
else:
# init_private_attributes() does not exist in v2
m._init_private_attributes() # type: ignore
# copied from Pydantic v1's `construct()` method
object.__setattr__(m, "__fields_set__", _fields_set) # noqa: PLC2801
return m
if not TYPE_CHECKING:
# type checkers incorrectly complain about this assignment
# because the type signatures are technically different
# although not in practice
model_construct = construct
if not PYDANTIC_V2:
# we define aliases for some of the new pydantic v2 methods so
# that we can just document these methods without having to specify
# a specific pydantic version as some users may not know which
# pydantic version they are currently using
@override
def model_dump(
self,
*,
mode: Literal["json", "python"] | str = "python",
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
context: dict[str, Any] | None = None,
serialize_as_any: bool = False,
) -> dict[str, Any]:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
Args:
mode: The mode in which `to_python` should run.
If mode is 'json', the dictionary will only contain JSON serializable types.
If mode is 'python', the dictionary may contain any Python objects.
include: A list of fields to include in the output.
exclude: A list of fields to exclude from the output.
by_alias: Whether to use the field's alias in the dictionary key if defined.
exclude_unset: Whether to exclude fields that are unset or None from the output.
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
exclude_none: Whether to exclude fields that have a value of `None` from the output.
round_trip: Whether to enable serialization and deserialization round-trip support.
warnings: Whether to log warnings when invalid fields are encountered.
Returns:
A dictionary representation of the model.
"""
if mode != "python":
raise ValueError("mode is only supported in Pydantic v2")
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
raise ValueError("warnings is only supported in Pydantic v2")
if context is not None:
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
return super().dict( # pyright: ignore[reportDeprecated]
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
@override
def model_dump_json(
self,
*,
indent: int | None = None,
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
context: dict[str, Any] | None = None,
serialize_as_any: bool = False,
) -> str:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json
Generates a JSON representation of the model using Pydantic's `to_json` method.
Args:
indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
include: Field(s) to include in the JSON output. Can take either a string or set of strings.
exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
by_alias: Whether to serialize using field aliases.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
round_trip: Whether to use serialization/deserialization between JSON and class instance.
warnings: Whether to show any warnings that occurred during serialization.
Returns:
A JSON string representation of the model.
"""
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
raise ValueError("warnings is only supported in Pydantic v2")
if context is not None:
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
return super().json( # type: ignore[reportDeprecated]
indent=indent,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
if value is None:
return field_get_default(field)
if PYDANTIC_V2:
type_ = field.annotation
else:
type_ = cast(type, field.outer_type_) # type: ignore
if type_ is None:
raise RuntimeError(f"Unexpected field type is None for {key}")
return construct_type(value=value, type_=type_)
def is_basemodel(type_: type) -> bool:
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
if is_union(type_):
return any(is_basemodel(variant) for variant in get_args(type_))
return is_basemodel_type(type_)
def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
origin = get_origin(type_) or type_
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
def build(
base_model_cls: Callable[P, _BaseModelT],
*args: P.args,
**kwargs: P.kwargs,
) -> _BaseModelT:
"""Construct a BaseModel class without validation.
This is useful for cases where you need to instantiate a `BaseModel`
from an API response as this provides type-safe params which isn't supported
by helpers like `construct_type()`.
```py
build(MyModel, my_field_a="foo", my_field_b=123)
```
"""
if args:
raise TypeError(
"Received positional arguments which are not supported; Keyword arguments must be used instead",
)
return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))
def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
"""Loose coercion to the expected type with construction of nested values.
Note: the returned value from this function is not guaranteed to match the
given type.
"""
return cast(_T, construct_type(value=value, type_=type_))
def construct_type(*, value: object, type_: type) -> object:
"""Loose coercion to the expected type with construction of nested values.
If the given value does not match the expected type then it is returned as-is.
"""
# we allow `object` as the input type because otherwise, passing things like
# `Literal['value']` will be reported as a type error by type checkers
type_ = cast("type[object]", type_)
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(type_):
meta: tuple[Any, ...] = get_args(type_)[1:]
type_ = extract_type_arg(type_, 0)
else:
meta = ()
# we need to use the origin class for any types that are subscripted generics
# e.g. Dict[str, object]
origin = get_origin(type_) or type_
args = get_args(type_)
if is_union(origin):
try:
return validate_type(type_=cast("type[object]", type_), value=value)
except Exception:
pass
# if the type is a discriminated union then we want to construct the right variant
# in the union, even if the data doesn't match exactly, otherwise we'd break code
# that relies on the constructed class types, e.g.
#
# class FooType:
# kind: Literal['foo']
# value: str
#
# class BarType:
# kind: Literal['bar']
# value: int
#
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
# we'd end up constructing `FooType` when it should be `BarType`.
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value)
if variant_type:
return construct_type(type_=variant_type, value=value)
# if the data is not valid, use the first variant that doesn't fail while deserializing
for variant in args:
try:
return construct_type(value=value, type_=variant)
except Exception:
continue
raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
if origin == dict:
if not is_mapping(value):
return value
_, items_type = get_args(type_) # Dict[_, items_type]
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
if is_list(value):
return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
if is_mapping(value):
if issubclass(type_, BaseModel):
return type_.construct(**value) # type: ignore[arg-type]
return cast(Any, type_).construct(**value)
if origin == list:
if not is_list(value):
return value
inner_type = args[0] # List[inner_type]
return [construct_type(value=entry, type_=inner_type) for entry in value]
if origin == float:
if isinstance(value, int):
coerced = float(value)
if coerced != value:
return value
return coerced
return value
if type_ == datetime:
try:
return parse_datetime(value) # type: ignore
except Exception:
return value
if type_ == date:
try:
return parse_date(value) # type: ignore
except Exception:
return value
return value
@runtime_checkable
class CachedDiscriminatorType(Protocol):
__discriminator__: DiscriminatorDetails
class DiscriminatorDetails:
field_name: str
"""The name of the discriminator field in the variant class, e.g.
```py
class Foo(BaseModel):
type: Literal['foo']
```
Will result in field_name='type'
"""
field_alias_from: str | None
"""The name of the discriminator field in the API response, e.g.
```py
class Foo(BaseModel):
type: Literal['foo'] = Field(alias='type_from_api')
```
Will result in field_alias_from='type_from_api'
"""
mapping: dict[str, type]
"""Mapping of discriminator value to variant type, e.g.
{'foo': FooVariant, 'bar': BarVariant}
"""
def __init__(
self,
*,
mapping: dict[str, type],
discriminator_field: str,
discriminator_alias: str | None,
) -> None:
self.mapping = mapping
self.field_name = discriminator_field
self.field_alias_from = discriminator_alias
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
if isinstance(union, CachedDiscriminatorType):
return union.__discriminator__
discriminator_field_name: str | None = None
for annotation in meta_annotations:
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
discriminator_field_name = annotation.discriminator
break
if not discriminator_field_name:
return None
mapping: dict[str, type] = {}
discriminator_alias: str | None = None
for variant in get_args(union):
variant = strip_annotated_type(variant)
if is_basemodel_type(variant):
if PYDANTIC_V2:
field = _extract_field_schema_pv2(variant, discriminator_field_name)
if not field:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field.get("serialization_alias")
field_schema = field["schema"]
if field_schema["type"] == "literal":
for entry in cast("LiteralSchema", field_schema)["expected"]:
if isinstance(entry, str):
mapping[entry] = variant
else:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field_info.alias
if field_info.annotation and is_literal_type(field_info.annotation):
for entry in get_args(field_info.annotation):
if isinstance(entry, str):
mapping[entry] = variant
if not mapping:
return None
details = DiscriminatorDetails(
mapping=mapping,
discriminator_field=discriminator_field_name,
discriminator_alias=discriminator_alias,
)
cast(CachedDiscriminatorType, union).__discriminator__ = details
return details
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
schema = model.__pydantic_core_schema__
if schema["type"] != "model":
return None
fields_schema = schema["schema"]
if fields_schema["type"] != "model-fields":
return None
fields_schema = cast("ModelFieldsSchema", fields_schema)
field = fields_schema["fields"].get(field_name)
if not field:
return None
return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]
def validate_type(*, type_: type[_T], value: object) -> _T:
"""Strict validation that the given value matches the expected type"""
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
return cast(_T, parse_obj(type_, value))
return cast(_T, _validate_non_model_type(type_=type_, value=value))
# our use of subclasssing here causes weirdness for type checkers,
# so we just pretend that we don't subclass
if TYPE_CHECKING:
GenericModel = BaseModel
else:
class GenericModel(BaseGenericModel, BaseModel):
pass
if PYDANTIC_V2:
from pydantic import TypeAdapter
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
return TypeAdapter(type_).validate_python(value)
elif not TYPE_CHECKING:
class TypeAdapter(Generic[_T]):
"""Used as a placeholder to easily convert runtime types to a Pydantic format
to provide validation.
For example:
```py
validated = RootModel[int](__root__="5").__root__
# validated: 5
```
"""
def __init__(self, type_: type[_T]):
self.type_ = type_
def validate_python(self, value: Any) -> _T:
if not isinstance(value, self.type_):
raise ValueError(f"Invalid type: {value} is not of type {self.type_}")
return value
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
return TypeAdapter(type_).validate_python(value)

View File

@ -1,11 +1,21 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
from os import PathLike
from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar, Union
from typing import (
IO,
TYPE_CHECKING,
Any,
Literal,
Optional,
TypeAlias,
TypeVar,
Union,
)
import pydantic
from typing_extensions import override
from httpx import Response
from typing_extensions import Protocol, TypedDict, override, runtime_checkable
Query = Mapping[str, object]
Body = object
@ -22,7 +32,7 @@ else:
# Sentinel class used until PEP 0661 is accepted
class NotGiven(pydantic.BaseModel):
class NotGiven:
"""
A sentinel singleton class used to distinguish omitted keyword arguments
from those passed in with the value None (which may have different behavior).
@ -50,7 +60,7 @@ NotGivenOr = Union[_T, NotGiven]
NOT_GIVEN = NotGiven()
class Omit(pydantic.BaseModel):
class Omit:
"""In certain situations you need to be able to represent a case where a default value has
to be explicitly removed and `None` is not an appropriate substitute, for example:
@ -71,37 +81,90 @@ class Omit(pydantic.BaseModel):
return False
@runtime_checkable
class ModelBuilderProtocol(Protocol):
@classmethod
def build(
cls: type[_T],
*,
response: Response,
data: object,
) -> _T: ...
Headers = Mapping[str, Union[str, Omit]]
class HeadersLikeProtocol(Protocol):
def get(self, __key: str) -> str | None: ...
HeadersLike = Union[Headers, HeadersLikeProtocol]
ResponseT = TypeVar(
"ResponseT",
bound="Union[str, None, BaseModel, list[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol,"
" BinaryResponseContent]",
bound="Union[str, None, BaseModel, list[Any], dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", # noqa: E501
)
StrBytesIntFloat = Union[str, bytes, int, float]
# Note: copied from Pydantic
# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49
IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"
PostParser = Callable[[Any], Any]
@runtime_checkable
class InheritsGeneric(Protocol):
"""Represents a type that has inherited from `Generic`
The `__orig_bases__` property can be used to determine the resolved
type variable for a given base class.
"""
__orig_bases__: tuple[_GenericAlias]
class _GenericAlias(Protocol):
__origin__: type[object]
class HttpxSendArgs(TypedDict, total=False):
auth: httpx.Auth
# for user input files
if TYPE_CHECKING:
Base64FileInput = Union[IO[bytes], PathLike[str]]
FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
Base64FileInput = Union[IO[bytes], PathLike]
FileContent = Union[IO[bytes], bytes, PathLike]
FileTypes = Union[
FileContent, # file content
tuple[str, FileContent], # (filename, file)
tuple[str, FileContent, str], # (filename, file , content_type)
tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers)
# file (or bytes)
FileContent,
# (filename, file (or bytes))
tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
tuple[Optional[str], FileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]]
# for httpx client supported files
# duplicate of the above but without our custom file support
HttpxFileContent = Union[bytes, IO[bytes]]
HttpxFileTypes = Union[
FileContent, # file content
tuple[str, HttpxFileContent], # (filename, file)
tuple[str, HttpxFileContent, str], # (filename, file , content_type)
tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers)
# file (or bytes)
HttpxFileContent,
# (filename, file (or bytes))
tuple[Optional[str], HttpxFileContent],
# (filename, file (or bytes), content_type)
tuple[Optional[str], HttpxFileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]]

View File

@ -0,0 +1,12 @@
import httpx
RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
# 通过 `Timeout` 控制接口`connect` 和 `read` 超时时间,默认为`timeout=300.0, connect=8.0`
ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
# 通过 `retry` 参数控制重试次数默认为3次
ZHIPUAI_DEFAULT_MAX_RETRIES = 3
# 通过 `Limits` 控制最大连接数和保持连接数,默认为`max_connections=50, max_keepalive_connections=10`
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)
INITIAL_RETRY_DELAY = 0.5
MAX_RETRY_DELAY = 8.0

View File

@ -13,6 +13,7 @@ __all__ = [
"APIResponseError",
"APIResponseValidationError",
"APITimeoutError",
"APIConnectionError",
]
@ -24,7 +25,7 @@ class ZhipuAIError(Exception):
super().__init__(message)
class APIStatusError(Exception):
class APIStatusError(ZhipuAIError):
response: httpx.Response
status_code: int
@ -49,7 +50,7 @@ class APIInternalError(APIStatusError): ...
class APIServerFlowExceedError(APIStatusError): ...
class APIResponseError(Exception):
class APIResponseError(ZhipuAIError):
message: str
request: httpx.Request
json_data: object
@ -75,9 +76,11 @@ class APIResponseValidationError(APIResponseError):
self.status_code = response.status_code
class APITimeoutError(Exception):
request: httpx.Request
class APIConnectionError(APIResponseError):
def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
super().__init__(message, request, json_data=None)
def __init__(self, request: httpx.Request):
self.request = request
super().__init__("Request Timeout")
class APITimeoutError(APIConnectionError):
def __init__(self, request: httpx.Request) -> None:
super().__init__(message="Request timed out.", request=request)

View File

@ -2,40 +2,74 @@ from __future__ import annotations
import io
import os
from collections.abc import Mapping, Sequence
from pathlib import Path
import pathlib
from typing import TypeGuard, overload
from ._base_type import FileTypes, HttpxFileTypes, HttpxRequestFiles, RequestFiles
from ._base_type import (
Base64FileInput,
FileContent,
FileTypes,
HttpxFileContent,
HttpxFileTypes,
HttpxRequestFiles,
RequestFiles,
)
from ._utils import is_mapping_t, is_sequence_t, is_tuple_t
def is_file_content(obj: object) -> bool:
def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
return isinstance(obj, io.IOBase | os.PathLike)
def is_file_content(obj: object) -> TypeGuard[FileContent]:
return isinstance(obj, bytes | tuple | io.IOBase | os.PathLike)
def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
if not is_file_content(obj):
prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads"
) from None
@overload
def to_httpx_files(files: None) -> None: ...
@overload
def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ...
def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None
if is_mapping_t(files):
files = {key: _transform_file(file) for key, file in files.items()}
elif is_sequence_t(files):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
return files
def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = Path(file)
return path.name, path.read_bytes()
else:
return file
if isinstance(file, tuple):
if isinstance(file[1], os.PathLike):
return (file[0], Path(file[1]).read_bytes(), *file[2:])
else:
return (file[0], file[1], *file[2:])
else:
raise TypeError(f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type")
path = pathlib.Path(file)
return (path.name, path.read_bytes())
return file
if is_tuple_t(file):
return (file[0], _read_file_content(file[1]), *file[2:])
raise TypeError("Expected file types input to be a FileContent type or to be a tuple")
def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None
if isinstance(files, Mapping):
files = {key: _transform_file(file) for key, file in files.items()}
elif isinstance(files, Sequence):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence")
return files
def _read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return pathlib.Path(file).read_bytes()
return file

View File

@ -1,23 +1,70 @@
from __future__ import annotations
import inspect
from collections.abc import Mapping
from typing import Any, Union, cast
import logging
import time
import warnings
from collections.abc import Iterator, Mapping
from itertools import starmap
from random import random
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
import httpx
import pydantic
from httpx import URL, Timeout
from tenacity import retry
from tenacity.stop import stop_after_attempt
from . import _errors
from ._base_type import NOT_GIVEN, AnyMapping, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
from ._files import make_httpx_files
from ._request_opt import ClientRequestParam, UserRequestInput
from ._response import HttpResponse
from . import _errors, get_origin
from ._base_compat import model_copy
from ._base_models import GenericModel, construct_type, validate_type
from ._base_type import (
NOT_GIVEN,
AnyMapping,
Body,
Data,
Headers,
HttpxSendArgs,
ModelBuilderProtocol,
NotGiven,
Omit,
PostParser,
Query,
RequestFiles,
ResponseT,
)
from ._constants import (
INITIAL_RETRY_DELAY,
MAX_RETRY_DELAY,
RAW_RESPONSE_HEADER,
ZHIPUAI_DEFAULT_LIMITS,
ZHIPUAI_DEFAULT_MAX_RETRIES,
ZHIPUAI_DEFAULT_TIMEOUT,
)
from ._errors import APIConnectionError, APIResponseValidationError, APIStatusError, APITimeoutError
from ._files import to_httpx_files
from ._legacy_response import LegacyAPIResponse
from ._request_opt import FinalRequestOptions, UserRequestInput
from ._response import APIResponse, BaseAPIResponse, extract_response_type
from ._sse_client import StreamResponse
from ._utils import flatten
from ._utils import flatten, is_given, is_mapping
log: logging.Logger = logging.getLogger(__name__)
# TODO: make base page type vars covariant
SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]")
# AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]")
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
if TYPE_CHECKING:
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
else:
try:
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
except ImportError:
# taken from https://github.com/encode/httpx/blob/3ba5fe0d7ac70222590e759c31442b1cab263791/httpx/_config.py#L366
HTTPX_DEFAULT_TIMEOUT = Timeout(5.0)
headers = {
"Accept": "application/json",
@ -25,50 +72,180 @@ headers = {
}
def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
merged = {**map1, **map2}
return {key: val for key, val in merged.items() if val is not None}
class PageInfo:
"""Stores the necessary information to build the request to retrieve the next page.
Either `url` or `params` must be set.
"""
url: URL | NotGiven
params: Query | NotGiven
@overload
def __init__(
self,
*,
url: URL,
) -> None: ...
@overload
def __init__(
self,
*,
params: Query,
) -> None: ...
def __init__(
self,
*,
url: URL | NotGiven = NOT_GIVEN,
params: Query | NotGiven = NOT_GIVEN,
) -> None:
self.url = url
self.params = params
from itertools import starmap
class BasePage(GenericModel, Generic[_T]):
"""
Defines the core interface for pagination.
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
Type Args:
ModelT: The pydantic model that represents an item in the response.
ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
ZHIPUAI_DEFAULT_MAX_RETRIES = 3
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=5, max_keepalive_connections=5)
Methods:
has_next_page(): Check if there is another page available
next_page_info(): Get the necessary information to make a request for the next page
"""
_options: FinalRequestOptions = pydantic.PrivateAttr()
_model: type[_T] = pydantic.PrivateAttr()
def has_next_page(self) -> bool:
items = self._get_page_items()
if not items:
return False
return self.next_page_info() is not None
def next_page_info(self) -> Optional[PageInfo]: ...
def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body]
...
def _params_from_url(self, url: URL) -> httpx.QueryParams:
# TODO: do we have to preprocess params here?
return httpx.QueryParams(cast(Any, self._options.params)).merge(url.params)
def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
options = model_copy(self._options)
options._strip_raw_response_header()
if not isinstance(info.params, NotGiven):
options.params = {**options.params, **info.params}
return options
if not isinstance(info.url, NotGiven):
params = self._params_from_url(info.url)
url = info.url.copy_with(params=params)
options.params = dict(url.params)
options.url = str(url)
return options
raise ValueError("Unexpected PageInfo state")
class BaseSyncPage(BasePage[_T], Generic[_T]):
_client: HttpClient = pydantic.PrivateAttr()
def _set_private_attributes(
self,
client: HttpClient,
model: type[_T],
options: FinalRequestOptions,
) -> None:
self._model = model
self._client = client
self._options = options
# Pydantic uses a custom `__iter__` method to support casting BaseModels
# to dictionaries. e.g. dict(model).
# As we want to support `for item in page`, this is inherently incompatible
# with the default pydantic behaviour. It is not possible to support both
# use cases at once. Fortunately, this is not a big deal as all other pydantic
# methods should continue to work as expected as there is an alternative method
# to cast a model to a dictionary, model.dict(), which is used internally
# by pydantic.
def __iter__(self) -> Iterator[_T]: # type: ignore
for page in self.iter_pages():
yield from page._get_page_items()
def iter_pages(self: SyncPageT) -> Iterator[SyncPageT]:
page = self
while True:
yield page
if page.has_next_page():
page = page.get_next_page()
else:
return
def get_next_page(self: SyncPageT) -> SyncPageT:
info = self.next_page_info()
if not info:
raise RuntimeError(
"No next page expected; please check `.has_next_page()` before calling `.get_next_page()`."
)
options = self._info_to_options(info)
return self._client._request_api_list(self._model, page=self.__class__, options=options)
class HttpClient:
_client: httpx.Client
_version: str
_base_url: URL
max_retries: int
timeout: Union[float, Timeout, None]
_limits: httpx.Limits
_has_custom_http_client: bool
_default_stream_cls: type[StreamResponse[Any]] | None = None
_strict_response_validation: bool
def __init__(
self,
*,
version: str,
base_url: URL,
_strict_response_validation: bool,
max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
timeout: Union[float, Timeout, None],
limits: httpx.Limits | None = None,
custom_httpx_client: httpx.Client | None = None,
custom_headers: Mapping[str, str] | None = None,
) -> None:
if timeout is None or isinstance(timeout, NotGiven):
if limits is not None:
warnings.warn(
"The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead", # noqa: E501
category=DeprecationWarning,
stacklevel=3,
)
if custom_httpx_client is not None:
raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`")
else:
limits = ZHIPUAI_DEFAULT_LIMITS
if not is_given(timeout):
if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT:
timeout = custom_httpx_client.timeout
else:
timeout = ZHIPUAI_DEFAULT_TIMEOUT
self.timeout = cast(Timeout, timeout)
self.max_retries = max_retries
self.timeout = timeout
self._limits = limits
self._has_custom_http_client = bool(custom_httpx_client)
self._client = custom_httpx_client or httpx.Client(
base_url=base_url,
timeout=self.timeout,
limits=ZHIPUAI_DEFAULT_LIMITS,
limits=limits,
)
self._version = version
url = URL(url=base_url)
@ -76,6 +253,7 @@ class HttpClient:
url = url.copy_with(raw_path=url.raw_path + b"/")
self._base_url = url
self._custom_headers = custom_headers or {}
self._strict_response_validation = _strict_response_validation
def _prepare_url(self, url: str) -> URL:
sub_url = URL(url)
@ -93,55 +271,101 @@ class HttpClient:
"ZhipuAI-SDK-Ver": self._version,
"source_type": "zhipu-sdk-python",
"x-request-sdk": "zhipu-sdk-python",
**self._auth_headers,
**self.auth_headers,
**self._custom_headers,
}
@property
def _auth_headers(self):
def custom_auth(self) -> httpx.Auth | None:
return None
@property
def auth_headers(self):
return {}
def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers:
custom_headers = request_param.headers or {}
headers_dict = _merge_map(self._default_headers, custom_headers)
def _prepare_headers(self, options: FinalRequestOptions) -> httpx.Headers:
custom_headers = options.headers or {}
headers_dict = _merge_mappings(self._default_headers, custom_headers)
httpx_headers = httpx.Headers(headers_dict)
return httpx_headers
def _prepare_request(self, request_param: ClientRequestParam) -> httpx.Request:
def _remaining_retries(
self,
remaining_retries: Optional[int],
options: FinalRequestOptions,
) -> int:
return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)
def _calculate_retry_timeout(
self,
remaining_retries: int,
options: FinalRequestOptions,
response_headers: Optional[httpx.Headers] = None,
) -> float:
max_retries = options.get_max_retries(self.max_retries)
# If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
# retry_after = self._parse_retry_after_header(response_headers)
# if retry_after is not None and 0 < retry_after <= 60:
# return retry_after
nb_retries = max_retries - remaining_retries
# Apply exponential backoff, but not more than the max.
sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)
# Apply some jitter, plus-or-minus half a second.
jitter = 1 - 0.25 * random()
timeout = sleep_seconds * jitter
return max(timeout, 0)
def _build_request(self, options: FinalRequestOptions) -> httpx.Request:
kwargs: dict[str, Any] = {}
json_data = request_param.json_data
headers = self._prepare_headers(request_param)
url = self._prepare_url(request_param.url)
json_data = request_param.json_data
headers = self._prepare_headers(options)
url = self._prepare_url(options.url)
json_data = options.json_data
if options.extra_json is not None:
if json_data is None:
json_data = cast(Body, options.extra_json)
elif is_mapping(json_data):
json_data = _merge_mappings(json_data, options.extra_json)
else:
raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")
content_type = headers.get("Content-Type")
# multipart/form-data; boundary=---abc--
if headers.get("Content-Type") == "multipart/form-data":
headers.pop("Content-Type")
if "boundary" not in content_type:
# only remove the header if the boundary hasn't been explicitly set
# as the caller doesn't want httpx to come up with their own boundary
headers.pop("Content-Type")
if json_data:
kwargs["data"] = self._make_multipartform(json_data)
return self._client.build_request(
headers=headers,
timeout=self.timeout if isinstance(request_param.timeout, NotGiven) else request_param.timeout,
method=request_param.method,
timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,
method=options.method,
url=url,
json=json_data,
files=request_param.files,
params=request_param.params,
files=options.files,
params=options.params,
**kwargs,
)
def _object_to_formdata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]:
def _object_to_formfata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]:
items = []
if isinstance(value, Mapping):
for k, v in value.items():
items.extend(self._object_to_formdata(f"{key}[{k}]", v))
items.extend(self._object_to_formfata(f"{key}[{k}]", v))
return items
if isinstance(value, list | tuple):
for v in value:
items.extend(self._object_to_formdata(key + "[]", v))
items.extend(self._object_to_formfata(key + "[]", v))
return items
def _primitive_value_to_str(val) -> str:
@ -161,7 +385,7 @@ class HttpClient:
return [(key, str_data)]
def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
items = flatten(list(starmap(self._object_to_formdata, data.items())))
items = flatten(list(starmap(self._object_to_formfata, data.items())))
serialized: dict[str, object] = {}
for key, value in items:
@ -170,20 +394,6 @@ class HttpClient:
serialized[key] = value
return serialized
def _parse_response(
self,
*,
cast_type: type[ResponseT],
response: httpx.Response,
enable_stream: bool,
request_param: ClientRequestParam,
stream_cls: type[StreamResponse[Any]] | None = None,
) -> HttpResponse:
http_response = HttpResponse(
raw_response=response, cast_type=cast_type, client=self, enable_stream=enable_stream, stream_cls=stream_cls
)
return http_response.parse()
def _process_response_data(
self,
*,
@ -194,14 +404,58 @@ class HttpClient:
if data is None:
return cast(ResponseT, None)
try:
if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel):
return cast(ResponseT, cast_type.validate(data))
if cast_type is object:
return cast(ResponseT, data)
return cast(ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data))
try:
if inspect.isclass(cast_type) and issubclass(cast_type, ModelBuilderProtocol):
return cast(ResponseT, cast_type.build(response=response, data=data))
if self._strict_response_validation:
return cast(ResponseT, validate_type(type_=cast_type, value=data))
return cast(ResponseT, construct_type(type_=cast_type, value=data))
except pydantic.ValidationError as err:
raise APIResponseValidationError(response=response, json_data=data) from err
def _should_stream_response_body(self, request: httpx.Request) -> bool:
return request.headers.get(RAW_RESPONSE_HEADER) == "stream" # type: ignore[no-any-return]
def _should_retry(self, response: httpx.Response) -> bool:
# Note: this is not a standard header
should_retry_header = response.headers.get("x-should-retry")
# If the server explicitly says whether or not to retry, obey.
if should_retry_header == "true":
log.debug("Retrying as header `x-should-retry` is set to `true`")
return True
if should_retry_header == "false":
log.debug("Not retrying as header `x-should-retry` is set to `false`")
return False
# Retry on request timeouts.
if response.status_code == 408:
log.debug("Retrying due to status code %i", response.status_code)
return True
# Retry on lock timeouts.
if response.status_code == 409:
log.debug("Retrying due to status code %i", response.status_code)
return True
# Retry on rate limits.
if response.status_code == 429:
log.debug("Retrying due to status code %i", response.status_code)
return True
# Retry internal errors.
if response.status_code >= 500:
log.debug("Retrying due to status code %i", response.status_code)
return True
log.debug("Not retrying")
return False
def is_closed(self) -> bool:
return self._client.is_closed
@ -214,117 +468,385 @@ class HttpClient:
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
@retry(stop=stop_after_attempt(ZHIPUAI_DEFAULT_MAX_RETRIES))
def request(
self,
cast_type: type[ResponseT],
options: FinalRequestOptions,
remaining_retries: Optional[int] = None,
*,
stream: bool = False,
stream_cls: type[StreamResponse] | None = None,
) -> ResponseT | StreamResponse:
return self._request(
cast_type=cast_type,
options=options,
stream=stream,
stream_cls=stream_cls,
remaining_retries=remaining_retries,
)
def _request(
self,
*,
cast_type: type[ResponseT],
params: ClientRequestParam,
enable_stream: bool = False,
stream_cls: type[StreamResponse[Any]] | None = None,
options: FinalRequestOptions,
remaining_retries: int | None,
stream: bool,
stream_cls: type[StreamResponse] | None,
) -> ResponseT | StreamResponse:
request = self._prepare_request(params)
retries = self._remaining_retries(remaining_retries, options)
request = self._build_request(options)
kwargs: HttpxSendArgs = {}
if self.custom_auth is not None:
kwargs["auth"] = self.custom_auth
try:
response = self._client.send(
request,
stream=enable_stream,
stream=stream or self._should_stream_response_body(request=request),
**kwargs,
)
response.raise_for_status()
except httpx.TimeoutException as err:
log.debug("Encountered httpx.TimeoutException", exc_info=True)
if retries > 0:
return self._retry_request(
options,
cast_type,
retries,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
)
log.debug("Raising timeout error")
raise APITimeoutError(request=request) from err
except httpx.HTTPStatusError as err:
err.response.read()
# raise err
except Exception as err:
log.debug("Encountered Exception", exc_info=True)
if retries > 0:
return self._retry_request(
options,
cast_type,
retries,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
)
log.debug("Raising connection error")
raise APIConnectionError(request=request) from err
log.debug(
'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
)
try:
response.raise_for_status()
except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
if retries > 0 and self._should_retry(err.response):
err.response.close()
return self._retry_request(
options,
cast_type,
retries,
err.response.headers,
stream=stream,
stream_cls=stream_cls,
)
# If the response is streamed then we need to explicitly read the response
# to completion before attempting to access the response text.
if not err.response.is_closed:
err.response.read()
log.debug("Re-raising status error")
raise self._make_status_error(err.response) from None
except Exception as err:
raise err
return self._parse_response(
# return self._parse_response(
# cast_type=cast_type,
# options=options,
# response=response,
# stream=stream,
# stream_cls=stream_cls,
# )
return self._process_response(
cast_type=cast_type,
request_param=params,
options=options,
response=response,
enable_stream=enable_stream,
stream=stream,
stream_cls=stream_cls,
)
def _retry_request(
self,
options: FinalRequestOptions,
cast_type: type[ResponseT],
remaining_retries: int,
response_headers: httpx.Headers | None,
*,
stream: bool,
stream_cls: type[StreamResponse] | None,
) -> ResponseT | StreamResponse:
remaining = remaining_retries - 1
if remaining == 1:
log.debug("1 retry left")
else:
log.debug("%i retries left", remaining)
timeout = self._calculate_retry_timeout(remaining, options, response_headers)
log.info("Retrying request to %s in %f seconds", options.url, timeout)
# In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
# different thread if necessary.
time.sleep(timeout)
return self._request(
options=options,
cast_type=cast_type,
remaining_retries=remaining,
stream=stream,
stream_cls=stream_cls,
)
def _process_response(
self,
*,
cast_type: type[ResponseT],
options: FinalRequestOptions,
response: httpx.Response,
stream: bool,
stream_cls: type[StreamResponse] | None,
) -> ResponseT:
# _legacy_response with raw_response_header to paser method
if response.request.headers.get(RAW_RESPONSE_HEADER) == "true":
return cast(
ResponseT,
LegacyAPIResponse(
raw=response,
client=self,
cast_type=cast_type,
stream=stream,
stream_cls=stream_cls,
options=options,
),
)
origin = get_origin(cast_type) or cast_type
if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse):
if not issubclass(origin, APIResponse):
raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}")
response_cls = cast("type[BaseAPIResponse[Any]]", cast_type)
return cast(
ResponseT,
response_cls(
raw=response,
client=self,
cast_type=extract_response_type(response_cls),
stream=stream,
stream_cls=stream_cls,
options=options,
),
)
if cast_type == httpx.Response:
return cast(ResponseT, response)
api_response = APIResponse(
raw=response,
client=self,
cast_type=cast("type[ResponseT]", cast_type), # pyright: ignore[reportUnnecessaryCast]
stream=stream,
stream_cls=stream_cls,
options=options,
)
if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
return cast(ResponseT, api_response)
return api_response.parse()
def _request_api_list(
self,
model: type[object],
page: type[SyncPageT],
options: FinalRequestOptions,
) -> SyncPageT:
def _parser(resp: SyncPageT) -> SyncPageT:
resp._set_private_attributes(
client=self,
model=model,
options=options,
)
return resp
options.post_parser = _parser
return self.request(page, options, stream=False)
@overload
def get(
self,
path: str,
*,
cast_type: type[ResponseT],
options: UserRequestInput = {},
stream: Literal[False] = False,
) -> ResponseT: ...
@overload
def get(
self,
path: str,
*,
cast_type: type[ResponseT],
options: UserRequestInput = {},
stream: Literal[True],
stream_cls: type[StreamResponse],
) -> StreamResponse: ...
@overload
def get(
self,
path: str,
*,
cast_type: type[ResponseT],
options: UserRequestInput = {},
stream: bool,
stream_cls: type[StreamResponse] | None = None,
) -> ResponseT | StreamResponse: ...
def get(
self,
path: str,
*,
cast_type: type[ResponseT],
options: UserRequestInput = {},
enable_stream: bool = False,
) -> ResponseT | StreamResponse:
opts = ClientRequestParam.construct(method="get", url=path, **options)
return self.request(cast_type=cast_type, params=opts, enable_stream=enable_stream)
stream: bool = False,
stream_cls: type[StreamResponse] | None = None,
) -> ResponseT:
opts = FinalRequestOptions.construct(method="get", url=path, **options)
return cast(ResponseT, self.request(cast_type, opts, stream=stream, stream_cls=stream_cls))
@overload
def post(
self,
path: str,
*,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
files: RequestFiles | None = None,
stream: Literal[False] = False,
) -> ResponseT: ...
@overload
def post(
self,
path: str,
*,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
files: RequestFiles | None = None,
stream: Literal[True],
stream_cls: type[StreamResponse],
) -> StreamResponse: ...
@overload
def post(
self,
path: str,
*,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
files: RequestFiles | None = None,
stream: bool,
stream_cls: type[StreamResponse] | None = None,
) -> ResponseT | StreamResponse: ...
def post(
self,
path: str,
*,
body: Body | None = None,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
files: RequestFiles | None = None,
enable_stream: bool = False,
stream: bool = False,
stream_cls: type[StreamResponse[Any]] | None = None,
) -> ResponseT | StreamResponse:
opts = ClientRequestParam.construct(
method="post", json_data=body, files=make_httpx_files(files), url=path, **options
opts = FinalRequestOptions.construct(
method="post", url=path, json_data=body, files=to_httpx_files(files), **options
)
return self.request(cast_type=cast_type, params=opts, enable_stream=enable_stream, stream_cls=stream_cls)
return cast(ResponseT, self.request(cast_type, opts, stream=stream, stream_cls=stream_cls))
def patch(
self,
path: str,
*,
body: Body | None = None,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
) -> ResponseT:
opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options)
opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options)
return self.request(
cast_type=cast_type,
params=opts,
options=opts,
)
def put(
self,
path: str,
*,
body: Body | None = None,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
files: RequestFiles | None = None,
) -> ResponseT | StreamResponse:
opts = ClientRequestParam.construct(
method="put", url=path, json_data=body, files=make_httpx_files(files), **options
opts = FinalRequestOptions.construct(
method="put", url=path, json_data=body, files=to_httpx_files(files), **options
)
return self.request(
cast_type=cast_type,
params=opts,
options=opts,
)
def delete(
self,
path: str,
*,
body: Body | None = None,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
) -> ResponseT | StreamResponse:
opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options)
opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, **options)
return self.request(
cast_type=cast_type,
params=opts,
options=opts,
)
def get_api_list(
self,
path: str,
*,
model: type[object],
page: type[SyncPageT],
body: Body | None = None,
options: UserRequestInput = {},
method: str = "get",
) -> SyncPageT:
opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options)
return self._request_api_list(model, page, opts)
def _make_status_error(self, response) -> APIStatusError:
response_text = response.text.strip()
status_code = response.status_code
@ -343,24 +865,46 @@ class HttpClient:
return APIStatusError(message=error_msg, response=response)
def make_user_request_input(
max_retries: int | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
extra_headers: Headers = None,
extra_body: Body | None = None,
def make_request_options(
*,
query: Query | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
post_parser: PostParser | NotGiven = NOT_GIVEN,
) -> UserRequestInput:
"""Create a dict of type RequestOptions without keys of NotGiven values."""
options: UserRequestInput = {}
if extra_headers is not None:
options["headers"] = extra_headers
if max_retries is not None:
options["max_retries"] = max_retries
if not isinstance(timeout, NotGiven):
options["timeout"] = timeout
if query is not None:
options["params"] = query
if extra_body is not None:
options["extra_json"] = cast(AnyMapping, extra_body)
if query is not None:
options["params"] = query
if extra_query is not None:
options["params"] = {**options.get("params", {}), **extra_query}
if not isinstance(timeout, NotGiven):
options["timeout"] = timeout
if is_given(post_parser):
# internal
options["post_parser"] = post_parser # type: ignore
return options
def _merge_mappings(
obj1: Mapping[_T_co, Union[_T, Omit]],
obj2: Mapping[_T_co, Union[_T, Omit]],
) -> dict[_T_co, _T]:
"""Merge two mappings of the same type, removing any values that are instances of `Omit`.
In cases with duplicate keys the second mapping takes precedence.
"""
merged = {**obj1, **obj2}
return {key: value for key, value in merged.items() if not isinstance(value, Omit)}

View File

@ -3,9 +3,11 @@ import time
import cachetools.func
import jwt
API_TOKEN_TTL_SECONDS = 3 * 60
# 缓存时间 3分钟
CACHE_TTL_SECONDS = 3 * 60
CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30
# token 有效期比缓存时间 多30秒
API_TOKEN_TTL_SECONDS = CACHE_TTL_SECONDS + 30
@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS)

View File

@ -0,0 +1,207 @@
from __future__ import annotations
import os
from collections.abc import AsyncIterator, Iterator
from typing import Any
import httpx
class HttpxResponseContent:
@property
def content(self) -> bytes:
raise NotImplementedError("This method is not implemented for this class.")
@property
def text(self) -> str:
raise NotImplementedError("This method is not implemented for this class.")
@property
def encoding(self) -> str | None:
raise NotImplementedError("This method is not implemented for this class.")
@property
def charset_encoding(self) -> str | None:
raise NotImplementedError("This method is not implemented for this class.")
def json(self, **kwargs: Any) -> Any:
raise NotImplementedError("This method is not implemented for this class.")
def read(self) -> bytes:
raise NotImplementedError("This method is not implemented for this class.")
def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
raise NotImplementedError("This method is not implemented for this class.")
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
raise NotImplementedError("This method is not implemented for this class.")
def iter_lines(self) -> Iterator[str]:
raise NotImplementedError("This method is not implemented for this class.")
def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]:
raise NotImplementedError("This method is not implemented for this class.")
def write_to_file(
self,
file: str | os.PathLike[str],
) -> None:
raise NotImplementedError("This method is not implemented for this class.")
def stream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
raise NotImplementedError("This method is not implemented for this class.")
def close(self) -> None:
raise NotImplementedError("This method is not implemented for this class.")
async def aread(self) -> bytes:
raise NotImplementedError("This method is not implemented for this class.")
async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
raise NotImplementedError("This method is not implemented for this class.")
async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
raise NotImplementedError("This method is not implemented for this class.")
async def aiter_lines(self) -> AsyncIterator[str]:
raise NotImplementedError("This method is not implemented for this class.")
async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
raise NotImplementedError("This method is not implemented for this class.")
async def astream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
raise NotImplementedError("This method is not implemented for this class.")
async def aclose(self) -> None:
raise NotImplementedError("This method is not implemented for this class.")
class HttpxBinaryResponseContent(HttpxResponseContent):
response: httpx.Response
def __init__(self, response: httpx.Response) -> None:
self.response = response
@property
def content(self) -> bytes:
return self.response.content
@property
def encoding(self) -> str | None:
return self.response.encoding
@property
def charset_encoding(self) -> str | None:
return self.response.charset_encoding
def read(self) -> bytes:
return self.response.read()
def text(self) -> str:
raise NotImplementedError("Not implemented for binary response content")
def json(self, **kwargs: Any) -> Any:
raise NotImplementedError("Not implemented for binary response content")
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
raise NotImplementedError("Not implemented for binary response content")
def iter_lines(self) -> Iterator[str]:
raise NotImplementedError("Not implemented for binary response content")
async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
raise NotImplementedError("Not implemented for binary response content")
async def aiter_lines(self) -> AsyncIterator[str]:
raise NotImplementedError("Not implemented for binary response content")
def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
return self.response.iter_bytes(chunk_size)
def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]:
return self.response.iter_raw(chunk_size)
def write_to_file(
self,
file: str | os.PathLike[str],
) -> None:
"""Write the output to the given file.
Accepts a filename or any path-like object, e.g. pathlib.Path
Note: if you want to stream the data to the file instead of writing
all at once then you should use `.with_streaming_response` when making
the API request, e.g. `client.with_streaming_response.foo().stream_to_file('my_filename.txt')`
"""
with open(file, mode="wb") as f:
for data in self.response.iter_bytes():
f.write(data)
def stream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
with open(file, mode="wb") as f:
for data in self.response.iter_bytes(chunk_size):
f.write(data)
def close(self) -> None:
return self.response.close()
async def aread(self) -> bytes:
return await self.response.aread()
async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
return self.response.aiter_bytes(chunk_size)
async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
return self.response.aiter_raw(chunk_size)
async def astream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
path = anyio.Path(file)
async with await path.open(mode="wb") as f:
async for data in self.response.aiter_bytes(chunk_size):
await f.write(data)
async def aclose(self) -> None:
return await self.response.aclose()
class HttpxTextBinaryResponseContent(HttpxBinaryResponseContent):
response: httpx.Response
@property
def text(self) -> str:
return self.response.text
def json(self, **kwargs: Any) -> Any:
return self.response.json(**kwargs)
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
return self.response.iter_text(chunk_size)
def iter_lines(self) -> Iterator[str]:
return self.response.iter_lines()
async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
return self.response.aiter_text(chunk_size)
async def aiter_lines(self) -> AsyncIterator[str]:
return self.response.aiter_lines()

View File

@ -0,0 +1,341 @@
from __future__ import annotations
import datetime
import functools
import inspect
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload
import httpx
import pydantic
from typing_extensions import ParamSpec, override
from ._base_models import BaseModel, is_basemodel
from ._base_type import NoneType
from ._constants import RAW_RESPONSE_HEADER
from ._errors import APIResponseValidationError
from ._legacy_binary_response import HttpxResponseContent, HttpxTextBinaryResponseContent
from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type
from ._utils import extract_type_arg, is_annotated_type, is_given
if TYPE_CHECKING:
from ._http_client import HttpClient
from ._request_opt import FinalRequestOptions
P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")
log: logging.Logger = logging.getLogger(__name__)
class LegacyAPIResponse(Generic[R]):
"""This is a legacy class as it will be replaced by `APIResponse`
and `AsyncAPIResponse` in the `_response.py` file in the next major
release.
For the sync client this will mostly be the same with the exception
of `content` & `text` will be methods instead of properties. In the
async client, all methods will be async.
A migration script will be provided & the migration in general should
be smooth.
"""
_cast_type: type[R]
_client: HttpClient
_parsed_by_type: dict[type[Any], Any]
_stream: bool
_stream_cls: type[StreamResponse[Any]] | None
_options: FinalRequestOptions
http_response: httpx.Response
def __init__(
self,
*,
raw: httpx.Response,
cast_type: type[R],
client: HttpClient,
stream: bool,
stream_cls: type[StreamResponse[Any]] | None,
options: FinalRequestOptions,
) -> None:
self._cast_type = cast_type
self._client = client
self._parsed_by_type = {}
self._stream = stream
self._stream_cls = stream_cls
self._options = options
self.http_response = raw
@property
def request_id(self) -> str | None:
return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return]
@overload
def parse(self, *, to: type[_T]) -> _T: ...
@overload
def parse(self) -> R: ...
def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
NOTE: For the async client: this will become a coroutine in the next major version.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
You can customise the type that the response is parsed into through
the `to` argument, e.g.
```py
from zhipuai import BaseModel
class MyModel(BaseModel):
foo: str
obj = response.parse(to=MyModel)
print(obj.foo)
```
We support parsing:
- `BaseModel`
- `dict`
- `list`
- `Union`
- `str`
- `int`
- `float`
- `httpx.Response`
"""
cache_key = to if to is not None else self._cast_type
cached = self._parsed_by_type.get(cache_key)
if cached is not None:
return cached # type: ignore[no-any-return]
parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
self._parsed_by_type[cache_key] = parsed
return parsed
@property
def headers(self) -> httpx.Headers:
return self.http_response.headers
@property
def http_request(self) -> httpx.Request:
return self.http_response.request
@property
def status_code(self) -> int:
return self.http_response.status_code
@property
def url(self) -> httpx.URL:
return self.http_response.url
@property
def method(self) -> str:
return self.http_request.method
@property
def content(self) -> bytes:
"""Return the binary response content.
NOTE: this will be removed in favour of `.read()` in the
next major version.
"""
return self.http_response.content
@property
def text(self) -> str:
"""Return the decoded response content.
NOTE: this will be turned into a method in the next major version.
"""
return self.http_response.text
@property
def http_version(self) -> str:
return self.http_response.http_version
@property
def is_closed(self) -> bool:
return self.http_response.is_closed
@property
def elapsed(self) -> datetime.timedelta:
"""The time taken for the complete request/response cycle to complete."""
return self.http_response.elapsed
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
# unwrap `Annotated[T, ...]` -> `T`
if to and is_annotated_type(to):
to = extract_type_arg(to, 0)
if self._stream:
if to:
if not is_stream_class_type(to):
raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}")
return cast(
_T,
to(
cast_type=extract_stream_chunk_type(
to,
failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]", # noqa: E501
),
response=self.http_response,
client=cast(Any, self._client),
),
)
if self._stream_cls:
return cast(
R,
self._stream_cls(
cast_type=extract_stream_chunk_type(self._stream_cls),
response=self.http_response,
client=cast(Any, self._client),
),
)
stream_cls = cast("type[StreamResponse[Any]] | None", self._client._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return cast(
R,
stream_cls(
cast_type=self._cast_type,
response=self.http_response,
client=cast(Any, self._client),
),
)
cast_type = to if to is not None else self._cast_type
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(cast_type):
cast_type = extract_type_arg(cast_type, 0)
if cast_type is NoneType:
return cast(R, None)
response = self.http_response
if cast_type == str:
return cast(R, response.text)
if cast_type == int:
return cast(R, int(response.text))
if cast_type == float:
return cast(R, float(response.text))
origin = get_origin(cast_type) or cast_type
if inspect.isclass(origin) and issubclass(origin, HttpxResponseContent):
# in the response, e.g. mime file
*_, filename = response.headers.get("content-disposition", "").split("filename=")
# 判断文件类型是jsonl类型的使用HttpxTextBinaryResponseContent
if filename and filename.endswith(".jsonl") or filename and filename.endswith(".xlsx"):
return cast(R, HttpxTextBinaryResponseContent(response))
else:
return cast(R, cast_type(response)) # type: ignore
if origin == LegacyAPIResponse:
raise RuntimeError("Unexpected state - cast_type is `APIResponse`")
if inspect.isclass(origin) and issubclass(origin, httpx.Response):
# Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
# and pass that class to our request functions. We cannot change the variance to be either
# covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
# the response class ourselves but that is something that should be supported directly in httpx
# as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
if cast_type != httpx.Response:
raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`")
return cast(R, response)
if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
if (
cast_type is not object
and origin is not list
and origin is not dict
and origin is not Union
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." # noqa: E501
)
# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type", "*").split(";")
if content_type != "application/json":
if is_basemodel(cast_type):
try:
data = response.json()
except Exception as exc:
log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
else:
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=response,
)
if self._client._strict_response_validation:
raise APIResponseValidationError(
response=response,
message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.", # noqa: E501
json_data=response.text,
)
# If the API responds with content that isn't JSON then we just return
# the (decoded) text without performing any parsing so that you can still
# handle the response however you need to.
return response.text # type: ignore
data = response.json()
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=response,
)
@override
def __repr__(self) -> str:
return f"<APIResponse [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_type}>"
class MissingStreamClassError(TypeError):
def __init__(self) -> None:
super().__init__(
"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference", # noqa: E501
)
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "true"
kwargs["extra_headers"] = extra_headers
return cast(LegacyAPIResponse[R], func(*args, **kwargs))
return wrapped

View File

@ -1,48 +1,97 @@
from __future__ import annotations
from typing import Any, ClassVar, Union
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, ClassVar, Union, cast
import pydantic.generics
from httpx import Timeout
from pydantic import ConfigDict
from typing_extensions import TypedDict, Unpack
from typing_extensions import Required, TypedDict, Unpack, final
from ._base_type import Body, Headers, HttpxRequestFiles, NotGiven, Query
from ._utils import remove_notgiven_indict
from ._base_compat import PYDANTIC_V2, ConfigDict
from ._base_type import AnyMapping, Body, Headers, HttpxRequestFiles, NotGiven, Query
from ._constants import RAW_RESPONSE_HEADER
from ._utils import is_given, strip_not_given
class UserRequestInput(TypedDict, total=False):
headers: Headers
max_retries: int
timeout: float | Timeout | None
params: Query
extra_json: AnyMapping
class FinalRequestOptionsInput(TypedDict, total=False):
method: Required[str]
url: Required[str]
params: Query
headers: Headers
params: Query | None
max_retries: int
timeout: float | Timeout | None
files: HttpxRequestFiles | None
json_data: Body
extra_json: AnyMapping
class ClientRequestParam:
@final
class FinalRequestOptions(pydantic.BaseModel):
method: str
url: str
max_retries: Union[int, NotGiven] = NotGiven()
timeout: Union[float, NotGiven] = NotGiven()
headers: Union[Headers, NotGiven] = NotGiven()
json_data: Union[Body, None] = None
files: Union[HttpxRequestFiles, None] = None
params: Query = {}
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
headers: Union[Headers, NotGiven] = NotGiven()
max_retries: Union[int, NotGiven] = NotGiven()
timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
files: Union[HttpxRequestFiles, None] = None
idempotency_key: Union[str, None] = None
post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
def get_max_retries(self, max_retries) -> int:
# It should be noted that we cannot use `json` here as that would override
# a BaseModel method in an incompatible fashion.
json_data: Union[Body, None] = None
extra_json: Union[AnyMapping, None] = None
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
else:
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
arbitrary_types_allowed: bool = True
def get_max_retries(self, max_retries: int) -> int:
if isinstance(self.max_retries, NotGiven):
return max_retries
return self.max_retries
def _strip_raw_response_header(self) -> None:
if not is_given(self.headers):
return
if self.headers.get(RAW_RESPONSE_HEADER):
self.headers = {**self.headers}
self.headers.pop(RAW_RESPONSE_HEADER)
# override the `construct` method so that we can run custom transformations.
# this is necessary as we don't want to do any actual runtime type checking
# (which means we can't use validators) but we do want to ensure that `NotGiven`
# values are not present
#
# type ignore required because we're adding explicit types to `**values`
@classmethod
def construct( # type: ignore
cls,
_fields_set: set[str] | None = None,
**values: Unpack[UserRequestInput],
) -> ClientRequestParam:
kwargs: dict[str, Any] = {key: remove_notgiven_indict(value) for key, value in values.items()}
client = cls()
client.__dict__.update(kwargs)
) -> FinalRequestOptions:
kwargs: dict[str, Any] = {
# we unconditionally call `strip_not_given` on any value
# as it will just ignore any non-mapping types
key: strip_not_given(value)
for key, value in values.items()
}
if PYDANTIC_V2:
return super().model_construct(_fields_set, **kwargs)
return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
return client
model_construct = construct
if not TYPE_CHECKING:
# type checkers incorrectly complain about this assignment
model_construct = construct

View File

@ -1,87 +1,193 @@
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, get_args, get_origin
import inspect
import logging
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload
import httpx
import pydantic
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, override
from ._base_models import BaseModel, is_basemodel
from ._base_type import NoneType
from ._sse_client import StreamResponse
from ._errors import APIResponseValidationError, ZhipuAIError
from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type
from ._utils import extract_type_arg, extract_type_var_from_base, is_annotated_type, is_given
if TYPE_CHECKING:
from ._http_client import HttpClient
from ._request_opt import FinalRequestOptions
P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")
_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]")
log: logging.Logger = logging.getLogger(__name__)
class HttpResponse(Generic[R]):
class BaseAPIResponse(Generic[R]):
_cast_type: type[R]
_client: HttpClient
_parsed: R | None
_enable_stream: bool
_parsed_by_type: dict[type[Any], Any]
_is_sse_stream: bool
_stream_cls: type[StreamResponse[Any]]
_options: FinalRequestOptions
http_response: httpx.Response
def __init__(
self,
*,
raw_response: httpx.Response,
raw: httpx.Response,
cast_type: type[R],
client: HttpClient,
enable_stream: bool = False,
stream: bool,
stream_cls: type[StreamResponse[Any]] | None = None,
options: FinalRequestOptions,
) -> None:
self._cast_type = cast_type
self._client = client
self._parsed = None
self._parsed_by_type = {}
self._is_sse_stream = stream
self._stream_cls = stream_cls
self._enable_stream = enable_stream
self.http_response = raw_response
self._options = options
self.http_response = raw
def parse(self) -> R:
self._parsed = self._parse()
return self._parsed
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
# unwrap `Annotated[T, ...]` -> `T`
if to and is_annotated_type(to):
to = extract_type_arg(to, 0)
def _parse(self) -> R:
if self._enable_stream:
self._parsed = cast(
R,
self._stream_cls(
cast_type=cast(type, get_args(self._stream_cls)[0]),
response=self.http_response,
client=self._client,
),
)
return self._parsed
cast_type = self._cast_type
if cast_type is NoneType:
return cast(R, None)
http_response = self.http_response
if cast_type == str:
return cast(R, http_response.text)
if self._is_sse_stream:
if to:
if not is_stream_class_type(to):
raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}")
content_type, *_ = http_response.headers.get("content-type", "application/json").split(";")
origin = get_origin(cast_type) or cast_type
if content_type != "application/json":
if issubclass(origin, pydantic.BaseModel):
data = http_response.json()
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=http_response,
return cast(
_T,
to(
cast_type=extract_stream_chunk_type(
to,
failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]", # noqa: E501
),
response=self.http_response,
client=cast(Any, self._client),
),
)
return http_response.text
if self._stream_cls:
return cast(
R,
self._stream_cls(
cast_type=extract_stream_chunk_type(self._stream_cls),
response=self.http_response,
client=cast(Any, self._client),
),
)
data = http_response.json()
stream_cls = cast("type[Stream[Any]] | None", self._client._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return cast(
R,
stream_cls(
cast_type=self._cast_type,
response=self.http_response,
client=cast(Any, self._client),
),
)
cast_type = to if to is not None else self._cast_type
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(cast_type):
cast_type = extract_type_arg(cast_type, 0)
if cast_type is NoneType:
return cast(R, None)
response = self.http_response
if cast_type == str:
return cast(R, response.text)
if cast_type == bytes:
return cast(R, response.content)
if cast_type == int:
return cast(R, int(response.text))
if cast_type == float:
return cast(R, float(response.text))
origin = get_origin(cast_type) or cast_type
# handle the legacy binary response case
if inspect.isclass(cast_type) and cast_type.__name__ == "HttpxBinaryResponseContent":
return cast(R, cast_type(response)) # type: ignore
if origin == APIResponse:
raise RuntimeError("Unexpected state - cast_type is `APIResponse`")
if inspect.isclass(origin) and issubclass(origin, httpx.Response):
# Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
# and pass that class to our request functions. We cannot change the variance to be either
# covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
# the response class ourselves but that is something that should be supported directly in httpx
# as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
if cast_type != httpx.Response:
raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`")
return cast(R, response)
if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
if (
cast_type is not object
and origin is not list
and origin is not dict
and origin is not Union
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." # noqa: E501
)
# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type", "*").split(";")
if content_type != "application/json":
if is_basemodel(cast_type):
try:
data = response.json()
except Exception as exc:
log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
else:
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=response,
)
if self._client._strict_response_validation:
raise APIResponseValidationError(
response=response,
message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.", # noqa: E501
json_data=response.text,
)
# If the API responds with content that isn't JSON then we just return
# the (decoded) text without performing any parsing so that you can still
# handle the response however you need to.
return response.text # type: ignore
data = response.json()
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=http_response,
response=response,
)
@property
@ -90,6 +196,7 @@ class HttpResponse(Generic[R]):
@property
def http_request(self) -> httpx.Request:
"""Returns the httpx Request instance associated with the current response."""
return self.http_response.request
@property
@ -98,24 +205,194 @@ class HttpResponse(Generic[R]):
@property
def url(self) -> httpx.URL:
"""Returns the URL for which the request was made."""
return self.http_response.url
@property
def method(self) -> str:
return self.http_request.method
@property
def content(self) -> bytes:
return self.http_response.content
@property
def text(self) -> str:
return self.http_response.text
@property
def http_version(self) -> str:
return self.http_response.http_version
@property
def elapsed(self) -> datetime.timedelta:
"""The time taken for the complete request/response cycle to complete."""
return self.http_response.elapsed
@property
def is_closed(self) -> bool:
"""Whether or not the response body has been closed.
If this is False then there is response data that has not been read yet.
You must either fully consume the response body or call `.close()`
before discarding the response to prevent resource leaks.
"""
return self.http_response.is_closed
@override
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_type}>" # noqa: E501
class APIResponse(BaseAPIResponse[R]):
@property
def request_id(self) -> str | None:
return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return]
@overload
def parse(self, *, to: type[_T]) -> _T: ...
@overload
def parse(self) -> R: ...
def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
You can customise the type that the response is parsed into through
the `to` argument, e.g.
```py
from openai import BaseModel
class MyModel(BaseModel):
foo: str
obj = response.parse(to=MyModel)
print(obj.foo)
```
We support parsing:
- `BaseModel`
- `dict`
- `list`
- `Union`
- `str`
- `int`
- `float`
- `httpx.Response`
"""
cache_key = to if to is not None else self._cast_type
cached = self._parsed_by_type.get(cache_key)
if cached is not None:
return cached # type: ignore[no-any-return]
if not self._is_sse_stream:
self.read()
parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
self._parsed_by_type[cache_key] = parsed
return parsed
def read(self) -> bytes:
"""Read and return the binary response content."""
try:
return self.http_response.read()
except httpx.StreamConsumed as exc:
# The default error raised by httpx isn't very
# helpful in our case so we re-raise it with
# a different error message.
raise StreamAlreadyConsumed() from exc
def text(self) -> str:
"""Read and decode the response content into a string."""
self.read()
return self.http_response.text
def json(self) -> object:
"""Read and decode the JSON response content."""
self.read()
return self.http_response.json()
def close(self) -> None:
"""Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self.http_response.close()
def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
"""
A byte-iterator over the decoded response content.
This automatically handles gzip, deflate and brotli encoded responses.
"""
yield from self.http_response.iter_bytes(chunk_size)
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
"""A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
yield from self.http_response.iter_text(chunk_size)
def iter_lines(self) -> Iterator[str]:
"""Like `iter_text()` but will only yield chunks for each line"""
yield from self.http_response.iter_lines()
class MissingStreamClassError(TypeError):
def __init__(self) -> None:
super().__init__(
"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference", # noqa: E501
)
class StreamAlreadyConsumed(ZhipuAIError): # noqa: N818
"""
Attempted to read or stream content, but the content has already
been streamed.
This can happen if you use a method like `.iter_lines()` and then attempt
to read th entire response body afterwards, e.g.
```py
response = await client.post(...)
async for line in response.iter_lines():
... # do something with `line`
content = await response.read()
# ^ error
```
If you want this behaviour you'll need to either manually accumulate the response
content or call `await response.read()` before iterating over the stream.
"""
def __init__(self) -> None:
message = (
"Attempted to read or stream some content, but the content has "
"already been streamed. "
"This could be due to attempting to stream the response "
"content more than once."
"\n\n"
"You can fix this by manually accumulating the response content while streaming "
"or by calling `.read()` before starting to stream."
)
super().__init__(message)
def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:
"""Given a type like `APIResponse[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyResponse(APIResponse[bytes]):
...
extract_response_type(MyResponse) -> bytes
```
"""
return extract_type_var_from_base(
typ,
generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse)),
index=0,
)

View File

@ -1,13 +1,16 @@
from __future__ import annotations
import inspect
import json
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING, Generic
from typing import TYPE_CHECKING, Generic, TypeGuard, cast
import httpx
from . import get_origin
from ._base_type import ResponseT
from ._errors import APIResponseError
from ._utils import extract_type_var_from_base, is_mapping
_FIELD_SEPARATOR = ":"
@ -53,8 +56,41 @@ class StreamResponse(Generic[ResponseT]):
request=self.response.request,
json_data=data["error"],
)
if sse.event is None:
data = sse.json_data()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"
raise APIResponseError(
message=message,
request=self.response.request,
json_data=data["error"],
)
yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)
else:
data = sse.json_data()
if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"
raise APIResponseError(
message=message,
request=self.response.request,
json_data=data["error"],
)
yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)
for sse in iterator:
pass
@ -138,3 +174,33 @@ class SSELineParser:
except (TypeError, ValueError):
pass
return
def is_stream_class_type(typ: type) -> TypeGuard[type[StreamResponse[object]]]:
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
origin = get_origin(typ) or typ
return inspect.isclass(origin) and issubclass(origin, StreamResponse)
def extract_stream_chunk_type(
stream_cls: type,
*,
failure_message: str | None = None,
) -> type:
"""Given a type like `StreamResponse[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyStream(StreamResponse[bytes]):
...
extract_stream_chunk_type(MyStream) -> bytes
```
"""
return extract_type_var_from_base(
stream_cls,
index=0,
generic_bases=cast("tuple[type, ...]", (StreamResponse,)),
failure_message=failure_message,
)

View File

@ -1,19 +0,0 @@
from __future__ import annotations
from collections.abc import Iterable, Mapping
from typing import TypeVar
from ._base_type import NotGiven
def remove_notgiven_indict(obj):
if obj is None or (not isinstance(obj, Mapping)):
return obj
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
_T = TypeVar("_T")
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
return [item for sublist in t for item in sublist]

View File

@ -0,0 +1,52 @@
from ._utils import ( # noqa: I001
remove_notgiven_indict as remove_notgiven_indict, # noqa: PLC0414
flatten as flatten, # noqa: PLC0414
is_dict as is_dict, # noqa: PLC0414
is_list as is_list, # noqa: PLC0414
is_given as is_given, # noqa: PLC0414
is_tuple as is_tuple, # noqa: PLC0414
is_mapping as is_mapping, # noqa: PLC0414
is_tuple_t as is_tuple_t, # noqa: PLC0414
parse_date as parse_date, # noqa: PLC0414
is_iterable as is_iterable, # noqa: PLC0414
is_sequence as is_sequence, # noqa: PLC0414
coerce_float as coerce_float, # noqa: PLC0414
is_mapping_t as is_mapping_t, # noqa: PLC0414
removeprefix as removeprefix, # noqa: PLC0414
removesuffix as removesuffix, # noqa: PLC0414
extract_files as extract_files, # noqa: PLC0414
is_sequence_t as is_sequence_t, # noqa: PLC0414
required_args as required_args, # noqa: PLC0414
coerce_boolean as coerce_boolean, # noqa: PLC0414
coerce_integer as coerce_integer, # noqa: PLC0414
file_from_path as file_from_path, # noqa: PLC0414
parse_datetime as parse_datetime, # noqa: PLC0414
strip_not_given as strip_not_given, # noqa: PLC0414
deepcopy_minimal as deepcopy_minimal, # noqa: PLC0414
get_async_library as get_async_library, # noqa: PLC0414
maybe_coerce_float as maybe_coerce_float, # noqa: PLC0414
get_required_header as get_required_header, # noqa: PLC0414
maybe_coerce_boolean as maybe_coerce_boolean, # noqa: PLC0414
maybe_coerce_integer as maybe_coerce_integer, # noqa: PLC0414
drop_prefix_image_data as drop_prefix_image_data, # noqa: PLC0414
)
from ._typing import (
is_list_type as is_list_type, # noqa: PLC0414
is_union_type as is_union_type, # noqa: PLC0414
extract_type_arg as extract_type_arg, # noqa: PLC0414
is_iterable_type as is_iterable_type, # noqa: PLC0414
is_required_type as is_required_type, # noqa: PLC0414
is_annotated_type as is_annotated_type, # noqa: PLC0414
strip_annotated_type as strip_annotated_type, # noqa: PLC0414
extract_type_var_from_base as extract_type_var_from_base, # noqa: PLC0414
)
from ._transform import (
PropertyInfo as PropertyInfo, # noqa: PLC0414
transform as transform, # noqa: PLC0414
async_transform as async_transform, # noqa: PLC0414
maybe_transform as maybe_transform, # noqa: PLC0414
async_maybe_transform as async_maybe_transform, # noqa: PLC0414
)

View File

@ -0,0 +1,383 @@
from __future__ import annotations
import base64
import io
import pathlib
from collections.abc import Mapping
from datetime import date, datetime
from typing import Any, Literal, TypeVar, cast, get_args, get_type_hints
import anyio
import pydantic
from typing_extensions import override
from .._base_compat import is_typeddict, model_dump
from .._files import is_base64_file_input
from ._typing import (
extract_type_arg,
is_annotated_type,
is_iterable_type,
is_list_type,
is_required_type,
is_union_type,
strip_annotated_type,
)
from ._utils import (
is_iterable,
is_list,
is_mapping,
)
_T = TypeVar("_T")
# TODO: support for drilling globals() and locals()
# TODO: ensure works correctly with forward references in all cases
PropertyFormat = Literal["iso8601", "base64", "custom"]
class PropertyInfo:
"""Metadata class to be used in Annotated types to provide information about a given type.
For example:
class MyParams(TypedDict):
account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]
This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
""" # noqa: E501
alias: str | None
format: PropertyFormat | None
format_template: str | None
discriminator: str | None
def __init__(
self,
*,
alias: str | None = None,
format: PropertyFormat | None = None,
format_template: str | None = None,
discriminator: str | None = None,
) -> None:
self.alias = alias
self.format = format
self.format_template = format_template
self.discriminator = discriminator
@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')" # noqa: E501
def maybe_transform(
data: object,
expected_type: object,
) -> Any | None:
"""Wrapper over `transform()` that allows `None` to be passed.
See `transform()` for more details.
"""
if data is None:
return None
return transform(data, expected_type)
# Wrapper over _transform_recursive providing fake types
def transform(
data: _T,
expected_type: object,
) -> _T:
"""Transform dictionaries based off of type information from the given type, for example:
```py
class Params(TypedDict, total=False):
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
transformed = transform({"card_id": "<my card ID>"}, Params)
# {'cardID': '<my card ID>'}
```
Any keys / data that does not have type information given will be included as is.
It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = _transform_recursive(data, annotation=cast(type, expected_type))
return cast(_T, transformed)
def _get_annotated_type(type_: type) -> type | None:
"""If the given type is an `Annotated` type then it is returned, if not `None` is returned.
This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`
"""
if is_required_type(type_):
# Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`
type_ = get_args(type_)[0]
if is_annotated_type(type_):
return type_
return None
def _maybe_transform_key(key: str, type_: type) -> str:
"""Transform the given `data` based on the annotations provided in `type_`.
Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata.
"""
annotated_type = _get_annotated_type(type_)
if annotated_type is None:
# no `Annotated` definition for this type, no transformation needed
return key
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.alias is not None:
return annotation.alias
return key
def _transform_recursive(
data: object,
*,
annotation: type,
inner_type: type | None = None,
) -> object:
"""Transform the given data against the expected type.
Args:
annotation: The direct type annotation given to the particular piece of data.
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
the list can be transformed using the metadata from the container type.
Defaults to the same value as the `annotation` argument.
"""
if inner_type is None:
inner_type = annotation
stripped_type = strip_annotated_type(inner_type)
if is_typeddict(stripped_type) and is_mapping(data):
return _transform_typeddict(data, stripped_type)
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
):
inner_type = extract_type_arg(stripped_type, 0)
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
return data
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return _format_data(data, annotation.format, annotation.format_template)
return data
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
if isinstance(data, date | datetime):
if format_ == "iso8601":
return data.isoformat()
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)
if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None
if isinstance(data, pathlib.Path):
binary = data.read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()
if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()
if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
return base64.b64encode(binary).decode("ascii")
return data
def _transform_typeddict(
data: Mapping[str, object],
expected_type: type,
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
return result
async def async_maybe_transform(
data: object,
expected_type: object,
) -> Any | None:
"""Wrapper over `async_transform()` that allows `None` to be passed.
See `async_transform()` for more details.
"""
if data is None:
return None
return await async_transform(data, expected_type)
async def async_transform(
data: _T,
expected_type: object,
) -> _T:
"""Transform dictionaries based off of type information from the given type, for example:
```py
class Params(TypedDict, total=False):
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
transformed = transform({"card_id": "<my card ID>"}, Params)
# {'cardID': '<my card ID>'}
```
Any keys / data that does not have type information given will be included as is.
It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
return cast(_T, transformed)
async def _async_transform_recursive(
data: object,
*,
annotation: type,
inner_type: type | None = None,
) -> object:
"""Transform the given data against the expected type.
Args:
annotation: The direct type annotation given to the particular piece of data.
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
the list can be transformed using the metadata from the container type.
Defaults to the same value as the `annotation` argument.
"""
if inner_type is None:
inner_type = annotation
stripped_type = strip_annotated_type(inner_type)
if is_typeddict(stripped_type) and is_mapping(data):
return await _async_transform_typeddict(data, stripped_type)
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
):
inner_type = extract_type_arg(stripped_type, 0)
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
return data
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return await _async_format_data(data, annotation.format, annotation.format_template)
return data
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
if isinstance(data, date | datetime):
if format_ == "iso8601":
return data.isoformat()
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)
if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None
if isinstance(data, pathlib.Path):
binary = await anyio.Path(data).read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()
if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()
if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
return base64.b64encode(binary).decode("ascii")
return data
async def _async_transform_typeddict(
data: Mapping[str, object],
expected_type: type,
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
return result

View File

@ -0,0 +1,122 @@
from __future__ import annotations
from collections import abc as _c_abc
from collections.abc import Iterable
from typing import Annotated, Any, TypeVar, cast, get_args, get_origin
from typing_extensions import Required
from .._base_compat import is_union as _is_union
from .._base_type import InheritsGeneric
def is_annotated_type(typ: type) -> bool:
return get_origin(typ) == Annotated
def is_list_type(typ: type) -> bool:
return (get_origin(typ) or typ) == list
def is_iterable_type(typ: type) -> bool:
"""If the given type is `typing.Iterable[T]`"""
origin = get_origin(typ) or typ
return origin in {Iterable, _c_abc.Iterable}
def is_union_type(typ: type) -> bool:
return _is_union(get_origin(typ))
def is_required_type(typ: type) -> bool:
return get_origin(typ) == Required
def is_typevar(typ: type) -> bool:
# type ignore is required because type checkers
# think this expression will always return False
return type(typ) == TypeVar # type: ignore
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
return strip_annotated_type(cast(type, get_args(typ)[0]))
return typ
def extract_type_arg(typ: type, index: int) -> type:
args = get_args(typ)
try:
return cast(type, args[index])
except IndexError as err:
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
def extract_type_var_from_base(
typ: type,
*,
generic_bases: tuple[type, ...],
index: int,
failure_message: str | None = None,
) -> type:
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyResponse(Foo[bytes]):
...
extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
```
And where a generic subclass is given:
```py
_T = TypeVar('_T')
class MyResponse(Foo[_T]):
...
extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
```
"""
cls = cast(object, get_origin(typ) or typ)
if cls in generic_bases:
# we're given the class directly
return extract_type_arg(typ, index)
# if a subclass is given
# ---
# this is needed as __orig_bases__ is not present in the typeshed stubs
# because it is intended to be for internal use only, however there does
# not seem to be a way to resolve generic TypeVars for inherited subclasses
# without using it.
if isinstance(cls, InheritsGeneric):
target_base_class: Any | None = None
for base in cls.__orig_bases__:
if base.__origin__ in generic_bases:
target_base_class = base
break
if target_base_class is None:
raise RuntimeError(
"Could not find the generic base class;\n"
"This should never happen;\n"
f"Does {cls} inherit from one of {generic_bases} ?"
)
extracted = extract_type_arg(target_base_class, index)
if is_typevar(extracted):
# If the extracted type argument is itself a type variable
# then that means the subclass itself is generic, so we have
# to resolve the type argument from the class itself, not
# the base class.
#
# Note: if there is more than 1 type argument, the subclass could
# change the ordering of the type arguments, this is not currently
# supported.
return extract_type_arg(typ, index)
return extracted
raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")

View File

@ -0,0 +1,409 @@
from __future__ import annotations
import functools
import inspect
import os
import re
from collections.abc import Callable, Iterable, Mapping, Sequence
from pathlib import Path
from typing import (
Any,
TypeGuard,
TypeVar,
Union,
cast,
overload,
)
import sniffio
from .._base_compat import parse_date as parse_date # noqa: PLC0414
from .._base_compat import parse_datetime as parse_datetime # noqa: PLC0414
from .._base_type import FileTypes, Headers, HeadersLike, NotGiven, NotGivenOr
def remove_notgiven_indict(obj):
if obj is None or (not isinstance(obj, Mapping)):
return obj
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
_T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=tuple[object, ...])
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
return [item for sublist in t for item in sublist]
def extract_files(
# TODO: this needs to take Dict but variance issues.....
# create protocol type ?
query: Mapping[str, object],
*,
paths: Sequence[Sequence[str]],
) -> list[tuple[str, FileTypes]]:
"""Recursively extract files from the given dictionary based on specified paths.
A path may look like this ['foo', 'files', '<array>', 'data'].
Note: this mutates the given dictionary.
"""
files: list[tuple[str, FileTypes]] = []
for path in paths:
files.extend(_extract_items(query, path, index=0, flattened_key=None))
return files
def _extract_items(
obj: object,
path: Sequence[str],
*,
index: int,
flattened_key: str | None,
) -> list[tuple[str, FileTypes]]:
try:
key = path[index]
except IndexError:
if isinstance(obj, NotGiven):
# no value was provided - we can safely ignore
return []
# cyclical import
from .._files import assert_is_file_content
# We have exhausted the path, return the entry we found.
assert_is_file_content(obj, key=flattened_key)
assert flattened_key is not None
return [(flattened_key, cast(FileTypes, obj))]
index += 1
if is_dict(obj):
try:
# We are at the last entry in the path so we must remove the field
if (len(path)) == index:
item = obj.pop(key)
else:
item = obj[key]
except KeyError:
# Key was not present in the dictionary, this is not indicative of an error
# as the given path may not point to a required field. We also do not want
# to enforce required fields as the API may differ from the spec in some cases.
return []
if flattened_key is None:
flattened_key = key
else:
flattened_key += f"[{key}]"
return _extract_items(
item,
path,
index=index,
flattened_key=flattened_key,
)
elif is_list(obj):
if key != "<array>":
return []
return flatten(
[
_extract_items(
item,
path,
index=index,
flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
)
for item in obj
]
)
# Something unexpected was passed, just ignore it.
return []
def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
return not isinstance(obj, NotGiven)
# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this cause Pyright to rightfully report errors. As we know we don't
# care about the contained types we can safely use `object` in it's place.
#
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
# `is_*` is for when you're dealing with an unknown input
# `is_*_t` is for when you're narrowing a known union type to a specific subset
def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
return isinstance(obj, tuple)
def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
return isinstance(obj, tuple)
def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
return isinstance(obj, Sequence)
def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
return isinstance(obj, Sequence)
def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
return isinstance(obj, Mapping)
def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
return isinstance(obj, Mapping)
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
return isinstance(obj, dict)
def is_list(obj: object) -> TypeGuard[list[object]]:
return isinstance(obj, list)
def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
return isinstance(obj, Iterable)
def deepcopy_minimal(item: _T) -> _T:
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
- mappings, e.g. `dict`
- list
This is done for performance reasons.
"""
if is_mapping(item):
return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})
if is_list(item):
return cast(_T, [deepcopy_minimal(entry) for entry in item])
return item
# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
size = len(seq)
if size == 0:
return ""
if size == 1:
return seq[0]
if size == 2:
return f"{seq[0]} {final} {seq[1]}"
return delim.join(seq[:-1]) + f" {final} {seq[-1]}"
def quote(string: str) -> str:
"""Add single quotation marks around the given string. Does *not* do any escaping."""
return f"'{string}'"
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
"""Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
Useful for enforcing runtime validation of overloaded functions.
Example usage:
```py
@overload
def foo(*, a: str) -> str:
...
@overload
def foo(*, b: bool) -> str:
...
# This enforces the same constraints that a static type checker would
# i.e. that either a or b must be passed to the function
@required_args(["a"], ["b"])
def foo(*, a: str | None = None, b: bool | None = None) -> str:
...
```
"""
def inner(func: CallableT) -> CallableT:
params = inspect.signature(func).parameters
positional = [
name
for name, param in params.items()
if param.kind
in {
param.POSITIONAL_ONLY,
param.POSITIONAL_OR_KEYWORD,
}
]
@functools.wraps(func)
def wrapper(*args: object, **kwargs: object) -> object:
given_params: set[str] = set()
for i, _ in enumerate(args):
try:
given_params.add(positional[i])
except IndexError:
raise TypeError(
f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
) from None
given_params.update(kwargs.keys())
for variant in variants:
matches = all(param in given_params for param in variant)
if matches:
break
else: # no break
if len(variants) > 1:
variations = human_join(
["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
)
msg = f"Missing required arguments; Expected either {variations} arguments to be given"
else:
# TODO: this error message is not deterministic
missing = list(set(variants[0]) - given_params)
if len(missing) > 1:
msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
else:
msg = f"Missing required argument: {quote(missing[0])}"
raise TypeError(msg)
return func(*args, **kwargs)
return wrapper # type: ignore
return inner
_K = TypeVar("_K")
_V = TypeVar("_V")
@overload
def strip_not_given(obj: None) -> None: ...
@overload
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...
@overload
def strip_not_given(obj: object) -> object: ...
def strip_not_given(obj: object | None) -> object:
"""Remove all top-level keys where their values are instances of `NotGiven`"""
if obj is None:
return None
if not is_mapping(obj):
return obj
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
def coerce_integer(val: str) -> int:
return int(val, base=10)
def coerce_float(val: str) -> float:
return float(val)
def coerce_boolean(val: str) -> bool:
return val in {"true", "1", "on"}
def maybe_coerce_integer(val: str | None) -> int | None:
if val is None:
return None
return coerce_integer(val)
def maybe_coerce_float(val: str | None) -> float | None:
if val is None:
return None
return coerce_float(val)
def maybe_coerce_boolean(val: str | None) -> bool | None:
if val is None:
return None
return coerce_boolean(val)
def removeprefix(string: str, prefix: str) -> str:
"""Remove a prefix from a string.
Backport of `str.removeprefix` for Python < 3.9
"""
if string.startswith(prefix):
return string[len(prefix) :]
return string
def removesuffix(string: str, suffix: str) -> str:
"""Remove a suffix from a string.
Backport of `str.removesuffix` for Python < 3.9
"""
if string.endswith(suffix):
return string[: -len(suffix)]
return string
def file_from_path(path: str) -> FileTypes:
contents = Path(path).read_bytes()
file_name = os.path.basename(path)
return (file_name, contents)
def get_required_header(headers: HeadersLike, header: str) -> str:
lower_header = header.lower()
if isinstance(headers, Mapping):
headers = cast(Headers, headers)
for k, v in headers.items():
if k.lower() == lower_header and isinstance(v, str):
return v
""" to deal with the case where the header looks like Stainless-Event-Id """
intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
value = headers.get(normalized_header)
if value:
return value
raise ValueError(f"Could not find {header} header")
def get_async_library() -> str:
try:
return sniffio.current_async_library()
except Exception:
return "false"
def drop_prefix_image_data(content: Union[str, list[dict]]) -> Union[str, list[dict]]:
"""
删除 ;base64, 前缀
:param image_data:
:return:
"""
if isinstance(content, list):
for data in content:
if data.get("type") == "image_url":
image_data = data.get("image_url").get("url")
if image_data.startswith("data:image/"):
image_data = image_data.split("base64,")[-1]
data["image_url"]["url"] = image_data
return content

View File

@ -0,0 +1,78 @@
import logging
import os
import time
logger = logging.getLogger(__name__)
class LoggerNameFilter(logging.Filter):
def filter(self, record):
# return record.name.startswith("loom_core") or record.name in "ERROR" or (
# record.name.startswith("uvicorn.error")
# and record.getMessage().startswith("Uvicorn running on")
# )
return True
def get_log_file(log_path: str, sub_dir: str):
"""
sub_dir should contain a timestamp.
"""
log_dir = os.path.join(log_path, sub_dir)
# Here should be creating a new directory each time, so `exist_ok=False`
os.makedirs(log_dir, exist_ok=False)
return os.path.join(log_dir, "zhipuai.log")
def get_config_dict(log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int) -> dict:
# for windows, the path should be a raw string.
log_file_path = log_file_path.encode("unicode-escape").decode() if os.name == "nt" else log_file_path
log_level = log_level.upper()
config_dict = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"formatter": {"format": ("%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s")},
},
"filters": {
"logger_name_filter": {
"()": __name__ + ".LoggerNameFilter",
},
},
"handlers": {
"stream_handler": {
"class": "logging.StreamHandler",
"formatter": "formatter",
"level": log_level,
# "stream": "ext://sys.stdout",
# "filters": ["logger_name_filter"],
},
"file_handler": {
"class": "logging.handlers.RotatingFileHandler",
"formatter": "formatter",
"level": log_level,
"filename": log_file_path,
"mode": "a",
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf8",
},
},
"loggers": {
"loom_core": {
"handlers": ["stream_handler", "file_handler"],
"level": log_level,
"propagate": False,
}
},
"root": {
"level": log_level,
"handlers": ["stream_handler", "file_handler"],
},
}
return config_dict
def get_timestamp_ms():
t = time.time()
return int(round(t * 1000))

View File

@ -0,0 +1,62 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Any, Generic, Optional, TypeVar, cast
from typing_extensions import Protocol, override, runtime_checkable
from ._http_client import BasePage, BaseSyncPage, PageInfo
__all__ = ["SyncPage", "SyncCursorPage"]
_T = TypeVar("_T")
@runtime_checkable
class CursorPageItem(Protocol):
id: Optional[str]
class SyncPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
"""Note: no pagination actually occurs yet, this is for forwards-compatibility."""
data: list[_T]
object: str
@override
def _get_page_items(self) -> list[_T]:
data = self.data
if not data:
return []
return data
@override
def next_page_info(self) -> None:
"""
This page represents a response that isn't actually paginated at the API level
so there will never be a next page.
"""
return None
class SyncCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
data: list[_T]
@override
def _get_page_items(self) -> list[_T]:
data = self.data
if not data:
return []
return data
@override
def next_page_info(self) -> Optional[PageInfo]:
data = self.data
if not data:
return None
item = cast(Any, data[-1])
if not isinstance(item, CursorPageItem) or item.id is None:
# TODO emit warning log
return None
return PageInfo(params={"after": item.id})

View File

@ -0,0 +1,5 @@
from .assistant_completion import AssistantCompletion
__all__ = [
"AssistantCompletion",
]

View File

@ -0,0 +1,40 @@
from typing import Any, Optional
from ...core import BaseModel
from .message import MessageContent
__all__ = ["AssistantCompletion", "CompletionUsage"]
class ErrorInfo(BaseModel):
code: str # 错误码
message: str # 错误信息
class AssistantChoice(BaseModel):
index: int # 结果下标
delta: MessageContent # 当前会话输出消息体
finish_reason: str
"""
# 推理结束原因 stop代表推理自然结束或触发停止词。 sensitive 代表模型推理内容被安全审核接口拦截。请注意,针对此类内容,请用户自行判断并决定是否撤回已公开的内容。
# network_error 代表模型推理服务异常。
""" # noqa: E501
metadata: dict # 元信息,拓展字段
class CompletionUsage(BaseModel):
prompt_tokens: int # 输入的 tokens 数量
completion_tokens: int # 输出的 tokens 数量
total_tokens: int # 总 tokens 数量
class AssistantCompletion(BaseModel):
id: str # 请求 ID
conversation_id: str # 会话 ID
assistant_id: str # 智能体 ID
created: int # 请求创建时间Unix 时间戳
status: str # 返回状态,包括:`completed` 表示生成结束`in_progress`表示生成中 `failed` 表示生成异常
last_error: Optional[ErrorInfo] # 异常信息
choices: list[AssistantChoice] # 增量返回的信息
metadata: Optional[dict[str, Any]] # 元信息,拓展字段
usage: Optional[CompletionUsage] # tokens 数量统计

View File

@ -0,0 +1,7 @@
from typing import TypedDict
class ConversationParameters(TypedDict, total=False):
assistant_id: str # 智能体 ID
page: int # 当前分页
page_size: int # 分页数量

View File

@ -0,0 +1,29 @@
from ...core import BaseModel
__all__ = ["ConversationUsageListResp"]
class Usage(BaseModel):
prompt_tokens: int # 用户输入的 tokens 数量
completion_tokens: int # 模型输入的 tokens 数量
total_tokens: int # 总 tokens 数量
class ConversationUsage(BaseModel):
id: str # 会话 id
assistant_id: str # 智能体Assistant id
create_time: int # 创建时间
update_time: int # 更新时间
usage: Usage # 会话中 tokens 数量统计
class ConversationUsageList(BaseModel):
assistant_id: str # 智能体id
has_more: bool # 是否还有更多页
conversation_list: list[ConversationUsage] # 返回的
class ConversationUsageListResp(BaseModel):
code: int
msg: str
data: ConversationUsageList

View File

@ -0,0 +1,32 @@
from typing import Optional, TypedDict, Union
class AssistantAttachments:
file_id: str
class MessageTextContent:
type: str # 目前支持 type = text
text: str
MessageContent = Union[MessageTextContent]
class ConversationMessage(TypedDict):
"""会话消息体"""
role: str # 用户的输入角色,例如 'user'
content: list[MessageContent] # 会话消息体的内容
class AssistantParameters(TypedDict, total=False):
"""智能体参数类"""
assistant_id: str # 智能体 ID
conversation_id: Optional[str] # 会话 ID不传则创建新会话
model: str # 模型名称,默认为 'GLM-4-Assistant'
stream: bool # 是否支持流式 SSE需要传入 True
messages: list[ConversationMessage] # 会话消息体
attachments: Optional[list[AssistantAttachments]] # 会话指定的文件,非必填
metadata: Optional[dict] # 元信息,拓展字段,非必填

View File

@ -0,0 +1,21 @@
from ...core import BaseModel
__all__ = ["AssistantSupportResp"]
class AssistantSupport(BaseModel):
assistant_id: str # 智能体的 Assistant id用于智能体会话
created_at: int # 创建时间
updated_at: int # 更新时间
name: str # 智能体名称
avatar: str # 智能体头像
description: str # 智能体描述
status: str # 智能体状态,目前只有 publish
tools: list[str] # 智能体支持的工具名
starter_prompts: list[str] # 智能体启动推荐的 prompt
class AssistantSupportResp(BaseModel):
code: int
msg: str
data: list[AssistantSupport] # 智能体列表

View File

@ -0,0 +1,3 @@
from .message_content import MessageContent
__all__ = ["MessageContent"]

View File

@ -0,0 +1,13 @@
from typing import Annotated, TypeAlias, Union
from ....core._utils import PropertyInfo
from .text_content_block import TextContentBlock
from .tools_delta_block import ToolsDeltaBlock
__all__ = ["MessageContent"]
MessageContent: TypeAlias = Annotated[
Union[ToolsDeltaBlock, TextContentBlock],
PropertyInfo(discriminator="type"),
]

View File

@ -0,0 +1,14 @@
from typing import Literal
from ....core import BaseModel
__all__ = ["TextContentBlock"]
class TextContentBlock(BaseModel):
content: str
role: str = "assistant"
type: Literal["content"] = "content"
"""Always `content`."""

View File

@ -0,0 +1,27 @@
from typing import Literal
__all__ = ["CodeInterpreterToolBlock"]
from .....core import BaseModel
class CodeInterpreterToolOutput(BaseModel):
"""代码工具输出结果"""
type: str # 代码执行日志,目前只有 logs
logs: str # 代码执行的日志结果
error_msg: str # 错误信息
class CodeInterpreter(BaseModel):
"""代码解释器"""
input: str # 生成的代码片段,输入给代码沙盒
outputs: list[CodeInterpreterToolOutput] # 代码执行后的输出结果
class CodeInterpreterToolBlock(BaseModel):
"""代码工具块"""
code_interpreter: CodeInterpreter # 代码解释器对象
type: Literal["code_interpreter"] # 调用工具的类型,始终为 `code_interpreter`

View File

@ -0,0 +1,21 @@
from typing import Literal
from .....core import BaseModel
__all__ = ["DrawingToolBlock"]
class DrawingToolOutput(BaseModel):
image: str
class DrawingTool(BaseModel):
input: str
outputs: list[DrawingToolOutput]
class DrawingToolBlock(BaseModel):
drawing_tool: DrawingTool
type: Literal["drawing_tool"]
"""Always `drawing_tool`."""

View File

@ -0,0 +1,22 @@
from typing import Literal, Union
__all__ = ["FunctionToolBlock"]
from .....core import BaseModel
class FunctionToolOutput(BaseModel):
content: str
class FunctionTool(BaseModel):
name: str
arguments: Union[str, dict]
outputs: list[FunctionToolOutput]
class FunctionToolBlock(BaseModel):
function: FunctionTool
type: Literal["function"]
"""Always `drawing_tool`."""

View File

@ -0,0 +1,41 @@
from typing import Literal
from .....core import BaseModel
class RetrievalToolOutput(BaseModel):
"""
This class represents the output of a retrieval tool.
Attributes:
- text (str): The text snippet retrieved from the knowledge base.
- document (str): The name of the document from which the text snippet was retrieved, returned only in intelligent configuration.
""" # noqa: E501
text: str
document: str
class RetrievalTool(BaseModel):
"""
This class represents the outputs of a retrieval tool.
Attributes:
- outputs (List[RetrievalToolOutput]): A list of text snippets and their respective document names retrieved from the knowledge base.
""" # noqa: E501
outputs: list[RetrievalToolOutput]
class RetrievalToolBlock(BaseModel):
"""
This class represents a block for invoking the retrieval tool.
Attributes:
- retrieval (RetrievalTool): An instance of the RetrievalTool class containing the retrieval outputs.
- type (Literal["retrieval"]): The type of tool being used, always set to "retrieval".
"""
retrieval: RetrievalTool
type: Literal["retrieval"]
"""Always `retrieval`."""

View File

@ -0,0 +1,16 @@
from typing import Annotated, TypeAlias, Union
from .....core._utils import PropertyInfo
from .code_interpreter_delta_block import CodeInterpreterToolBlock
from .drawing_tool_delta_block import DrawingToolBlock
from .function_delta_block import FunctionToolBlock
from .retrieval_delta_black import RetrievalToolBlock
from .web_browser_delta_block import WebBrowserToolBlock
__all__ = ["ToolsType"]
ToolsType: TypeAlias = Annotated[
Union[DrawingToolBlock, CodeInterpreterToolBlock, WebBrowserToolBlock, RetrievalToolBlock, FunctionToolBlock],
PropertyInfo(discriminator="type"),
]

View File

@ -0,0 +1,48 @@
from typing import Literal
from .....core import BaseModel
__all__ = ["WebBrowserToolBlock"]
class WebBrowserOutput(BaseModel):
"""
This class represents the output of a web browser search result.
Attributes:
- title (str): The title of the search result.
- link (str): The URL link to the search result's webpage.
- content (str): The textual content extracted from the search result.
- error_msg (str): Any error message encountered during the search or retrieval process.
"""
title: str
link: str
content: str
error_msg: str
class WebBrowser(BaseModel):
"""
This class represents the input and outputs of a web browser search.
Attributes:
- input (str): The input query for the web browser search.
- outputs (List[WebBrowserOutput]): A list of search results returned by the web browser.
"""
input: str
outputs: list[WebBrowserOutput]
class WebBrowserToolBlock(BaseModel):
"""
This class represents a block for invoking the web browser tool.
Attributes:
- web_browser (WebBrowser): An instance of the WebBrowser class containing the search input and outputs.
- type (Literal["web_browser"]): The type of tool being used, always set to "web_browser".
"""
web_browser: WebBrowser
type: Literal["web_browser"]

View File

@ -0,0 +1,16 @@
from typing import Literal
from ....core import BaseModel
from .tools.tools_type import ToolsType
__all__ = ["ToolsDeltaBlock"]
class ToolsDeltaBlock(BaseModel):
tool_calls: list[ToolsType]
"""The index of the content part in the message."""
role: str = "tool"
type: Literal["tool_calls"] = "tool_calls"
"""Always `tool_calls`."""

View File

@ -0,0 +1,82 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
import builtins
from typing import Literal, Optional
from ..core import BaseModel
from .batch_error import BatchError
from .batch_request_counts import BatchRequestCounts
__all__ = ["Batch", "Errors"]
class Errors(BaseModel):
data: Optional[list[BatchError]] = None
object: Optional[str] = None
"""这个类型,一直是`list`。"""
class Batch(BaseModel):
id: str
completion_window: str
"""用于执行请求的地址信息。"""
created_at: int
"""这是 Unix timestamp (in seconds) 表示的创建时间。"""
endpoint: str
"""这是ZhipuAI endpoint的地址。"""
input_file_id: str
"""标记为batch的输入文件的ID。"""
object: Literal["batch"]
"""这个类型,一直是`batch`."""
status: Literal[
"validating", "failed", "in_progress", "finalizing", "completed", "expired", "cancelling", "cancelled"
]
"""batch 的状态。"""
cancelled_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示的取消时间。"""
cancelling_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示发起取消的请求时间 """
completed_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示的完成时间。"""
error_file_id: Optional[str] = None
"""这个文件id包含了执行请求失败的请求的输出。"""
errors: Optional[Errors] = None
expired_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示的将在过期时间。"""
expires_at: Optional[int] = None
"""Unix timestamp (in seconds) 触发过期"""
failed_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示的失败时间。"""
finalizing_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示的最终时间。"""
in_progress_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示的开始处理时间。"""
metadata: Optional[builtins.object] = None
"""
key:value形式的元数据以便将信息存储
结构化格式键的长度是64个字符值最长512个字符
"""
output_file_id: Optional[str] = None
"""完成请求的输出文件的ID。"""
request_counts: Optional[BatchRequestCounts] = None
"""批次中不同状态的请求计数"""

View File

@ -0,0 +1,37 @@
from __future__ import annotations
from typing import Literal, Optional
from typing_extensions import Required, TypedDict
__all__ = ["BatchCreateParams"]
class BatchCreateParams(TypedDict, total=False):
completion_window: Required[str]
"""The time frame within which the batch should be processed.
Currently only `24h` is supported.
"""
endpoint: Required[Literal["/v1/chat/completions", "/v1/embeddings"]]
"""The endpoint to be used for all requests in the batch.
Currently `/v1/chat/completions` and `/v1/embeddings` are supported.
"""
input_file_id: Required[str]
"""The ID of an uploaded file that contains requests for the new batch.
See [upload file](https://platform.openai.com/docs/api-reference/files/create)
for how to upload a file.
Your input file must be formatted as a
[JSONL file](https://platform.openai.com/docs/api-reference/batch/requestInput),
and must be uploaded with the purpose `batch`.
"""
metadata: Optional[dict[str, str]]
"""Optional custom metadata for the batch."""
auto_delete_input_file: Optional[bool]

View File

@ -0,0 +1,21 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from ..core import BaseModel
__all__ = ["BatchError"]
class BatchError(BaseModel):
code: Optional[str] = None
"""定义的业务错误码"""
line: Optional[int] = None
"""文件中的行号"""
message: Optional[str] = None
"""关于对话文件中的错误的描述"""
param: Optional[str] = None
"""参数名称,如果有的话"""

View File

@ -0,0 +1,20 @@
from __future__ import annotations
from typing_extensions import TypedDict
__all__ = ["BatchListParams"]
class BatchListParams(TypedDict, total=False):
after: str
"""分页的游标,用于获取下一页的数据。
`after` 是一个指向当前页面的游标用于获取下一页的数据如果没有提供 `after`则返回第一页的数据
list.
"""
limit: int
"""这个参数用于限制返回的结果数量。
Limit 用于限制返回的结果数量默认值为 10
"""

View File

@ -0,0 +1,14 @@
from ..core import BaseModel
__all__ = ["BatchRequestCounts"]
class BatchRequestCounts(BaseModel):
completed: int
"""这个数字表示已经完成的请求。"""
failed: int
"""这个数字表示失败的请求。"""
total: int
"""这个数字表示总的请求。"""

View File

@ -1,10 +1,9 @@
from typing import Optional
from pydantic import BaseModel
from ...core import BaseModel
from .chat_completion import CompletionChoice, CompletionUsage
__all__ = ["AsyncTaskStatus"]
__all__ = ["AsyncTaskStatus", "AsyncCompletion"]
class AsyncTaskStatus(BaseModel):

View File

@ -1,6 +1,6 @@
from typing import Optional
from pydantic import BaseModel
from ...core import BaseModel
__all__ = ["Completion", "CompletionUsage"]

View File

@ -1,8 +1,9 @@
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel
from ...core import BaseModel
__all__ = [
"CompletionUsage",
"ChatCompletionChunk",
"Choice",
"ChoiceDelta",
@ -53,3 +54,4 @@ class ChatCompletionChunk(BaseModel):
created: Optional[int] = None
model: Optional[str] = None
usage: Optional[CompletionUsage] = None
extra_json: dict[str, Any]

View File

@ -0,0 +1,146 @@
from typing import Literal, Optional
from typing_extensions import Required, TypedDict
__all__ = [
"CodeGeexTarget",
"CodeGeexContext",
"CodeGeexExtra",
]
class CodeGeexTarget(TypedDict, total=False):
"""补全的内容参数"""
path: Optional[str]
"""文件路径"""
language: Required[
Literal[
"c",
"c++",
"cpp",
"c#",
"csharp",
"c-sharp",
"css",
"cuda",
"dart",
"lua",
"objectivec",
"objective-c",
"objective-c++",
"python",
"perl",
"prolog",
"swift",
"lisp",
"java",
"scala",
"tex",
"jsx",
"tsx",
"vue",
"markdown",
"html",
"php",
"js",
"javascript",
"typescript",
"go",
"shell",
"rust",
"sql",
"kotlin",
"vb",
"ruby",
"pascal",
"r",
"fortran",
"lean",
"matlab",
"delphi",
"scheme",
"basic",
"assembly",
"groovy",
"abap",
"gdscript",
"haskell",
"julia",
"elixir",
"excel",
"clojure",
"actionscript",
"solidity",
"powershell",
"erlang",
"cobol",
"alloy",
"awk",
"thrift",
"sparql",
"augeas",
"cmake",
"f-sharp",
"stan",
"isabelle",
"dockerfile",
"rmarkdown",
"literate-agda",
"tcl",
"glsl",
"antlr",
"verilog",
"racket",
"standard-ml",
"elm",
"yaml",
"smalltalk",
"ocaml",
"idris",
"visual-basic",
"protocol-buffer",
"bluespec",
"applescript",
"makefile",
"tcsh",
"maple",
"systemverilog",
"literate-coffeescript",
"vhdl",
"restructuredtext",
"sas",
"literate-haskell",
"java-server-pages",
"coffeescript",
"emacs-lisp",
"mathematica",
"xslt",
"zig",
"common-lisp",
"stata",
"agda",
"ada",
]
]
"""代码语言类型如python"""
code_prefix: Required[str]
"""补全位置的前文"""
code_suffix: Required[str]
"""补全位置的后文"""
class CodeGeexContext(TypedDict, total=False):
"""附加代码"""
path: Required[str]
"""附加代码文件的路径"""
code: Required[str]
"""附加的代码内容"""
class CodeGeexExtra(TypedDict, total=False):
target: Required[CodeGeexTarget]
"""补全的内容参数"""
contexts: Optional[list[CodeGeexContext]]
"""附加代码"""

View File

@ -2,8 +2,7 @@ from __future__ import annotations
from typing import Optional
from pydantic import BaseModel
from ..core import BaseModel
from .chat.chat_completion import CompletionUsage
__all__ = ["Embedding", "EmbeddingsResponded"]

View File

@ -0,0 +1,5 @@
from .file_deleted import FileDeleted
from .file_object import FileObject, ListOfFileObject
from .upload_detail import UploadDetail
__all__ = ["FileObject", "ListOfFileObject", "UploadDetail", "FileDeleted"]

View File

@ -0,0 +1,38 @@
from __future__ import annotations
from typing import Literal, Optional
from typing_extensions import Required, TypedDict
__all__ = ["FileCreateParams"]
from ...core import FileTypes
from . import UploadDetail
class FileCreateParams(TypedDict, total=False):
file: FileTypes
"""file和 upload_detail二选一必填"""
upload_detail: list[UploadDetail]
"""file和 upload_detail二选一必填"""
purpose: Required[Literal["fine-tune", "retrieval", "batch"]]
"""
上传文件的用途支持 "fine-tune和 "retrieval"
retrieval支持上传DocDocxPDFXlsxURL类型文件且单个文件的大小不超过 5MB
fine-tune支持上传.jsonl文件且当前单个文件的大小最大可为 100 MB 文件中语料格式需满足微调指南中所描述的格式
"""
custom_separator: Optional[list[str]]
"""
purpose retrieval 且文件类型为 pdf, url, docx 时上传切片规则默认为 `\n`
"""
knowledge_id: str
"""
当文件上传目的为 retrieval 需要指定知识库ID进行上传
"""
sentence_size: int
"""
当文件上传目的为 retrieval 需要指定知识库ID进行上传
"""

View File

@ -0,0 +1,13 @@
from typing import Literal
from ...core import BaseModel
__all__ = ["FileDeleted"]
class FileDeleted(BaseModel):
id: str
deleted: bool
object: Literal["file"]

View File

@ -1,8 +1,8 @@
from typing import Optional
from pydantic import BaseModel
from ...core import BaseModel
__all__ = ["FileObject"]
__all__ = ["FileObject", "ListOfFileObject"]
class FileObject(BaseModel):

View File

@ -0,0 +1,13 @@
from typing import Optional
from ...core import BaseModel
class UploadDetail(BaseModel):
url: str
knowledge_type: int
file_name: Optional[str] = None
sentence_size: Optional[int] = None
custom_separator: Optional[list[str]] = None
callback_url: Optional[str] = None
callback_header: Optional[dict[str, str]] = None

View File

@ -1,6 +1,6 @@
from typing import Optional, Union
from pydantic import BaseModel
from ...core import BaseModel
__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"]

View File

@ -1,6 +1,6 @@
from typing import Optional, Union
from pydantic import BaseModel
from ...core import BaseModel
__all__ = ["FineTuningJobEvent", "Metric", "JobEvent"]

View File

@ -0,0 +1 @@
from .fine_tuned_models import FineTunedModelsStatus

View File

@ -0,0 +1,13 @@
from typing import ClassVar
from ....core import PYDANTIC_V2, BaseModel, ConfigDict
__all__ = ["FineTunedModelsStatus"]
class FineTunedModelsStatus(BaseModel):
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow", protected_namespaces=())
request_id: str # 请求id
model_name: str # 模型名称
delete_status: str # 删除状态 deleting删除中, deleted (已删除)

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Optional
from pydantic import BaseModel
from ..core import BaseModel
__all__ = ["GeneratedImage", "ImagesResponded"]

View File

@ -0,0 +1,8 @@
from .knowledge import KnowledgeInfo
from .knowledge_used import KnowledgeStatistics, KnowledgeUsed
__all__ = [
"KnowledgeInfo",
"KnowledgeStatistics",
"KnowledgeUsed",
]

View File

@ -0,0 +1,8 @@
from .document import DocumentData, DocumentFailedInfo, DocumentObject, DocumentSuccessinfo
__all__ = [
"DocumentData",
"DocumentObject",
"DocumentSuccessinfo",
"DocumentFailedInfo",
]

View File

@ -0,0 +1,51 @@
from typing import Optional
from ....core import BaseModel
__all__ = ["DocumentData", "DocumentObject", "DocumentSuccessinfo", "DocumentFailedInfo"]
class DocumentSuccessinfo(BaseModel):
documentId: Optional[str] = None
"""文件id"""
filename: Optional[str] = None
"""文件名称"""
class DocumentFailedInfo(BaseModel):
failReason: Optional[str] = None
"""上传失败的原因,包括:文件格式不支持、文件大小超出限制、知识库容量已满、容量上限为 50 万字。"""
filename: Optional[str] = None
"""文件名称"""
documentId: Optional[str] = None
"""知识库id"""
class DocumentObject(BaseModel):
"""文档信息"""
successInfos: Optional[list[DocumentSuccessinfo]] = None
"""上传成功的文件信息"""
failedInfos: Optional[list[DocumentFailedInfo]] = None
"""上传失败的文件信息"""
class DocumentDataFailInfo(BaseModel):
"""失败原因"""
embedding_code: Optional[int] = (
None # 失败码 10001知识不可用知识库空间已达上限 10002知识不可用知识库空间已达上限(字数超出限制)
)
embedding_msg: Optional[str] = None # 失败原因
class DocumentData(BaseModel):
id: str = None # 知识唯一id
custom_separator: list[str] = None # 切片规则
sentence_size: str = None # 切片大小
length: int = None # 文件大小(字节)
word_num: int = None # 文件字数
name: str = None # 文件名
url: str = None # 文件下载链接
embedding_stat: int = None # 0:向量化中 1:向量化完成 2:向量化失败
failInfo: Optional[DocumentDataFailInfo] = None # 失败原因 向量化失败embedding_stat=2的时候 会有此值

View File

@ -0,0 +1,29 @@
from typing import Optional, TypedDict
__all__ = ["DocumentEditParams"]
class DocumentEditParams(TypedDict):
"""
知识参数类型定义
Attributes:
id (str): 知识ID
knowledge_type (int): 知识类型:
1:文章知识: 支持pdf,url,docx
2.问答知识-文档: 支持pdf,url,docx
3.问答知识-表格: 支持xlsx
4.商品库-表格: 支持xlsx
5.自定义: 支持pdf,url,docx
custom_separator (Optional[List[str]]): 当前知识类型为自定义(knowledge_type=5)时的切片规则默认\n
sentence_size (Optional[int]): 当前知识类型为自定义(knowledge_type=5)时的切片字数取值范围: 20-2000默认300
callback_url (Optional[str]): 回调地址
callback_header (Optional[dict]): 回调时携带的header
"""
id: str
knowledge_type: int
custom_separator: Optional[list[str]]
sentence_size: Optional[int]
callback_url: Optional[str]
callback_header: Optional[dict[str, str]]

View File

@ -0,0 +1,26 @@
from __future__ import annotations
from typing import Optional
from typing_extensions import TypedDict
class DocumentListParams(TypedDict, total=False):
"""
文件查询参数类型定义
Attributes:
purpose (Optional[str]): 文件用途
knowledge_id (Optional[str]): 当文件用途为 retrieval 需要提供查询的知识库ID
page (Optional[int]): 默认1
limit (Optional[int]): 查询文件列表数默认10
after (Optional[str]): 查询指定fileID之后的文件列表当文件用途为 fine-tune 时需要
order (Optional[str]): 排序规则可选值['desc', 'asc']默认desc当文件用途为 fine-tune 时需要
"""
purpose: Optional[str]
knowledge_id: Optional[str]
page: Optional[int]
limit: Optional[int]
after: Optional[str]
order: Optional[str]

View File

@ -0,0 +1,11 @@
from __future__ import annotations
from ....core import BaseModel
from . import DocumentData
__all__ = ["DocumentPage"]
class DocumentPage(BaseModel):
list: list[DocumentData]
object: str

View File

@ -0,0 +1,21 @@
from typing import Optional
from ...core import BaseModel
__all__ = ["KnowledgeInfo"]
class KnowledgeInfo(BaseModel):
id: Optional[str] = None
"""知识库唯一 id"""
embedding_id: Optional[str] = (
None # 知识库绑定的向量化模型 见模型列表 [内部服务开放接口文档](https://lslfd0slxc.feishu.cn/docx/YauWdbBiMopV0FxB7KncPWCEn8f#H15NduiQZo3ugmxnWQFcfAHpnQ4)
)
name: Optional[str] = None # 知识库名称 100字限制
customer_identifier: Optional[str] = None # 用户标识 长度32位以内
description: Optional[str] = None # 知识库描述 500字限制
background: Optional[str] = None # 背景颜色(给枚举)'blue', 'red', 'orange', 'purple', 'sky'
icon: Optional[str] = (
None # 知识库图标(给枚举) question: 问号、book: 书籍、seal: 印章、wrench: 扳手、tag: 标签、horn: 喇叭、house: 房子 # noqa: E501
)
bucket_id: Optional[str] = None # 桶id 限制32位

View File

@ -0,0 +1,30 @@
from __future__ import annotations
from typing import Literal, Optional
from typing_extensions import TypedDict
__all__ = ["KnowledgeBaseParams"]
class KnowledgeBaseParams(TypedDict):
"""
知识库参数类型定义
Attributes:
embedding_id (int): 知识库绑定的向量化模型ID
name (str): 知识库名称限制100字
customer_identifier (Optional[str]): 用户标识长度32位以内
description (Optional[str]): 知识库描述限制500字
background (Optional[Literal['blue', 'red', 'orange', 'purple', 'sky']]): 背景颜色
icon (Optional[Literal['question', 'book', 'seal', 'wrench', 'tag', 'horn', 'house']]): 知识库图标
bucket_id (Optional[str]): 桶ID限制32位
"""
embedding_id: int
name: str
customer_identifier: Optional[str]
description: Optional[str]
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None
bucket_id: Optional[str]

View File

@ -0,0 +1,15 @@
from __future__ import annotations
from typing_extensions import TypedDict
__all__ = ["KnowledgeListParams"]
class KnowledgeListParams(TypedDict, total=False):
page: int = 1
""" 页码,默认 1第一页
"""
size: int = 10
"""每页数量 默认10
"""

View File

@ -0,0 +1,11 @@
from __future__ import annotations
from ...core import BaseModel
from . import KnowledgeInfo
__all__ = ["KnowledgePage"]
class KnowledgePage(BaseModel):
list: list[KnowledgeInfo]
object: str

View File

@ -0,0 +1,21 @@
from typing import Optional
from ...core import BaseModel
__all__ = ["KnowledgeStatistics", "KnowledgeUsed"]
class KnowledgeStatistics(BaseModel):
"""
使用量统计
"""
word_num: Optional[int] = None
length: Optional[int] = None
class KnowledgeUsed(BaseModel):
used: Optional[KnowledgeStatistics] = None
"""已使用量"""
total: Optional[KnowledgeStatistics] = None
"""知识库总量"""

View File

@ -0,0 +1,3 @@
from .sensitive_word_check import SensitiveWordCheckRequest
__all__ = ["SensitiveWordCheckRequest"]

View File

@ -0,0 +1,14 @@
from typing import Optional
from typing_extensions import TypedDict
class SensitiveWordCheckRequest(TypedDict, total=False):
type: Optional[str]
"""敏感词类型当前仅支持ALL"""
status: Optional[str]
"""敏感词启用禁用状态
启用ENABLE
禁用DISABLE
备注默认开启敏感词校验如果要关闭敏感词校验需联系商务获取对应权限否则敏感词禁用不生效
"""

View File

@ -0,0 +1,9 @@
from .web_search import (
SearchIntent,
SearchRecommend,
SearchResult,
WebSearch,
)
from .web_search_chunk import WebSearchChunk
__all__ = ["WebSearch", "SearchIntent", "SearchResult", "SearchRecommend", "WebSearchChunk"]

Some files were not shown because too many files have changed in this diff Show More