Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
Synced 2025-08-17 00:35:56 +08:00

Commit 8fe5028f74: Merge branch 'main' into feat/attachments
@@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001
 # The time in seconds after the signature is rejected
 FILES_ACCESS_TIMEOUT=300

+# Access token expiration time in minutes
+ACCESS_TOKEN_EXPIRE_MINUTES=60
+
 # celery configuration
 CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1

@@ -39,7 +42,7 @@ DB_DATABASE=dify

 # Storage configuration
 # use for store upload files, private keys...
-# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs
+# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs, supabase
 STORAGE_TYPE=local
 STORAGE_LOCAL_PATH=storage
 S3_USE_AWS_MANAGED_IAM=false
@@ -99,11 +102,16 @@ VOLCENGINE_TOS_ACCESS_KEY=your-access-key
 VOLCENGINE_TOS_SECRET_KEY=your-secret-key
 VOLCENGINE_TOS_REGION=your-region

+# Supabase Storage Configuration
+SUPABASE_BUCKET_NAME=your-bucket-name
+SUPABASE_API_KEY=your-access-key
+SUPABASE_URL=your-server-url
+
 # CORS configuration
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

-# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
+# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb
 VECTOR_STORE=weaviate

 # Weaviate configuration
@@ -203,6 +211,24 @@ OPENSEARCH_USER=admin
 OPENSEARCH_PASSWORD=admin
 OPENSEARCH_SECURE=true

+# Baidu configuration
+BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
+BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
+BAIDU_VECTOR_DB_ACCOUNT=root
+BAIDU_VECTOR_DB_API_KEY=dify
+BAIDU_VECTOR_DB_DATABASE=dify
+BAIDU_VECTOR_DB_SHARD=1
+BAIDU_VECTOR_DB_REPLICAS=3
+
+# ViKingDB configuration
+VIKINGDB_ACCESS_KEY=your-ak
+VIKINGDB_SECRET_KEY=your-sk
+VIKINGDB_REGION=cn-shanghai
+VIKINGDB_HOST=api-vikingdb.xxx.volces.com
+VIKINGDB_SCHEMA=http
+VIKINGDB_CONNECTION_TIMEOUT=30
+VIKINGDB_SOCKET_TIMEOUT=30
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5
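All of the keys added above are plain environment variables: Dify's API process loads them through pydantic-settings, where a field on a settings class picks up the environment variable of the same name (the config classes for these keys appear later in this diff). A minimal, self-contained sketch of that mechanism, using an illustrative class name rather than Dify's real one:

    from typing import Optional

    from pydantic import Field
    from pydantic_settings import BaseSettings


    class VikingDBEnvSketch(BaseSettings):
        # field names match the environment variable names above
        VIKINGDB_HOST: Optional[str] = Field(default=None)
        VIKINGDB_CONNECTION_TIMEOUT: int = Field(default=30)


    settings = VikingDBEnvSketch()  # reads VIKINGDB_* from the environment if set
    print(settings.VIKINGDB_HOST, settings.VIKINGDB_CONNECTION_TIMEOUT)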
@@ -183,7 +183,7 @@ def load_user_from_request(request_from_flask_login):
     decoded = PassportService().verify(auth_token)
     user_id = decoded.get("user_id")

-    logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
+    logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
     if logged_in_account:
         contexts.tenant_id.set(logged_in_account.current_tenant_id)
     return logged_in_account
@@ -347,6 +347,14 @@ def migrate_knowledge_vector_database():
                 index_name = Dataset.gen_collection_name_by_id(dataset_id)
                 index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
                 dataset.index_struct = json.dumps(index_struct_dict)
+            elif vector_type == VectorType.BAIDU:
+                dataset_id = dataset.id
+                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                index_struct_dict = {
+                    "type": VectorType.BAIDU,
+                    "vector_store": {"class_prefix": collection_name},
+                }
+                dataset.index_struct = json.dumps(index_struct_dict)
             else:
                 raise ValueError(f"Vector store {vector_type} is not supported.")

@@ -360,9 +360,9 @@ class WorkflowConfig(BaseSettings):
     )


-class OAuthConfig(BaseSettings):
+class AuthConfig(BaseSettings):
     """
-    Configuration for OAuth authentication
+    Configuration for authentication and OAuth
     """

     OAUTH_REDIRECT_PATH: str = Field(
@@ -371,7 +371,7 @@ class OAuthConfig(BaseSettings):
     )

     GITHUB_CLIENT_ID: Optional[str] = Field(
-        description="GitHub OAuth client secret",
+        description="GitHub OAuth client ID",
         default=None,
     )

@@ -390,6 +390,11 @@ class OAuthConfig(BaseSettings):
         default=None,
     )

+    ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field(
+        description="Expiration time for access tokens in minutes",
+        default=60,
+    )
+

 class ModerationConfig(BaseSettings):
     """
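Because ACCESS_TOKEN_EXPIRE_MINUTES is typed PositiveInt, a zero or negative value fails validation at startup instead of silently producing never-expiring or instantly-expired tokens. A small pydantic-only demonstration (the class here is a stand-in, not Dify's AuthConfig):

    from pydantic import BaseModel, PositiveInt, ValidationError


    class AuthConfigSketch(BaseModel):
        ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = 60


    print(AuthConfigSketch().ACCESS_TOKEN_EXPIRE_MINUTES)  # 60

    try:
        AuthConfigSketch(ACCESS_TOKEN_EXPIRE_MINUTES=0)
    except ValidationError as err:
        print("rejected:", err.errors()[0]["type"])  # greater_than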
@@ -607,6 +612,7 @@ class PositionConfig(BaseSettings):
 class FeatureConfig(
     # place the configs in alphabet order
     AppExecutionConfig,
+    AuthConfig,  # Changed from OAuthConfig to AuthConfig
     BillingConfig,
     CodeExecutionSandboxConfig,
     DataSetConfig,
@@ -621,14 +627,13 @@ class FeatureConfig(
     MailConfig,
     ModelLoadBalanceConfig,
     ModerationConfig,
-    OAuthConfig,
+    PositionConfig,
     RagEtlConfig,
     SecurityConfig,
     ToolConfig,
     UpdateConfig,
     WorkflowConfig,
     WorkspaceConfig,
-    PositionConfig,
     # hosted services config
     HostedServiceConfig,
     CeleryBeatConfig,
@@ -12,6 +12,7 @@ from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageConfig
 from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
 from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
 from configs.middleware.storage.oci_storage_config import OCIStorageConfig
+from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig
 from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
 from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
 from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
@@ -27,6 +28,7 @@ from configs.middleware.vdb.qdrant_config import QdrantConfig
 from configs.middleware.vdb.relyt_config import RelytConfig
 from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
 from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
+from configs.middleware.vdb.vikingdb_config import VikingDBConfig
 from configs.middleware.vdb.weaviate_config import WeaviateConfig


@@ -191,6 +193,22 @@ class CeleryConfig(DatabaseConfig):
         return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False


+class InternalTestConfig(BaseSettings):
+    """
+    Configuration settings for Internal Test
+    """
+
+    AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
+        description="Internal test AWS secret access key",
+        default=None,
+    )
+
+    AWS_ACCESS_KEY_ID: Optional[str] = Field(
+        description="Internal test AWS access key ID",
+        default=None,
+    )
+
+
 class MiddlewareConfig(
     # place the configs in alphabet order
     CeleryConfig,
@@ -206,6 +224,7 @@ class MiddlewareConfig(
     HuaweiCloudOBSStorageConfig,
     OCIStorageConfig,
     S3StorageConfig,
+    SupabaseStorageConfig,
     TencentCloudCOSStorageConfig,
     VolcengineTOSStorageConfig,
     # configs of vdb and vdb providers
@@ -224,5 +243,7 @@ class MiddlewareConfig(
     TiDBVectorConfig,
     WeaviateConfig,
     ElasticsearchConfig,
+    InternalTestConfig,
+    VikingDBConfig,
 ):
     pass
api/configs/middleware/storage/supabase_storage_config.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class SupabaseStorageConfig(BaseModel):
+    """
+    Configuration settings for Supabase Object Storage Service
+    """
+
+    SUPABASE_BUCKET_NAME: Optional[str] = Field(
+        description="Name of the Supabase bucket to store and retrieve objects (e.g., 'dify-bucket')",
+        default=None,
+    )
+
+    SUPABASE_API_KEY: Optional[str] = Field(
+        description="API KEY for authenticating with Supabase",
+        default=None,
+    )
+
+    SUPABASE_URL: Optional[str] = Field(
+        description="URL of the Supabase",
+        default=None,
+    )
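For context on how these three values get consumed: the storage backend itself is not part of this diff, but with the supabase-py package the same settings map directly onto the client, roughly like this (bucket name, URL, and key are placeholders):

    from supabase import create_client

    SUPABASE_URL = "https://your-project.supabase.co"  # SUPABASE_URL
    SUPABASE_API_KEY = "your-service-key"              # SUPABASE_API_KEY
    SUPABASE_BUCKET_NAME = "dify-bucket"               # SUPABASE_BUCKET_NAME

    client = create_client(SUPABASE_URL, SUPABASE_API_KEY)
    bucket = client.storage.from_(SUPABASE_BUCKET_NAME)
    bucket.upload("keys/example.txt", b"hello")        # store an object
    data = bucket.download("keys/example.txt")         # fetch it back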
api/configs/middleware/vdb/baidu_vector_config.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+from typing import Optional
+
+from pydantic import Field, NonNegativeInt, PositiveInt
+from pydantic_settings import BaseSettings
+
+
+class BaiduVectorDBConfig(BaseSettings):
+    """
+    Configuration settings for Baidu Vector Database
+    """
+
+    BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
+        description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
+        default=None,
+    )
+
+    BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
+        description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
+        default=30000,
+    )
+
+    BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
+        description="Account for authenticating with the Baidu Vector Database",
+        default=None,
+    )
+
+    BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
+        description="API key for authenticating with the Baidu Vector Database service",
+        default=None,
+    )
+
+    BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
+        description="Name of the specific Baidu Vector Database to connect to",
+        default=None,
+    )
+
+    BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
+        description="Number of shards for the Baidu Vector Database (default is 1)",
+        default=1,
+    )
+
+    BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
+        description="Number of replicas for the Baidu Vector Database (default is 3)",
+        default=3,
+    )
api/configs/middleware/vdb/vikingdb_config.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class VikingDBConfig(BaseModel):
+    """
+    Configuration for connecting to Volcengine VikingDB.
+    Refer to the following documentation for details on obtaining credentials:
+    https://www.volcengine.com/docs/6291/65568
+    """
+
+    VIKINGDB_ACCESS_KEY: Optional[str] = Field(
+        default=None, description="The Access Key provided by Volcengine VikingDB for API authentication."
+    )
+    VIKINGDB_SECRET_KEY: Optional[str] = Field(
+        default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication."
+    )
+    VIKINGDB_REGION: Optional[str] = Field(
+        default="cn-shanghai",
+        description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
+    )
+    VIKINGDB_HOST: Optional[str] = Field(
+        default="api-vikingdb.mlp.cn-shanghai.volces.com",
+        description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
+            'api-vikingdb.mlp.cn-shanghai.volces.com')",
+    )
+    VIKINGDB_SCHEME: Optional[str] = Field(
+        default="http",
+        description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
+    )
+    VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(
+        default=30, description="The connection timeout of the Volcengine VikingDB service."
+    )
+    VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field(
+        default=30, description="The socket timeout of the Volcengine VikingDB service."
+    )
@@ -7,7 +7,7 @@ from flask_restful import Resource, reqparse
 import services
 from controllers.console import api
 from controllers.console.setup import setup_required
-from libs.helper import email, get_remote_ip
+from libs.helper import email, extract_remote_ip
 from libs.password import valid_password
 from models.account import Account
 from services.account_service import AccountService, TenantService
@@ -40,17 +40,16 @@ class LoginApi(Resource):
                 "data": "workspace not found, please contact system admin to invite you to join in a workspace",
             }

-        token = AccountService.login(account, ip_address=get_remote_ip(request))
+        token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))

-        return {"result": "success", "data": token}
+        return {"result": "success", "data": token_pair.model_dump()}


 class LogoutApi(Resource):
     @setup_required
     def get(self):
         account = cast(Account, flask_login.current_user)
-        token = request.headers.get("Authorization", "").split(" ")[1]
-        AccountService.logout(account=account, token=token)
+        AccountService.logout(account=account)
         flask_login.logout_user()
         return {"result": "success"}

@@ -106,5 +105,19 @@ class ResetPasswordApi(Resource):
         return {"result": "success"}


+class RefreshTokenApi(Resource):
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("refresh_token", type=str, required=True, location="json")
+        args = parser.parse_args()
+
+        try:
+            new_token_pair = AccountService.refresh_token(args["refresh_token"])
+            return {"result": "success", "data": new_token_pair.model_dump()}
+        except Exception as e:
+            return {"result": "fail", "data": str(e)}, 401
+
+
 api.add_resource(LoginApi, "/login")
 api.add_resource(LogoutApi, "/logout")
+api.add_resource(RefreshTokenApi, "/refresh-token")
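Net effect of the auth changes in this file: login now returns an access/refresh token pair (token_pair.model_dump()), logout no longer needs the raw Authorization header, and the new /refresh-token endpoint exchanges a valid refresh token for a fresh pair. A client-side sketch — the /console/api prefix and the exact login payload are assumptions, not shown in this hunk:

    import requests

    BASE = "http://127.0.0.1:5001/console/api"  # assumed console API prefix

    resp = requests.post(f"{BASE}/login", json={"email": "user@example.com", "password": "secret"})
    tokens = resp.json()["data"]  # {"access_token": ..., "refresh_token": ...}

    # Once the access token expires (ACCESS_TOKEN_EXPIRE_MINUTES, default 60):
    resp = requests.post(f"{BASE}/refresh-token", json={"refresh_token": tokens["refresh_token"]})
    if resp.status_code == 200:
        tokens = resp.json()["data"]
    else:
        pass  # 401 with {"result": "fail", ...}: re-authenticate via /login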
@@ -9,7 +9,7 @@ from flask_restful import Resource
 from configs import dify_config
 from constants.languages import languages
 from extensions.ext_database import db
-from libs.helper import get_remote_ip
+from libs.helper import extract_remote_ip
 from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
 from models.account import Account, AccountStatus
 from services.account_service import AccountService, RegisterService, TenantService
@@ -81,9 +81,14 @@ class OAuthCallback(Resource):

         TenantService.create_owner_tenant_if_not_exist(account)

-        token = AccountService.login(account, ip_address=get_remote_ip(request))
+        token_pair = AccountService.login(
+            account=account,
+            ip_address=extract_remote_ip(request),
+        )

-        return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
+        return redirect(
+            f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
+        )


 def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
@@ -617,6 +617,8 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.CHROMA
                 | VectorType.TENCENT
                 | VectorType.PGVECTO_RS
+                | VectorType.BAIDU
+                | VectorType.VIKINGDB
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (
@@ -653,6 +655,8 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.CHROMA
                 | VectorType.TENCENT
                 | VectorType.PGVECTO_RS
+                | VectorType.BAIDU
+                | VectorType.VIKINGDB
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (
@@ -13,6 +13,7 @@ from libs.login import login_required
 from services.dataset_service import DatasetService
 from services.external_knowledge_service import ExternalDatasetService
 from services.hit_testing_service import HitTestingService
+from services.knowledge_service import ExternalDatasetTestService


 def _validate_name(name):
@@ -232,8 +233,31 @@ class ExternalKnowledgeHitTestingApi(Resource):
             raise InternalServerError(str(e))


+class BedrockRetrievalApi(Resource):
+    # this api is only for internal testing
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
+        parser.add_argument(
+            "query",
+            nullable=False,
+            required=True,
+            type=str,
+        )
+        parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
+        args = parser.parse_args()
+
+        # Call the knowledge retrieval service
+        result = ExternalDatasetTestService.knowledge_retrieval(
+            args["retrieval_setting"], args["query"], args["knowledge_id"]
+        )
+        return result, 200
+
+
 api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing")
 api.add_resource(ExternalDatasetCreateApi, "/datasets/external")
 api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api")
 api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
 api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
+# this api is only for internal test
+api.add_resource(BedrockRetrievalApi, "/test/retrieval")
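The internal-test endpoint mirrors the external-knowledge retrieval contract. A hedged request sketch (URL prefix assumed as above; the retrieval_setting keys follow Dify's external knowledge API, e.g. top_k and score_threshold):

    import requests

    resp = requests.post(
        "http://127.0.0.1:5001/console/api/test/retrieval",  # assumed prefix
        json={
            "retrieval_setting": {"top_k": 2, "score_threshold": 0.5},
            "query": "What is Dify?",
            "knowledge_id": "your-knowledge-id",
        },
    )
    print(resp.status_code, resp.json())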
@@ -4,7 +4,7 @@ from flask import request
 from flask_restful import Resource, reqparse

 from configs import dify_config
-from libs.helper import StrLen, email, get_remote_ip
+from libs.helper import StrLen, email, extract_remote_ip
 from libs.password import valid_password
 from models.model import DifySetup
 from services.account_service import RegisterService, TenantService
@@ -46,7 +46,7 @@ class SetupApi(Resource):

         # setup
         RegisterService.setup(
-            email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
+            email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
         )

         return {"result": "success"}, 201
@@ -126,13 +126,12 @@ class ModelProviderIconApi(Resource):
     Get model provider icon
     """

-    @setup_required
-    @login_required
-    @account_initialization_required
     def get(self, provider: str, icon_type: str, lang: str):
         model_provider_service = ModelProviderService()
         icon, mimetype = model_provider_service.get_model_provider_icon(
-            provider=provider, icon_type=icon_type, lang=lang
+            provider=provider,
+            icon_type=icon_type,
+            lang=lang,
         )

         return send_file(io.BytesIO(icon), mimetype=mimetype)
@@ -56,6 +56,7 @@ from models.account import Account
 from models.model import Conversation, EndUser, Message
 from models.workflow import (
     Workflow,
+    WorkflowNodeExecution,
     WorkflowRunStatus,
 )

@@ -72,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
     _workflow: Workflow
     _user: Union[Account, EndUser]
     _workflow_system_variables: dict[SystemVariableKey, Any]
+    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]

     def __init__(
         self,
@@ -115,6 +117,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
         }

         self._task_state = WorkflowTaskState()
+        self._wip_workflow_node_executions = {}

         self._conversation_name_generate_thread = None

@@ -52,6 +52,7 @@ from models.workflow import (
     Workflow,
     WorkflowAppLog,
     WorkflowAppLogCreatedFrom,
+    WorkflowNodeExecution,
     WorkflowRun,
     WorkflowRunStatus,
 )
@@ -69,6 +70,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
     _task_state: WorkflowTaskState
     _application_generate_entity: WorkflowAppGenerateEntity
     _workflow_system_variables: dict[SystemVariableKey, Any]
+    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]

     def __init__(
         self,
@@ -103,6 +105,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
         }

         self._task_state = WorkflowTaskState()
+        self._wip_workflow_node_executions = {}

     def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
         """
@@ -1,8 +1,10 @@
+import logging
 from threading import Thread
 from typing import Optional, Union

 from flask import Flask, current_app

+from configs import dify_config
 from core.app.entities.app_invoke_entities import (
     AdvancedChatAppGenerateEntity,
     AgentChatAppGenerateEntity,
@@ -83,7 +85,9 @@ class MessageCycleManage:
                     name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
                     conversation.name = name
                 except Exception as e:
-                    logging.exception(f"generate conversation name failed: {e}")
+                    if dify_config.DEBUG:
+                        logging.exception(f"generate conversation name failed: {e}")
+                    pass

                 db.session.merge(conversation)
                 db.session.commit()
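The same gating recurs in cached_embedding.py further down: per-event logging.exception calls now only fire when dify_config.DEBUG is set, keeping production logs quiet for expected, recoverable failures. The shape of the pattern in isolation (DEBUG stands in for dify_config.DEBUG):

    import logging

    DEBUG = False  # stand-in for dify_config.DEBUG


    def generate_name(query: str) -> str:
        raise RuntimeError("LLM unavailable")


    try:
        name = generate_name("hello")
    except Exception as e:
        if DEBUG:
            logging.exception(f"generate conversation name failed: {e}")
        # swallowed otherwise: the conversation simply keeps its previous name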
@@ -57,6 +57,7 @@ class WorkflowCycleManage:
     _user: Union[Account, EndUser]
     _task_state: WorkflowTaskState
     _workflow_system_variables: dict[SystemVariableKey, Any]
+    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]

     def _handle_workflow_run_start(self) -> WorkflowRun:
         max_sequence = (
@@ -251,6 +252,8 @@ class WorkflowCycleManage:
         db.session.refresh(workflow_node_execution)
         db.session.close()

+        self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
+
         return workflow_node_execution

     def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
@@ -263,20 +266,36 @@

         inputs = WorkflowEntry.handle_special_values(event.inputs)
         outputs = WorkflowEntry.handle_special_values(event.outputs)
+        execution_metadata = (
+            json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
+        )
+        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        elapsed_time = (finished_at - event.start_at).total_seconds()
+
+        db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
+            {
+                WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
+                WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
+                WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
+                WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
+                WorkflowNodeExecution.execution_metadata: execution_metadata,
+                WorkflowNodeExecution.finished_at: finished_at,
+                WorkflowNodeExecution.elapsed_time: elapsed_time,
+            }
+        )
+
+        db.session.commit()
+        db.session.close()
+
         workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
         workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
         workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
         workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.execution_metadata = (
-            json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
-        )
-        workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
-        workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
+        workflow_node_execution.execution_metadata = execution_metadata
+        workflow_node_execution.finished_at = finished_at
+        workflow_node_execution.elapsed_time = elapsed_time

-        db.session.commit()
-        db.session.refresh(workflow_node_execution)
-        db.session.close()
+        self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)

         return workflow_node_execution

@@ -290,18 +309,33 @@

         inputs = WorkflowEntry.handle_special_values(event.inputs)
         outputs = WorkflowEntry.handle_special_values(event.outputs)
+        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        elapsed_time = (finished_at - event.start_at).total_seconds()
+
+        db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
+            {
+                WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
+                WorkflowNodeExecution.error: event.error,
+                WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
+                WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
+                WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
+                WorkflowNodeExecution.finished_at: finished_at,
+                WorkflowNodeExecution.elapsed_time: elapsed_time,
+            }
+        )
+
+        db.session.commit()
+        db.session.close()
+
         workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
         workflow_node_execution.error = event.error
-        workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
         workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
         workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
         workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
+        workflow_node_execution.finished_at = finished_at
+        workflow_node_execution.elapsed_time = elapsed_time

-        db.session.commit()
-        db.session.refresh(workflow_node_execution)
-        db.session.close()
+        self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)

         return workflow_node_execution

@@ -678,17 +712,7 @@
         :param node_execution_id: workflow node execution id
         :return:
         """
-        workflow_node_execution = (
-            db.session.query(WorkflowNodeExecution)
-            .filter(
-                WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
-                WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
-                WorkflowNodeExecution.workflow_id == self._workflow.id,
-                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
-                WorkflowNodeExecution.node_execution_id == node_execution_id,
-            )
-            .first()
-        )
+        workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id)

         if not workflow_node_execution:
             raise Exception(f"Workflow node execution not found: {node_execution_id}")
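Taken together, the workflow-cycle changes replace a per-event SQL lookup with an in-process cache: executions go into _wip_workflow_node_executions when they start, are finalized with a single bulk UPDATE, and are popped on completion, so _get_workflow_node_execution becomes a dict read. A stripped-down sketch of the pattern (names mirror the diff; everything else is simplified):

    class NodeExecutionTracker:
        """In-memory write-through cache keyed by node_execution_id (sketch)."""

        def __init__(self) -> None:
            self._wip_workflow_node_executions: dict[str, object] = {}

        def start(self, node_execution_id: str, execution: object) -> None:
            # the DB insert happens first; cache the ORM object for later events
            self._wip_workflow_node_executions[node_execution_id] = execution

        def get(self, node_execution_id: str) -> object:
            execution = self._wip_workflow_node_executions.get(node_execution_id)
            if execution is None:
                raise Exception(f"Workflow node execution not found: {node_execution_id}")
            return execution

        def finish(self, node_execution_id: str) -> None:
            # one bulk UPDATE for the final fields would go here
            self._wip_workflow_node_executions.pop(node_execution_id)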
@@ -5,6 +5,7 @@ from typing import Optional, cast
 import numpy as np
 from sqlalchemy.exc import IntegrityError

+from configs import dify_config
 from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelPropertyKey
@@ -110,6 +111,8 @@ class CacheEmbedding(Embeddings):
             embedding_results = embedding_result.embeddings[0]
             embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
         except Exception as ex:
+            if dify_config.DEBUG:
+                logging.exception(f"Failed to embed query text: {ex}")
             raise ex

         try:
@@ -122,6 +125,8 @@ class CacheEmbedding(Embeddings):
             encoded_str = encoded_vector.decode("utf-8")
             redis_client.setex(embedding_cache_key, 600, encoded_str)
         except Exception as ex:
-            logging.exception("Failed to add embedding to redis %s", ex)
+            if dify_config.DEBUG:
+                logging.exception("Failed to add embedding to redis %s", ex)
+            raise ex

         return embedding_results
@@ -60,8 +60,8 @@ class TokenBufferMemory:
         thread_messages = extract_thread_messages(messages)

         # for newly created message, its answer is temporarily empty, we don't need to add it to memory
-        if thread_messages and not thread_messages[-1].answer:
-            thread_messages.pop()
+        if thread_messages and not thread_messages[0].answer:
+            thread_messages.pop(0)

         messages = list(reversed(thread_messages))

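This is a real bug fix, not a rename: extract_thread_messages evidently returns messages newest-first (note the list(reversed(...)) that follows), so the just-created message with an empty answer sits at index 0, while the old code inspected and popped the tail, i.e. the oldest message. A tiny illustration with stand-in records:

    # newest-first, as extract_thread_messages returns
    thread_messages = [
        {"query": "newest", "answer": ""},   # just created, answer still empty
        {"query": "older", "answer": "hi"},
        {"query": "oldest", "answer": "hello"},
    ]

    if thread_messages and not thread_messages[0]["answer"]:
        thread_messages.pop(0)               # drop the unanswered newest message

    messages = list(reversed(thread_messages))  # oldest-first for the prompt
    assert [m["query"] for m in messages] == ["oldest", "older"]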
@@ -1,8 +1,18 @@
 from collections.abc import Generator
 from typing import Optional, Union

-from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+)
 from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


@@ -29,3 +39,53 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
     def _add_custom_parameters(cls, credentials: dict) -> None:
         credentials["mode"] = "chat"
         credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        return AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model, zh_Hans=model),
+            model_type=ModelType.LLM,
+            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
+            if credentials.get("function_calling_type") == "tool_call"
+            else [],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)),
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name="temperature",
+                    use_template="temperature",
+                    label=I18nObject(en_US="Temperature", zh_Hans="温度"),
+                    type=ParameterType.FLOAT,
+                ),
+                ParameterRule(
+                    name="max_tokens",
+                    use_template="max_tokens",
+                    default=512,
+                    min=1,
+                    max=int(credentials.get("max_tokens", 1024)),
+                    label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
+                    type=ParameterType.INT,
+                ),
+                ParameterRule(
+                    name="top_p",
+                    use_template="top_p",
+                    label=I18nObject(en_US="Top P", zh_Hans="Top P"),
+                    type=ParameterType.FLOAT,
+                ),
+                ParameterRule(
+                    name="top_k",
+                    use_template="top_k",
+                    label=I18nObject(en_US="Top K", zh_Hans="Top K"),
+                    type=ParameterType.FLOAT,
+                ),
+                ParameterRule(
+                    name="frequency_penalty",
+                    use_template="frequency_penalty",
+                    label=I18nObject(en_US="Frequency Penalty", zh_Hans="重复惩罚"),
+                    type=ParameterType.FLOAT,
+                ),
+            ],
+        )
@@ -20,6 +20,7 @@ supported_model_types:
 - speech2text
 configurate_methods:
 - predefined-model
+- customizable-model
 provider_credential_schema:
   credential_form_schemas:
   - variable: api_key
@@ -30,3 +31,57 @@ provider_credential_schema:
     placeholder:
       zh_Hans: 在此输入您的 API Key
       en_US: Enter your API Key
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+  - variable: api_key
+    label:
+      en_US: API Key
+    type: secret-input
+    required: true
+    placeholder:
+      zh_Hans: 在此输入您的 API Key
+      en_US: Enter your API Key
+  - variable: context_size
+    label:
+      zh_Hans: 模型上下文长度
+      en_US: Model context size
+    required: true
+    type: text-input
+    default: '4096'
+    placeholder:
+      zh_Hans: 在此输入您的模型上下文长度
+      en_US: Enter your Model context size
+  - variable: max_tokens
+    label:
+      zh_Hans: 最大 token 上限
+      en_US: Upper bound for max tokens
+    default: '4096'
+    type: text-input
+    show_on:
+    - variable: __model_type
+      value: llm
+  - variable: function_calling_type
+    label:
+      en_US: Function calling
+    type: select
+    required: false
+    default: no_call
+    options:
+    - value: no_call
+      label:
+        en_US: Not Support
+        zh_Hans: 不支持
+    - value: function_call
+      label:
+        en_US: Support
+        zh_Hans: 支持
+    show_on:
+    - variable: __model_type
+      value: llm
@@ -0,0 +1 @@
+- gte-rerank
@@ -0,0 +1,4 @@
+model: gte-rerank
+model_type: rerank
+model_properties:
+  context_size: 4000
api/core/model_runtime/model_providers/tongyi/rerank/rerank.py (new file, 136 lines)
@@ -0,0 +1,136 @@
+from typing import Optional
+
+import dashscope
+from dashscope.common.error import (
+    AuthenticationError,
+    InvalidParameter,
+    RequestFailure,
+    ServiceUnavailableError,
+    UnsupportedHTTPMethod,
+    UnsupportedModel,
+)
+
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class GTERerankModel(RerankModel):
+    """
+    Model class for GTE rerank model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        query: str,
+        docs: list[str],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id
+        :return: rerank result
+        """
+        if len(docs) == 0:
+            return RerankResult(model=model, docs=docs)
+
+        # initialize client
+        dashscope.api_key = credentials["dashscope_api_key"]
+
+        response = dashscope.TextReRank.call(
+            query=query,
+            documents=docs,
+            model=model,
+            top_n=top_n,
+            return_documents=True,
+        )
+
+        rerank_documents = []
+        for _, result in enumerate(response.output.results):
+            # format document
+            rerank_document = RerankDocument(
+                index=result.index,
+                score=result.relevance_score,
+                text=result["document"]["text"],
+            )
+
+            # score threshold check
+            if score_threshold is not None:
+                if result.relevance_score >= score_threshold:
+                    rerank_documents.append(rerank_document)
+            else:
+                rerank_documents.append(rerank_document)
+
+        return RerankResult(model=model, docs=rerank_documents)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            self.invoke(
+                model=model,
+                credentials=credentials,
+                query="What is the capital of the United States?",
+                docs=[
+                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+                    "Census, Carson City had a population of 55,274.",
+                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                    "are a political division controlled by the United States. Its capital is Saipan.",
+                ],
+                score_threshold=0.8,
+            )
+        except Exception as ex:
+            print(ex)
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                RequestFailure,
+            ],
+            InvokeServerUnavailableError: [
+                ServiceUnavailableError,
+            ],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [
+                AuthenticationError,
+            ],
+            InvokeBadRequestError: [
+                InvalidParameter,
+                UnsupportedModel,
+                UnsupportedHTTPMethod,
+            ],
+        }
@@ -18,6 +18,7 @@ supported_model_types:
 - llm
 - tts
 - text-embedding
+- rerank
 configurate_methods:
 - predefined-model
 - customizable-model
@@ -1,6 +1,10 @@
 from collections.abc import Generator
 from typing import Optional, Union

+from zhipuai import ZhipuAI
+from zhipuai.types.chat.chat_completion import Completion
+from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -16,9 +20,6 @@ from core.model_runtime.entities.message_entities import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
-from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
-from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
-from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
 from core.model_runtime.utils import helper

 GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
@@ -1,13 +1,14 @@
 import time
 from typing import Optional

+from zhipuai import ZhipuAI
+
 from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
-from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI


 class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
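The remaining hunks delete the SDK that was vendored under zhipuai_sdk/; the imports above now resolve to the official zhipuai package from PyPI instead. Equivalent client usage against the published SDK looks roughly like this (model name and prompt are placeholders):

    from zhipuai import ZhipuAI

    client = ZhipuAI(api_key="your-api-key")  # or set ZHIPUAI_API_KEY in the environment

    completion = client.chat.completions.create(
        model="glm-4",
        messages=[{"role": "user", "content": "Say hello"}],
    )
    print(completion.choices[0].message.content)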
deleted: api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py
@@ -1,15 +0,0 @@
-from .__version__ import __version__
-from ._client import ZhipuAI
-from .core import (
-    APIAuthenticationError,
-    APIConnectionError,
-    APIInternalError,
-    APIReachLimitError,
-    APIRequestFailedError,
-    APIResponseError,
-    APIResponseValidationError,
-    APIServerFlowExceedError,
-    APIStatusError,
-    APITimeoutError,
-    ZhipuAIError,
-)
deleted: api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py
@@ -1 +0,0 @@
-__version__ = "v2.1.0"
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py
+++ /dev/null
@@ -1,82 +0,0 @@
-from __future__ import annotations
-
-import os
-from collections.abc import Mapping
-from typing import Union
-
-import httpx
-from httpx import Timeout
-from typing_extensions import override
-
-from . import api_resource
-from .core import NOT_GIVEN, ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient, NotGiven, ZhipuAIError, _jwt_token
-
-
-class ZhipuAI(HttpClient):
-    chat: api_resource.chat.Chat
-    api_key: str
-    _disable_token_cache: bool = True
-
-    def __init__(
-        self,
-        *,
-        api_key: str | None = None,
-        base_url: str | httpx.URL | None = None,
-        timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
-        max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
-        http_client: httpx.Client | None = None,
-        custom_headers: Mapping[str, str] | None = None,
-        disable_token_cache: bool = True,
-        _strict_response_validation: bool = False,
-    ) -> None:
-        if api_key is None:
-            api_key = os.environ.get("ZHIPUAI_API_KEY")
-        if api_key is None:
-            raise ZhipuAIError("No api_key given; provide it as an argument or via the ZHIPUAI_API_KEY environment variable")
-        self.api_key = api_key
-        self._disable_token_cache = disable_token_cache
-
-        if base_url is None:
-            base_url = os.environ.get("ZHIPUAI_BASE_URL")
-        if base_url is None:
-            base_url = "https://open.bigmodel.cn/api/paas/v4"
-        from .__version__ import __version__
-
-        super().__init__(
-            version=__version__,
-            base_url=base_url,
-            max_retries=max_retries,
-            timeout=timeout,
-            custom_httpx_client=http_client,
-            custom_headers=custom_headers,
-            _strict_response_validation=_strict_response_validation,
-        )
-        self.chat = api_resource.chat.Chat(self)
-        self.images = api_resource.images.Images(self)
-        self.embeddings = api_resource.embeddings.Embeddings(self)
-        self.files = api_resource.files.Files(self)
-        self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
-        self.batches = api_resource.Batches(self)
-        self.knowledge = api_resource.Knowledge(self)
-        self.tools = api_resource.Tools(self)
-        self.videos = api_resource.Videos(self)
-        self.assistant = api_resource.Assistant(self)
-
-    @property
-    @override
-    def auth_headers(self) -> dict[str, str]:
-        api_key = self.api_key
-        if self._disable_token_cache:
-            return {"Authorization": f"Bearer {api_key}"}
-        else:
-            return {"Authorization": f"Bearer {_jwt_token.generate_token(api_key)}"}
-
-    def __del__(self) -> None:
-        if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"):
-            # if the '__init__' method raised an error, self would not have client attr
-            return
-
-        if self._has_custom_http_client:
-            return
-
-        self.close()
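The deleted `_client.py` shows how the client resolves credentials (explicit argument first, then `ZHIPUAI_API_KEY`) and how `auth_headers` either sends the raw key as a Bearer token or derives a JWT from it when token caching is enabled. A minimal construction sketch against the pip package, which keeps the same constructor surface; the key value is a placeholder:

```python
from zhipuai import ZhipuAI

# Sketch only: the key is a placeholder, and the keyword arguments mirror the
# constructor of the deleted vendored client.
client = ZhipuAI(
    api_key="your-api-key",  # or omit and set ZHIPUAI_API_KEY in the environment
    timeout=30.0,            # seconds, applied to every request unless overridden per call
    max_retries=3,           # transport-level retries
)
```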
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from .assistant import (
-    Assistant,
-)
-from .batches import Batches
-from .chat import (
-    AsyncCompletions,
-    Chat,
-    Completions,
-)
-from .embeddings import Embeddings
-from .files import Files, FilesWithRawResponse
-from .fine_tuning import FineTuning
-from .images import Images
-from .knowledge import Knowledge
-from .tools import Tools
-from .videos import (
-    Videos,
-)
-
-__all__ = [
-    "Videos",
-    "AsyncCompletions",
-    "Chat",
-    "Completions",
-    "Images",
-    "Embeddings",
-    "Files",
-    "FilesWithRawResponse",
-    "FineTuning",
-    "Batches",
-    "Knowledge",
-    "Tools",
-    "Assistant",
-]
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .assistant import Assistant
-
-__all__ = ["Assistant"]
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/assistant.py
+++ /dev/null
@@ -1,122 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Optional
-
-import httpx
-
-from ...core import (
-    NOT_GIVEN,
-    BaseAPI,
-    Body,
-    Headers,
-    NotGiven,
-    StreamResponse,
-    deepcopy_minimal,
-    make_request_options,
-    maybe_transform,
-)
-from ...types.assistant import AssistantCompletion
-from ...types.assistant.assistant_conversation_resp import ConversationUsageListResp
-from ...types.assistant.assistant_support_resp import AssistantSupportResp
-
-if TYPE_CHECKING:
-    from ..._client import ZhipuAI
-
-from ...types.assistant import assistant_conversation_params, assistant_create_params
-
-__all__ = ["Assistant"]
-
-
-class Assistant(BaseAPI):
-    def __init__(self, client: ZhipuAI) -> None:
-        super().__init__(client)
-
-    def conversation(
-        self,
-        assistant_id: str,
-        model: str,
-        messages: list[assistant_create_params.ConversationMessage],
-        *,
-        stream: bool = True,
-        conversation_id: Optional[str] = None,
-        attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None,
-        metadata: dict | None = None,
-        request_id: Optional[str] = None,
-        user_id: Optional[str] = None,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> StreamResponse[AssistantCompletion]:
-        body = deepcopy_minimal(
-            {
-                "assistant_id": assistant_id,
-                "model": model,
-                "messages": messages,
-                "stream": stream,
-                "conversation_id": conversation_id,
-                "attachments": attachments,
-                "metadata": metadata,
-                "request_id": request_id,
-                "user_id": user_id,
-            }
-        )
-        return self._post(
-            "/assistant",
-            body=maybe_transform(body, assistant_create_params.AssistantParameters),
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=AssistantCompletion,
-            stream=stream or True,
-            stream_cls=StreamResponse[AssistantCompletion],
-        )
-
-    def query_support(
-        self,
-        *,
-        assistant_id_list: Optional[list[str]] = None,
-        request_id: Optional[str] = None,
-        user_id: Optional[str] = None,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> AssistantSupportResp:
-        body = deepcopy_minimal(
-            {
-                "assistant_id_list": assistant_id_list,
-                "request_id": request_id,
-                "user_id": user_id,
-            }
-        )
-        return self._post(
-            "/assistant/list",
-            body=body,
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=AssistantSupportResp,
-        )
-
-    def query_conversation_usage(
-        self,
-        assistant_id: str,
-        page: int = 1,
-        page_size: int = 10,
-        *,
-        request_id: Optional[str] = None,
-        user_id: Optional[str] = None,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> ConversationUsageListResp:
-        body = deepcopy_minimal(
-            {
-                "assistant_id": assistant_id,
-                "page": page,
-                "page_size": page_size,
-                "request_id": request_id,
-                "user_id": user_id,
-            }
-        )
-        return self._post(
-            "/assistant/conversation/list",
-            body=maybe_transform(body, assistant_conversation_params.ConversationParameters),
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=ConversationUsageListResp,
-        )
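Note that `conversation` passes `stream=stream or True`, so the call streams even when a caller explicitly passes `stream=False`, and the return type is always `StreamResponse[AssistantCompletion]`. A hedged sketch of consuming it, assuming the published `zhipuai` package mirrors this vendored surface; the assistant id, model name, and message shape are placeholders:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")  # placeholder credential
events = client.assistant.conversation(
    assistant_id="your-assistant-id",     # placeholder
    model="glm-4-assistant",              # placeholder model name
    messages=[{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
    stream=True,
)
for event in events:  # StreamResponse yields AssistantCompletion events
    print(event)
```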
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/batches.py
+++ /dev/null
@@ -1,146 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Literal, Optional
-
-import httpx
-
-from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options, maybe_transform
-from ..core.pagination import SyncCursorPage
-from ..types import batch_create_params, batch_list_params
-from ..types.batch import Batch
-
-if TYPE_CHECKING:
-    from .._client import ZhipuAI
-
-
-class Batches(BaseAPI):
-    def __init__(self, client: ZhipuAI) -> None:
-        super().__init__(client)
-
-    def create(
-        self,
-        *,
-        completion_window: str | None = None,
-        endpoint: Literal["/v1/chat/completions", "/v1/embeddings"],
-        input_file_id: str,
-        metadata: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
-        auto_delete_input_file: bool = True,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Batch:
-        return self._post(
-            "/batches",
-            body=maybe_transform(
-                {
-                    "completion_window": completion_window,
-                    "endpoint": endpoint,
-                    "input_file_id": input_file_id,
-                    "metadata": metadata,
-                    "auto_delete_input_file": auto_delete_input_file,
-                },
-                batch_create_params.BatchCreateParams,
-            ),
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=Batch,
-        )
-
-    def retrieve(
-        self,
-        batch_id: str,
-        *,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Batch:
-        """
-        Retrieves a batch.
-
-        Args:
-            extra_headers: Send extra headers
-
-            extra_body: Add additional JSON properties to the request
-
-            timeout: Override the client-level default timeout for this request, in seconds
-        """
-        if not batch_id:
-            raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
-        return self._get(
-            f"/batches/{batch_id}",
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=Batch,
-        )
-
-    def list(
-        self,
-        *,
-        after: str | NotGiven = NOT_GIVEN,
-        limit: int | NotGiven = NOT_GIVEN,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> SyncCursorPage[Batch]:
-        """List your organization's batches.
-
-        Args:
-            after: A cursor for use in pagination. `after` is an object ID that defines your place
-                in the list. For instance, if you make a list request and receive 100 objects,
-                ending with obj_foo, your subsequent call can include after=obj_foo in order to
-                fetch the next page of the list.
-
-            limit: A limit on the number of objects to be returned. Limit can range between 1 and
-                100, and the default is 20.
-
-            extra_headers: Send extra headers
-
-            extra_body: Add additional JSON properties to the request
-
-            timeout: Override the client-level default timeout for this request, in seconds
-        """
-        return self._get_api_list(
-            "/batches",
-            page=SyncCursorPage[Batch],
-            options=make_request_options(
-                extra_headers=extra_headers,
-                extra_body=extra_body,
-                timeout=timeout,
-                query=maybe_transform(
-                    {
-                        "after": after,
-                        "limit": limit,
-                    },
-                    batch_list_params.BatchListParams,
-                ),
-            ),
-            model=Batch,
-        )
-
-    def cancel(
-        self,
-        batch_id: str,
-        *,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Batch:
-        """
-        Cancels an in-progress batch.
-
-        Args:
-            batch_id: The ID of the batch to cancel.
-            extra_headers: Send extra headers
-
-            extra_body: Add additional JSON properties to the request
-
-            timeout: Override the client-level default timeout for this request, in seconds
-        """
-        if not batch_id:
-            raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
-        return self._post(
-            f"/batches/{batch_id}/cancel",
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=Batch,
-        )
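The batch resource mirrors OpenAI's batch API: submit a previously uploaded JSONL input file against one endpoint, then poll. A sketch under that assumption; the file id and completion window are placeholders, not values taken from this diff:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")
batch = client.batches.create(
    endpoint="/v1/chat/completions",
    input_file_id="file-placeholder-id",  # from a prior files.create(..., purpose="batch")
    completion_window="24h",              # assumed window value
    auto_delete_input_file=True,
)
batch = client.batches.retrieve(batch.id)  # poll until the batch reaches a terminal status
```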
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .async_completions import AsyncCompletions
-from .chat import Chat
-from .completions import Completions
-
-__all__ = ["AsyncCompletions", "Chat", "Completions"]
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from __future__ import annotations
-
-import logging
-from typing import TYPE_CHECKING, Literal, Optional, Union
-
-import httpx
-
-from ...core import (
-    NOT_GIVEN,
-    BaseAPI,
-    Body,
-    Headers,
-    NotGiven,
-    drop_prefix_image_data,
-    make_request_options,
-    maybe_transform,
-)
-from ...types.chat.async_chat_completion import AsyncCompletion, AsyncTaskStatus
-from ...types.chat.code_geex import code_geex_params
-from ...types.sensitive_word_check import SensitiveWordCheckRequest
-
-logger = logging.getLogger(__name__)
-
-if TYPE_CHECKING:
-    from ..._client import ZhipuAI
-
-
-class AsyncCompletions(BaseAPI):
-    def __init__(self, client: ZhipuAI) -> None:
-        super().__init__(client)
-
-    def create(
-        self,
-        *,
-        model: str,
-        request_id: Optional[str] | NotGiven = NOT_GIVEN,
-        user_id: Optional[str] | NotGiven = NOT_GIVEN,
-        do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
-        temperature: Optional[float] | NotGiven = NOT_GIVEN,
-        top_p: Optional[float] | NotGiven = NOT_GIVEN,
-        max_tokens: int | NotGiven = NOT_GIVEN,
-        seed: int | NotGiven = NOT_GIVEN,
-        messages: Union[str, list[str], list[int], list[list[int]], None],
-        stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
-        sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
-        tools: Optional[object] | NotGiven = NOT_GIVEN,
-        tool_choice: str | NotGiven = NOT_GIVEN,
-        meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
-        extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> AsyncTaskStatus:
-        _cast_type = AsyncTaskStatus
-        logger.debug(f"temperature:{temperature}, top_p:{top_p}")
-        if temperature is not None and temperature != NOT_GIVEN:
-            if temperature <= 0:
-                do_sample = False
-                temperature = 0.01
-                # logger.warning("temperature must lie in the open interval (0.0, 1.0); do_sample rewritten to false (top_p and temperature take no effect)")  # noqa: E501
-            if temperature >= 1:
-                temperature = 0.99
-                # logger.warning("temperature must lie in the open interval (0.0, 1.0)")
-        if top_p is not None and top_p != NOT_GIVEN:
-            if top_p >= 1:
-                top_p = 0.99
-                # logger.warning("top_p must lie in the open interval (0.0, 1.0), excluding 0 and 1")
-            if top_p <= 0:
-                top_p = 0.01
-                # logger.warning("top_p must lie in the open interval (0.0, 1.0), excluding 0 and 1")
-
-        logger.debug(f"temperature:{temperature}, top_p:{top_p}")
-        if isinstance(messages, list):
-            for item in messages:
-                if item.get("content"):
-                    item["content"] = drop_prefix_image_data(item["content"])
-
-        body = {
-            "model": model,
-            "request_id": request_id,
-            "user_id": user_id,
-            "temperature": temperature,
-            "top_p": top_p,
-            "do_sample": do_sample,
-            "max_tokens": max_tokens,
-            "seed": seed,
-            "messages": messages,
-            "stop": stop,
-            "sensitive_word_check": sensitive_word_check,
-            "tools": tools,
-            "tool_choice": tool_choice,
-            "meta": meta,
-            "extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
-        }
-        return self._post(
-            "/async/chat/completions",
-            body=body,
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=_cast_type,
-            stream=False,
-        )
-
-    def retrieve_completion_result(
-        self,
-        id: str,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Union[AsyncCompletion, AsyncTaskStatus]:
-        _cast_type = Union[AsyncCompletion, AsyncTaskStatus]
-        return self._get(
-            path=f"/async-result/{id}",
-            cast_type=_cast_type,
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-        )
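`create` here returns an `AsyncTaskStatus` immediately and the finished `AsyncCompletion` is fetched later by task id, so the caller polls. A sketch of that flow; the model name is a placeholder and the `task_status` string is an assumption about the API's status values:

```python
import time

from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")
task = client.chat.asyncCompletions.create(
    model="glm-4",  # placeholder model name
    messages=[{"role": "user", "content": "Summarize the benefits of batching."}],
)
# The response switches from AsyncTaskStatus to AsyncCompletion once finished.
result = client.chat.asyncCompletions.retrieve_completion_result(id=task.id)
while getattr(result, "task_status", None) == "PROCESSING":  # assumed status value
    time.sleep(2)
    result = client.chat.asyncCompletions.retrieve_completion_result(id=task.id)
```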
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from typing import TYPE_CHECKING
-
-from ...core import BaseAPI, cached_property
-from .async_completions import AsyncCompletions
-from .completions import Completions
-
-if TYPE_CHECKING:
-    pass
-
-
-class Chat(BaseAPI):
-    @cached_property
-    def completions(self) -> Completions:
-        return Completions(self._client)
-
-    @cached_property
-    def asyncCompletions(self) -> AsyncCompletions:  # noqa: N802
-        return AsyncCompletions(self._client)
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py
+++ /dev/null
@@ -1,108 +0,0 @@
-from __future__ import annotations
-
-import logging
-from typing import TYPE_CHECKING, Literal, Optional, Union
-
-import httpx
-
-from ...core import (
-    NOT_GIVEN,
-    BaseAPI,
-    Body,
-    Headers,
-    NotGiven,
-    StreamResponse,
-    deepcopy_minimal,
-    drop_prefix_image_data,
-    make_request_options,
-    maybe_transform,
-)
-from ...types.chat.chat_completion import Completion
-from ...types.chat.chat_completion_chunk import ChatCompletionChunk
-from ...types.chat.code_geex import code_geex_params
-from ...types.sensitive_word_check import SensitiveWordCheckRequest
-
-logger = logging.getLogger(__name__)
-
-if TYPE_CHECKING:
-    from ..._client import ZhipuAI
-
-
-class Completions(BaseAPI):
-    def __init__(self, client: ZhipuAI) -> None:
-        super().__init__(client)
-
-    def create(
-        self,
-        *,
-        model: str,
-        request_id: Optional[str] | NotGiven = NOT_GIVEN,
-        user_id: Optional[str] | NotGiven = NOT_GIVEN,
-        do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
-        stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
-        temperature: Optional[float] | NotGiven = NOT_GIVEN,
-        top_p: Optional[float] | NotGiven = NOT_GIVEN,
-        max_tokens: int | NotGiven = NOT_GIVEN,
-        seed: int | NotGiven = NOT_GIVEN,
-        messages: Union[str, list[str], list[int], object, None],
-        stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
-        sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
-        tools: Optional[object] | NotGiven = NOT_GIVEN,
-        tool_choice: str | NotGiven = NOT_GIVEN,
-        meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
-        extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Completion | StreamResponse[ChatCompletionChunk]:
-        logger.debug(f"temperature:{temperature}, top_p:{top_p}")
-        if temperature is not None and temperature != NOT_GIVEN:
-            if temperature <= 0:
-                do_sample = False
-                temperature = 0.01
-                # logger.warning("temperature must lie in the open interval (0.0, 1.0); do_sample rewritten to false (top_p and temperature take no effect)")  # noqa: E501
-            if temperature >= 1:
-                temperature = 0.99
-                # logger.warning("temperature must lie in the open interval (0.0, 1.0)")
-        if top_p is not None and top_p != NOT_GIVEN:
-            if top_p >= 1:
-                top_p = 0.99
-                # logger.warning("top_p must lie in the open interval (0.0, 1.0), excluding 0 and 1")
-            if top_p <= 0:
-                top_p = 0.01
-                # logger.warning("top_p must lie in the open interval (0.0, 1.0), excluding 0 and 1")
-
-        logger.debug(f"temperature:{temperature}, top_p:{top_p}")
-        if isinstance(messages, list):
-            for item in messages:
-                if item.get("content"):
-                    item["content"] = drop_prefix_image_data(item["content"])
-
-        body = deepcopy_minimal(
-            {
-                "model": model,
-                "request_id": request_id,
-                "user_id": user_id,
-                "temperature": temperature,
-                "top_p": top_p,
-                "do_sample": do_sample,
-                "max_tokens": max_tokens,
-                "seed": seed,
-                "messages": messages,
-                "stop": stop,
-                "sensitive_word_check": sensitive_word_check,
-                "stream": stream,
-                "tools": tools,
-                "tool_choice": tool_choice,
-                "meta": meta,
-                "extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
-            }
-        )
-        return self._post(
-            "/chat/completions",
-            body=body,
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=Completion,
-            stream=stream or False,
-            stream_cls=StreamResponse[ChatCompletionChunk],
-        )
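The synchronous endpoint clamps `temperature` and `top_p` into the open interval (0.0, 1.0) before posting and returns either a `Completion` or, with `stream=True`, an iterable of `ChatCompletionChunk`s. A minimal streaming sketch; the model name is a placeholder:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")
stream = client.chat.completions.create(
    model="glm-4",  # placeholder model name
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta  # incremental payload of each ChatCompletionChunk
    if delta.content:
        print(delta.content, end="", flush=True)
```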
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Optional, Union
-
-import httpx
-
-from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
-from ..types.embeddings import EmbeddingsResponded
-
-if TYPE_CHECKING:
-    from .._client import ZhipuAI
-
-
-class Embeddings(BaseAPI):
-    def __init__(self, client: ZhipuAI) -> None:
-        super().__init__(client)
-
-    def create(
-        self,
-        *,
-        input: Union[str, list[str], list[int], list[list[int]]],
-        model: Union[str],
-        dimensions: Union[int] | NotGiven = NOT_GIVEN,
-        encoding_format: str | NotGiven = NOT_GIVEN,
-        user: str | NotGiven = NOT_GIVEN,
-        request_id: Optional[str] | NotGiven = NOT_GIVEN,
-        sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        disable_strict_validation: Optional[bool] | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> EmbeddingsResponded:
-        _cast_type = EmbeddingsResponded
-        if disable_strict_validation:
-            _cast_type = object
-        return self._post(
-            "/embeddings",
-            body={
-                "input": input,
-                "model": model,
-                "dimensions": dimensions,
-                "encoding_format": encoding_format,
-                "user": user,
-                "request_id": request_id,
-                "sensitive_word_check": sensitive_word_check,
-            },
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=_cast_type,
-            stream=False,
-        )
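Passing `disable_strict_validation=True` downgrades the response cast from the `EmbeddingsResponded` model to a plain `object`, skipping schema validation. Typical usage, with a placeholder model name:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")
resp = client.embeddings.create(
    model="embedding-2",  # placeholder model name
    input=["first sentence", "second sentence"],
)
vectors = [item.embedding for item in resp.data]  # one vector per input string
```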
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py
+++ /dev/null
@@ -1,194 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Mapping
-from typing import TYPE_CHECKING, Literal, Optional, cast
-
-import httpx
-
-from ..core import (
-    NOT_GIVEN,
-    BaseAPI,
-    Body,
-    FileTypes,
-    Headers,
-    NotGiven,
-    _legacy_binary_response,
-    _legacy_response,
-    deepcopy_minimal,
-    extract_files,
-    make_request_options,
-    maybe_transform,
-)
-from ..types.files import FileDeleted, FileObject, ListOfFileObject, UploadDetail, file_create_params
-
-if TYPE_CHECKING:
-    from .._client import ZhipuAI
-
-__all__ = ["Files", "FilesWithRawResponse"]
-
-
-class Files(BaseAPI):
-    def __init__(self, client: ZhipuAI) -> None:
-        super().__init__(client)
-
-    def create(
-        self,
-        *,
-        file: Optional[FileTypes] = None,
-        upload_detail: Optional[list[UploadDetail]] = None,
-        purpose: Literal["fine-tune", "retrieval", "batch"],
-        knowledge_id: Optional[str] = None,
-        sentence_size: Optional[int] = None,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> FileObject:
-        if not file and not upload_detail:
-            raise ValueError("At least one of `file` and `upload_detail` must be provided.")
-        body = deepcopy_minimal(
-            {
-                "file": file,
-                "upload_detail": upload_detail,
-                "purpose": purpose,
-                "knowledge_id": knowledge_id,
-                "sentence_size": sentence_size,
-            }
-        )
-        files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
-        if files:
-            # It should be noted that the actual Content-Type header that will be
-            # sent to the server will contain a `boundary` parameter, e.g.
-            # multipart/form-data; boundary=---abc--
-            extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
-        return self._post(
-            "/files",
-            body=maybe_transform(body, file_create_params.FileCreateParams),
-            files=files,
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=FileObject,
-        )
-
-    # def retrieve(
-    #     self,
-    #     file_id: str,
-    #     *,
-    #     extra_headers: Headers | None = None,
-    #     extra_body: Body | None = None,
-    #     timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    # ) -> FileObject:
-    #     """
-    #     Returns information about a specific file.
-    #
-    #     Args:
-    #         file_id: The ID of the file to retrieve information about
-    #         extra_headers: Send extra headers
-    #
-    #         extra_body: Add additional JSON properties to the request
-    #
-    #         timeout: Override the client-level default timeout for this request, in seconds
-    #     """
-    #     if not file_id:
-    #         raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
-    #     return self._get(
-    #         f"/files/{file_id}",
-    #         options=make_request_options(
-    #             extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
-    #         ),
-    #         cast_type=FileObject,
-    #     )
-
-    def list(
-        self,
-        *,
-        purpose: str | NotGiven = NOT_GIVEN,
-        limit: int | NotGiven = NOT_GIVEN,
-        after: str | NotGiven = NOT_GIVEN,
-        order: str | NotGiven = NOT_GIVEN,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> ListOfFileObject:
-        return self._get(
-            "/files",
-            cast_type=ListOfFileObject,
-            options=make_request_options(
-                extra_headers=extra_headers,
-                extra_body=extra_body,
-                timeout=timeout,
-                query={
-                    "purpose": purpose,
-                    "limit": limit,
-                    "after": after,
-                    "order": order,
-                },
-            ),
-        )
-
-    def delete(
-        self,
-        file_id: str,
-        *,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> FileDeleted:
-        """
-        Delete a file.
-
-        Args:
-            file_id: The ID of the file to delete
-            extra_headers: Send extra headers
-
-            extra_body: Add additional JSON properties to the request
-
-            timeout: Override the client-level default timeout for this request, in seconds
-        """
-        if not file_id:
-            raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
-        return self._delete(
-            f"/files/{file_id}",
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=FileDeleted,
-        )
-
-    def content(
-        self,
-        file_id: str,
-        *,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> _legacy_response.HttpxBinaryResponseContent:
-        """
-        Returns the contents of the specified file.
-
-        Args:
-            extra_headers: Send extra headers
-
-            extra_body: Add additional JSON properties to the request
-
-            timeout: Override the client-level default timeout for this request, in seconds
-        """
-        if not file_id:
-            raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
-        extra_headers = {"Accept": "application/binary", **(extra_headers or {})}
-        return self._get(
-            f"/files/{file_id}/content",
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=_legacy_binary_response.HttpxBinaryResponseContent,
-        )
-
-
-class FilesWithRawResponse:
-    def __init__(self, files: Files) -> None:
-        self._files = files
-
-        self.create = _legacy_response.to_raw_response_wrapper(
-            files.create,
-        )
-        self.list = _legacy_response.to_raw_response_wrapper(
-            files.list,
-        )
-        self.content = _legacy_response.to_raw_response_wrapper(
-            files.content,
-        )
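`create` accepts either a raw file or an `upload_detail` payload and, when a file is present, forces a multipart/form-data upload; `content` streams the raw bytes back. A sketch of uploading a batch input file; the local path is a placeholder:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")
with open("batch_input.jsonl", "rb") as f:  # placeholder local path
    file_obj = client.files.create(file=f, purpose="batch")
# The returned FileObject's id can then be fed to batches.create as input_file_id.
print(file_obj.id)
```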
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .fine_tuning import FineTuning
-from .jobs import Jobs
-from .models import FineTunedModels
-
-__all__ = ["Jobs", "FineTunedModels", "FineTuning"]
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from typing import TYPE_CHECKING
-
-from ...core import BaseAPI, cached_property
-from .jobs import Jobs
-from .models import FineTunedModels
-
-if TYPE_CHECKING:
-    pass
-
-
-class FineTuning(BaseAPI):
-    @cached_property
-    def jobs(self) -> Jobs:
-        return Jobs(self._client)
-
-    @cached_property
-    def models(self) -> FineTunedModels:
-        return FineTunedModels(self._client)
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .jobs import Jobs
-
-__all__ = ["Jobs"]
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/jobs.py
+++ /dev/null
@@ -1,152 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Optional
-
-import httpx
-
-from ....core import (
-    NOT_GIVEN,
-    BaseAPI,
-    Body,
-    Headers,
-    NotGiven,
-    make_request_options,
-)
-from ....types.fine_tuning import (
-    FineTuningJob,
-    FineTuningJobEvent,
-    ListOfFineTuningJob,
-    job_create_params,
-)
-
-if TYPE_CHECKING:
-    from ...._client import ZhipuAI
-
-__all__ = ["Jobs"]
-
-
-class Jobs(BaseAPI):
-    def __init__(self, client: ZhipuAI) -> None:
-        super().__init__(client)
-
-    def create(
-        self,
-        *,
-        model: str,
-        training_file: str,
-        hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
-        suffix: Optional[str] | NotGiven = NOT_GIVEN,
-        request_id: Optional[str] | NotGiven = NOT_GIVEN,
-        validation_file: Optional[str] | NotGiven = NOT_GIVEN,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> FineTuningJob:
-        return self._post(
-            "/fine_tuning/jobs",
-            body={
-                "model": model,
-                "training_file": training_file,
-                "hyperparameters": hyperparameters,
-                "suffix": suffix,
-                "validation_file": validation_file,
-                "request_id": request_id,
-            },
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=FineTuningJob,
-        )
-
-    def retrieve(
-        self,
-        fine_tuning_job_id: str,
-        *,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> FineTuningJob:
-        return self._get(
-            f"/fine_tuning/jobs/{fine_tuning_job_id}",
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=FineTuningJob,
-        )
-
-    def list(
-        self,
-        *,
-        after: str | NotGiven = NOT_GIVEN,
-        limit: int | NotGiven = NOT_GIVEN,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> ListOfFineTuningJob:
-        return self._get(
-            "/fine_tuning/jobs",
-            cast_type=ListOfFineTuningJob,
-            options=make_request_options(
-                extra_headers=extra_headers,
-                extra_body=extra_body,
-                timeout=timeout,
-                query={
-                    "after": after,
-                    "limit": limit,
-                },
-            ),
-        )
-
-    def cancel(
-        self,
-        fine_tuning_job_id: str,
-        *,
-        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.  # noqa: E501
-        # The extra values given here take precedence over values defined on the client or passed to this method.
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> FineTuningJob:
-        if not fine_tuning_job_id:
-            raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}")
-        return self._post(
-            f"/fine_tuning/jobs/{fine_tuning_job_id}/cancel",
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=FineTuningJob,
-        )
-
-    def list_events(
-        self,
-        fine_tuning_job_id: str,
-        *,
-        after: str | NotGiven = NOT_GIVEN,
-        limit: int | NotGiven = NOT_GIVEN,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> FineTuningJobEvent:
-        return self._get(
-            f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
-            cast_type=FineTuningJobEvent,
-            options=make_request_options(
-                extra_headers=extra_headers,
-                extra_body=extra_body,
-                timeout=timeout,
-                query={
-                    "after": after,
-                    "limit": limit,
-                },
-            ),
-        )
-
-    def delete(
-        self,
-        fine_tuning_job_id: str,
-        *,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> FineTuningJob:
-        if not fine_tuning_job_id:
-            raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}")
-        return self._delete(
-            f"/fine_tuning/jobs/{fine_tuning_job_id}",
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=FineTuningJob,
-        )
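The job resource covers the full lifecycle: create, retrieve, list, cancel, list_events, delete. A sketch of starting a job from an uploaded training file, assuming the pip package keeps this surface; the base model and file id are placeholders:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")
job = client.fine_tuning.jobs.create(
    model="chatglm3-6b",                  # placeholder base model
    training_file="file-placeholder-id",  # from files.create(..., purpose="fine-tune")
)
events = client.fine_tuning.jobs.list_events(fine_tuning_job_id=job.id, limit=20)
```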
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .fine_tuned_models import FineTunedModels
-
-__all__ = ["FineTunedModels"]
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/fine_tuned_models.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-import httpx
-
-from ....core import (
-    NOT_GIVEN,
-    BaseAPI,
-    Body,
-    Headers,
-    NotGiven,
-    make_request_options,
-)
-from ....types.fine_tuning.models import FineTunedModelsStatus
-
-if TYPE_CHECKING:
-    from ...._client import ZhipuAI
-
-__all__ = ["FineTunedModels"]
-
-
-class FineTunedModels(BaseAPI):
-    def __init__(self, client: ZhipuAI) -> None:
-        super().__init__(client)
-
-    def delete(
-        self,
-        fine_tuned_model: str,
-        *,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> FineTunedModelsStatus:
-        if not fine_tuned_model:
-            raise ValueError(f"Expected a non-empty value for `fine_tuned_model` but received {fine_tuned_model!r}")
-        return self._delete(
-            f"fine_tuning/fine_tuned_models/{fine_tuned_model}",
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=FineTunedModelsStatus,
-        )
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Optional
-
-import httpx
-
-from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
-from ..types.image import ImagesResponded
-from ..types.sensitive_word_check import SensitiveWordCheckRequest
-
-if TYPE_CHECKING:
-    from .._client import ZhipuAI
-
-
-class Images(BaseAPI):
-    def __init__(self, client: ZhipuAI) -> None:
-        super().__init__(client)
-
-    def generations(
-        self,
-        *,
-        prompt: str,
-        model: str | NotGiven = NOT_GIVEN,
-        n: Optional[int] | NotGiven = NOT_GIVEN,
-        quality: Optional[str] | NotGiven = NOT_GIVEN,
-        response_format: Optional[str] | NotGiven = NOT_GIVEN,
-        size: Optional[str] | NotGiven = NOT_GIVEN,
-        style: Optional[str] | NotGiven = NOT_GIVEN,
-        sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
-        user: str | NotGiven = NOT_GIVEN,
-        request_id: Optional[str] | NotGiven = NOT_GIVEN,
-        user_id: Optional[str] | NotGiven = NOT_GIVEN,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        disable_strict_validation: Optional[bool] | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> ImagesResponded:
-        _cast_type = ImagesResponded
-        if disable_strict_validation:
-            _cast_type = object
-        return self._post(
-            "/images/generations",
-            body={
-                "prompt": prompt,
-                "model": model,
-                "n": n,
-                "quality": quality,
-                "response_format": response_format,
-                "sensitive_word_check": sensitive_word_check,
-                "size": size,
-                "style": style,
-                "user": user,
-                "user_id": user_id,
-                "request_id": request_id,
-            },
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=_cast_type,
-            stream=False,
-        )
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .knowledge import Knowledge
-
-__all__ = ["Knowledge"]
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .document import Document
-
-__all__ = ["Document"]
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/document.py
+++ /dev/null
@@ -1,217 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Mapping
-from typing import TYPE_CHECKING, Literal, Optional, cast
-
-import httpx
-
-from ....core import (
-    NOT_GIVEN,
-    BaseAPI,
-    Body,
-    FileTypes,
-    Headers,
-    NotGiven,
-    deepcopy_minimal,
-    extract_files,
-    make_request_options,
-    maybe_transform,
-)
-from ....types.files import UploadDetail, file_create_params
-from ....types.knowledge.document import DocumentData, DocumentObject, document_edit_params, document_list_params
-from ....types.knowledge.document.document_list_resp import DocumentPage
-
-if TYPE_CHECKING:
-    from ...._client import ZhipuAI
-
-__all__ = ["Document"]
-
-
-class Document(BaseAPI):
-    def __init__(self, client: ZhipuAI) -> None:
-        super().__init__(client)
-
-    def create(
-        self,
-        *,
-        file: Optional[FileTypes] = None,
-        custom_separator: Optional[list[str]] = None,
-        upload_detail: Optional[list[UploadDetail]] = None,
-        purpose: Literal["retrieval"],
-        knowledge_id: Optional[str] = None,
-        sentence_size: Optional[int] = None,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> DocumentObject:
-        if not file and not upload_detail:
-            raise ValueError("At least one of `file` and `upload_detail` must be provided.")
-        body = deepcopy_minimal(
-            {
-                "file": file,
-                "upload_detail": upload_detail,
-                "purpose": purpose,
-                "custom_separator": custom_separator,
-                "knowledge_id": knowledge_id,
-                "sentence_size": sentence_size,
-            }
-        )
-        files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
-        if files:
-            # It should be noted that the actual Content-Type header that will be
-            # sent to the server will contain a `boundary` parameter, e.g.
-            # multipart/form-data; boundary=---abc--
-            extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
-        return self._post(
-            "/files",
-            body=maybe_transform(body, file_create_params.FileCreateParams),
-            files=files,
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=DocumentObject,
-        )
-
-    def edit(
-        self,
-        document_id: str,
-        knowledge_type: str,
-        *,
-        custom_separator: Optional[list[str]] = None,
-        sentence_size: Optional[int] = None,
-        callback_url: Optional[str] = None,
-        callback_header: Optional[dict[str, str]] = None,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> httpx.Response:
-        """
-
-        Args:
-            document_id: knowledge document id
-            knowledge_type: knowledge type:
-                1: article knowledge, supports pdf, url, docx
-                2: Q&A knowledge - document, supports pdf, url, docx
-                3: Q&A knowledge - table, supports xlsx
-                4: product database - table, supports xlsx
-                5: custom, supports pdf, url, docx
-            extra_headers: Send extra headers
-
-            extra_body: Add additional JSON properties to the request
-
-            timeout: Override the client-level default timeout for this request, in seconds
-        :param knowledge_type:
-        :param document_id:
-        :param timeout:
-        :param extra_body:
-        :param callback_header:
-        :param sentence_size:
-        :param extra_headers:
-        :param callback_url:
-        :param custom_separator:
-        """
-        if not document_id:
-            raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
-
-        body = deepcopy_minimal(
-            {
-                "id": document_id,
-                "knowledge_type": knowledge_type,
-                "custom_separator": custom_separator,
-                "sentence_size": sentence_size,
-                "callback_url": callback_url,
-                "callback_header": callback_header,
-            }
-        )
-
-        return self._put(
-            f"/document/{document_id}",
-            body=maybe_transform(body, document_edit_params.DocumentEditParams),
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=httpx.Response,
-        )
-
-    def list(
-        self,
-        knowledge_id: str,
-        *,
-        purpose: str | NotGiven = NOT_GIVEN,
-        page: str | NotGiven = NOT_GIVEN,
-        limit: str | NotGiven = NOT_GIVEN,
-        order: Literal["desc", "asc"] | NotGiven = NOT_GIVEN,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> DocumentPage:
-        return self._get(
-            "/files",
-            options=make_request_options(
-                extra_headers=extra_headers,
-                extra_body=extra_body,
-                timeout=timeout,
-                query=maybe_transform(
-                    {
-                        "knowledge_id": knowledge_id,
-                        "purpose": purpose,
-                        "page": page,
-                        "limit": limit,
-                        "order": order,
-                    },
-                    document_list_params.DocumentListParams,
-                ),
-            ),
-            cast_type=DocumentPage,
-        )
-
-    def delete(
-        self,
-        document_id: str,
-        *,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> httpx.Response:
-        """
-        Delete a knowledge document.
-
-        Args:
-            document_id: knowledge document id
-            extra_headers: Send extra headers
-
-            extra_body: Add additional JSON properties to the request
-
-            timeout: Override the client-level default timeout for this request, in seconds
-        """
-        if not document_id:
-            raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
-
-        return self._delete(
-            f"/document/{document_id}",
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=httpx.Response,
-        )
-
-    def retrieve(
-        self,
-        document_id: str,
-        *,
-        extra_headers: Headers | None = None,
-        extra_body: Body | None = None,
-        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> DocumentData:
-        """
-
-        Args:
-            extra_headers: Send extra headers
-
-            extra_body: Add additional JSON properties to the request
-
-            timeout: Override the client-level default timeout for this request, in seconds
-        """
-        if not document_id:
-            raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
-
-        return self._get(
-            f"/document/{document_id}",
-            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
-            cast_type=DocumentData,
-        )
@ -1,173 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Literal, Optional
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from ...core import (
|
|
||||||
NOT_GIVEN,
|
|
||||||
BaseAPI,
|
|
||||||
Body,
|
|
||||||
Headers,
|
|
||||||
NotGiven,
|
|
||||||
cached_property,
|
|
||||||
deepcopy_minimal,
|
|
||||||
make_request_options,
|
|
||||||
maybe_transform,
|
|
||||||
)
|
|
||||||
from ...types.knowledge import KnowledgeInfo, KnowledgeUsed, knowledge_create_params, knowledge_list_params
|
|
||||||
from ...types.knowledge.knowledge_list_resp import KnowledgePage
|
|
||||||
from .document import Document
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ..._client import ZhipuAI
|
|
||||||
|
|
||||||
__all__ = ["Knowledge"]
|
|
||||||
|
|
||||||
|
|
||||||
class Knowledge(BaseAPI):
|
|
||||||
def __init__(self, client: ZhipuAI) -> None:
|
|
||||||
super().__init__(client)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def document(self) -> Document:
|
|
||||||
return Document(self._client)
|
|
||||||
|
|
||||||
def create(
|
|
||||||
self,
|
|
||||||
embedding_id: int,
|
|
||||||
name: str,
|
|
||||||
*,
|
|
||||||
customer_identifier: Optional[str] = None,
|
|
||||||
description: Optional[str] = None,
|
|
||||||
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
|
|
||||||
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
|
|
||||||
bucket_id: Optional[str] = None,
|
|
||||||
extra_headers: Headers | None = None,
|
|
||||||
extra_body: Body | None = None,
|
|
||||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
|
||||||
) -> KnowledgeInfo:
|
|
||||||
body = deepcopy_minimal(
|
|
||||||
{
|
|
||||||
"embedding_id": embedding_id,
|
|
||||||
"name": name,
|
|
||||||
"customer_identifier": customer_identifier,
|
|
||||||
"description": description,
|
|
||||||
"background": background,
|
|
||||||
"icon": icon,
|
|
||||||
"bucket_id": bucket_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return self._post(
|
|
||||||
"/knowledge",
|
|
||||||
body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
|
|
||||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
|
||||||
cast_type=KnowledgeInfo,
|
|
||||||
)
|
|
||||||
|
|
||||||
def modify(
|
|
||||||
self,
|
|
||||||
knowledge_id: str,
|
|
||||||
embedding_id: int,
|
|
||||||
*,
|
|
||||||
name: str,
|
|
||||||
description: Optional[str] = None,
|
|
||||||
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
|
|
||||||
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
|
|
||||||
extra_headers: Headers | None = None,
|
|
||||||
extra_body: Body | None = None,
|
|
||||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
|
||||||
) -> httpx.Response:
|
|
||||||
body = deepcopy_minimal(
|
|
||||||
{
|
|
||||||
"id": knowledge_id,
|
|
||||||
"embedding_id": embedding_id,
|
|
||||||
"name": name,
|
|
||||||
"description": description,
|
|
||||||
"background": background,
|
|
||||||
"icon": icon,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return self._put(
|
|
||||||
f"/knowledge/{knowledge_id}",
|
|
||||||
body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
|
|
||||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
|
||||||
cast_type=httpx.Response,
|
|
||||||
)
|
|
||||||
|
|
||||||
def query(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
page: int | NotGiven = 1,
|
|
||||||
size: int | NotGiven = 10,
|
|
||||||
extra_headers: Headers | None = None,
|
|
||||||
extra_body: Body | None = None,
|
|
||||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
|
||||||
) -> KnowledgePage:
|
|
||||||
return self._get(
|
|
||||||
"/knowledge",
|
|
||||||
options=make_request_options(
|
|
||||||
extra_headers=extra_headers,
|
|
||||||
extra_body=extra_body,
|
|
||||||
timeout=timeout,
|
|
||||||
query=maybe_transform(
|
|
||||||
{
|
|
||||||
"page": page,
|
|
||||||
"size": size,
|
|
||||||
},
|
|
||||||
knowledge_list_params.KnowledgeListParams,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
cast_type=KnowledgePage,
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete(
|
|
||||||
self,
|
|
||||||
knowledge_id: str,
|
|
||||||
*,
|
|
||||||
extra_headers: Headers | None = None,
|
|
||||||
extra_body: Body | None = None,
|
|
||||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
|
||||||
) -> httpx.Response:
|
|
||||||
"""
|
|
||||||
Delete a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
knowledge_id: 知识库ID
|
|
||||||
extra_headers: Send extra headers
|
|
||||||
|
|
||||||
extra_body: Add additional JSON properties to the request
|
|
||||||
|
|
||||||
timeout: Override the client-level default timeout for this request, in seconds
|
|
||||||
"""
|
|
||||||
if not knowledge_id:
|
|
||||||
raise ValueError("Expected a non-empty value for `knowledge_id`")
|
|
||||||
|
|
||||||
return self._delete(
|
|
||||||
f"/knowledge/{knowledge_id}",
|
|
||||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
|
||||||
cast_type=httpx.Response,
|
|
||||||
)
|
|
||||||
|
|
||||||
def used(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
extra_headers: Headers | None = None,
|
|
||||||
extra_body: Body | None = None,
|
|
||||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
|
||||||
) -> KnowledgeUsed:
|
|
||||||
"""
|
|
||||||
Returns the contents of the specified file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
extra_headers: Send extra headers
|
|
||||||
|
|
||||||
extra_body: Add additional JSON properties to the request
|
|
||||||
|
|
||||||
timeout: Override the client-level default timeout for this request, in seconds
|
|
||||||
"""
|
|
||||||
return self._get(
|
|
||||||
"/knowledge/capacity",
|
|
||||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
|
||||||
cast_type=KnowledgeUsed,
|
|
||||||
)
|
|
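
# Illustrative usage sketch (not part of the original file): how the knowledge
# endpoints above compose into a create/modify/query/delete round trip. The
# `create(...)` keyword arguments, the `id` attribute on `KnowledgeInfo`, and
# the `client.knowledge` mount point are assumptions.
#
#     kb = client.knowledge.create(embedding_id=1, name="docs")
#     client.knowledge.modify(kb.id, 1, name="docs-v2")   # PUT /knowledge/{id}
#     page = client.knowledge.query(page=1, size=10)      # GET /knowledge
#     used = client.knowledge.used()                      # GET /knowledge/capacity
#     client.knowledge.delete(kb.id)                      # DELETE /knowledge/{id}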
@ -1,3 +0,0 @@
from .tools import Tools

__all__ = ["Tools"]
@ -1,65 +0,0 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Literal, Optional, Union

import httpx

from ...core import (
    NOT_GIVEN,
    BaseAPI,
    Body,
    Headers,
    NotGiven,
    StreamResponse,
    deepcopy_minimal,
    make_request_options,
    maybe_transform,
)
from ...types.tools import WebSearch, WebSearchChunk, tools_web_search_params

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._client import ZhipuAI

__all__ = ["Tools"]


class Tools(BaseAPI):
    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def web_search(
        self,
        *,
        model: str,
        request_id: Optional[str] | NotGiven = NOT_GIVEN,
        stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
        messages: Union[str, list[str], list[int], object, None],
        scope: Optional[str] | NotGiven = NOT_GIVEN,
        location: Optional[str] | NotGiven = NOT_GIVEN,
        recent_days: Optional[int] | NotGiven = NOT_GIVEN,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> WebSearch | StreamResponse[WebSearchChunk]:
        body = deepcopy_minimal(
            {
                "model": model,
                "request_id": request_id,
                "messages": messages,
                "stream": stream,
                "scope": scope,
                "location": location,
                "recent_days": recent_days,
            }
        )
        return self._post(
            "/tools",
            body=maybe_transform(body, tools_web_search_params.WebSearchParams),
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=WebSearch,
            stream=stream or False,
            stream_cls=StreamResponse[WebSearchChunk],
        )
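
# Illustrative usage sketch (not part of the original file): `web_search`
# returns a `WebSearch` object by default and a `StreamResponse[WebSearchChunk]`
# when `stream=True`. The model name and `client.tools` mount point are
# assumptions.
#
#     result = client.tools.web_search(model="web-search-pro", messages="dify attachments")
#     for chunk in client.tools.web_search(model="web-search-pro", messages="dify", stream=True):
#         ...  # consume WebSearchChunk items as they arrive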
@ -1,7 +0,0 @@
from .videos import (
    Videos,
)

__all__ = [
    "Videos",
]
@ -1,77 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import httpx

from ...core import (
    NOT_GIVEN,
    BaseAPI,
    Body,
    Headers,
    NotGiven,
    deepcopy_minimal,
    make_request_options,
    maybe_transform,
)
from ...types.sensitive_word_check import SensitiveWordCheckRequest
from ...types.video import VideoObject, video_create_params

if TYPE_CHECKING:
    from ..._client import ZhipuAI

__all__ = ["Videos"]


class Videos(BaseAPI):
    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def generations(
        self,
        model: str,
        *,
        prompt: Optional[str] = None,
        image_url: Optional[str] = None,
        sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> VideoObject:
        if not model and not prompt:
            raise ValueError("At least one of `model` and `prompt` must be provided.")
        body = deepcopy_minimal(
            {
                "model": model,
                "prompt": prompt,
                "image_url": image_url,
                "sensitive_word_check": sensitive_word_check,
                "request_id": request_id,
                "user_id": user_id,
            }
        )
        return self._post(
            "/videos/generations",
            body=maybe_transform(body, video_create_params.VideoCreateParams),
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=VideoObject,
        )

    def retrieve_videos_result(
        self,
        id: str,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> VideoObject:
        if not id:
            raise ValueError("Expected a non-empty value for `id`")

        return self._get(
            f"/async-result/{id}",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=VideoObject,
        )
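
# Illustrative usage sketch (not part of the original file): video generation
# is asynchronous, so the object returned by `generations` is polled through
# `retrieve_videos_result`. The model name, the `id` attribute, and the
# `client.videos` mount point are assumptions.
#
#     task = client.videos.generations(model="cogvideox", prompt="a cat surfing")
#     result = client.videos.retrieve_videos_result(task.id)  # GET /async-result/{id}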
@ -1,108 +0,0 @@
from ._base_api import BaseAPI
from ._base_compat import (
    PYDANTIC_V2,
    ConfigDict,
    GenericModel,
    cached_property,
    field_get_default,
    get_args,
    get_model_config,
    get_model_fields,
    get_origin,
    is_literal_type,
    is_union,
    parse_obj,
)
from ._base_models import BaseModel, construct_type
from ._base_type import (
    NOT_GIVEN,
    Body,
    FileTypes,
    Headers,
    IncEx,
    ModelT,
    NotGiven,
    Query,
)
from ._constants import (
    ZHIPUAI_DEFAULT_LIMITS,
    ZHIPUAI_DEFAULT_MAX_RETRIES,
    ZHIPUAI_DEFAULT_TIMEOUT,
)
from ._errors import (
    APIAuthenticationError,
    APIConnectionError,
    APIInternalError,
    APIReachLimitError,
    APIRequestFailedError,
    APIResponseError,
    APIResponseValidationError,
    APIServerFlowExceedError,
    APIStatusError,
    APITimeoutError,
    ZhipuAIError,
)
from ._files import is_file_content
from ._http_client import HttpClient, make_request_options
from ._sse_client import StreamResponse
from ._utils import (
    deepcopy_minimal,
    drop_prefix_image_data,
    extract_files,
    is_given,
    is_list,
    is_mapping,
    maybe_transform,
    parse_date,
    parse_datetime,
)

__all__ = [
    "BaseModel",
    "construct_type",
    "BaseAPI",
    "NOT_GIVEN",
    "Headers",
    "NotGiven",
    "Body",
    "IncEx",
    "ModelT",
    "Query",
    "FileTypes",
    "PYDANTIC_V2",
    "ConfigDict",
    "GenericModel",
    "get_args",
    "is_union",
    "parse_obj",
    "get_origin",
    "is_literal_type",
    "get_model_config",
    "get_model_fields",
    "field_get_default",
    "is_file_content",
    "ZhipuAIError",
    "APIStatusError",
    "APIRequestFailedError",
    "APIAuthenticationError",
    "APIReachLimitError",
    "APIInternalError",
    "APIServerFlowExceedError",
    "APIResponseError",
    "APIResponseValidationError",
    "APITimeoutError",
    "make_request_options",
    "HttpClient",
    "ZHIPUAI_DEFAULT_TIMEOUT",
    "ZHIPUAI_DEFAULT_MAX_RETRIES",
    "ZHIPUAI_DEFAULT_LIMITS",
    "is_list",
    "is_mapping",
    "parse_date",
    "parse_datetime",
    "is_given",
    "maybe_transform",
    "deepcopy_minimal",
    "extract_files",
    "StreamResponse",
]
@ -1,19 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .._client import ZhipuAI


class BaseAPI:
    _client: ZhipuAI

    def __init__(self, client: ZhipuAI) -> None:
        self._client = client
        self._delete = client.delete
        self._get = client.get
        self._post = client.post
        self._put = client.put
        self._patch = client.patch
        self._get_api_list = client.get_api_list
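
# Illustrative sketch (not part of the original file): every resource class in
# this SDK follows the same pattern -- subclass `BaseAPI`, accept the shared
# `ZhipuAI` client, and call the aliased HTTP verbs. The `/ping` endpoint and
# option values below are made up.
#
#     class Ping(BaseAPI):
#         def check(self) -> httpx.Response:
#             return self._get(
#                 "/ping",
#                 options=make_request_options(extra_headers=None, extra_body=None, timeout=NOT_GIVEN),
#                 cast_type=httpx.Response,
#             )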
@ -1,209 +0,0 @@
from __future__ import annotations

from collections.abc import Callable
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, overload

import pydantic
from pydantic.fields import FieldInfo
from typing_extensions import Self

from ._base_type import StrBytesIntFloat

_T = TypeVar("_T")
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)

# --------------- Pydantic v2 compatibility ---------------

# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false

PYDANTIC_V2 = pydantic.VERSION.startswith("2.")

# v1 re-exports
if TYPE_CHECKING:

    def parse_date(value: date | StrBytesIntFloat) -> date: ...

    def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: ...

    def get_args(t: type[Any]) -> tuple[Any, ...]: ...

    def is_union(tp: type[Any] | None) -> bool: ...

    def get_origin(t: type[Any]) -> type[Any] | None: ...

    def is_literal_type(type_: type[Any]) -> bool: ...

    def is_typeddict(type_: type[Any]) -> bool: ...

else:
    if PYDANTIC_V2:
        from pydantic.v1.typing import (  # noqa: I001
            get_args as get_args,  # noqa: PLC0414
            is_union as is_union,  # noqa: PLC0414
            get_origin as get_origin,  # noqa: PLC0414
            is_typeddict as is_typeddict,  # noqa: PLC0414
            is_literal_type as is_literal_type,  # noqa: PLC0414
        )
        from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime  # noqa: PLC0414
    else:
        from pydantic.typing import (  # noqa: I001
            get_args as get_args,  # noqa: PLC0414
            is_union as is_union,  # noqa: PLC0414
            get_origin as get_origin,  # noqa: PLC0414
            is_typeddict as is_typeddict,  # noqa: PLC0414
            is_literal_type as is_literal_type,  # noqa: PLC0414
        )
        from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime  # noqa: PLC0414


# refactored config
if TYPE_CHECKING:
    from pydantic import ConfigDict
else:
    if PYDANTIC_V2:
        from pydantic import ConfigDict
    else:
        # TODO: provide an error message here?
        ConfigDict = None


# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
    if PYDANTIC_V2:
        return model.model_validate(value)
    else:
        # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
        return cast(_ModelT, model.parse_obj(value))


def field_is_required(field: FieldInfo) -> bool:
    if PYDANTIC_V2:
        return field.is_required()
    return field.required  # type: ignore


def field_get_default(field: FieldInfo) -> Any:
    value = field.get_default()
    if PYDANTIC_V2:
        from pydantic_core import PydanticUndefined

        if value == PydanticUndefined:
            return None
        return value
    return value


def field_outer_type(field: FieldInfo) -> Any:
    if PYDANTIC_V2:
        return field.annotation
    return field.outer_type_  # type: ignore


def get_model_config(model: type[pydantic.BaseModel]) -> Any:
    if PYDANTIC_V2:
        return model.model_config
    return model.__config__  # type: ignore


def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
    if PYDANTIC_V2:
        return model.model_fields
    return model.__fields__  # type: ignore


def model_copy(model: _ModelT) -> _ModelT:
    if PYDANTIC_V2:
        return model.model_copy()
    return model.copy()  # type: ignore


def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
    if PYDANTIC_V2:
        return model.model_dump_json(indent=indent)
    return model.json(indent=indent)  # type: ignore


def model_dump(
    model: pydantic.BaseModel,
    *,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
) -> dict[str, Any]:
    if PYDANTIC_V2:
        return model.model_dump(
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
        )
    return cast(
        "dict[str, Any]",
        model.dict(  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
        ),
    )


def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
    if PYDANTIC_V2:
        return model.model_validate(data)
    return model.parse_obj(data)  # pyright: ignore[reportDeprecated]


# generic models
if TYPE_CHECKING:

    class GenericModel(pydantic.BaseModel): ...

else:
    if PYDANTIC_V2:
        # there no longer needs to be a distinction in v2 but
        # we still have to create our own subclass to avoid
        # inconsistent MRO ordering errors
        class GenericModel(pydantic.BaseModel): ...

    else:
        import pydantic.generics

        class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...


# cached properties
if TYPE_CHECKING:
    cached_property = property

    # we define a separate type (copied from typeshed)
    # that represents that `cached_property` is `set`able
    # at runtime, which differs from `@property`.
    #
    # this is a separate type as editors likely special case
    # `@property` and we don't want to cause issues just to have
    # more helpful internal types.

    class typed_cached_property(Generic[_T]):  # noqa: N801
        func: Callable[[Any], _T]
        attrname: str | None

        def __init__(self, func: Callable[[Any], _T]) -> None: ...

        @overload
        def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...

        @overload
        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...

        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
            raise NotImplementedError()

        def __set_name__(self, owner: type[Any], name: str) -> None: ...

        # __set__ is not defined at runtime, but @cached_property is designed to be settable
        def __set__(self, instance: object, value: _T) -> None: ...
else:
    try:
        from functools import cached_property
    except ImportError:
        from cached_property import cached_property

    typed_cached_property = cached_property
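
# Illustrative sketch (not part of the original file): the helpers above give
# the rest of the SDK a single spelling per operation regardless of the
# installed Pydantic major version, e.g.
#
#     class User(pydantic.BaseModel):
#         name: str
#
#     user = parse_obj(User, {"name": "ada"})       # model_validate on v2, parse_obj on v1
#     data = model_dump(user, exclude_unset=True)   # model_dump on v2, dict on v1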
@ -1,670 +0,0 @@
from __future__ import annotations

import inspect
import os
from collections.abc import Callable
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeGuard, TypeVar, cast

import pydantic
import pydantic.generics
from pydantic.fields import FieldInfo
from typing_extensions import (
    ParamSpec,
    Protocol,
    override,
    runtime_checkable,
)

from ._base_compat import (
    PYDANTIC_V2,
    ConfigDict,
    field_get_default,
    get_args,
    get_model_config,
    get_model_fields,
    get_origin,
    is_literal_type,
    is_union,
    parse_obj,
)
from ._base_compat import (
    GenericModel as BaseGenericModel,
)
from ._base_type import (
    IncEx,
    ModelT,
)
from ._utils import (
    PropertyInfo,
    coerce_boolean,
    extract_type_arg,
    is_annotated_type,
    is_list,
    is_mapping,
    parse_date,
    parse_datetime,
    strip_annotated_type,
)

if TYPE_CHECKING:
    from pydantic_core.core_schema import ModelField

__all__ = ["BaseModel", "GenericModel"]
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")

_T = TypeVar("_T")
P = ParamSpec("P")


@runtime_checkable
class _ConfigProtocol(Protocol):
    allow_population_by_field_name: bool


class BaseModel(pydantic.BaseModel):
    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
        )
    else:

        @property
        @override
        def model_fields_set(self) -> set[str]:
            # a forwards-compat shim for pydantic v2
            return self.__fields_set__  # type: ignore

        class Config(pydantic.BaseConfig):  # pyright: ignore[reportDeprecated]
            extra: Any = pydantic.Extra.allow  # type: ignore

    def to_dict(
        self,
        *,
        mode: Literal["json", "python"] = "python",
        use_api_names: bool = True,
        exclude_unset: bool = True,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        warnings: bool = True,
    ) -> dict[str, object]:
        """Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.

        By default, fields that were not set by the API will not be included,
        and keys will match the API response, *not* the property names from the model.

        For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
        the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).

        Args:
          mode:
            If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
            If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`

          use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
          exclude_unset: Whether to exclude fields that have not been explicitly set.
          exclude_defaults: Whether to exclude fields that are set to their default value from the output.
          exclude_none: Whether to exclude fields that have a value of `None` from the output.
          warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
        """  # noqa: E501
        return self.model_dump(
            mode=mode,
            by_alias=use_api_names,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            warnings=warnings,
        )

    def to_json(
        self,
        *,
        indent: int | None = 2,
        use_api_names: bool = True,
        exclude_unset: bool = True,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        warnings: bool = True,
    ) -> str:
        """Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).

        By default, fields that were not set by the API will not be included,
        and keys will match the API response, *not* the property names from the model.

        For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
        the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).

        Args:
          indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
          use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
          exclude_unset: Whether to exclude fields that have not been explicitly set.
          exclude_defaults: Whether to exclude fields that have the default value.
          exclude_none: Whether to exclude fields that have a value of `None`.
          warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
        """  # noqa: E501
        return self.model_dump_json(
            indent=indent,
            by_alias=use_api_names,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            warnings=warnings,
        )
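
    # Illustrative sketch (not part of the original file): `to_dict`/`to_json`
    # keep the API's own key casing unless `use_api_names=False` is passed, e.g.
    #
    #     class Foo(BaseModel):
    #         foo_bar: bool = pydantic.Field(alias="fooBar")
    #
    #     Foo.construct(fooBar=True).to_dict()                     # {'fooBar': True}
    #     Foo.construct(fooBar=True).to_dict(use_api_names=False)  # {'foo_bar': True}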
    @override
    def __str__(self) -> str:
        # mypy complains about an invalid self arg
        return f'{self.__repr_name__()}({self.__repr_str__(", ")})'  # type: ignore[misc]

    # Override the 'construct' method in a way that supports recursive parsing without validation.
    # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
    @classmethod
    @override
    def construct(
        cls: type[ModelT],
        _fields_set: set[str] | None = None,
        **values: object,
    ) -> ModelT:
        m = cls.__new__(cls)
        fields_values: dict[str, object] = {}

        config = get_model_config(cls)
        populate_by_name = (
            config.allow_population_by_field_name
            if isinstance(config, _ConfigProtocol)
            else config.get("populate_by_name")
        )

        if _fields_set is None:
            _fields_set = set()

        model_fields = get_model_fields(cls)
        for name, field in model_fields.items():
            key = field.alias
            if key is None or (key not in values and populate_by_name):
                key = name

            if key in values:
                fields_values[name] = _construct_field(value=values[key], field=field, key=key)
                _fields_set.add(name)
            else:
                fields_values[name] = field_get_default(field)

        _extra = {}
        for key, value in values.items():
            if key not in model_fields:
                if PYDANTIC_V2:
                    _extra[key] = value
                else:
                    _fields_set.add(key)
                    fields_values[key] = value

        object.__setattr__(m, "__dict__", fields_values)  # noqa: PLC2801

        if PYDANTIC_V2:
            # these properties are copied from Pydantic's `model_construct()` method
            object.__setattr__(m, "__pydantic_private__", None)  # noqa: PLC2801
            object.__setattr__(m, "__pydantic_extra__", _extra)  # noqa: PLC2801
            object.__setattr__(m, "__pydantic_fields_set__", _fields_set)  # noqa: PLC2801
        else:
            # init_private_attributes() does not exist in v2
            m._init_private_attributes()  # type: ignore

            # copied from Pydantic v1's `construct()` method
            object.__setattr__(m, "__fields_set__", _fields_set)  # noqa: PLC2801

        return m

    if not TYPE_CHECKING:
        # type checkers incorrectly complain about this assignment
        # because the type signatures are technically different
        # although not in practice
        model_construct = construct
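
    # Illustrative sketch (not part of the original file): `construct` builds a
    # model from raw API data without validation, recursing into nested fields
    # via `_construct_field`/`construct_type` defined below, e.g.
    #
    #     class Item(BaseModel):
    #         count: int
    #
    #     item = Item.construct(count="not-an-int")  # no ValidationError raised
    #     item.count                                 # "not-an-int", kept as-is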
    if not PYDANTIC_V2:
        # we define aliases for some of the new pydantic v2 methods so
        # that we can just document these methods without having to specify
        # a specific pydantic version as some users may not know which
        # pydantic version they are currently using

        @override
        def model_dump(
            self,
            *,
            mode: Literal["json", "python"] | str = "python",
            include: IncEx = None,
            exclude: IncEx = None,
            by_alias: bool = False,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
            round_trip: bool = False,
            warnings: bool | Literal["none", "warn", "error"] = True,
            context: dict[str, Any] | None = None,
            serialize_as_any: bool = False,
        ) -> dict[str, Any]:
            """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump

            Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.

            Args:
                mode: The mode in which `to_python` should run.
                    If mode is 'json', the dictionary will only contain JSON serializable types.
                    If mode is 'python', the dictionary may contain any Python objects.
                include: A list of fields to include in the output.
                exclude: A list of fields to exclude from the output.
                by_alias: Whether to use the field's alias in the dictionary key if defined.
                exclude_unset: Whether to exclude fields that are unset or None from the output.
                exclude_defaults: Whether to exclude fields that are set to their default value from the output.
                exclude_none: Whether to exclude fields that have a value of `None` from the output.
                round_trip: Whether to enable serialization and deserialization round-trip support.
                warnings: Whether to log warnings when invalid fields are encountered.

            Returns:
                A dictionary representation of the model.
            """
            if mode != "python":
                raise ValueError("mode is only supported in Pydantic v2")
            if round_trip != False:
                raise ValueError("round_trip is only supported in Pydantic v2")
            if warnings != True:
                raise ValueError("warnings is only supported in Pydantic v2")
            if context is not None:
                raise ValueError("context is only supported in Pydantic v2")
            if serialize_as_any != False:
                raise ValueError("serialize_as_any is only supported in Pydantic v2")
            return super().dict(  # pyright: ignore[reportDeprecated]
                include=include,
                exclude=exclude,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )

        @override
        def model_dump_json(
            self,
            *,
            indent: int | None = None,
            include: IncEx = None,
            exclude: IncEx = None,
            by_alias: bool = False,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
            round_trip: bool = False,
            warnings: bool | Literal["none", "warn", "error"] = True,
            context: dict[str, Any] | None = None,
            serialize_as_any: bool = False,
        ) -> str:
            """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json

            Generates a JSON representation of the model using Pydantic's `to_json` method.

            Args:
                indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
                include: Field(s) to include in the JSON output. Can take either a string or set of strings.
                exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
                by_alias: Whether to serialize using field aliases.
                exclude_unset: Whether to exclude fields that have not been explicitly set.
                exclude_defaults: Whether to exclude fields that have the default value.
                exclude_none: Whether to exclude fields that have a value of `None`.
                round_trip: Whether to use serialization/deserialization between JSON and class instance.
                warnings: Whether to show any warnings that occurred during serialization.

            Returns:
                A JSON string representation of the model.
            """
            if round_trip != False:
                raise ValueError("round_trip is only supported in Pydantic v2")
            if warnings != True:
                raise ValueError("warnings is only supported in Pydantic v2")
            if context is not None:
                raise ValueError("context is only supported in Pydantic v2")
            if serialize_as_any != False:
                raise ValueError("serialize_as_any is only supported in Pydantic v2")
            return super().json(  # type: ignore[reportDeprecated]
                indent=indent,
                include=include,
                exclude=exclude,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )


def _construct_field(value: object, field: FieldInfo, key: str) -> object:
    if value is None:
        return field_get_default(field)

    if PYDANTIC_V2:
        type_ = field.annotation
    else:
        type_ = cast(type, field.outer_type_)  # type: ignore

    if type_ is None:
        raise RuntimeError(f"Unexpected field type is None for {key}")

    return construct_type(value=value, type_=type_)


def is_basemodel(type_: type) -> bool:
    """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
    if is_union(type_):
        return any(is_basemodel(variant) for variant in get_args(type_))

    return is_basemodel_type(type_)


def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
    origin = get_origin(type_) or type_
    return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)


def build(
    base_model_cls: Callable[P, _BaseModelT],
    *args: P.args,
    **kwargs: P.kwargs,
) -> _BaseModelT:
    """Construct a BaseModel class without validation.

    This is useful for cases where you need to instantiate a `BaseModel`
    from an API response as this provides type-safe params which isn't supported
    by helpers like `construct_type()`.

    ```py
    build(MyModel, my_field_a="foo", my_field_b=123)
    ```
    """
    if args:
        raise TypeError(
            "Received positional arguments which are not supported; Keyword arguments must be used instead",
        )

    return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))


def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
    """Loose coercion to the expected type with construction of nested values.

    Note: the returned value from this function is not guaranteed to match the
    given type.
    """
    return cast(_T, construct_type(value=value, type_=type_))


def construct_type(*, value: object, type_: type) -> object:
    """Loose coercion to the expected type with construction of nested values.

    If the given value does not match the expected type then it is returned as-is.
    """
    # we allow `object` as the input type because otherwise, passing things like
    # `Literal['value']` will be reported as a type error by type checkers
    type_ = cast("type[object]", type_)

    # unwrap `Annotated[T, ...]` -> `T`
    if is_annotated_type(type_):
        meta: tuple[Any, ...] = get_args(type_)[1:]
        type_ = extract_type_arg(type_, 0)
    else:
        meta = ()
    # we need to use the origin class for any types that are subscripted generics
    # e.g. Dict[str, object]
    origin = get_origin(type_) or type_
    args = get_args(type_)

    if is_union(origin):
        try:
            return validate_type(type_=cast("type[object]", type_), value=value)
        except Exception:
            pass

        # if the type is a discriminated union then we want to construct the right variant
        # in the union, even if the data doesn't match exactly, otherwise we'd break code
        # that relies on the constructed class types, e.g.
        #
        # class FooType:
        #   kind: Literal['foo']
        #   value: str
        #
        # class BarType:
        #   kind: Literal['bar']
        #   value: int
        #
        # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
        # we'd end up constructing `FooType` when it should be `BarType`.
        discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
        if discriminator and is_mapping(value):
            variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
            if variant_value and isinstance(variant_value, str):
                variant_type = discriminator.mapping.get(variant_value)
                if variant_type:
                    return construct_type(type_=variant_type, value=value)

        # if the data is not valid, use the first variant that doesn't fail while deserializing
        for variant in args:
            try:
                return construct_type(value=value, type_=variant)
            except Exception:
                continue

        raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
    if origin == dict:
        if not is_mapping(value):
            return value

        _, items_type = get_args(type_)  # Dict[_, items_type]
        return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}

    if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
        if is_list(value):
            return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]

        if is_mapping(value):
            if issubclass(type_, BaseModel):
                return type_.construct(**value)  # type: ignore[arg-type]

            return cast(Any, type_).construct(**value)

    if origin == list:
        if not is_list(value):
            return value

        inner_type = args[0]  # List[inner_type]
        return [construct_type(value=entry, type_=inner_type) for entry in value]

    if origin == float:
        if isinstance(value, int):
            coerced = float(value)
            if coerced != value:
                return value
            return coerced

        return value

    if type_ == datetime:
        try:
            return parse_datetime(value)  # type: ignore
        except Exception:
            return value

    if type_ == date:
        try:
            return parse_date(value)  # type: ignore
        except Exception:
            return value

    return value
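
# Illustrative sketch (not part of the original file): `construct_type` coerces
# loosely and returns the value unchanged when it cannot, e.g.
#
#     construct_type(value="2024-03-22T18:11:19Z", type_=datetime)  # -> datetime(...)
#     construct_type(value={"a": "1"}, type_=dict[str, str])        # recurses into values
#     construct_type(value="oops", type_=int)                       # -> "oops", as-is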
@runtime_checkable
class CachedDiscriminatorType(Protocol):
    __discriminator__: DiscriminatorDetails


class DiscriminatorDetails:
    field_name: str
    """The name of the discriminator field in the variant class, e.g.

    ```py
    class Foo(BaseModel):
        type: Literal['foo']
    ```

    Will result in field_name='type'
    """

    field_alias_from: str | None
    """The name of the discriminator field in the API response, e.g.

    ```py
    class Foo(BaseModel):
        type: Literal['foo'] = Field(alias='type_from_api')
    ```

    Will result in field_alias_from='type_from_api'
    """

    mapping: dict[str, type]
    """Mapping of discriminator value to variant type, e.g.

    {'foo': FooVariant, 'bar': BarVariant}
    """

    def __init__(
        self,
        *,
        mapping: dict[str, type],
        discriminator_field: str,
        discriminator_alias: str | None,
    ) -> None:
        self.mapping = mapping
        self.field_name = discriminator_field
        self.field_alias_from = discriminator_alias


def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
    if isinstance(union, CachedDiscriminatorType):
        return union.__discriminator__

    discriminator_field_name: str | None = None

    for annotation in meta_annotations:
        if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
            discriminator_field_name = annotation.discriminator
            break

    if not discriminator_field_name:
        return None

    mapping: dict[str, type] = {}
    discriminator_alias: str | None = None

    for variant in get_args(union):
        variant = strip_annotated_type(variant)
        if is_basemodel_type(variant):
            if PYDANTIC_V2:
                field = _extract_field_schema_pv2(variant, discriminator_field_name)
                if not field:
                    continue

                # Note: if one variant defines an alias then they all should
                discriminator_alias = field.get("serialization_alias")

                field_schema = field["schema"]

                if field_schema["type"] == "literal":
                    for entry in cast("LiteralSchema", field_schema)["expected"]:
                        if isinstance(entry, str):
                            mapping[entry] = variant
            else:
                field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name)  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
                if not field_info:
                    continue

                # Note: if one variant defines an alias then they all should
                discriminator_alias = field_info.alias

                if field_info.annotation and is_literal_type(field_info.annotation):
                    for entry in get_args(field_info.annotation):
                        if isinstance(entry, str):
                            mapping[entry] = variant

    if not mapping:
        return None

    details = DiscriminatorDetails(
        mapping=mapping,
        discriminator_field=discriminator_field_name,
        discriminator_alias=discriminator_alias,
    )
    cast(CachedDiscriminatorType, union).__discriminator__ = details
    return details


def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
    schema = model.__pydantic_core_schema__
    if schema["type"] != "model":
        return None

    fields_schema = schema["schema"]
    if fields_schema["type"] != "model-fields":
        return None

    fields_schema = cast("ModelFieldsSchema", fields_schema)

    field = fields_schema["fields"].get(field_name)
    if not field:
        return None

    return cast("ModelField", field)  # pyright: ignore[reportUnnecessaryCast]


def validate_type(*, type_: type[_T], value: object) -> _T:
    """Strict validation that the given value matches the expected type"""
    if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
        return cast(_T, parse_obj(type_, value))

    return cast(_T, _validate_non_model_type(type_=type_, value=value))


# Subclassing here confuses type checkers, so we treat this class as non-inheriting.
if TYPE_CHECKING:
    GenericModel = BaseModel
else:

    class GenericModel(BaseGenericModel, BaseModel):
        pass


if PYDANTIC_V2:
    from pydantic import TypeAdapter

    def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
        return TypeAdapter(type_).validate_python(value)

elif not TYPE_CHECKING:

    class TypeAdapter(Generic[_T]):
        """Used as a placeholder to easily convert runtime types to a Pydantic format
        to provide validation.

        For example:
        ```py
        validated = RootModel[int](__root__="5").__root__
        # validated: 5
        ```
        """

        def __init__(self, type_: type[_T]):
            self.type_ = type_

        def validate_python(self, value: Any) -> _T:
            if not isinstance(value, self.type_):
                raise ValueError(f"Invalid type: {value} is not of type {self.type_}")
            return value

    def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
        return TypeAdapter(type_).validate_python(value)
@ -1,170 +0,0 @@
from __future__ import annotations

from collections.abc import Callable, Mapping, Sequence
from os import PathLike
from typing import (
    IO,
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypeVar,
    Union,
)

import httpx
import pydantic
from httpx import Response
from typing_extensions import Protocol, TypedDict, override, runtime_checkable

Query = Mapping[str, object]
Body = object
AnyMapping = Mapping[str, object]
PrimitiveData = Union[str, int, float, bool, None]
Data = Union[PrimitiveData, list[Any], tuple[Any], "Mapping[str, Any]"]
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
_T = TypeVar("_T")

if TYPE_CHECKING:
    NoneType: type[None]
else:
    NoneType = type(None)


# Sentinel class used until PEP 0661 is accepted
class NotGiven:
    """
    A sentinel singleton class used to distinguish omitted keyword arguments
    from those passed in with the value None (which may have different behavior).

    For example:

    ```py
    def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...


    get(timeout=1)  # 1s timeout
    get(timeout=None)  # No timeout
    get()  # Default timeout behavior, which may not be statically known at the method definition.
    ```
    """

    def __bool__(self) -> Literal[False]:
        return False

    @override
    def __repr__(self) -> str:
        return "NOT_GIVEN"


NotGivenOr = Union[_T, NotGiven]
NOT_GIVEN = NotGiven()


class Omit:
    """In certain situations you need to be able to represent a case where a default value has
    to be explicitly removed and `None` is not an appropriate substitute, for example:

    ```py
    # as the default `Content-Type` header is `application/json` that will be sent
    client.post('/upload/files', files={'file': b'my raw file content'})

    # you can't explicitly override the header as it has to be dynamically generated
    # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
    client.post(..., headers={'Content-Type': 'multipart/form-data'})

    # instead you can remove the default `application/json` header by passing Omit
    client.post(..., headers={'Content-Type': Omit()})
    ```
    """

    def __bool__(self) -> Literal[False]:
        return False


@runtime_checkable
class ModelBuilderProtocol(Protocol):
    @classmethod
    def build(
        cls: type[_T],
        *,
        response: Response,
        data: object,
    ) -> _T: ...


Headers = Mapping[str, Union[str, Omit]]


class HeadersLikeProtocol(Protocol):
    def get(self, __key: str) -> str | None: ...


HeadersLike = Union[Headers, HeadersLikeProtocol]

ResponseT = TypeVar(
    "ResponseT",
    bound="Union[str, None, BaseModel, list[Any], dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",  # noqa: E501
)

StrBytesIntFloat = Union[str, bytes, int, float]

# Note: copied from Pydantic
# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49
IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"

PostParser = Callable[[Any], Any]


@runtime_checkable
class InheritsGeneric(Protocol):
    """Represents a type that has inherited from `Generic`

    The `__orig_bases__` property can be used to determine the resolved
    type variable for a given base class.
    """

    __orig_bases__: tuple[_GenericAlias]


class _GenericAlias(Protocol):
    __origin__: type[object]


class HttpxSendArgs(TypedDict, total=False):
    auth: httpx.Auth


# for user input files
if TYPE_CHECKING:
    Base64FileInput = Union[IO[bytes], PathLike[str]]
    FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
    Base64FileInput = Union[IO[bytes], PathLike]
    FileContent = Union[IO[bytes], bytes, PathLike]

FileTypes = Union[
    # file (or bytes)
    FileContent,
    # (filename, file (or bytes))
    tuple[Optional[str], FileContent],
    # (filename, file (or bytes), content_type)
    tuple[Optional[str], FileContent, Optional[str]],
    # (filename, file (or bytes), content_type, headers)
    tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]]

# duplicate of the above but without our custom file support
HttpxFileContent = Union[bytes, IO[bytes]]
HttpxFileTypes = Union[
    # file (or bytes)
    HttpxFileContent,
    # (filename, file (or bytes))
    tuple[Optional[str], HttpxFileContent],
    # (filename, file (or bytes), content_type)
    tuple[Optional[str], HttpxFileContent, Optional[str]],
    # (filename, file (or bytes), content_type, headers)
    tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
]

HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]]
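
# Illustrative sketch (not part of the original file): `FileTypes` accepts a
# bare file/bytes value or progressively richer tuples, e.g.
#
#     files = {"file": b"raw bytes"}                               # FileContent
#     files = {"file": ("report.pdf", open("report.pdf", "rb"))}   # (filename, file)
#     files = {"file": ("report.pdf", b"...", "application/pdf")}  # + content_type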
@ -1,12 +0,0 @@
import httpx

RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
# `Timeout` controls the request `connect` and `read` timeouts; defaults to `timeout=300.0, connect=8.0`
ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
# The `retry` parameter controls the number of retries; defaults to 3
ZHIPUAI_DEFAULT_MAX_RETRIES = 3
# `Limits` controls the maximum connection count and keep-alive connections; defaults to `max_connections=50, max_keepalive_connections=10`
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)

INITIAL_RETRY_DELAY = 0.5
MAX_RETRY_DELAY = 8.0
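
# Illustrative sketch (not part of the original file): the defaults above can
# be overridden per client, assuming the `ZhipuAI` constructor forwards these
# httpx-style settings, e.g.
#
#     client = ZhipuAI(
#         api_key="...",
#         timeout=httpx.Timeout(timeout=60.0, connect=5.0),
#         max_retries=ZHIPUAI_DEFAULT_MAX_RETRIES,
#     )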
@ -1,86 +0,0 @@
from __future__ import annotations

import httpx

__all__ = [
    "ZhipuAIError",
    "APIStatusError",
    "APIRequestFailedError",
    "APIAuthenticationError",
    "APIReachLimitError",
    "APIInternalError",
    "APIServerFlowExceedError",
    "APIResponseError",
    "APIResponseValidationError",
    "APITimeoutError",
    "APIConnectionError",
]


class ZhipuAIError(Exception):
    def __init__(
        self,
        message: str,
    ) -> None:
        super().__init__(message)


class APIStatusError(ZhipuAIError):
    response: httpx.Response
    status_code: int

    def __init__(self, message: str, *, response: httpx.Response) -> None:
        super().__init__(message)
        self.response = response
        self.status_code = response.status_code


class APIRequestFailedError(APIStatusError): ...


class APIAuthenticationError(APIStatusError): ...


class APIReachLimitError(APIStatusError): ...


class APIInternalError(APIStatusError): ...


class APIServerFlowExceedError(APIStatusError): ...


class APIResponseError(ZhipuAIError):
    message: str
    request: httpx.Request
    json_data: object

    def __init__(self, message: str, request: httpx.Request, json_data: object):
        self.message = message
        self.request = request
        self.json_data = json_data
        super().__init__(message)


class APIResponseValidationError(APIResponseError):
    status_code: int
    response: httpx.Response

    def __init__(self, response: httpx.Response, json_data: object | None, *, message: str | None = None) -> None:
        super().__init__(
            message=message or "Data returned by API invalid for expected schema.",
            request=response.request,
            json_data=json_data,
        )
        self.response = response
        self.status_code = response.status_code


class APIConnectionError(APIResponseError):
    def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
        super().__init__(message, request, json_data=None)


class APITimeoutError(APIConnectionError):
    def __init__(self, request: httpx.Request) -> None:
        super().__init__(message="Request timed out.", request=request)
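
# Illustrative sketch (not part of the original file): callers can branch on
# the hierarchy above from most to least specific. Note `APITimeoutError`
# derives from `APIConnectionError`/`APIResponseError`, not `APIStatusError`.
#
#     try:
#         client.tools.web_search(model="web-search-pro", messages="dify")
#     except APIAuthenticationError:
#         ...  # bad API key; HTTP status is available as .status_code
#     except APIStatusError as e:
#         ...  # any other non-2xx response; inspect e.response
#     except APITimeoutError:
#         ...  # the request timed out before a response arrived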
@ -1,75 +0,0 @@
from __future__ import annotations

import io
import os
import pathlib
from typing import TypeGuard, overload

from ._base_type import (
    Base64FileInput,
    FileContent,
    FileTypes,
    HttpxFileContent,
    HttpxFileTypes,
    HttpxRequestFiles,
    RequestFiles,
)
from ._utils import is_mapping_t, is_sequence_t, is_tuple_t


def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
    return isinstance(obj, io.IOBase | os.PathLike)


def is_file_content(obj: object) -> TypeGuard[FileContent]:
    return isinstance(obj, bytes | tuple | io.IOBase | os.PathLike)


def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
    if not is_file_content(obj):
        prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
        raise RuntimeError(
            f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads"
        ) from None


@overload
def to_httpx_files(files: None) -> None: ...


@overload
def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ...


def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
    if files is None:
        return None

    if is_mapping_t(files):
        files = {key: _transform_file(file) for key, file in files.items()}
    elif is_sequence_t(files):
        files = [(key, _transform_file(file)) for key, file in files]
    else:
        raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")

    return files


def _transform_file(file: FileTypes) -> HttpxFileTypes:
    if is_file_content(file):
        if isinstance(file, os.PathLike):
            path = pathlib.Path(file)
            return (path.name, path.read_bytes())

        return file

    if is_tuple_t(file):
        return (file[0], _read_file_content(file[1]), *file[2:])

    raise TypeError("Expected file types input to be a FileContent type or to be a tuple")


def _read_file_content(file: FileContent) -> HttpxFileContent:
    if isinstance(file, os.PathLike):
        return pathlib.Path(file).read_bytes()
    return file
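
# Example: `to_httpx_files` normalizes path-like inputs into the
# (filename, bytes) tuples that httpx expects, e.g.
#
#     to_httpx_files({"file": pathlib.Path("batch.jsonl")})
#     # -> {"file": ("batch.jsonl", b"<file bytes>")}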
@ -1,910 +0,0 @@
from __future__ import annotations

import inspect
import logging
import time
import warnings
from collections.abc import Iterable, Iterator, Mapping
from itertools import starmap
from random import random
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, cast, overload

import httpx
import pydantic
from httpx import URL, Timeout

from . import _errors, get_origin
from ._base_compat import model_copy
from ._base_models import GenericModel, construct_type, validate_type
from ._base_type import (
    NOT_GIVEN,
    AnyMapping,
    Body,
    Data,
    Headers,
    HttpxSendArgs,
    ModelBuilderProtocol,
    NotGiven,
    Omit,
    PostParser,
    Query,
    RequestFiles,
    ResponseT,
)
from ._constants import (
    INITIAL_RETRY_DELAY,
    MAX_RETRY_DELAY,
    RAW_RESPONSE_HEADER,
    ZHIPUAI_DEFAULT_LIMITS,
    ZHIPUAI_DEFAULT_MAX_RETRIES,
    ZHIPUAI_DEFAULT_TIMEOUT,
)
from ._errors import APIConnectionError, APIResponseValidationError, APIStatusError, APITimeoutError
from ._files import to_httpx_files
from ._legacy_response import LegacyAPIResponse
from ._request_opt import FinalRequestOptions, UserRequestInput
from ._response import APIResponse, BaseAPIResponse, extract_response_type
from ._sse_client import StreamResponse
from ._utils import flatten, is_given, is_mapping

log: logging.Logger = logging.getLogger(__name__)

# TODO: make base page type vars covariant
SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]")
# AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]")

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)

if TYPE_CHECKING:
    from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
else:
    try:
        from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
    except ImportError:
        # taken from https://github.com/encode/httpx/blob/3ba5fe0d7ac70222590e759c31442b1cab263791/httpx/_config.py#L366
        HTTPX_DEFAULT_TIMEOUT = Timeout(5.0)


headers = {
    "Accept": "application/json",
    "Content-Type": "application/json; charset=UTF-8",
}


class PageInfo:
    """Stores the necessary information to build the request to retrieve the next page.

    Either `url` or `params` must be set.
    """

    url: URL | NotGiven
    params: Query | NotGiven

    @overload
    def __init__(
        self,
        *,
        url: URL,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        params: Query,
    ) -> None: ...

    def __init__(
        self,
        *,
        url: URL | NotGiven = NOT_GIVEN,
        params: Query | NotGiven = NOT_GIVEN,
    ) -> None:
        self.url = url
        self.params = params


class BasePage(GenericModel, Generic[_T]):
    """
    Defines the core interface for pagination.

    Type Args:
        ModelT: The pydantic model that represents an item in the response.

    Methods:
        has_next_page(): Check if there is another page available
        next_page_info(): Get the necessary information to make a request for the next page
    """

    _options: FinalRequestOptions = pydantic.PrivateAttr()
    _model: type[_T] = pydantic.PrivateAttr()

    def has_next_page(self) -> bool:
        items = self._get_page_items()
        if not items:
            return False
        return self.next_page_info() is not None

    def next_page_info(self) -> Optional[PageInfo]: ...

    def _get_page_items(self) -> Iterable[_T]:  # type: ignore[empty-body]
        ...

    def _params_from_url(self, url: URL) -> httpx.QueryParams:
        # TODO: do we have to preprocess params here?
        return httpx.QueryParams(cast(Any, self._options.params)).merge(url.params)

    def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
        options = model_copy(self._options)
        options._strip_raw_response_header()

        if not isinstance(info.params, NotGiven):
            options.params = {**options.params, **info.params}
            return options

        if not isinstance(info.url, NotGiven):
            params = self._params_from_url(info.url)
            url = info.url.copy_with(params=params)
            options.params = dict(url.params)
            options.url = str(url)
            return options

        raise ValueError("Unexpected PageInfo state")


class BaseSyncPage(BasePage[_T], Generic[_T]):
    _client: HttpClient = pydantic.PrivateAttr()

    def _set_private_attributes(
        self,
        client: HttpClient,
        model: type[_T],
        options: FinalRequestOptions,
    ) -> None:
        self._model = model
        self._client = client
        self._options = options

    # Pydantic uses a custom `__iter__` method to support casting BaseModels
    # to dictionaries. e.g. dict(model).
    # As we want to support `for item in page`, this is inherently incompatible
    # with the default pydantic behavior. It is not possible to support both
    # use cases at once. Fortunately, this is not a big deal as all other pydantic
    # methods should continue to work as expected as there is an alternative method
    # to cast a model to a dictionary, model.dict(), which is used internally
    # by pydantic.
    def __iter__(self) -> Iterator[_T]:  # type: ignore
        for page in self.iter_pages():
            yield from page._get_page_items()

    def iter_pages(self: SyncPageT) -> Iterator[SyncPageT]:
        page = self
        while True:
            yield page
            if page.has_next_page():
                page = page.get_next_page()
            else:
                return
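
    # Example: iterating a page walks every item across all pages, fetching
    # the next page lazily as the loop advances. `FileObject` and the page
    # class below are illustrative placeholders:
    #
    #     page = client.get_api_list("/files", model=FileObject, page=SyncPage[FileObject])
    #     for item in page:  # transparently requests page 2, 3, ... as needed
    #         print(item.id)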

    def get_next_page(self: SyncPageT) -> SyncPageT:
        info = self.next_page_info()
        if not info:
            raise RuntimeError(
                "No next page expected; please check `.has_next_page()` before calling `.get_next_page()`."
            )

        options = self._info_to_options(info)
        return self._client._request_api_list(self._model, page=self.__class__, options=options)


class HttpClient:
    _client: httpx.Client
    _version: str
    _base_url: URL
    max_retries: int
    timeout: Union[float, Timeout, None]
    _limits: httpx.Limits
    _has_custom_http_client: bool
    _default_stream_cls: type[StreamResponse[Any]] | None = None

    _strict_response_validation: bool

    def __init__(
        self,
        *,
        version: str,
        base_url: URL,
        _strict_response_validation: bool,
        max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
        timeout: Union[float, Timeout, None],
        limits: httpx.Limits | None = None,
        custom_httpx_client: httpx.Client | None = None,
        custom_headers: Mapping[str, str] | None = None,
    ) -> None:
        if limits is not None:
            warnings.warn(
                "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",  # noqa: E501
                category=DeprecationWarning,
                stacklevel=3,
            )
            if custom_httpx_client is not None:
                raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`")
        else:
            limits = ZHIPUAI_DEFAULT_LIMITS

        if not is_given(timeout):
            if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT:
                timeout = custom_httpx_client.timeout
            else:
                timeout = ZHIPUAI_DEFAULT_TIMEOUT
        self.max_retries = max_retries
        self.timeout = timeout
        self._limits = limits
        self._has_custom_http_client = bool(custom_httpx_client)
        self._client = custom_httpx_client or httpx.Client(
            base_url=base_url,
            timeout=self.timeout,
            limits=limits,
        )
        self._version = version
        url = URL(url=base_url)
        if not url.raw_path.endswith(b"/"):
            url = url.copy_with(raw_path=url.raw_path + b"/")
        self._base_url = url
        self._custom_headers = custom_headers or {}
        self._strict_response_validation = _strict_response_validation

    def _prepare_url(self, url: str) -> URL:
        sub_url = URL(url)
        if sub_url.is_relative_url:
            request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/")
            return self._base_url.copy_with(raw_path=request_raw_url)

        return sub_url

    @property
    def _default_headers(self):
        return {
            "Accept": "application/json",
            "Content-Type": "application/json; charset=UTF-8",
            "ZhipuAI-SDK-Ver": self._version,
            "source_type": "zhipu-sdk-python",
            "x-request-sdk": "zhipu-sdk-python",
            **self.auth_headers,
            **self._custom_headers,
        }

    @property
    def custom_auth(self) -> httpx.Auth | None:
        return None

    @property
    def auth_headers(self):
        return {}

    def _prepare_headers(self, options: FinalRequestOptions) -> httpx.Headers:
        custom_headers = options.headers or {}
        headers_dict = _merge_mappings(self._default_headers, custom_headers)

        httpx_headers = httpx.Headers(headers_dict)

        return httpx_headers

    def _remaining_retries(
        self,
        remaining_retries: Optional[int],
        options: FinalRequestOptions,
    ) -> int:
        return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)

    def _calculate_retry_timeout(
        self,
        remaining_retries: int,
        options: FinalRequestOptions,
        response_headers: Optional[httpx.Headers] = None,
    ) -> float:
        max_retries = options.get_max_retries(self.max_retries)

        # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
        # retry_after = self._parse_retry_after_header(response_headers)
        # if retry_after is not None and 0 < retry_after <= 60:
        #     return retry_after

        nb_retries = max_retries - remaining_retries

        # Apply exponential backoff, but not more than the max.
        sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)

        # Apply some jitter, reducing the delay by up to 25%.
        jitter = 1 - 0.25 * random()
        timeout = sleep_seconds * jitter
        return max(timeout, 0)
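
    # Example schedule, assuming for illustration INITIAL_RETRY_DELAY = 0.5
    # and MAX_RETRY_DELAY = 8.0 (the real values live in `._constants`):
    # successive attempts sleep roughly 0.5s, 1s, 2s, 4s, 8s, 8s, ...,
    # each reduced by up to 25% of jitter.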

    def _build_request(self, options: FinalRequestOptions) -> httpx.Request:
        kwargs: dict[str, Any] = {}
        headers = self._prepare_headers(options)
        url = self._prepare_url(options.url)
        json_data = options.json_data
        if options.extra_json is not None:
            if json_data is None:
                json_data = cast(Body, options.extra_json)
            elif is_mapping(json_data):
                json_data = _merge_mappings(json_data, options.extra_json)
            else:
                raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")

        content_type = headers.get("Content-Type")
        # multipart/form-data; boundary=---abc--
        if content_type == "multipart/form-data":
            if "boundary" not in content_type:
                # only remove the header if the boundary hasn't been explicitly set
                # as the caller doesn't want httpx to come up with their own boundary
                headers.pop("Content-Type")

            if json_data:
                kwargs["data"] = self._make_multipartform(json_data)

        return self._client.build_request(
            headers=headers,
            timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,
            method=options.method,
            url=url,
            json=json_data,
            files=options.files,
            params=options.params,
            **kwargs,
        )

    def _object_to_formdata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]:
        items = []

        if isinstance(value, Mapping):
            for k, v in value.items():
                items.extend(self._object_to_formdata(f"{key}[{k}]", v))
            return items
        if isinstance(value, list | tuple):
            for v in value:
                items.extend(self._object_to_formdata(key + "[]", v))
            return items

        def _primitive_value_to_str(val) -> str:
            # copied from httpx
            if val is True:
                return "true"
            elif val is False:
                return "false"
            elif val is None:
                return ""
            return str(val)

        str_data = _primitive_value_to_str(value)

        if not str_data:
            return []
        return [(key, str_data)]

    def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
        items = flatten(list(starmap(self._object_to_formdata, data.items())))

        serialized: dict[str, object] = {}
        for key, value in items:
            if key in serialized:
                raise ValueError(f"Duplicate key: {key}")
            serialized[key] = value
        return serialized
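
    # Example: nested mappings are flattened into the bracket notation used
    # for multipart form fields, e.g.
    #
    #     self._make_multipartform({"meta": {"lang": "en"}, "purpose": "batch"})
    #     # -> {"meta[lang]": "en", "purpose": "batch"}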

    def _process_response_data(
        self,
        *,
        data: object,
        cast_type: type[ResponseT],
        response: httpx.Response,
    ) -> ResponseT:
        if data is None:
            return cast(ResponseT, None)

        if cast_type is object:
            return cast(ResponseT, data)

        try:
            if inspect.isclass(cast_type) and issubclass(cast_type, ModelBuilderProtocol):
                return cast(ResponseT, cast_type.build(response=response, data=data))

            if self._strict_response_validation:
                return cast(ResponseT, validate_type(type_=cast_type, value=data))

            return cast(ResponseT, construct_type(type_=cast_type, value=data))
        except pydantic.ValidationError as err:
            raise APIResponseValidationError(response=response, json_data=data) from err

    def _should_stream_response_body(self, request: httpx.Request) -> bool:
        return request.headers.get(RAW_RESPONSE_HEADER) == "stream"  # type: ignore[no-any-return]

    def _should_retry(self, response: httpx.Response) -> bool:
        # Note: this is not a standard header
        should_retry_header = response.headers.get("x-should-retry")

        # If the server explicitly says whether or not to retry, obey.
        if should_retry_header == "true":
            log.debug("Retrying as header `x-should-retry` is set to `true`")
            return True
        if should_retry_header == "false":
            log.debug("Not retrying as header `x-should-retry` is set to `false`")
            return False

        # Retry on request timeouts.
        if response.status_code == 408:
            log.debug("Retrying due to status code %i", response.status_code)
            return True

        # Retry on lock timeouts.
        if response.status_code == 409:
            log.debug("Retrying due to status code %i", response.status_code)
            return True

        # Retry on rate limits.
        if response.status_code == 429:
            log.debug("Retrying due to status code %i", response.status_code)
            return True

        # Retry internal errors.
        if response.status_code >= 500:
            log.debug("Retrying due to status code %i", response.status_code)
            return True

        log.debug("Not retrying")
        return False

    def is_closed(self) -> bool:
        return self._client.is_closed

    def close(self):
        self._client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def request(
        self,
        cast_type: type[ResponseT],
        options: FinalRequestOptions,
        remaining_retries: Optional[int] = None,
        *,
        stream: bool = False,
        stream_cls: type[StreamResponse] | None = None,
    ) -> ResponseT | StreamResponse:
        return self._request(
            cast_type=cast_type,
            options=options,
            stream=stream,
            stream_cls=stream_cls,
            remaining_retries=remaining_retries,
        )

    def _request(
        self,
        *,
        cast_type: type[ResponseT],
        options: FinalRequestOptions,
        remaining_retries: int | None,
        stream: bool,
        stream_cls: type[StreamResponse] | None,
    ) -> ResponseT | StreamResponse:
        retries = self._remaining_retries(remaining_retries, options)
        request = self._build_request(options)

        kwargs: HttpxSendArgs = {}
        if self.custom_auth is not None:
            kwargs["auth"] = self.custom_auth
        try:
            response = self._client.send(
                request,
                stream=stream or self._should_stream_response_body(request=request),
                **kwargs,
            )
        except httpx.TimeoutException as err:
            log.debug("Encountered httpx.TimeoutException", exc_info=True)

            if retries > 0:
                return self._retry_request(
                    options,
                    cast_type,
                    retries,
                    stream=stream,
                    stream_cls=stream_cls,
                    response_headers=None,
                )

            log.debug("Raising timeout error")
            raise APITimeoutError(request=request) from err
        except Exception as err:
            log.debug("Encountered Exception", exc_info=True)

            if retries > 0:
                return self._retry_request(
                    options,
                    cast_type,
                    retries,
                    stream=stream,
                    stream_cls=stream_cls,
                    response_headers=None,
                )

            log.debug("Raising connection error")
            raise APIConnectionError(request=request) from err

        log.debug(
            'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
        )

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
            log.debug("Encountered httpx.HTTPStatusError", exc_info=True)

            if retries > 0 and self._should_retry(err.response):
                err.response.close()
                return self._retry_request(
                    options,
                    cast_type,
                    retries,
                    err.response.headers,
                    stream=stream,
                    stream_cls=stream_cls,
                )

            # If the response is streamed then we need to explicitly read the response
            # to completion before attempting to access the response text.
            if not err.response.is_closed:
                err.response.read()

            log.debug("Re-raising status error")
            raise self._make_status_error(err.response) from None

        # return self._parse_response(
        #     cast_type=cast_type,
        #     options=options,
        #     response=response,
        #     stream=stream,
        #     stream_cls=stream_cls,
        # )
        return self._process_response(
            cast_type=cast_type,
            options=options,
            response=response,
            stream=stream,
            stream_cls=stream_cls,
        )

    def _retry_request(
        self,
        options: FinalRequestOptions,
        cast_type: type[ResponseT],
        remaining_retries: int,
        response_headers: httpx.Headers | None,
        *,
        stream: bool,
        stream_cls: type[StreamResponse] | None,
    ) -> ResponseT | StreamResponse:
        remaining = remaining_retries - 1
        if remaining == 1:
            log.debug("1 retry left")
        else:
            log.debug("%i retries left", remaining)

        timeout = self._calculate_retry_timeout(remaining, options, response_headers)
        log.info("Retrying request to %s in %f seconds", options.url, timeout)

        # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
        # different thread if necessary.
        time.sleep(timeout)

        return self._request(
            options=options,
            cast_type=cast_type,
            remaining_retries=remaining,
            stream=stream,
            stream_cls=stream_cls,
        )

    def _process_response(
        self,
        *,
        cast_type: type[ResponseT],
        options: FinalRequestOptions,
        response: httpx.Response,
        stream: bool,
        stream_cls: type[StreamResponse] | None,
    ) -> ResponseT:
        # _legacy_response with raw_response_header to parser method
        if response.request.headers.get(RAW_RESPONSE_HEADER) == "true":
            return cast(
                ResponseT,
                LegacyAPIResponse(
                    raw=response,
                    client=self,
                    cast_type=cast_type,
                    stream=stream,
                    stream_cls=stream_cls,
                    options=options,
                ),
            )

        origin = get_origin(cast_type) or cast_type

        if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse):
            if not issubclass(origin, APIResponse):
                raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}")

            response_cls = cast("type[BaseAPIResponse[Any]]", cast_type)
            return cast(
                ResponseT,
                response_cls(
                    raw=response,
                    client=self,
                    cast_type=extract_response_type(response_cls),
                    stream=stream,
                    stream_cls=stream_cls,
                    options=options,
                ),
            )

        if cast_type == httpx.Response:
            return cast(ResponseT, response)

        api_response = APIResponse(
            raw=response,
            client=self,
            cast_type=cast("type[ResponseT]", cast_type),  # pyright: ignore[reportUnnecessaryCast]
            stream=stream,
            stream_cls=stream_cls,
            options=options,
        )
        if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
            return cast(ResponseT, api_response)

        return api_response.parse()

    def _request_api_list(
        self,
        model: type[object],
        page: type[SyncPageT],
        options: FinalRequestOptions,
    ) -> SyncPageT:
        def _parser(resp: SyncPageT) -> SyncPageT:
            resp._set_private_attributes(
                client=self,
                model=model,
                options=options,
            )
            return resp

        options.post_parser = _parser

        return self.request(page, options, stream=False)

    @overload
    def get(
        self,
        path: str,
        *,
        cast_type: type[ResponseT],
        options: UserRequestInput = {},
        stream: Literal[False] = False,
    ) -> ResponseT: ...

    @overload
    def get(
        self,
        path: str,
        *,
        cast_type: type[ResponseT],
        options: UserRequestInput = {},
        stream: Literal[True],
        stream_cls: type[StreamResponse],
    ) -> StreamResponse: ...

    @overload
    def get(
        self,
        path: str,
        *,
        cast_type: type[ResponseT],
        options: UserRequestInput = {},
        stream: bool,
        stream_cls: type[StreamResponse] | None = None,
    ) -> ResponseT | StreamResponse: ...

    def get(
        self,
        path: str,
        *,
        cast_type: type[ResponseT],
        options: UserRequestInput = {},
        stream: bool = False,
        stream_cls: type[StreamResponse] | None = None,
    ) -> ResponseT | StreamResponse:
        opts = FinalRequestOptions.construct(method="get", url=path, **options)
        return cast(ResponseT, self.request(cast_type, opts, stream=stream, stream_cls=stream_cls))

    @overload
    def post(
        self,
        path: str,
        *,
        cast_type: type[ResponseT],
        body: Body | None = None,
        options: UserRequestInput = {},
        files: RequestFiles | None = None,
        stream: Literal[False] = False,
    ) -> ResponseT: ...

    @overload
    def post(
        self,
        path: str,
        *,
        cast_type: type[ResponseT],
        body: Body | None = None,
        options: UserRequestInput = {},
        files: RequestFiles | None = None,
        stream: Literal[True],
        stream_cls: type[StreamResponse],
    ) -> StreamResponse: ...

    @overload
    def post(
        self,
        path: str,
        *,
        cast_type: type[ResponseT],
        body: Body | None = None,
        options: UserRequestInput = {},
        files: RequestFiles | None = None,
        stream: bool,
        stream_cls: type[StreamResponse] | None = None,
    ) -> ResponseT | StreamResponse: ...

    def post(
        self,
        path: str,
        *,
        cast_type: type[ResponseT],
        body: Body | None = None,
        options: UserRequestInput = {},
        files: RequestFiles | None = None,
        stream: bool = False,
        stream_cls: type[StreamResponse[Any]] | None = None,
    ) -> ResponseT | StreamResponse:
        opts = FinalRequestOptions.construct(
            method="post", url=path, json_data=body, files=to_httpx_files(files), **options
        )

        return cast(ResponseT, self.request(cast_type, opts, stream=stream, stream_cls=stream_cls))
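
    # Example: a streaming POST returns a StreamResponse yielding parsed
    # chunks. The path and types below are illustrative placeholders:
    #
    #     stream = client.post(
    #         "/chat/completions",
    #         cast_type=Completion,
    #         body={"model": "...", "stream": True},
    #         stream=True,
    #         stream_cls=StreamResponse[ChatCompletionChunk],
    #     )
    #     for chunk in stream:
    #         ...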

    def patch(
        self,
        path: str,
        *,
        cast_type: type[ResponseT],
        body: Body | None = None,
        options: UserRequestInput = {},
    ) -> ResponseT:
        opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options)

        return self.request(
            cast_type=cast_type,
            options=opts,
        )

    def put(
        self,
        path: str,
        *,
        cast_type: type[ResponseT],
        body: Body | None = None,
        options: UserRequestInput = {},
        files: RequestFiles | None = None,
    ) -> ResponseT | StreamResponse:
        opts = FinalRequestOptions.construct(
            method="put", url=path, json_data=body, files=to_httpx_files(files), **options
        )

        return self.request(
            cast_type=cast_type,
            options=opts,
        )

    def delete(
        self,
        path: str,
        *,
        cast_type: type[ResponseT],
        body: Body | None = None,
        options: UserRequestInput = {},
    ) -> ResponseT | StreamResponse:
        opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, **options)

        return self.request(
            cast_type=cast_type,
            options=opts,
        )

    def get_api_list(
        self,
        path: str,
        *,
        model: type[object],
        page: type[SyncPageT],
        body: Body | None = None,
        options: UserRequestInput = {},
        method: str = "get",
    ) -> SyncPageT:
        opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options)
        return self._request_api_list(model, page, opts)

    def _make_status_error(self, response) -> APIStatusError:
        response_text = response.text.strip()
        status_code = response.status_code
        error_msg = f"Error code: {status_code}, with error text {response_text}"

        if status_code == 400:
            return _errors.APIRequestFailedError(message=error_msg, response=response)
        elif status_code == 401:
            return _errors.APIAuthenticationError(message=error_msg, response=response)
        elif status_code == 429:
            return _errors.APIReachLimitError(message=error_msg, response=response)
        elif status_code == 500:
            return _errors.APIInternalError(message=error_msg, response=response)
        elif status_code == 503:
            return _errors.APIServerFlowExceedError(message=error_msg, response=response)
        return APIStatusError(message=error_msg, response=response)


def make_request_options(
    *,
    query: Query | None = None,
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
    timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    post_parser: PostParser | NotGiven = NOT_GIVEN,
) -> UserRequestInput:
    """Create a dict of type RequestOptions without keys of NotGiven values."""
    options: UserRequestInput = {}
    if extra_headers is not None:
        options["headers"] = extra_headers

    if extra_body is not None:
        options["extra_json"] = cast(AnyMapping, extra_body)

    if query is not None:
        options["params"] = query

    if extra_query is not None:
        options["params"] = {**options.get("params", {}), **extra_query}

    if not isinstance(timeout, NotGiven):
        options["timeout"] = timeout

    if is_given(post_parser):
        # internal
        options["post_parser"] = post_parser  # type: ignore

    return options


def _merge_mappings(
    obj1: Mapping[_T_co, Union[_T, Omit]],
    obj2: Mapping[_T_co, Union[_T, Omit]],
) -> dict[_T_co, _T]:
    """Merge two mappings of the same type, removing any values that are instances of `Omit`.

    In cases with duplicate keys the second mapping takes precedence.
    """
    merged = {**obj1, **obj2}
    return {key: value for key, value in merged.items() if not isinstance(value, Omit)}
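
# Example: an `Omit()` value deletes a default key rather than overriding it:
#
#     _merge_mappings({"Content-Type": "application/json"}, {"Content-Type": Omit()})
#     # -> {}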
@ -1,31 +0,0 @@
import time

import cachetools.func
import jwt

# Cache TTL: 3 minutes
CACHE_TTL_SECONDS = 3 * 60

# Token validity extends 30 seconds past the cache TTL
API_TOKEN_TTL_SECONDS = CACHE_TTL_SECONDS + 30


@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS)
def generate_token(apikey: str):
    try:
        api_key, secret = apikey.split(".")
    except Exception as e:
        raise Exception("invalid api_key", e)

    payload = {
        "api_key": api_key,
        "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
        "timestamp": int(round(time.time() * 1000)),
    }
    ret = jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
    return ret
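
# Note: the API key is expected to have the form "<id>.<secret>" (the split
# above relies on it), and both `exp` and `timestamp` are written in
# milliseconds rather than the seconds of a standard JWT `exp` claim, which
# appears to be what the upstream service expects.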
@ -1,207 +0,0 @@
from __future__ import annotations

import os
from collections.abc import AsyncIterator, Iterator
from typing import Any

import anyio
import httpx


class HttpxResponseContent:
    @property
    def content(self) -> bytes:
        raise NotImplementedError("This method is not implemented for this class.")

    @property
    def text(self) -> str:
        raise NotImplementedError("This method is not implemented for this class.")

    @property
    def encoding(self) -> str | None:
        raise NotImplementedError("This method is not implemented for this class.")

    @property
    def charset_encoding(self) -> str | None:
        raise NotImplementedError("This method is not implemented for this class.")

    def json(self, **kwargs: Any) -> Any:
        raise NotImplementedError("This method is not implemented for this class.")

    def read(self) -> bytes:
        raise NotImplementedError("This method is not implemented for this class.")

    def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
        raise NotImplementedError("This method is not implemented for this class.")

    def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
        raise NotImplementedError("This method is not implemented for this class.")

    def iter_lines(self) -> Iterator[str]:
        raise NotImplementedError("This method is not implemented for this class.")

    def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]:
        raise NotImplementedError("This method is not implemented for this class.")

    def write_to_file(
        self,
        file: str | os.PathLike[str],
    ) -> None:
        raise NotImplementedError("This method is not implemented for this class.")

    def stream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        raise NotImplementedError("This method is not implemented for this class.")

    def close(self) -> None:
        raise NotImplementedError("This method is not implemented for this class.")

    async def aread(self) -> bytes:
        raise NotImplementedError("This method is not implemented for this class.")

    async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
        raise NotImplementedError("This method is not implemented for this class.")

    async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
        raise NotImplementedError("This method is not implemented for this class.")

    async def aiter_lines(self) -> AsyncIterator[str]:
        raise NotImplementedError("This method is not implemented for this class.")

    async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
        raise NotImplementedError("This method is not implemented for this class.")

    async def astream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        raise NotImplementedError("This method is not implemented for this class.")

    async def aclose(self) -> None:
        raise NotImplementedError("This method is not implemented for this class.")


class HttpxBinaryResponseContent(HttpxResponseContent):
    response: httpx.Response

    def __init__(self, response: httpx.Response) -> None:
        self.response = response

    @property
    def content(self) -> bytes:
        return self.response.content

    @property
    def encoding(self) -> str | None:
        return self.response.encoding

    @property
    def charset_encoding(self) -> str | None:
        return self.response.charset_encoding

    def read(self) -> bytes:
        return self.response.read()

    def text(self) -> str:
        raise NotImplementedError("Not implemented for binary response content")

    def json(self, **kwargs: Any) -> Any:
        raise NotImplementedError("Not implemented for binary response content")

    def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
        raise NotImplementedError("Not implemented for binary response content")

    def iter_lines(self) -> Iterator[str]:
        raise NotImplementedError("Not implemented for binary response content")

    async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
        raise NotImplementedError("Not implemented for binary response content")

    async def aiter_lines(self) -> AsyncIterator[str]:
        raise NotImplementedError("Not implemented for binary response content")

    def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
        return self.response.iter_bytes(chunk_size)

    def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]:
        return self.response.iter_raw(chunk_size)

    def write_to_file(
        self,
        file: str | os.PathLike[str],
    ) -> None:
        """Write the output to the given file.

        Accepts a filename or any path-like object, e.g. pathlib.Path

        Note: if you want to stream the data to the file instead of writing
        all at once then you should use `.with_streaming_response` when making
        the API request, e.g. `client.with_streaming_response.foo().stream_to_file('my_filename.txt')`
        """
        with open(file, mode="wb") as f:
            for data in self.response.iter_bytes():
                f.write(data)

    def stream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        with open(file, mode="wb") as f:
            for data in self.response.iter_bytes(chunk_size):
                f.write(data)

    def close(self) -> None:
        return self.response.close()

    async def aread(self) -> bytes:
        return await self.response.aread()

    async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
        return self.response.aiter_bytes(chunk_size)

    async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
        return self.response.aiter_raw(chunk_size)

    async def astream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        path = anyio.Path(file)
        async with await path.open(mode="wb") as f:
            async for data in self.response.aiter_bytes(chunk_size):
                await f.write(data)

    async def aclose(self) -> None:
        return await self.response.aclose()


class HttpxTextBinaryResponseContent(HttpxBinaryResponseContent):
    response: httpx.Response

    @property
    def text(self) -> str:
        return self.response.text

    def json(self, **kwargs: Any) -> Any:
        return self.response.json(**kwargs)

    def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
        return self.response.iter_text(chunk_size)

    def iter_lines(self) -> Iterator[str]:
        return self.response.iter_lines()

    async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
        return self.response.aiter_text(chunk_size)

    async def aiter_lines(self) -> AsyncIterator[str]:
        return self.response.aiter_lines()
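
# Example: downloading a binary response to disk in fixed-size chunks:
#
#     content = HttpxBinaryResponseContent(response)
#     content.stream_to_file("output.bin", chunk_size=8192)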
@ -1,341 +0,0 @@
from __future__ import annotations

import datetime
import functools
import inspect
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload

import httpx
import pydantic
from typing_extensions import ParamSpec, override

from ._base_models import BaseModel, is_basemodel
from ._base_type import NoneType
from ._constants import RAW_RESPONSE_HEADER
from ._errors import APIResponseValidationError
from ._legacy_binary_response import HttpxResponseContent, HttpxTextBinaryResponseContent
from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type
from ._utils import extract_type_arg, is_annotated_type, is_given

if TYPE_CHECKING:
    from ._http_client import HttpClient
    from ._request_opt import FinalRequestOptions

P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")

log: logging.Logger = logging.getLogger(__name__)


class LegacyAPIResponse(Generic[R]):
    """This is a legacy class as it will be replaced by `APIResponse`
    and `AsyncAPIResponse` in the `_response.py` file in the next major
    release.

    For the sync client this will mostly be the same with the exception
    of `content` & `text` will be methods instead of properties. In the
    async client, all methods will be async.

    A migration script will be provided & the migration in general should
    be smooth.
    """

    _cast_type: type[R]
    _client: HttpClient
    _parsed_by_type: dict[type[Any], Any]
    _stream: bool
    _stream_cls: type[StreamResponse[Any]] | None
    _options: FinalRequestOptions

    http_response: httpx.Response

    def __init__(
        self,
        *,
        raw: httpx.Response,
        cast_type: type[R],
        client: HttpClient,
        stream: bool,
        stream_cls: type[StreamResponse[Any]] | None,
        options: FinalRequestOptions,
    ) -> None:
        self._cast_type = cast_type
        self._client = client
        self._parsed_by_type = {}
        self._stream = stream
        self._stream_cls = stream_cls
        self._options = options
        self.http_response = raw

    @property
    def request_id(self) -> str | None:
        return self.http_response.headers.get("x-request-id")  # type: ignore[no-any-return]

    @overload
    def parse(self, *, to: type[_T]) -> _T: ...

    @overload
    def parse(self) -> R: ...

    def parse(self, *, to: type[_T] | None = None) -> R | _T:
        """Returns the rich python representation of this response's data.

        NOTE: For the async client: this will become a coroutine in the next major version.

        For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.

        You can customize the type that the response is parsed into through
        the `to` argument, e.g.

        ```py
        from zhipuai import BaseModel


        class MyModel(BaseModel):
            foo: str


        obj = response.parse(to=MyModel)
        print(obj.foo)
        ```

        We support parsing:
        - `BaseModel`
        - `dict`
        - `list`
        - `Union`
        - `str`
        - `int`
        - `float`
        - `httpx.Response`
        """
        cache_key = to if to is not None else self._cast_type
        cached = self._parsed_by_type.get(cache_key)
        if cached is not None:
            return cached  # type: ignore[no-any-return]

        parsed = self._parse(to=to)
        if is_given(self._options.post_parser):
            parsed = self._options.post_parser(parsed)

        self._parsed_by_type[cache_key] = parsed
        return parsed

    @property
    def headers(self) -> httpx.Headers:
        return self.http_response.headers

    @property
    def http_request(self) -> httpx.Request:
        return self.http_response.request

    @property
    def status_code(self) -> int:
        return self.http_response.status_code

    @property
    def url(self) -> httpx.URL:
        return self.http_response.url

    @property
    def method(self) -> str:
        return self.http_request.method

    @property
    def content(self) -> bytes:
        """Return the binary response content.

        NOTE: this will be removed in favour of `.read()` in the
        next major version.
        """
        return self.http_response.content

    @property
    def text(self) -> str:
        """Return the decoded response content.

        NOTE: this will be turned into a method in the next major version.
        """
        return self.http_response.text

    @property
    def http_version(self) -> str:
        return self.http_response.http_version

    @property
    def is_closed(self) -> bool:
        return self.http_response.is_closed

    @property
    def elapsed(self) -> datetime.timedelta:
        """The time taken for the complete request/response cycle to complete."""
        return self.http_response.elapsed

    def _parse(self, *, to: type[_T] | None = None) -> R | _T:
        # unwrap `Annotated[T, ...]` -> `T`
        if to and is_annotated_type(to):
            to = extract_type_arg(to, 0)

        if self._stream:
            if to:
                if not is_stream_class_type(to):
                    raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}")

                return cast(
                    _T,
                    to(
                        cast_type=extract_stream_chunk_type(
                            to,
                            failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]",  # noqa: E501
                        ),
                        response=self.http_response,
                        client=cast(Any, self._client),
                    ),
                )

            if self._stream_cls:
                return cast(
                    R,
                    self._stream_cls(
                        cast_type=extract_stream_chunk_type(self._stream_cls),
                        response=self.http_response,
                        client=cast(Any, self._client),
                    ),
                )

            stream_cls = cast("type[StreamResponse[Any]] | None", self._client._default_stream_cls)
            if stream_cls is None:
                raise MissingStreamClassError()

            return cast(
                R,
                stream_cls(
                    cast_type=self._cast_type,
                    response=self.http_response,
                    client=cast(Any, self._client),
                ),
            )

        cast_type = to if to is not None else self._cast_type

        # unwrap `Annotated[T, ...]` -> `T`
        if is_annotated_type(cast_type):
            cast_type = extract_type_arg(cast_type, 0)

        if cast_type is NoneType:
            return cast(R, None)

        response = self.http_response
        if cast_type == str:
            return cast(R, response.text)

        if cast_type == int:
            return cast(R, int(response.text))

        if cast_type == float:
            return cast(R, float(response.text))

        origin = get_origin(cast_type) or cast_type

        if inspect.isclass(origin) and issubclass(origin, HttpxResponseContent):
            # in the response, e.g. mime file
            *_, filename = response.headers.get("content-disposition", "").split("filename=")
            # for .jsonl and .xlsx downloads, use HttpxTextBinaryResponseContent
            if filename and (filename.endswith(".jsonl") or filename.endswith(".xlsx")):
                return cast(R, HttpxTextBinaryResponseContent(response))
            else:
                return cast(R, cast_type(response))  # type: ignore

        if origin == LegacyAPIResponse:
            raise RuntimeError("Unexpected state - cast_type is `APIResponse`")

        if inspect.isclass(origin) and issubclass(origin, httpx.Response):
            # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
            # and pass that class to our request functions. We cannot change the variance to be either
            # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
            # the response class ourselves but that is something that should be supported directly in httpx
            # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
            if cast_type != httpx.Response:
                raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`")
            return cast(R, response)

        if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
            raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")

        if (
            cast_type is not object
            and origin is not list
            and origin is not dict
            and origin is not Union
            and not issubclass(origin, BaseModel)
        ):
            raise RuntimeError(
                f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."  # noqa: E501
            )

        # split is required to handle cases where additional information is included
        # in the response, e.g. application/json; charset=utf-8
        content_type, *_ = response.headers.get("content-type", "*").split(";")
        if content_type != "application/json":
            if is_basemodel(cast_type):
                try:
                    data = response.json()
                except Exception as exc:
                    log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
                else:
                    return self._client._process_response_data(
                        data=data,
                        cast_type=cast_type,  # type: ignore
                        response=response,
                    )

            if self._client._strict_response_validation:
                raise APIResponseValidationError(
                    response=response,
                    message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",  # noqa: E501
                    json_data=response.text,
                )

            # If the API responds with content that isn't JSON then we just return
            # the (decoded) text without performing any parsing so that you can still
            # handle the response however you need to.
            return response.text  # type: ignore

        data = response.json()

        return self._client._process_response_data(
            data=data,
            cast_type=cast_type,  # type: ignore
            response=response,
        )

    @override
    def __repr__(self) -> str:
        return f"<APIResponse [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_type}>"


class MissingStreamClassError(TypeError):
    def __init__(self) -> None:
        super().__init__(
            "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference",  # noqa: E501
        )


def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:
    """Higher order function that takes one of our bound API methods and wraps it
    to support returning the raw `APIResponse` object directly.
    """

    @functools.wraps(func)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:
        extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        extra_headers[RAW_RESPONSE_HEADER] = "true"

        kwargs["extra_headers"] = extra_headers

        return cast(LegacyAPIResponse[R], func(*args, **kwargs))

    return wrapped
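
# Example: wrapping a bound API method so the caller receives the raw
# LegacyAPIResponse instead of the parsed model (the method name below is an
# illustrative placeholder):
#
#     raw_retrieve = to_raw_response_wrapper(client.files.retrieve)
#     response = raw_retrieve("file-id")
#     print(response.status_code, response.request_id)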
@ -1,97 +0,0 @@
from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, ClassVar, Union, cast

import pydantic.generics
from httpx import Timeout
from typing_extensions import Required, TypedDict, Unpack, final

from ._base_compat import PYDANTIC_V2, ConfigDict
from ._base_type import AnyMapping, Body, Headers, HttpxRequestFiles, NotGiven, Query
from ._constants import RAW_RESPONSE_HEADER
from ._utils import is_given, strip_not_given


class UserRequestInput(TypedDict, total=False):
    headers: Headers
    max_retries: int
    timeout: float | Timeout | None
    params: Query
    extra_json: AnyMapping


class FinalRequestOptionsInput(TypedDict, total=False):
    method: Required[str]
    url: Required[str]
    params: Query
    headers: Headers
    max_retries: int
    timeout: float | Timeout | None
    files: HttpxRequestFiles | None
    json_data: Body
    extra_json: AnyMapping


@final
class FinalRequestOptions(pydantic.BaseModel):
    method: str
    url: str
    params: Query = {}
    headers: Union[Headers, NotGiven] = NotGiven()
    max_retries: Union[int, NotGiven] = NotGiven()
    timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
    files: Union[HttpxRequestFiles, None] = None
    idempotency_key: Union[str, None] = None
    post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()

    # It should be noted that we cannot use `json` here as that would override
    # a BaseModel method in an incompatible fashion.
    json_data: Union[Body, None] = None
    extra_json: Union[AnyMapping, None] = None

    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
    else:

        class Config(pydantic.BaseConfig):  # pyright: ignore[reportDeprecated]
            arbitrary_types_allowed: bool = True

    def get_max_retries(self, max_retries: int) -> int:
        if isinstance(self.max_retries, NotGiven):
            return max_retries
        return self.max_retries

    def _strip_raw_response_header(self) -> None:
        if not is_given(self.headers):
            return

        if self.headers.get(RAW_RESPONSE_HEADER):
            self.headers = {**self.headers}
            self.headers.pop(RAW_RESPONSE_HEADER)

    # override the `construct` method so that we can run custom transformations.
    # this is necessary as we don't want to do any actual runtime type checking
    # (which means we can't use validators) but we do want to ensure that `NotGiven`
    # values are not present
    #
    # type ignore required because we're adding explicit types to `**values`
    @classmethod
    def construct(  # type: ignore
        cls,
        _fields_set: set[str] | None = None,
        **values: Unpack[UserRequestInput],
    ) -> FinalRequestOptions:
        kwargs: dict[str, Any] = {
            # we unconditionally call `strip_not_given` on any value
            # as it will just ignore any non-mapping types
            key: strip_not_given(value)
            for key, value in values.items()
        }
        if PYDANTIC_V2:
            return super().model_construct(_fields_set, **kwargs)
        return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs))  # pyright: ignore[reportDeprecated]

    if not TYPE_CHECKING:
        # type checkers incorrectly complain about this assignment
        model_construct = construct
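
# A minimal construction sketch, assuming this module's relative imports resolve
# (run in package context): `construct` strips `NotGiven` values before building.
if __name__ == "__main__":
    options = FinalRequestOptions.construct(
        method="post",
        url="/chat/completions",
        headers={"X-Trace-Id": "abc123", "X-Optional": NotGiven()},
    )
    # Only the explicitly-given header survives.
    print(options.headers)  # {'X-Trace-Id': 'abc123'}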
@ -1,398 +0,0 @@
from __future__ import annotations

import datetime
import inspect
import logging
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload

import httpx
import pydantic
from typing_extensions import ParamSpec, override

from ._base_models import BaseModel, is_basemodel
from ._base_type import NoneType
from ._errors import APIResponseValidationError, ZhipuAIError
from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type
from ._utils import extract_type_arg, extract_type_var_from_base, is_annotated_type, is_given

if TYPE_CHECKING:
    from ._http_client import HttpClient
    from ._request_opt import FinalRequestOptions

P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")
_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]")
log: logging.Logger = logging.getLogger(__name__)


class BaseAPIResponse(Generic[R]):
    _cast_type: type[R]
    _client: HttpClient
    _parsed_by_type: dict[type[Any], Any]
    _is_sse_stream: bool
    _stream_cls: type[StreamResponse[Any]] | None
    _options: FinalRequestOptions
    http_response: httpx.Response

    def __init__(
        self,
        *,
        raw: httpx.Response,
        cast_type: type[R],
        client: HttpClient,
        stream: bool,
        stream_cls: type[StreamResponse[Any]] | None = None,
        options: FinalRequestOptions,
    ) -> None:
        self._cast_type = cast_type
        self._client = client
        self._parsed_by_type = {}
        self._is_sse_stream = stream
        self._stream_cls = stream_cls
        self._options = options
        self.http_response = raw

    def _parse(self, *, to: type[_T] | None = None) -> R | _T:
        # unwrap `Annotated[T, ...]` -> `T`
        if to and is_annotated_type(to):
            to = extract_type_arg(to, 0)

        if self._is_sse_stream:
            if to:
                if not is_stream_class_type(to):
                    raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}")

                return cast(
                    _T,
                    to(
                        cast_type=extract_stream_chunk_type(
                            to,
                            failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]",  # noqa: E501
                        ),
                        response=self.http_response,
                        client=cast(Any, self._client),
                    ),
                )

            if self._stream_cls:
                return cast(
                    R,
                    self._stream_cls(
                        cast_type=extract_stream_chunk_type(self._stream_cls),
                        response=self.http_response,
                        client=cast(Any, self._client),
                    ),
                )

            stream_cls = cast("type[StreamResponse[Any]] | None", self._client._default_stream_cls)
            if stream_cls is None:
                raise MissingStreamClassError()

            return cast(
                R,
                stream_cls(
                    cast_type=self._cast_type,
                    response=self.http_response,
                    client=cast(Any, self._client),
                ),
            )

        cast_type = to if to is not None else self._cast_type

        # unwrap `Annotated[T, ...]` -> `T`
        if is_annotated_type(cast_type):
            cast_type = extract_type_arg(cast_type, 0)

        if cast_type is NoneType:
            return cast(R, None)

        response = self.http_response
        if cast_type == str:
            return cast(R, response.text)

        if cast_type == bytes:
            return cast(R, response.content)

        if cast_type == int:
            return cast(R, int(response.text))

        if cast_type == float:
            return cast(R, float(response.text))

        origin = get_origin(cast_type) or cast_type

        # handle the legacy binary response case
        if inspect.isclass(cast_type) and cast_type.__name__ == "HttpxBinaryResponseContent":
            return cast(R, cast_type(response))  # type: ignore

        if origin == APIResponse:
            raise RuntimeError("Unexpected state - cast_type is `APIResponse`")

        if inspect.isclass(origin) and issubclass(origin, httpx.Response):
            # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
            # and pass that class to our request functions. We cannot change the variance to be either
            # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
            # the response class ourselves but that is something that should be supported directly in httpx
            # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
            if cast_type != httpx.Response:
                raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`")
            return cast(R, response)

        if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
            raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")

        if (
            cast_type is not object
            and origin is not list
            and origin is not dict
            and origin is not Union
            and not issubclass(origin, BaseModel)
        ):
            raise RuntimeError(
                f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."  # noqa: E501
            )

        # split is required to handle cases where additional information is included
        # in the response, e.g. application/json; charset=utf-8
        content_type, *_ = response.headers.get("content-type", "*").split(";")
        if content_type != "application/json":
            if is_basemodel(cast_type):
                try:
                    data = response.json()
                except Exception as exc:
                    log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
                else:
                    return self._client._process_response_data(
                        data=data,
                        cast_type=cast_type,  # type: ignore
                        response=response,
                    )

            if self._client._strict_response_validation:
                raise APIResponseValidationError(
                    response=response,
                    message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",  # noqa: E501
                    json_data=response.text,
                )

            # If the API responds with content that isn't JSON then we just return
            # the (decoded) text without performing any parsing so that you can still
            # handle the response however you need to.
            return response.text  # type: ignore

        data = response.json()

        return self._client._process_response_data(
            data=data,
            cast_type=cast_type,  # type: ignore
            response=response,
        )

    @property
    def headers(self) -> httpx.Headers:
        return self.http_response.headers

    @property
    def http_request(self) -> httpx.Request:
        """Returns the httpx Request instance associated with the current response."""
        return self.http_response.request

    @property
    def status_code(self) -> int:
        return self.http_response.status_code

    @property
    def url(self) -> httpx.URL:
        """Returns the URL for which the request was made."""
        return self.http_response.url

    @property
    def method(self) -> str:
        return self.http_request.method

    @property
    def http_version(self) -> str:
        return self.http_response.http_version

    @property
    def elapsed(self) -> datetime.timedelta:
        """The time taken for the complete request/response cycle to complete."""
        return self.http_response.elapsed

    @property
    def is_closed(self) -> bool:
        """Whether or not the response body has been closed.

        If this is False then there is response data that has not been read yet.
        You must either fully consume the response body or call `.close()`
        before discarding the response to prevent resource leaks.
        """
        return self.http_response.is_closed

    @override
    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_type}>"  # noqa: E501


class APIResponse(BaseAPIResponse[R]):
    @property
    def request_id(self) -> str | None:
        return self.http_response.headers.get("x-request-id")  # type: ignore[no-any-return]

    @overload
    def parse(self, *, to: type[_T]) -> _T: ...

    @overload
    def parse(self) -> R: ...

    def parse(self, *, to: type[_T] | None = None) -> R | _T:
        """Returns the rich python representation of this response's data.

        For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.

        You can customize the type that the response is parsed into through
        the `to` argument, e.g.

        ```py
        from openai import BaseModel


        class MyModel(BaseModel):
            foo: str


        obj = response.parse(to=MyModel)
        print(obj.foo)
        ```

        We support parsing:
        - `BaseModel`
        - `dict`
        - `list`
        - `Union`
        - `str`
        - `int`
        - `float`
        - `httpx.Response`
        """
        cache_key = to if to is not None else self._cast_type
        cached = self._parsed_by_type.get(cache_key)
        if cached is not None:
            return cached  # type: ignore[no-any-return]

        if not self._is_sse_stream:
            self.read()

        parsed = self._parse(to=to)
        if is_given(self._options.post_parser):
            parsed = self._options.post_parser(parsed)

        self._parsed_by_type[cache_key] = parsed
        return parsed

    def read(self) -> bytes:
        """Read and return the binary response content."""
        try:
            return self.http_response.read()
        except httpx.StreamConsumed as exc:
            # The default error raised by httpx isn't very
            # helpful in our case so we re-raise it with
            # a different error message.
            raise StreamAlreadyConsumed() from exc

    def text(self) -> str:
        """Read and decode the response content into a string."""
        self.read()
        return self.http_response.text

    def json(self) -> object:
        """Read and decode the JSON response content."""
        self.read()
        return self.http_response.json()

    def close(self) -> None:
        """Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self.http_response.close()

    def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
        """
        A byte-iterator over the decoded response content.

        This automatically handles gzip, deflate and brotli encoded responses.
        """
        yield from self.http_response.iter_bytes(chunk_size)

    def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
        """A str-iterator over the decoded response content
        that handles both gzip, deflate, etc but also detects the content's
        string encoding.
        """
        yield from self.http_response.iter_text(chunk_size)

    def iter_lines(self) -> Iterator[str]:
        """Like `iter_text()` but will only yield chunks for each line"""
        yield from self.http_response.iter_lines()


class MissingStreamClassError(TypeError):
    def __init__(self) -> None:
        super().__init__(
            "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference",  # noqa: E501
        )


class StreamAlreadyConsumed(ZhipuAIError):  # noqa: N818
    """
    Attempted to read or stream content, but the content has already
    been streamed.

    This can happen if you use a method like `.iter_lines()` and then attempt
    to read the entire response body afterwards, e.g.

    ```py
    response = await client.post(...)
    async for line in response.iter_lines():
        ...  # do something with `line`

    content = await response.read()
    # ^ error
    ```

    If you want this behavior you'll need to either manually accumulate the response
    content or call `await response.read()` before iterating over the stream.
    """

    def __init__(self) -> None:
        message = (
            "Attempted to read or stream some content, but the content has "
            "already been streamed. "
            "This could be due to attempting to stream the response "
            "content more than once."
            "\n\n"
            "You can fix this by manually accumulating the response content while streaming "
            "or by calling `.read()` before starting to stream."
        )
        super().__init__(message)


def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:
    """Given a type like `APIResponse[T]`, returns the generic type variable `T`.

    This also handles the case where a concrete subclass is given, e.g.
    ```py
    class MyResponse(APIResponse[bytes]):
        ...

    extract_response_type(MyResponse) -> bytes
    ```
    """
    return extract_type_var_from_base(
        typ,
        generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse)),
        index=0,
    )
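
# A quick self-check of the helper above, assuming the module imports cleanly:
if __name__ == "__main__":

    class MyByteResponse(APIResponse[bytes]):
        pass

    assert extract_response_type(MyByteResponse) is bytes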
@ -1,206 +0,0 @@
from __future__ import annotations

import inspect
import json
from collections.abc import Iterator
from typing import TYPE_CHECKING, Generic, TypeGuard, cast

import httpx

from . import get_origin
from ._base_type import ResponseT
from ._errors import APIResponseError
from ._utils import extract_type_var_from_base, is_mapping

_FIELD_SEPARATOR = ":"

if TYPE_CHECKING:
    from ._http_client import HttpClient


class StreamResponse(Generic[ResponseT]):
    response: httpx.Response
    _cast_type: type[ResponseT]

    def __init__(
        self,
        *,
        cast_type: type[ResponseT],
        response: httpx.Response,
        client: HttpClient,
    ) -> None:
        self.response = response
        self._cast_type = cast_type
        self._data_process_func = client._process_response_data
        self._stream_chunks = self.__stream__()

    def __next__(self) -> ResponseT:
        return self._stream_chunks.__next__()

    def __iter__(self) -> Iterator[ResponseT]:
        yield from self._stream_chunks

    def __stream__(self) -> Iterator[ResponseT]:
        sse_line_parser = SSELineParser()
        iterator = sse_line_parser.iter_lines(self.response.iter_lines())

        for sse in iterator:
            if sse.data.startswith("[DONE]"):
                break

            if sse.event is None:
                data = sse.json_data()
                if is_mapping(data) and data.get("error"):
                    message = None
                    error = data.get("error")
                    if is_mapping(error):
                        message = error.get("message")
                    if not message or not isinstance(message, str):
                        message = "An error occurred during streaming"

                    raise APIResponseError(
                        message=message,
                        request=self.response.request,
                        json_data=data["error"],
                    )
                yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)

            else:
                data = sse.json_data()

                if sse.event == "error" and is_mapping(data) and data.get("error"):
                    message = None
                    error = data.get("error")
                    if is_mapping(error):
                        message = error.get("message")
                    if not message or not isinstance(message, str):
                        message = "An error occurred during streaming"

                    raise APIResponseError(
                        message=message,
                        request=self.response.request,
                        json_data=data["error"],
                    )
                yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)

        # Drain anything remaining on the wire once streaming stops.
        for sse in iterator:
            pass

class Event:
    def __init__(
        self, event: str | None = None, data: str | None = None, id: str | None = None, retry: int | None = None
    ):
        self._event = event
        self._data = data
        self._id = id
        self._retry = retry

    def __repr__(self):
        data_len = len(self._data) if self._data else 0
        return f"Event(event={self._event}, data={self._data}, data_length={data_len}, id={self._id}, retry={self._retry})"  # noqa: E501

    @property
    def event(self):
        return self._event

    @property
    def data(self):
        return self._data

    def json_data(self):
        return json.loads(self._data)

    @property
    def id(self):
        return self._id

    @property
    def retry(self):
        return self._retry


class SSELineParser:
    _data: list[str]
    _event: str | None
    _retry: int | None
    _id: str | None

    def __init__(self):
        self._event = None
        self._data = []
        self._id = None
        self._retry = None

    def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]:
        for line in lines:
            line = line.rstrip("\n")
            if not line:
                # A blank line terminates an event; skip if no fields were buffered.
                if self._event is None and not self._data and self._id is None and self._retry is None:
                    continue
                sse_event = Event(event=self._event, data="\n".join(self._data), id=self._id, retry=self._retry)
                self._event = None
                self._data = []
                self._id = None
                self._retry = None

                yield sse_event
            self.decode_line(line)

    def decode_line(self, line: str):
        if line.startswith(":") or not line:
            return

        field, _p, value = line.partition(":")

        value = value.removeprefix(" ")
        if field == "data":
            self._data.append(value)
        elif field == "event":
            self._event = value
        elif field == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                pass
        return

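# A small demonstration of the parser on a canned SSE exchange:
if __name__ == "__main__":
    raw_lines = ["event: message", 'data: {"ok": true}', "", "data: [DONE]", ""]
    for parsed in SSELineParser().iter_lines(iter(raw_lines)):
        print(parsed)  # the "message" event, then the terminal [DONE] event
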
def is_stream_class_type(typ: type) -> TypeGuard[type[StreamResponse[object]]]:
    """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
    origin = get_origin(typ) or typ
    return inspect.isclass(origin) and issubclass(origin, StreamResponse)


def extract_stream_chunk_type(
    stream_cls: type,
    *,
    failure_message: str | None = None,
) -> type:
    """Given a type like `StreamResponse[T]`, returns the generic type variable `T`.

    This also handles the case where a concrete subclass is given, e.g.
    ```py
    class MyStream(StreamResponse[bytes]):
        ...

    extract_stream_chunk_type(MyStream) -> bytes
    ```
    """

    return extract_type_var_from_base(
        stream_cls,
        index=0,
        generic_bases=cast("tuple[type, ...]", (StreamResponse,)),
        failure_message=failure_message,
    )
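
# An end-to-end sketch with a stubbed client: `_process_response_data` here just
# echoes the decoded chunk (the real HttpClient validates it against `cast_type`).
if __name__ == "__main__":
    import types

    sse_body = b'data: {"chunk": 1}\n\ndata: {"chunk": 2}\n\ndata: [DONE]\n\n'
    stub_client = types.SimpleNamespace(
        _process_response_data=lambda *, data, cast_type, response: data,
    )
    stream = StreamResponse(
        cast_type=dict,
        response=httpx.Response(200, content=sse_body),
        client=stub_client,  # type: ignore[arg-type]
    )
    print(list(stream))  # [{'chunk': 1}, {'chunk': 2}]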
@ -1,52 +0,0 @@
from ._utils import (  # noqa: I001
    remove_notgiven_indict as remove_notgiven_indict,  # noqa: PLC0414
    flatten as flatten,  # noqa: PLC0414
    is_dict as is_dict,  # noqa: PLC0414
    is_list as is_list,  # noqa: PLC0414
    is_given as is_given,  # noqa: PLC0414
    is_tuple as is_tuple,  # noqa: PLC0414
    is_mapping as is_mapping,  # noqa: PLC0414
    is_tuple_t as is_tuple_t,  # noqa: PLC0414
    parse_date as parse_date,  # noqa: PLC0414
    is_iterable as is_iterable,  # noqa: PLC0414
    is_sequence as is_sequence,  # noqa: PLC0414
    coerce_float as coerce_float,  # noqa: PLC0414
    is_mapping_t as is_mapping_t,  # noqa: PLC0414
    removeprefix as removeprefix,  # noqa: PLC0414
    removesuffix as removesuffix,  # noqa: PLC0414
    extract_files as extract_files,  # noqa: PLC0414
    is_sequence_t as is_sequence_t,  # noqa: PLC0414
    required_args as required_args,  # noqa: PLC0414
    coerce_boolean as coerce_boolean,  # noqa: PLC0414
    coerce_integer as coerce_integer,  # noqa: PLC0414
    file_from_path as file_from_path,  # noqa: PLC0414
    parse_datetime as parse_datetime,  # noqa: PLC0414
    strip_not_given as strip_not_given,  # noqa: PLC0414
    deepcopy_minimal as deepcopy_minimal,  # noqa: PLC0414
    get_async_library as get_async_library,  # noqa: PLC0414
    maybe_coerce_float as maybe_coerce_float,  # noqa: PLC0414
    get_required_header as get_required_header,  # noqa: PLC0414
    maybe_coerce_boolean as maybe_coerce_boolean,  # noqa: PLC0414
    maybe_coerce_integer as maybe_coerce_integer,  # noqa: PLC0414
    drop_prefix_image_data as drop_prefix_image_data,  # noqa: PLC0414
)


from ._typing import (
    is_list_type as is_list_type,  # noqa: PLC0414
    is_union_type as is_union_type,  # noqa: PLC0414
    extract_type_arg as extract_type_arg,  # noqa: PLC0414
    is_iterable_type as is_iterable_type,  # noqa: PLC0414
    is_required_type as is_required_type,  # noqa: PLC0414
    is_annotated_type as is_annotated_type,  # noqa: PLC0414
    strip_annotated_type as strip_annotated_type,  # noqa: PLC0414
    extract_type_var_from_base as extract_type_var_from_base,  # noqa: PLC0414
)


from ._transform import (
    PropertyInfo as PropertyInfo,  # noqa: PLC0414
    transform as transform,  # noqa: PLC0414
    async_transform as async_transform,  # noqa: PLC0414
    maybe_transform as maybe_transform,  # noqa: PLC0414
    async_maybe_transform as async_maybe_transform,  # noqa: PLC0414
)
@ -1,383 +0,0 @@
from __future__ import annotations

import base64
import io
import pathlib
from collections.abc import Mapping
from datetime import date, datetime
from typing import Any, Literal, TypeVar, cast, get_args, get_type_hints

import anyio
import pydantic
from typing_extensions import override

from .._base_compat import is_typeddict, model_dump
from .._files import is_base64_file_input
from ._typing import (
    extract_type_arg,
    is_annotated_type,
    is_iterable_type,
    is_list_type,
    is_required_type,
    is_union_type,
    strip_annotated_type,
)
from ._utils import (
    is_iterable,
    is_list,
    is_mapping,
)

_T = TypeVar("_T")


# TODO: support for drilling globals() and locals()
# TODO: ensure works correctly with forward references in all cases


PropertyFormat = Literal["iso8601", "base64", "custom"]


class PropertyInfo:
    """Metadata class to be used in Annotated types to provide information about a given type.

    For example:

    class MyParams(TypedDict):
        account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]

    This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
    """  # noqa: E501

    alias: str | None
    format: PropertyFormat | None
    format_template: str | None
    discriminator: str | None

    def __init__(
        self,
        *,
        alias: str | None = None,
        format: PropertyFormat | None = None,
        format_template: str | None = None,
        discriminator: str | None = None,
    ) -> None:
        self.alias = alias
        self.format = format
        self.format_template = format_template
        self.discriminator = discriminator

    @override
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"  # noqa: E501


def maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """Wrapper over `transform()` that allows `None` to be passed.

    See `transform()` for more details.
    """
    if data is None:
        return None
    return transform(data, expected_type)


# Wrapper over _transform_recursive providing fake types
def transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Transform dictionaries based off of type information from the given type, for example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    Any keys / data that do not have type information will be included as-is.

    It should be noted that the transformations that this function does are not represented in the type system.
    """
    transformed = _transform_recursive(data, annotation=cast(type, expected_type))
    return cast(_T, transformed)

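# A runnable sketch of the two transformations above: key aliasing via
# `PropertyInfo(alias=...)` and value formatting via `PropertyInfo(format=...)`
# (assumes this module's relative imports resolve when run in package context).
if __name__ == "__main__":
    from typing import Annotated

    from typing_extensions import Required, TypedDict

    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
        issued: Annotated[date, PropertyInfo(format="iso8601")]

    print(transform({"card_id": "c_1", "issued": date(2024, 1, 31)}, Params))
    # {'cardID': 'c_1', 'issued': '2024-01-31'}
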
def _get_annotated_type(type_: type) -> type | None:
    """If the given type is an `Annotated` type then it is returned, if not `None` is returned.

    This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`
    """
    if is_required_type(type_):
        # Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`
        type_ = get_args(type_)[0]

    if is_annotated_type(type_):
        return type_

    return None


def _maybe_transform_key(key: str, type_: type) -> str:
    """Transform the given `data` based on the annotations provided in `type_`.

    Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata.
    """
    annotated_type = _get_annotated_type(type_)
    if annotated_type is None:
        # no `Annotated` definition for this type, no transformation needed
        return key

    # ignore the first argument as it is the actual type
    annotations = get_args(annotated_type)[1:]
    for annotation in annotations:
        if isinstance(annotation, PropertyInfo) and annotation.alias is not None:
            return annotation.alias

    return key


def _transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform the given data against the expected type.

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.
    """
    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    if is_typeddict(stripped_type) and is_mapping(data):
        return _transform_typeddict(data, stripped_type)

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
    ):
        inner_type = extract_type_arg(stripped_type, 0)
        return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        return model_dump(data, exclude_unset=True)

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type
    annotations = get_args(annotated_type)[1:]
    for annotation in annotations:
        if isinstance(annotation, PropertyInfo) and annotation.format is not None:
            return _format_data(data, annotation.format, annotation.format_template)

    return data


def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
    if isinstance(data, date | datetime):
        if format_ == "iso8601":
            return data.isoformat()

        if format_ == "custom" and format_template is not None:
            return data.strftime(format_template)

    if format_ == "base64" and is_base64_file_input(data):
        binary: str | bytes | None = None

        if isinstance(data, pathlib.Path):
            binary = data.read_bytes()
        elif isinstance(data, io.IOBase):
            binary = data.read()

        if isinstance(binary, str):  # type: ignore[unreachable]
            binary = binary.encode()

        if not isinstance(binary, bytes):
            raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")

        return base64.b64encode(binary).decode("ascii")

    return data

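# The base64 branch reads file-like input; a small in-memory check (this assumes
# `is_base64_file_input` accepts io.IOBase instances, as the branch above implies):
if __name__ == "__main__":
    print(_format_data(io.BytesIO(b"hello"), "base64", None))  # aGVsbG8=
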
def _transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    result: dict[str, object] = {}
    annotations = get_type_hints(expected_type, include_extras=True)
    for key, value in data.items():
        type_ = annotations.get(key)
        if type_ is None:
            # we do not have a type annotation for this field, leave it as is
            result[key] = value
        else:
            result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
    return result


async def async_maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """Wrapper over `async_transform()` that allows `None` to be passed.

    See `async_transform()` for more details.
    """
    if data is None:
        return None
    return await async_transform(data, expected_type)


async def async_transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Transform dictionaries based off of type information from the given type, for example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = await async_transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    Any keys / data that do not have type information will be included as-is.

    It should be noted that the transformations that this function does are not represented in the type system.
    """
    transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
    return cast(_T, transformed)


async def _async_transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform the given data against the expected type.

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.
    """
    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    if is_typeddict(stripped_type) and is_mapping(data):
        return await _async_transform_typeddict(data, stripped_type)

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
    ):
        inner_type = extract_type_arg(stripped_type, 0)
        return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        return model_dump(data, exclude_unset=True)

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type
    annotations = get_args(annotated_type)[1:]
    for annotation in annotations:
        if isinstance(annotation, PropertyInfo) and annotation.format is not None:
            return await _async_format_data(data, annotation.format, annotation.format_template)

    return data


async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
    if isinstance(data, date | datetime):
        if format_ == "iso8601":
            return data.isoformat()

        if format_ == "custom" and format_template is not None:
            return data.strftime(format_template)

    if format_ == "base64" and is_base64_file_input(data):
        binary: str | bytes | None = None

        if isinstance(data, pathlib.Path):
            binary = await anyio.Path(data).read_bytes()
        elif isinstance(data, io.IOBase):
            binary = data.read()

        if isinstance(binary, str):  # type: ignore[unreachable]
            binary = binary.encode()

        if not isinstance(binary, bytes):
            raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")

        return base64.b64encode(binary).decode("ascii")

    return data


async def _async_transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    result: dict[str, object] = {}
    annotations = get_type_hints(expected_type, include_extras=True)
    for key, value in data.items():
        type_ = annotations.get(key)
        if type_ is None:
            # we do not have a type annotation for this field, leave it as is
            result[key] = value
        else:
            result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
    return result
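
# The async variant behaves identically for in-memory data; a quick check:
if __name__ == "__main__":
    import asyncio
    from typing import Annotated

    from typing_extensions import Required, TypedDict

    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]

    print(asyncio.run(async_transform({"card_id": "c_1"}, Params)))
    # {'cardID': 'c_1'}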
@ -1,122 +0,0 @@
from __future__ import annotations

from collections import abc as _c_abc
from collections.abc import Iterable
from typing import Annotated, Any, TypeVar, cast, get_args, get_origin

from typing_extensions import Required

from .._base_compat import is_union as _is_union
from .._base_type import InheritsGeneric


def is_annotated_type(typ: type) -> bool:
    return get_origin(typ) == Annotated


def is_list_type(typ: type) -> bool:
    return (get_origin(typ) or typ) == list


def is_iterable_type(typ: type) -> bool:
    """If the given type is `typing.Iterable[T]`"""
    origin = get_origin(typ) or typ
    return origin in {Iterable, _c_abc.Iterable}


def is_union_type(typ: type) -> bool:
    return _is_union(get_origin(typ))


def is_required_type(typ: type) -> bool:
    return get_origin(typ) == Required


def is_typevar(typ: type) -> bool:
    # type ignore is required because type checkers
    # think this expression will always return False
    return type(typ) == TypeVar  # type: ignore


# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
    if is_required_type(typ) or is_annotated_type(typ):
        return strip_annotated_type(cast(type, get_args(typ)[0]))

    return typ


def extract_type_arg(typ: type, index: int) -> type:
    args = get_args(typ)
    try:
        return cast(type, args[index])
    except IndexError as err:
        raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err


def extract_type_var_from_base(
    typ: type,
    *,
    generic_bases: tuple[type, ...],
    index: int,
    failure_message: str | None = None,
) -> type:
    """Given a type like `Foo[T]`, returns the generic type variable `T`.

    This also handles the case where a concrete subclass is given, e.g.
    ```py
    class MyResponse(Foo[bytes]):
        ...

    extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
    ```

    And where a generic subclass is given:
    ```py
    _T = TypeVar('_T')
    class MyResponse(Foo[_T]):
        ...

    extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
    ```
    """
    cls = cast(object, get_origin(typ) or typ)
    if cls in generic_bases:
        # we're given the class directly
        return extract_type_arg(typ, index)

    # if a subclass is given
    # ---
    # this is needed as __orig_bases__ is not present in the typeshed stubs
    # because it is intended to be for internal use only, however there does
    # not seem to be a way to resolve generic TypeVars for inherited subclasses
    # without using it.
    if isinstance(cls, InheritsGeneric):
        target_base_class: Any | None = None
        for base in cls.__orig_bases__:
            if base.__origin__ in generic_bases:
                target_base_class = base
                break

        if target_base_class is None:
            raise RuntimeError(
                "Could not find the generic base class;\n"
                "This should never happen;\n"
                f"Does {cls} inherit from one of {generic_bases} ?"
            )

        extracted = extract_type_arg(target_base_class, index)
        if is_typevar(extracted):
            # If the extracted type argument is itself a type variable
            # then that means the subclass itself is generic, so we have
            # to resolve the type argument from the class itself, not
            # the base class.
            #
            # Note: if there is more than 1 type argument, the subclass could
            # change the ordering of the type arguments, this is not currently
            # supported.
            return extract_type_arg(typ, index)

        return extracted

    raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")
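
# Spot-checks for the helpers above (assumes the module imports cleanly):
if __name__ == "__main__":
    from typing import Union

    assert is_list_type(list[int])
    assert is_union_type(Union[int, str])
    assert strip_annotated_type(Required[Annotated[int, "meta"]]) is int
    assert extract_type_arg(dict[str, int], 1) is int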
@ -1,409 +0,0 @@
from __future__ import annotations

import functools
import inspect
import os
import re
from collections.abc import Callable, Iterable, Mapping, Sequence
from pathlib import Path
from typing import (
    Any,
    TypeGuard,
    TypeVar,
    Union,
    cast,
    overload,
)

import sniffio

from .._base_compat import parse_date as parse_date  # noqa: PLC0414
from .._base_compat import parse_datetime as parse_datetime  # noqa: PLC0414
from .._base_type import FileTypes, Headers, HeadersLike, NotGiven, NotGivenOr


def remove_notgiven_indict(obj):
    if obj is None or (not isinstance(obj, Mapping)):
        return obj
    return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}


_T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=tuple[object, ...])
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])


def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
    return [item for sublist in t for item in sublist]


def extract_files(
    # TODO: this needs to take Dict but variance issues.....
    # create protocol type ?
    query: Mapping[str, object],
    *,
    paths: Sequence[Sequence[str]],
) -> list[tuple[str, FileTypes]]:
    """Recursively extract files from the given dictionary based on specified paths.

    A path may look like this ['foo', 'files', '<array>', 'data'].

    Note: this mutates the given dictionary.
    """
    files: list[tuple[str, FileTypes]] = []
    for path in paths:
        files.extend(_extract_items(query, path, index=0, flattened_key=None))
    return files


def _extract_items(
    obj: object,
    path: Sequence[str],
    *,
    index: int,
    flattened_key: str | None,
) -> list[tuple[str, FileTypes]]:
    try:
        key = path[index]
    except IndexError:
        if isinstance(obj, NotGiven):
            # no value was provided - we can safely ignore
            return []

        # cyclical import
        from .._files import assert_is_file_content

        # We have exhausted the path, return the entry we found.
        assert_is_file_content(obj, key=flattened_key)
        assert flattened_key is not None
        return [(flattened_key, cast(FileTypes, obj))]

    index += 1
    if is_dict(obj):
        try:
            # We are at the last entry in the path so we must remove the field
            if (len(path)) == index:
                item = obj.pop(key)
            else:
                item = obj[key]
        except KeyError:
            # Key was not present in the dictionary, this is not indicative of an error
            # as the given path may not point to a required field. We also do not want
            # to enforce required fields as the API may differ from the spec in some cases.
            return []
        if flattened_key is None:
            flattened_key = key
        else:
            flattened_key += f"[{key}]"
        return _extract_items(
            item,
            path,
            index=index,
            flattened_key=flattened_key,
        )
    elif is_list(obj):
        if key != "<array>":
            return []

        return flatten(
            [
                _extract_items(
                    item,
                    path,
                    index=index,
                    flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
                )
                for item in obj
            ]
        )

    # Something unexpected was passed, just ignore it.
    return []

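# Usage sketch: given
#
#     query = {"file": {"data": open("a.png", "rb"), "purpose": "vision"}}
#     extract_files(query, paths=[["file", "data"]])
#
# the call returns [("file[data]", <file object>)] and pops the "data" entry
# out of `query` (the mapping is mutated in place, as the docstring notes).
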
def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
    return not isinstance(obj, NotGiven)


# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this causes Pyright to rightfully report errors. As we know we don't
# care about the contained types we can safely use `object` in its place.
#
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
# `is_*` is for when you're dealing with an unknown input
# `is_*_t` is for when you're narrowing a known union type to a specific subset


def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
    return isinstance(obj, tuple)


def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
    return isinstance(obj, tuple)


def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
    return isinstance(obj, Sequence)


def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
    return isinstance(obj, Sequence)


def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
    return isinstance(obj, Mapping)


def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
    return isinstance(obj, Mapping)


def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
    return isinstance(obj, dict)


def is_list(obj: object) -> TypeGuard[list[object]]:
    return isinstance(obj, list)


def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
    return isinstance(obj, Iterable)


def deepcopy_minimal(item: _T) -> _T:
    """Minimal reimplementation of copy.deepcopy() that will only copy certain object types:

    - mappings, e.g. `dict`
    - list

    This is done for performance reasons.
    """
    if is_mapping(item):
        return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})
    if is_list(item):
        return cast(_T, [deepcopy_minimal(entry) for entry in item])
    return item


# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
    size = len(seq)
    if size == 0:
        return ""

    if size == 1:
        return seq[0]

    if size == 2:
        return f"{seq[0]} {final} {seq[1]}"

    return delim.join(seq[:-1]) + f" {final} {seq[-1]}"

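# e.g. turning ["'a'", "'b'", "'c'"] into "'a', 'b' and 'c'":
if __name__ == "__main__":
    print(human_join(["'a'", "'b'", "'c'"], final="and"))
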
def quote(string: str) -> str:
|
|
||||||
"""Add single quotation marks around the given string. Does *not* do any escaping."""
|
|
||||||
return f"'{string}'"
|
|
||||||
|
|
||||||
|
|
||||||
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
    """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.

    Useful for enforcing runtime validation of overloaded functions.

    Example usage:
    ```py
    @overload
    def foo(*, a: str) -> str:
        ...


    @overload
    def foo(*, b: bool) -> str:
        ...


    # This enforces the same constraints that a static type checker would
    # i.e. that either a or b must be passed to the function
    @required_args(["a"], ["b"])
    def foo(*, a: str | None = None, b: bool | None = None) -> str:
        ...
    ```
    """

    def inner(func: CallableT) -> CallableT:
        params = inspect.signature(func).parameters
        positional = [
            name
            for name, param in params.items()
            if param.kind
            in {
                param.POSITIONAL_ONLY,
                param.POSITIONAL_OR_KEYWORD,
            }
        ]

        @functools.wraps(func)
        def wrapper(*args: object, **kwargs: object) -> object:
            given_params: set[str] = set()
            for i in range(len(args)):
                try:
                    given_params.add(positional[i])
                except IndexError:
                    raise TypeError(
                        f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
                    ) from None

            given_params.update(kwargs.keys())

            for variant in variants:
                matches = all(param in given_params for param in variant)
                if matches:
                    break
            else:  # no break
                if len(variants) > 1:
                    variations = human_join(
                        ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
                    )
                    msg = f"Missing required arguments; Expected either {variations} arguments to be given"
                else:
                    # TODO: this error message is not deterministic
                    missing = list(set(variants[0]) - given_params)
                    if len(missing) > 1:
                        msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
                    else:
                        msg = f"Missing required argument: {quote(missing[0])}"
                raise TypeError(msg)
            return func(*args, **kwargs)

        return wrapper  # type: ignore

    return inner
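A sketch of the decorator's runtime behaviour, reusing the docstring's own `foo` example:

```py
@required_args(["a"], ["b"])
def foo(*, a: str | None = None, b: bool | None = None) -> str:
    return f"a={a} b={b}"


foo(a="x")   # ok
foo(b=True)  # ok
try:
    foo()
except TypeError as exc:
    # Missing required arguments; Expected either ('a') or ('b') arguments to be given
    print(exc)
```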
_K = TypeVar("_K")
_V = TypeVar("_V")


@overload
def strip_not_given(obj: None) -> None: ...


@overload
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...


@overload
def strip_not_given(obj: object) -> object: ...


def strip_not_given(obj: object | None) -> object:
    """Remove all top-level keys where their values are instances of `NotGiven`"""
    if obj is None:
        return None

    if not is_mapping(obj):
        return obj

    return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
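A usage sketch, assuming `NOT_GIVEN` is the module's `NotGiven` sentinel instance:

```py
headers = {"X-Trace": "abc", "X-Optional": NOT_GIVEN}
assert strip_not_given(headers) == {"X-Trace": "abc"}
assert strip_not_given(None) is None
assert strip_not_given("passthrough") == "passthrough"  # non-mappings pass through
```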
def coerce_integer(val: str) -> int:
    return int(val, base=10)


def coerce_float(val: str) -> float:
    return float(val)


def coerce_boolean(val: str) -> bool:
    return val in {"true", "1", "on"}


def maybe_coerce_integer(val: str | None) -> int | None:
    if val is None:
        return None
    return coerce_integer(val)


def maybe_coerce_float(val: str | None) -> float | None:
    if val is None:
        return None
    return coerce_float(val)


def maybe_coerce_boolean(val: str | None) -> bool | None:
    if val is None:
        return None
    return coerce_boolean(val)
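These pair naturally with `os.environ.get`, which returns `str | None`; a small sketch with illustrative variable names:

```py
import os

timeout = maybe_coerce_integer(os.environ.get("ZHIPUAI_TIMEOUT"))    # int | None
verify = maybe_coerce_boolean(os.environ.get("ZHIPUAI_VERIFY_SSL"))  # bool | None
assert coerce_boolean("true") and coerce_boolean("1")
assert not coerce_boolean("yes")  # only "true", "1", "on" count as truthy
```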
def removeprefix(string: str, prefix: str) -> str:
    """Remove a prefix from a string.

    Backport of `str.removeprefix` for Python < 3.9
    """
    if string.startswith(prefix):
        return string[len(prefix) :]
    return string


def removesuffix(string: str, suffix: str) -> str:
    """Remove a suffix from a string.

    Backport of `str.removesuffix` for Python < 3.9
    """
    if string.endswith(suffix):
        return string[: -len(suffix)]
    return string
def file_from_path(path: str) -> FileTypes:
    contents = Path(path).read_bytes()
    file_name = os.path.basename(path)
    return (file_name, contents)


def get_required_header(headers: HeadersLike, header: str) -> str:
    lower_header = header.lower()
    if isinstance(headers, Mapping):
        headers = cast(Headers, headers)
        for k, v in headers.items():
            if k.lower() == lower_header and isinstance(v, str):
                return v

    # to deal with the case where the header looks like Stainless-Event-Id
    intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())

    for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
        value = headers.get(normalized_header)
        if value:
            return value

    raise ValueError(f"Could not find {header} header")
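A sketch of the lookup order: the case-insensitive scan handles plain mappings, while the intercaps fallback covers header containers that only support exact-key access:

```py
headers = {"x-request-id": "req_123"}  # illustrative header name and value
assert get_required_header(headers, "X-Request-Id") == "req_123"
```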
def get_async_library() -> str:
    try:
        return sniffio.current_async_library()
    except Exception:
        return "false"
def drop_prefix_image_data(content: Union[str, list[dict]]) -> Union[str, list[dict]]:
    """
    Strip the `data:image/...;base64,` prefix from image URLs.

    :param content: message content, either a plain string or a list of content parts
    :return: the content with any base64 data-URL prefixes removed
    """
    if isinstance(content, list):
        for data in content:
            if data.get("type") == "image_url":
                image_data = data.get("image_url").get("url")
                if image_data.startswith("data:image/"):
                    image_data = image_data.split("base64,")[-1]
                    data["image_url"]["url"] = image_data

    return content
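An illustrative round-trip, with the content shaped like a chat message's parts list:

```py
parts = [{"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}}]
drop_prefix_image_data(parts)
assert parts[0]["image_url"]["url"] == "iVBORw0KGgo="
```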
@ -1,78 +0,0 @@
import logging
import os
import time

logger = logging.getLogger(__name__)


class LoggerNameFilter(logging.Filter):
    def filter(self, record):
        # return record.name.startswith("loom_core") or record.name in "ERROR" or (
        #     record.name.startswith("uvicorn.error")
        #     and record.getMessage().startswith("Uvicorn running on")
        # )
        return True


def get_log_file(log_path: str, sub_dir: str):
    """
    sub_dir should contain a timestamp.
    """
    log_dir = os.path.join(log_path, sub_dir)
    # A new directory should be created on each call, hence `exist_ok=False`
    os.makedirs(log_dir, exist_ok=False)
    return os.path.join(log_dir, "zhipuai.log")


def get_config_dict(log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int) -> dict:
    # for windows, the path should be a raw string.
    log_file_path = log_file_path.encode("unicode-escape").decode() if os.name == "nt" else log_file_path
    log_level = log_level.upper()
    config_dict = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "formatter": {"format": "%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s"},
        },
        "filters": {
            "logger_name_filter": {
                "()": __name__ + ".LoggerNameFilter",
            },
        },
        "handlers": {
            "stream_handler": {
                "class": "logging.StreamHandler",
                "formatter": "formatter",
                "level": log_level,
                # "stream": "ext://sys.stdout",
                # "filters": ["logger_name_filter"],
            },
            "file_handler": {
                "class": "logging.handlers.RotatingFileHandler",
                "formatter": "formatter",
                "level": log_level,
                "filename": log_file_path,
                "mode": "a",
                "maxBytes": log_max_bytes,
                "backupCount": log_backup_count,
                "encoding": "utf8",
            },
        },
        "loggers": {
            "loom_core": {
                "handlers": ["stream_handler", "file_handler"],
                "level": log_level,
                "propagate": False,
            }
        },
        "root": {
            "level": log_level,
            "handlers": ["stream_handler", "file_handler"],
        },
    }
    return config_dict
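To apply the returned dict, feed it to `logging.config.dictConfig`; a minimal wiring sketch, with the log directory and timestamped sub_dir chosen for illustration:

```py
import logging.config
from datetime import datetime

log_file = get_log_file("./logs", datetime.now().strftime("%Y%m%d-%H%M%S"))
logging.config.dictConfig(
    get_config_dict("info", log_file, log_backup_count=5, log_max_bytes=10 * 1024 * 1024)
)
logging.getLogger("loom_core").info("logging configured")
```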
def get_timestamp_ms():
    t = time.time()
    return int(round(t * 1000))
@ -1,62 +0,0 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from typing import Any, Generic, Optional, TypeVar, cast

from typing_extensions import Protocol, override, runtime_checkable

from ._http_client import BasePage, BaseSyncPage, PageInfo

__all__ = ["SyncPage", "SyncCursorPage"]

_T = TypeVar("_T")


@runtime_checkable
class CursorPageItem(Protocol):
    id: Optional[str]


class SyncPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
    """Note: no pagination actually occurs yet, this is for forwards-compatibility."""

    data: list[_T]
    object: str

    @override
    def _get_page_items(self) -> list[_T]:
        data = self.data
        if not data:
            return []
        return data

    @override
    def next_page_info(self) -> None:
        """
        This page represents a response that isn't actually paginated at the API level
        so there will never be a next page.
        """
        return None


class SyncCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
    data: list[_T]

    @override
    def _get_page_items(self) -> list[_T]:
        data = self.data
        if not data:
            return []
        return data

    @override
    def next_page_info(self) -> Optional[PageInfo]:
        data = self.data
        if not data:
            return None

        item = cast(Any, data[-1])
        if not isinstance(item, CursorPageItem) or item.id is None:
            # TODO: emit warning log
            return None

        return PageInfo(params={"after": item.id})
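The cursor logic hinges on the `runtime_checkable` protocol: any item exposing an `id` attribute qualifies. A sketch of that check in isolation, with a hypothetical `Message` item:

```py
from dataclasses import dataclass


@dataclass
class Message:
    id: str
    content: str


last = Message(id="msg_42", content="hi")
# runtime_checkable protocols check attribute presence, not declared types:
assert isinstance(last, CursorPageItem)
# so next_page_info() on a page ending in `last` yields PageInfo(params={"after": "msg_42"})
```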
@ -1,5 +0,0 @@
from .assistant_completion import AssistantCompletion

__all__ = [
    "AssistantCompletion",
]
@ -1,40 +0,0 @@
from typing import Any, Optional

from ...core import BaseModel
from .message import MessageContent

__all__ = ["AssistantCompletion", "CompletionUsage"]


class ErrorInfo(BaseModel):
    code: str  # error code
    message: str  # error message


class AssistantChoice(BaseModel):
    index: int  # result index
    delta: MessageContent  # output message body for the current conversation
    finish_reason: str
    """
    Reason the inference finished: `stop` means it ended naturally or hit a stop word;
    `sensitive` means the output was intercepted by the safety review interface (users
    should judge for themselves whether to retract any published content);
    `network_error` means the inference service failed.
    """
    metadata: dict  # metadata, extension field


class CompletionUsage(BaseModel):
    prompt_tokens: int  # number of input tokens
    completion_tokens: int  # number of output tokens
    total_tokens: int  # total number of tokens


class AssistantCompletion(BaseModel):
    id: str  # request ID
    conversation_id: str  # conversation ID
    assistant_id: str  # assistant ID
    created: int  # request creation time, Unix timestamp
    status: str  # status: `completed` (finished), `in_progress` (generating), `failed` (error)
    last_error: Optional[ErrorInfo]  # error information
    choices: list[AssistantChoice]  # incrementally returned information
    metadata: Optional[dict[str, Any]]  # metadata, extension field
    usage: Optional[CompletionUsage]  # token usage statistics
@ -1,7 +0,0 @@
from typing import TypedDict


class ConversationParameters(TypedDict, total=False):
    assistant_id: str  # assistant ID
    page: int  # current page
    page_size: int  # page size
@ -1,29 +0,0 @@
from ...core import BaseModel

__all__ = ["ConversationUsageListResp"]


class Usage(BaseModel):
    prompt_tokens: int  # number of tokens in the user input
    completion_tokens: int  # number of tokens in the model output
    total_tokens: int  # total number of tokens


class ConversationUsage(BaseModel):
    id: str  # conversation id
    assistant_id: str  # assistant id
    create_time: int  # creation time
    update_time: int  # update time
    usage: Usage  # token usage statistics for the conversation


class ConversationUsageList(BaseModel):
    assistant_id: str  # assistant id
    has_more: bool  # whether there are more pages
    conversation_list: list[ConversationUsage]  # returned conversation usage entries


class ConversationUsageListResp(BaseModel):
    code: int
    msg: str
    data: ConversationUsageList
@ -1,32 +0,0 @@
from typing import Optional, TypedDict, Union


class AssistantAttachments:
    file_id: str


class MessageTextContent:
    type: str  # currently only type = text is supported
    text: str


MessageContent = Union[MessageTextContent]


class ConversationMessage(TypedDict):
    """Conversation message body"""

    role: str  # the input role, e.g. 'user'
    content: list[MessageContent]  # the content of the conversation message


class AssistantParameters(TypedDict, total=False):
    """Assistant parameters"""

    assistant_id: str  # assistant ID
    conversation_id: Optional[str]  # conversation ID; a new conversation is created if omitted
    model: str  # model name, defaults to 'GLM-4-Assistant'
    stream: bool  # whether to stream via SSE; must be True
    messages: list[ConversationMessage]  # conversation message bodies
    attachments: Optional[list[AssistantAttachments]]  # files attached to the conversation, optional
    metadata: Optional[dict]  # metadata, extension field, optional
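A sketch of filling these parameters for a streaming assistant call (the IDs are placeholders; note that `MessageTextContent` is a plain class, so static checkers may flag the dict-shaped content parts even though that is the wire format):

```py
params: AssistantParameters = {
    "assistant_id": "asst_xxx",  # placeholder ID
    "model": "GLM-4-Assistant",
    "stream": True,
    "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
}
```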
@ -1,21 +0,0 @@
from ...core import BaseModel

__all__ = ["AssistantSupportResp"]


class AssistantSupport(BaseModel):
    assistant_id: str  # Assistant id, used for assistant conversations
    created_at: int  # creation time
    updated_at: int  # update time
    name: str  # assistant name
    avatar: str  # assistant avatar
    description: str  # assistant description
    status: str  # assistant status; currently only publish
    tools: list[str]  # names of the tools the assistant supports
    starter_prompts: list[str]  # recommended starter prompts for the assistant


class AssistantSupportResp(BaseModel):
    code: int
    msg: str
    data: list[AssistantSupport]  # assistant list
@ -1,3 +0,0 @@
from .message_content import MessageContent

__all__ = ["MessageContent"]
@ -1,13 +0,0 @@
from typing import Annotated, TypeAlias, Union

from ....core._utils import PropertyInfo
from .text_content_block import TextContentBlock
from .tools_delta_block import ToolsDeltaBlock

__all__ = ["MessageContent"]


MessageContent: TypeAlias = Annotated[
    Union[ToolsDeltaBlock, TextContentBlock],
    PropertyInfo(discriminator="type"),
]
@ -1,14 +0,0 @@
from typing import Literal

from ....core import BaseModel

__all__ = ["TextContentBlock"]


class TextContentBlock(BaseModel):
    content: str

    role: str = "assistant"

    type: Literal["content"] = "content"
    """Always `content`."""
@ -1,27 +0,0 @@
from typing import Literal

__all__ = ["CodeInterpreterToolBlock"]

from .....core import BaseModel


class CodeInterpreterToolOutput(BaseModel):
    """Output of the code tool"""

    type: str  # type of code execution log, currently only logs
    logs: str  # log output from the code execution
    error_msg: str  # error message


class CodeInterpreter(BaseModel):
    """Code interpreter"""

    input: str  # generated code snippet, fed to the code sandbox
    outputs: list[CodeInterpreterToolOutput]  # output produced after the code runs


class CodeInterpreterToolBlock(BaseModel):
    """Code tool block"""

    code_interpreter: CodeInterpreter  # code interpreter object
    type: Literal["code_interpreter"]  # type of tool invoked, always `code_interpreter`
@ -1,21 +0,0 @@
from typing import Literal

from .....core import BaseModel

__all__ = ["DrawingToolBlock"]


class DrawingToolOutput(BaseModel):
    image: str


class DrawingTool(BaseModel):
    input: str
    outputs: list[DrawingToolOutput]


class DrawingToolBlock(BaseModel):
    drawing_tool: DrawingTool

    type: Literal["drawing_tool"]
    """Always `drawing_tool`."""
@ -1,22 +0,0 @@
from typing import Literal, Union

__all__ = ["FunctionToolBlock"]

from .....core import BaseModel


class FunctionToolOutput(BaseModel):
    content: str


class FunctionTool(BaseModel):
    name: str
    arguments: Union[str, dict]
    outputs: list[FunctionToolOutput]


class FunctionToolBlock(BaseModel):
    function: FunctionTool

    type: Literal["function"]
    """Always `function`."""
@ -1,41 +0,0 @@
from typing import Literal

from .....core import BaseModel


class RetrievalToolOutput(BaseModel):
    """
    This class represents the output of a retrieval tool.

    Attributes:
    - text (str): The text snippet retrieved from the knowledge base.
    - document (str): The name of the document from which the text snippet was retrieved, returned only in intelligent configuration.
    """  # noqa: E501

    text: str
    document: str


class RetrievalTool(BaseModel):
    """
    This class represents the outputs of a retrieval tool.

    Attributes:
    - outputs (List[RetrievalToolOutput]): A list of text snippets and their respective document names retrieved from the knowledge base.
    """  # noqa: E501

    outputs: list[RetrievalToolOutput]


class RetrievalToolBlock(BaseModel):
    """
    This class represents a block for invoking the retrieval tool.

    Attributes:
    - retrieval (RetrievalTool): An instance of the RetrievalTool class containing the retrieval outputs.
    - type (Literal["retrieval"]): The type of tool being used, always set to "retrieval".
    """

    retrieval: RetrievalTool
    type: Literal["retrieval"]
    """Always `retrieval`."""
@ -1,16 +0,0 @@
from typing import Annotated, TypeAlias, Union

from .....core._utils import PropertyInfo
from .code_interpreter_delta_block import CodeInterpreterToolBlock
from .drawing_tool_delta_block import DrawingToolBlock
from .function_delta_block import FunctionToolBlock
from .retrieval_delta_black import RetrievalToolBlock
from .web_browser_delta_block import WebBrowserToolBlock

__all__ = ["ToolsType"]


ToolsType: TypeAlias = Annotated[
    Union[DrawingToolBlock, CodeInterpreterToolBlock, WebBrowserToolBlock, RetrievalToolBlock, FunctionToolBlock],
    PropertyInfo(discriminator="type"),
]
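Since every member of the union carries a distinct `Literal` type tag, a deserializer honouring `PropertyInfo(discriminator="type")` can pick the concrete block model from the `type` field alone; a sketch with hypothetical payloads:

```py
payloads = [
    {"type": "web_browser", "web_browser": {"input": "dify", "outputs": []}},
    {"type": "function", "function": {"name": "sum", "arguments": "{}", "outputs": []}},
]
# "web_browser" -> WebBrowserToolBlock, "function" -> FunctionToolBlock
```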
@ -1,48 +0,0 @@
from typing import Literal

from .....core import BaseModel

__all__ = ["WebBrowserToolBlock"]


class WebBrowserOutput(BaseModel):
    """
    This class represents the output of a web browser search result.

    Attributes:
    - title (str): The title of the search result.
    - link (str): The URL link to the search result's webpage.
    - content (str): The textual content extracted from the search result.
    - error_msg (str): Any error message encountered during the search or retrieval process.
    """

    title: str
    link: str
    content: str
    error_msg: str


class WebBrowser(BaseModel):
    """
    This class represents the input and outputs of a web browser search.

    Attributes:
    - input (str): The input query for the web browser search.
    - outputs (List[WebBrowserOutput]): A list of search results returned by the web browser.
    """

    input: str
    outputs: list[WebBrowserOutput]


class WebBrowserToolBlock(BaseModel):
    """
    This class represents a block for invoking the web browser tool.

    Attributes:
    - web_browser (WebBrowser): An instance of the WebBrowser class containing the search input and outputs.
    - type (Literal["web_browser"]): The type of tool being used, always set to "web_browser".
    """

    web_browser: WebBrowser
    type: Literal["web_browser"]
@ -1,16 +0,0 @@
from typing import Literal

from ....core import BaseModel
from .tools.tools_type import ToolsType

__all__ = ["ToolsDeltaBlock"]


class ToolsDeltaBlock(BaseModel):
    tool_calls: list[ToolsType]
    """The tool calls in this content part of the message."""

    role: str = "tool"

    type: Literal["tool_calls"] = "tool_calls"
    """Always `tool_calls`."""
@ -1,82 +0,0 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

import builtins
from typing import Literal, Optional

from ..core import BaseModel
from .batch_error import BatchError
from .batch_request_counts import BatchRequestCounts

__all__ = ["Batch", "Errors"]


class Errors(BaseModel):
    data: Optional[list[BatchError]] = None

    object: Optional[str] = None
    """The object type, always `list`."""


class Batch(BaseModel):
    id: str

    completion_window: str
    """The time frame within which the batch should be processed."""

    created_at: int
    """Unix timestamp (in seconds) of when the batch was created."""

    endpoint: str
    """The ZhipuAI API endpoint used by the batch."""

    input_file_id: str
    """The ID of the input file for the batch."""

    object: Literal["batch"]
    """The object type, always `batch`."""

    status: Literal[
        "validating", "failed", "in_progress", "finalizing", "completed", "expired", "cancelling", "cancelled"
    ]
    """The status of the batch."""

    cancelled_at: Optional[int] = None
    """Unix timestamp (in seconds) of when the batch was cancelled."""

    cancelling_at: Optional[int] = None
    """Unix timestamp (in seconds) of when cancellation was requested."""

    completed_at: Optional[int] = None
    """Unix timestamp (in seconds) of when the batch was completed."""

    error_file_id: Optional[str] = None
    """The ID of the file containing the outputs of requests that failed."""

    errors: Optional[Errors] = None

    expired_at: Optional[int] = None
    """Unix timestamp (in seconds) of when the batch expired."""

    expires_at: Optional[int] = None
    """Unix timestamp (in seconds) of when the batch will expire."""

    failed_at: Optional[int] = None
    """Unix timestamp (in seconds) of when the batch failed."""

    finalizing_at: Optional[int] = None
    """Unix timestamp (in seconds) of when the batch started finalizing."""

    in_progress_at: Optional[int] = None
    """Unix timestamp (in seconds) of when the batch started processing."""

    metadata: Optional[builtins.object] = None
    """
    Metadata in key:value form for storing additional structured information.
    Keys may be up to 64 characters long; values may be up to 512 characters long.
    """

    output_file_id: Optional[str] = None
    """The ID of the output file for completed requests."""

    request_counts: Optional[BatchRequestCounts] = None
    """Request counts for the different statuses within the batch."""
@ -1,37 +0,0 @@
from __future__ import annotations

from typing import Literal, Optional

from typing_extensions import Required, TypedDict

__all__ = ["BatchCreateParams"]


class BatchCreateParams(TypedDict, total=False):
    completion_window: Required[str]
    """The time frame within which the batch should be processed.

    Currently only `24h` is supported.
    """

    endpoint: Required[Literal["/v1/chat/completions", "/v1/embeddings"]]
    """The endpoint to be used for all requests in the batch.

    Currently `/v1/chat/completions` and `/v1/embeddings` are supported.
    """

    input_file_id: Required[str]
    """The ID of an uploaded file that contains requests for the new batch.

    See [upload file](https://platform.openai.com/docs/api-reference/files/create)
    for how to upload a file.

    Your input file must be formatted as a
    [JSONL file](https://platform.openai.com/docs/api-reference/batch/requestInput),
    and must be uploaded with the purpose `batch`.
    """

    metadata: Optional[dict[str, str]]
    """Optional custom metadata for the batch."""

    auto_delete_input_file: Optional[bool]
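A sketch of valid parameters under this schema (the file ID is a placeholder):

```py
params: BatchCreateParams = {
    "completion_window": "24h",
    "endpoint": "/v1/chat/completions",
    "input_file_id": "file-abc123",  # placeholder ID
    "metadata": {"job": "nightly-eval"},
}
```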
@ -1,21 +0,0 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from typing import Optional

from ..core import BaseModel

__all__ = ["BatchError"]


class BatchError(BaseModel):
    code: Optional[str] = None
    """A business error code as defined by the API."""

    line: Optional[int] = None
    """The line number in the file."""

    message: Optional[str] = None
    """A description of the error in the conversation file."""

    param: Optional[str] = None
    """The parameter name, if applicable."""
@ -1,20 +0,0 @@
from __future__ import annotations

from typing_extensions import TypedDict

__all__ = ["BatchListParams"]


class BatchListParams(TypedDict, total=False):
    after: str
    """A cursor for pagination, used to fetch the next page of results.

    `after` is a cursor pointing to the current page; pass it to fetch the next
    page. If `after` is not provided, the first page is returned.
    """

    limit: int
    """A limit on the number of results to be returned.

    Defaults to 10.
    """
@ -1,14 +0,0 @@
from ..core import BaseModel

__all__ = ["BatchRequestCounts"]


class BatchRequestCounts(BaseModel):
    completed: int
    """The number of requests that have completed."""

    failed: int
    """The number of requests that have failed."""

    total: int
    """The total number of requests."""
@ -1,22 +0,0 @@
from typing import Optional

from ...core import BaseModel
from .chat_completion import CompletionChoice, CompletionUsage

__all__ = ["AsyncTaskStatus", "AsyncCompletion"]


class AsyncTaskStatus(BaseModel):
    id: Optional[str] = None
    request_id: Optional[str] = None
    model: Optional[str] = None
    task_status: Optional[str] = None


class AsyncCompletion(BaseModel):
    id: Optional[str] = None
    request_id: Optional[str] = None
    model: Optional[str] = None
    task_status: str
    choices: list[CompletionChoice]
    usage: CompletionUsage