Merge branch 'main' into feat/attachments

StyleZhang 2024-10-14 16:09:05 +08:00
commit 8fe5028f74
266 changed files with 4108 additions and 9061 deletions

View File

@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001
# The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
@ -39,7 +42,7 @@ DB_DATABASE=dify
# Storage configuration
# use for store upload files, private keys...
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs, supabase
STORAGE_TYPE=local
STORAGE_LOCAL_PATH=storage
S3_USE_AWS_MANAGED_IAM=false
@ -99,11 +102,16 @@ VOLCENGINE_TOS_ACCESS_KEY=your-access-key
VOLCENGINE_TOS_SECRET_KEY=your-secret-key
VOLCENGINE_TOS_REGION=your-region
# Supabase Storage Configuration
SUPABASE_BUCKET_NAME=your-bucket-name
SUPABASE_API_KEY=your-access-key
SUPABASE_URL=your-server-url
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb
VECTOR_STORE=weaviate
# Weaviate configuration
@ -203,6 +211,24 @@ OPENSEARCH_USER=admin
OPENSEARCH_PASSWORD=admin
OPENSEARCH_SECURE=true
# Baidu configuration
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
BAIDU_VECTOR_DB_ACCOUNT=root
BAIDU_VECTOR_DB_API_KEY=dify
BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3
# ViKingDB configuration
VIKINGDB_ACCESS_KEY=your-ak
VIKINGDB_SECRET_KEY=your-sk
VIKINGDB_REGION=cn-shanghai
VIKINGDB_HOST=api-vikingdb.xxx.volces.com
VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
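The new keys above are consumed by the pydantic-settings classes added further down in this commit. A minimal sketch of reading the Supabase group from the environment, assuming python-dotenv for a local experiment; the names below are only illustrative:

import os

from dotenv import load_dotenv  # assumed helper, only for this sketch

load_dotenv(".env")
if os.environ.get("STORAGE_TYPE", "local") == "supabase":
    supabase_url = os.environ["SUPABASE_URL"]
    supabase_key = os.environ["SUPABASE_API_KEY"]
    supabase_bucket = os.environ["SUPABASE_BUCKET_NAME"]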

View File

@ -183,7 +183,7 @@ def load_user_from_request(request_from_flask_login):
decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id")
logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
if logged_in_account:
contexts.tenant_id.set(logged_in_account.current_tenant_id)
return logged_in_account

View File

@ -347,6 +347,14 @@ def migrate_knowledge_vector_database():
index_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.BAIDU:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.BAIDU,
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -360,9 +360,9 @@ class WorkflowConfig(BaseSettings):
)
class OAuthConfig(BaseSettings):
class AuthConfig(BaseSettings):
"""
Configuration for OAuth authentication
Configuration for authentication and OAuth
"""
OAUTH_REDIRECT_PATH: str = Field(
@ -371,7 +371,7 @@ class OAuthConfig(BaseSettings):
)
GITHUB_CLIENT_ID: Optional[str] = Field(
description="GitHub OAuth client secret",
description="GitHub OAuth client ID",
default=None,
)
@ -390,6 +390,11 @@ class OAuthConfig(BaseSettings):
default=None,
)
ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field(
description="Expiration time for access tokens in minutes",
default=60,
)
class ModerationConfig(BaseSettings):
"""
@ -607,6 +612,7 @@ class PositionConfig(BaseSettings):
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
DataSetConfig,
@ -621,14 +627,13 @@ class FeatureConfig(
MailConfig,
ModelLoadBalanceConfig,
ModerationConfig,
OAuthConfig,
PositionConfig,
RagEtlConfig,
SecurityConfig,
ToolConfig,
UpdateConfig,
WorkflowConfig,
WorkspaceConfig,
PositionConfig,
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,
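With AuthConfig folded into FeatureConfig, the new expiry is reachable on the global dify_config object. A hedged sketch of turning it into an absolute expiry for a token payload; the claim names are assumptions, not taken from this diff:

from datetime import datetime, timedelta, timezone

from configs import dify_config

# AuthConfig.ACCESS_TOKEN_EXPIRE_MINUTES defaults to 60 (see above).
expires_at = datetime.now(timezone.utc) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
payload = {"user_id": "some-account-id", "exp": int(expires_at.timestamp())}  # hypothetical claim layout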

View File

@ -12,6 +12,7 @@ from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageC
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
@ -27,6 +28,7 @@ from configs.middleware.vdb.qdrant_config import QdrantConfig
from configs.middleware.vdb.relyt_config import RelytConfig
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
from configs.middleware.vdb.weaviate_config import WeaviateConfig
@ -191,6 +193,22 @@ class CeleryConfig(DatabaseConfig):
return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False
class InternalTestConfig(BaseSettings):
"""
Configuration settings for Internal Test
"""
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
description="Internal test AWS secret access key",
default=None,
)
AWS_ACCESS_KEY_ID: Optional[str] = Field(
description="Internal test AWS access key ID",
default=None,
)
class MiddlewareConfig(
# place the configs in alphabet order
CeleryConfig,
@ -206,6 +224,7 @@ class MiddlewareConfig(
HuaweiCloudOBSStorageConfig,
OCIStorageConfig,
S3StorageConfig,
SupabaseStorageConfig,
TencentCloudCOSStorageConfig,
VolcengineTOSStorageConfig,
# configs of vdb and vdb providers
@ -224,5 +243,7 @@ class MiddlewareConfig(
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
InternalTestConfig,
VikingDBConfig,
):
pass

View File

@ -0,0 +1,24 @@
from typing import Optional
from pydantic import BaseModel, Field
class SupabaseStorageConfig(BaseModel):
"""
Configuration settings for Supabase Object Storage Service
"""
SUPABASE_BUCKET_NAME: Optional[str] = Field(
description="Name of the Supabase bucket to store and retrieve objects (e.g., 'dify-bucket')",
default=None,
)
SUPABASE_API_KEY: Optional[str] = Field(
description="API KEY for authenticating with Supabase",
default=None,
)
SUPABASE_URL: Optional[str] = Field(
description="URL of the Supabase",
default=None,
)
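These three fields presumably back a storage extension registered for STORAGE_TYPE=supabase. A rough sketch of how such a backend could use them with the supabase-py client; create_client and the storage.from_ bucket API are assumptions about that library, not part of this diff:

from supabase import create_client  # assumed dependency: supabase-py

from configs import dify_config

client = create_client(dify_config.SUPABASE_URL, dify_config.SUPABASE_API_KEY)
bucket = client.storage.from_(dify_config.SUPABASE_BUCKET_NAME)

bucket.upload("uploads/example.txt", b"hello")  # save an object
data = bucket.download("uploads/example.txt")   # read it back as bytes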

View File

@ -0,0 +1,45 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class BaiduVectorDBConfig(BaseSettings):
"""
Configuration settings for Baidu Vector Database
"""
BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
default=None,
)
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
default=30000,
)
BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
description="Account for authenticating with the Baidu Vector Database",
default=None,
)
BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Baidu Vector Database service",
default=None,
)
BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
description="Name of the specific Baidu Vector Database to connect to",
default=None,
)
BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
description="Number of shards for the Baidu Vector Database (default is 1)",
default=1,
)
BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description="Number of replicas for the Baidu Vector Database (default is 3)",
default=3,
)
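These settings mirror the BAIDU_VECTOR_DB_* keys added to .env.example above. A minimal sketch of gathering them for whatever Baidu VDB client the new vector factory uses; the client itself is not shown in this excerpt, so the dict layout below is only illustrative:

from configs import dify_config

connection = {
    "endpoint": dify_config.BAIDU_VECTOR_DB_ENDPOINT,
    "account": dify_config.BAIDU_VECTOR_DB_ACCOUNT,
    "api_key": dify_config.BAIDU_VECTOR_DB_API_KEY,
    "database": dify_config.BAIDU_VECTOR_DB_DATABASE,
    "connection_timeout_ms": dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
    "shard": dify_config.BAIDU_VECTOR_DB_SHARD,
    "replicas": dify_config.BAIDU_VECTOR_DB_REPLICAS,
}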

View File

@ -0,0 +1,37 @@
from typing import Optional
from pydantic import BaseModel, Field
class VikingDBConfig(BaseModel):
"""
Configuration for connecting to Volcengine VikingDB.
Refer to the following documentation for details on obtaining credentials:
https://www.volcengine.com/docs/6291/65568
"""
VIKINGDB_ACCESS_KEY: Optional[str] = Field(
default=None, description="The Access Key provided by Volcengine VikingDB for API authentication."
)
VIKINGDB_SECRET_KEY: Optional[str] = Field(
default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication."
)
VIKINGDB_REGION: Optional[str] = Field(
default="cn-shanghai",
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
)
VIKINGDB_HOST: Optional[str] = Field(
default="api-vikingdb.mlp.cn-shanghai.volces.com",
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
'api-vikingdb.mlp.cn-shanghai.volces.com')",
)
VIKINGDB_SCHEME: Optional[str] = Field(
default="http",
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
)
VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(
default=30, description="The connection timeout of the Volcengine VikingDB service."
)
VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field(
default=30, description="The socket timeout of the Volcengine VikingDB service."
)
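Note that .env.example earlier in this diff writes VIKINGDB_SCHEMA while the field here is VIKINGDB_SCHEME; pydantic-settings matches on the field name, so only the latter spelling takes effect. A small sketch of assembling the service URL from these fields:

from configs import dify_config

# Region is presumably used for request signing by the VikingDB SDK (not shown here).
base_url = f"{dify_config.VIKINGDB_SCHEME}://{dify_config.VIKINGDB_HOST}"
print(base_url)  # e.g. http://api-vikingdb.mlp.cn-shanghai.volces.com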

View File

@ -7,7 +7,7 @@ from flask_restful import Resource, reqparse
import services
from controllers.console import api
from controllers.console.setup import setup_required
from libs.helper import email, get_remote_ip
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
from models.account import Account
from services.account_service import AccountService, TenantService
@ -40,17 +40,16 @@ class LoginApi(Resource):
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
}
token = AccountService.login(account, ip_address=get_remote_ip(request))
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
return {"result": "success", "data": token}
return {"result": "success", "data": token_pair.model_dump()}
class LogoutApi(Resource):
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
token = request.headers.get("Authorization", "").split(" ")[1]
AccountService.logout(account=account, token=token)
AccountService.logout(account=account)
flask_login.logout_user()
return {"result": "success"}
@ -106,5 +105,19 @@ class ResetPasswordApi(Resource):
return {"result": "success"}
class RefreshTokenApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("refresh_token", type=str, required=True, location="json")
args = parser.parse_args()
try:
new_token_pair = AccountService.refresh_token(args["refresh_token"])
return {"result": "success", "data": new_token_pair.model_dump()}
except Exception as e:
return {"result": "fail", "data": str(e)}, 401
api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout")
api.add_resource(RefreshTokenApi, "/refresh-token")

View File

@ -9,7 +9,7 @@ from flask_restful import Resource
from configs import dify_config
from constants.languages import languages
from extensions.ext_database import db
from libs.helper import get_remote_ip
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
@ -81,9 +81,14 @@ class OAuthCallback(Resource):
TenantService.create_owner_tenant_if_not_exist(account)
token = AccountService.login(account, ip_address=get_remote_ip(request))
token_pair = AccountService.login(
account=account,
ip_address=extract_remote_ip(request),
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
return redirect(
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
)
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
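The OAuth callback now carries both tokens on the redirect URL instead of a single console_token. A tiny sketch of the new query-string shape; the values are placeholders and the web app's actual handling is outside this excerpt:

from urllib.parse import parse_qs, urlparse

redirect = "https://console.example.com/?access_token=aaa&refresh_token=bbb"  # hypothetical redirect
params = parse_qs(urlparse(redirect).query)
access_token = params["access_token"][0]
refresh_token = params["refresh_token"][0]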

View File

@ -617,6 +617,8 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
@ -653,6 +655,8 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (

View File

@ -13,6 +13,7 @@ from libs.login import login_required
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
def _validate_name(name):
@ -232,8 +233,31 @@ class ExternalKnowledgeHitTestingApi(Resource):
raise InternalServerError(str(e))
class BedrockRetrievalApi(Resource):
# this api is only for internal testing
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
parser.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
args = parser.parse_args()
# Call the knowledge retrieval service
result = ExternalDatasetTestService.knowledge_retrieval(
args["retrieval_setting"], args["query"], args["knowledge_id"]
)
return result, 200
api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing")
api.add_resource(ExternalDatasetCreateApi, "/datasets/external")
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api")
api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
# this api is only for internal test
api.add_resource(BedrockRetrievalApi, "/test/retrieval")

View File

@ -4,7 +4,7 @@ from flask import request
from flask_restful import Resource, reqparse
from configs import dify_config
from libs.helper import StrLen, email, get_remote_ip
from libs.helper import StrLen, email, extract_remote_ip
from libs.password import valid_password
from models.model import DifySetup
from services.account_service import RegisterService, TenantService
@ -46,7 +46,7 @@ class SetupApi(Resource):
# setup
RegisterService.setup(
email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
)
return {"result": "success"}, 201

View File

@ -126,13 +126,12 @@ class ModelProviderIconApi(Resource):
Get model provider icon
"""
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str, icon_type: str, lang: str):
model_provider_service = ModelProviderService()
icon, mimetype = model_provider_service.get_model_provider_icon(
provider=provider, icon_type=icon_type, lang=lang
provider=provider,
icon_type=icon_type,
lang=lang,
)
return send_file(io.BytesIO(icon), mimetype=mimetype)

View File

@ -56,6 +56,7 @@ from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)
@ -72,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow: Workflow
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
def __init__(
self,
@ -115,6 +117,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
}
self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self._conversation_name_generate_thread = None

View File

@ -52,6 +52,7 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
@ -69,6 +70,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
def __init__(
self,
@ -103,6 +105,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
}
self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""

View File

@ -1,8 +1,10 @@
import logging
from threading import Thread
from typing import Optional, Union
from flask import Flask, current_app
from configs import dify_config
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
@ -83,7 +85,9 @@ class MessageCycleManage:
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
conversation.name = name
except Exception as e:
logging.exception(f"generate conversation name failed: {e}")
if dify_config.DEBUG:
logging.exception(f"generate conversation name failed: {e}")
pass
db.session.merge(conversation)
db.session.commit()

View File

@ -57,6 +57,7 @@ class WorkflowCycleManage:
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
def _handle_workflow_run_start(self) -> WorkflowRun:
max_sequence = (
@ -251,6 +252,8 @@ class WorkflowCycleManage:
db.session.refresh(workflow_node_execution)
db.session.close()
self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
return workflow_node_execution
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
@ -263,20 +266,36 @@ class WorkflowCycleManage:
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.execution_metadata: execution_metadata,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
}
)
db.session.commit()
db.session.close()
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
return workflow_node_execution
@ -290,18 +309,33 @@ class WorkflowCycleManage:
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
}
)
db.session.commit()
db.session.close()
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = event.error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
return workflow_node_execution
@ -678,17 +712,7 @@ class WorkflowCycleManage:
:param node_execution_id: workflow node execution id
:return:
"""
workflow_node_execution = (
db.session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
WorkflowNodeExecution.workflow_id == self._workflow.id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.node_execution_id == node_execution_id,
)
.first()
)
workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id)
if not workflow_node_execution:
raise Exception(f"Workflow node execution not found: {node_execution_id}")

View File

@ -5,6 +5,7 @@ from typing import Optional, cast
import numpy as np
from sqlalchemy.exc import IntegrityError
from configs import dify_config
from core.embedding.embedding_constant import EmbeddingInputType
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
@ -110,6 +111,8 @@ class CacheEmbedding(Embeddings):
embedding_results = embedding_result.embeddings[0]
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
except Exception as ex:
if dify_config.DEBUG:
logging.exception(f"Failed to embed query text: {ex}")
raise ex
try:
@ -122,6 +125,8 @@ class CacheEmbedding(Embeddings):
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 600, encoded_str)
except Exception as ex:
logging.exception("Failed to add embedding to redis %s", ex)
if dify_config.DEBUG:
logging.exception("Failed to add embedding to redis %s", ex)
raise ex
return embedding_results

View File

@ -60,8 +60,8 @@ class TokenBufferMemory:
thread_messages = extract_thread_messages(messages)
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
if thread_messages and not thread_messages[-1].answer:
thread_messages.pop()
if thread_messages and not thread_messages[0].answer:
thread_messages.pop(0)
messages = list(reversed(thread_messages))
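extract_thread_messages evidently returns the thread newest-first, so the just-created, still-unanswered message sits at index 0 rather than at the end; the fix drops it from the front before the list is reversed into chronological order. A tiny illustration with plain dicts standing in for Message rows:

thread_messages = [
    {"query": "latest question", "answer": ""},               # just created, no answer yet
    {"query": "earlier question", "answer": "earlier answer"},
]

if thread_messages and not thread_messages[0]["answer"]:
    thread_messages.pop(0)                                    # drop the unanswered head

messages = list(reversed(thread_messages))                    # oldest-first for the prompt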

View File

@ -1,8 +1,18 @@
from collections.abc import Generator
from typing import Optional, Union
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
@ -29,3 +39,53 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _add_custom_parameters(cls, credentials: dict) -> None:
credentials["mode"] = "chat"
credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
return AIModelEntity(
model=model,
label=I18nObject(en_US=model, zh_Hans=model),
model_type=ModelType.LLM,
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
if credentials.get("function_calling_type") == "tool_call"
else [],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)),
ModelPropertyKey.MODE: LLMMode.CHAT.value,
},
parameter_rules=[
ParameterRule(
name="temperature",
use_template="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=ParameterType.FLOAT,
),
ParameterRule(
name="max_tokens",
use_template="max_tokens",
default=512,
min=1,
max=int(credentials.get("max_tokens", 1024)),
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
type=ParameterType.INT,
),
ParameterRule(
name="top_p",
use_template="top_p",
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
type=ParameterType.FLOAT,
),
ParameterRule(
name="top_k",
use_template="top_k",
label=I18nObject(en_US="Top K", zh_Hans="Top K"),
type=ParameterType.FLOAT,
),
ParameterRule(
name="frequency_penalty",
use_template="frequency_penalty",
label=I18nObject(en_US="Frequency Penalty", zh_Hans="重复惩罚"),
type=ParameterType.FLOAT,
),
],
)
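The schema above is built entirely from the per-model credentials collected by the customizable-model flow (the keys are defined in the provider YAML in the next file). A plain-Python sketch of how the feature list and limits are derived, mirroring the logic above; the feature strings are ModelFeature enum values written out here:

credentials = {"context_size": "32768", "max_tokens": "4096", "function_calling_type": "tool_call"}

# Tool-call features are only advertised when function calling is enabled.
features = ["tool-call", "multi-tool-call", "stream-tool-call"] if credentials.get("function_calling_type") == "tool_call" else []
context_size = int(credentials.get("context_size", 8000))  # falls back to 8000 tokens
max_tokens_cap = int(credentials.get("max_tokens", 1024))   # upper bound for the max_tokens rule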

View File

@ -20,6 +20,7 @@ supported_model_types:
- speech2text
configurate_methods:
- predefined-model
- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
@ -30,3 +31,57 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
default: '4096'
type: text-input
show_on:
- variable: __model_type
value: llm
- variable: function_calling_type
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: no_call
label:
en_US: Not Support
zh_Hans: 不支持
- value: function_call
label:
en_US: Support
zh_Hans: 支持
show_on:
- variable: __model_type
value: llm

View File

@ -0,0 +1 @@
- gte-rerank

View File

@ -0,0 +1,4 @@
model: gte-rerank
model_type: rerank
model_properties:
context_size: 4000

View File

@ -0,0 +1,136 @@
from typing import Optional
import dashscope
from dashscope.common.error import (
AuthenticationError,
InvalidParameter,
RequestFailure,
ServiceUnavailableError,
UnsupportedHTTPMethod,
UnsupportedModel,
)
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class GTERerankModel(RerankModel):
"""
Model class for GTE rerank model.
"""
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=docs)
# initialize client
dashscope.api_key = credentials["dashscope_api_key"]
response = dashscope.TextReRank.call(
query=query,
documents=docs,
model=model,
top_n=top_n,
return_documents=True,
)
rerank_documents = []
for _, result in enumerate(response.output.results):
# format document
rerank_document = RerankDocument(
index=result.index,
score=result.relevance_score,
text=result["document"]["text"],
)
# score threshold check
if score_threshold is not None:
if result.relevance_score >= score_threshold:
rerank_documents.append(rerank_document)
else:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self.invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
)
except Exception as ex:
print(ex)
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
RequestFailure,
],
InvokeServerUnavailableError: [
ServiceUnavailableError,
],
InvokeRateLimitError: [],
InvokeAuthorizationError: [
AuthenticationError,
],
InvokeBadRequestError: [
InvalidParameter,
UnsupportedModel,
UnsupportedHTTPMethod,
],
}
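The model wraps DashScope's TextReRank endpoint directly. A hedged standalone sketch of the same call outside Dify, assuming the dashscope package and a valid API key; the call shape and result fields follow the code above:

import dashscope

dashscope.api_key = "sk-..."  # hypothetical DashScope key

response = dashscope.TextReRank.call(
    model="gte-rerank",
    query="What is the capital of the United States?",
    documents=[
        "Carson City is the capital city of the American state of Nevada.",
        "Washington, D.C. is the capital of the United States.",
    ],
    top_n=2,
    return_documents=True,
)
for result in response.output.results:
    print(result.index, result.relevance_score, result["document"]["text"])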

View File

@ -18,6 +18,7 @@ supported_model_types:
- llm
- tts
- text-embedding
- rerank
configurate_methods:
- predefined-model
- customizable-model

View File

@ -1,6 +1,10 @@
from collections.abc import Generator
from typing import Optional, Union
from zhipuai import ZhipuAI
from zhipuai.types.chat.chat_completion import Completion
from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@ -16,9 +20,6 @@ from core.model_runtime.entities.message_entities import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.utils import helper
GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.

View File

@ -1,13 +1,14 @@
import time
from typing import Optional
from zhipuai import ZhipuAI
from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):

View File

@ -1,15 +0,0 @@
from .__version__ import __version__
from ._client import ZhipuAI
from .core import (
APIAuthenticationError,
APIConnectionError,
APIInternalError,
APIReachLimitError,
APIRequestFailedError,
APIResponseError,
APIResponseValidationError,
APIServerFlowExceedError,
APIStatusError,
APITimeoutError,
ZhipuAIError,
)

View File

@ -1 +0,0 @@
__version__ = "v2.1.0"

View File

@ -1,82 +0,0 @@
from __future__ import annotations
import os
from collections.abc import Mapping
from typing import Union
import httpx
from httpx import Timeout
from typing_extensions import override
from . import api_resource
from .core import NOT_GIVEN, ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient, NotGiven, ZhipuAIError, _jwt_token
class ZhipuAI(HttpClient):
chat: api_resource.chat.Chat
api_key: str
_disable_token_cache: bool = True
def __init__(
self,
*,
api_key: str | None = None,
base_url: str | httpx.URL | None = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
http_client: httpx.Client | None = None,
custom_headers: Mapping[str, str] | None = None,
disable_token_cache: bool = True,
_strict_response_validation: bool = False,
) -> None:
if api_key is None:
api_key = os.environ.get("ZHIPUAI_API_KEY")
if api_key is None:
raise ZhipuAIError("未提供api_key请通过参数或环境变量提供")
self.api_key = api_key
self._disable_token_cache = disable_token_cache
if base_url is None:
base_url = os.environ.get("ZHIPUAI_BASE_URL")
if base_url is None:
base_url = "https://open.bigmodel.cn/api/paas/v4"
from .__version__ import __version__
super().__init__(
version=__version__,
base_url=base_url,
max_retries=max_retries,
timeout=timeout,
custom_httpx_client=http_client,
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
)
self.chat = api_resource.chat.Chat(self)
self.images = api_resource.images.Images(self)
self.embeddings = api_resource.embeddings.Embeddings(self)
self.files = api_resource.files.Files(self)
self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
self.batches = api_resource.Batches(self)
self.knowledge = api_resource.Knowledge(self)
self.tools = api_resource.Tools(self)
self.videos = api_resource.Videos(self)
self.assistant = api_resource.Assistant(self)
@property
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
if self._disable_token_cache:
return {"Authorization": f"Bearer {api_key}"}
else:
return {"Authorization": f"Bearer {_jwt_token.generate_token(api_key)}"}
def __del__(self) -> None:
if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"):
# if the '__init__' method raised an error, self would not have client attr
return
if self._has_custom_http_client:
return
self.close()

View File

@ -1,34 +0,0 @@
from .assistant import (
Assistant,
)
from .batches import Batches
from .chat import (
AsyncCompletions,
Chat,
Completions,
)
from .embeddings import Embeddings
from .files import Files, FilesWithRawResponse
from .fine_tuning import FineTuning
from .images import Images
from .knowledge import Knowledge
from .tools import Tools
from .videos import (
Videos,
)
__all__ = [
"Videos",
"AsyncCompletions",
"Chat",
"Completions",
"Images",
"Embeddings",
"Files",
"FilesWithRawResponse",
"FineTuning",
"Batches",
"Knowledge",
"Tools",
"Assistant",
]

View File

@ -1,3 +0,0 @@
from .assistant import Assistant
__all__ = ["Assistant"]

View File

@ -1,122 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
StreamResponse,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.assistant import AssistantCompletion
from ...types.assistant.assistant_conversation_resp import ConversationUsageListResp
from ...types.assistant.assistant_support_resp import AssistantSupportResp
if TYPE_CHECKING:
from ..._client import ZhipuAI
from ...types.assistant import assistant_conversation_params, assistant_create_params
__all__ = ["Assistant"]
class Assistant(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def conversation(
self,
assistant_id: str,
model: str,
messages: list[assistant_create_params.ConversationMessage],
*,
stream: bool = True,
conversation_id: Optional[str] = None,
attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None,
metadata: dict | None = None,
request_id: Optional[str] = None,
user_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> StreamResponse[AssistantCompletion]:
body = deepcopy_minimal(
{
"assistant_id": assistant_id,
"model": model,
"messages": messages,
"stream": stream,
"conversation_id": conversation_id,
"attachments": attachments,
"metadata": metadata,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/assistant",
body=maybe_transform(body, assistant_create_params.AssistantParameters),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=AssistantCompletion,
stream=stream or True,
stream_cls=StreamResponse[AssistantCompletion],
)
def query_support(
self,
*,
assistant_id_list: Optional[list[str]] = None,
request_id: Optional[str] = None,
user_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AssistantSupportResp:
body = deepcopy_minimal(
{
"assistant_id_list": assistant_id_list,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/assistant/list",
body=body,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=AssistantSupportResp,
)
def query_conversation_usage(
self,
assistant_id: str,
page: int = 1,
page_size: int = 10,
*,
request_id: Optional[str] = None,
user_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ConversationUsageListResp:
body = deepcopy_minimal(
{
"assistant_id": assistant_id,
"page": page,
"page_size": page_size,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/assistant/conversation/list",
body=maybe_transform(body, assistant_conversation_params.ConversationParameters),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=ConversationUsageListResp,
)

View File

@ -1,146 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Literal, Optional
import httpx
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options, maybe_transform
from ..core.pagination import SyncCursorPage
from ..types import batch_create_params, batch_list_params
from ..types.batch import Batch
if TYPE_CHECKING:
from .._client import ZhipuAI
class Batches(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
completion_window: str | None = None,
endpoint: Literal["/v1/chat/completions", "/v1/embeddings"],
input_file_id: str,
metadata: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
auto_delete_input_file: bool = True,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Batch:
return self._post(
"/batches",
body=maybe_transform(
{
"completion_window": completion_window,
"endpoint": endpoint,
"input_file_id": input_file_id,
"metadata": metadata,
"auto_delete_input_file": auto_delete_input_file,
},
batch_create_params.BatchCreateParams,
),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Batch,
)
def retrieve(
self,
batch_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Batch:
"""
Retrieves a batch.
Args:
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not batch_id:
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
return self._get(
f"/batches/{batch_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Batch,
)
def list(
self,
*,
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> SyncCursorPage[Batch]:
"""List your organization's batches.
Args:
after: A cursor for use in pagination.
`after` is an object ID that defines your place
in the list. For instance, if you make a list request and receive 100 objects,
ending with obj_foo, your subsequent call can include after=obj_foo in order to
fetch the next page of the list.
limit: A limit on the number of objects to be returned. Limit can range between 1 and
100, and the default is 20.
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
return self._get_api_list(
"/batches",
page=SyncCursorPage[Batch],
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"after": after,
"limit": limit,
},
batch_list_params.BatchListParams,
),
),
model=Batch,
)
def cancel(
self,
batch_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Batch:
"""
Cancels an in-progress batch.
Args:
batch_id: The ID of the batch to cancel.
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not batch_id:
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
return self._post(
f"/batches/{batch_id}/cancel",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Batch,
)

View File

@ -1,5 +0,0 @@
from .async_completions import AsyncCompletions
from .chat import Chat
from .completions import Completions
__all__ = ["AsyncCompletions", "Chat", "Completions"]

View File

@ -1,115 +0,0 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Literal, Optional, Union
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
drop_prefix_image_data,
make_request_options,
maybe_transform,
)
from ...types.chat.async_chat_completion import AsyncCompletion, AsyncTaskStatus
from ...types.chat.code_geex import code_geex_params
from ...types.sensitive_word_check import SensitiveWordCheckRequest
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..._client import ZhipuAI
class AsyncCompletions(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
user_id: Optional[str] | NotGiven = NOT_GIVEN,
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
messages: Union[str, list[str], list[int], list[list[int]], None],
stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
tools: Optional[object] | NotGiven = NOT_GIVEN,
tool_choice: str | NotGiven = NOT_GIVEN,
meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AsyncTaskStatus:
_cast_type = AsyncTaskStatus
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
if temperature is not None and temperature != NOT_GIVEN:
if temperature <= 0:
do_sample = False
temperature = 0.01
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间do_sample重写为:false参数top_p temperature不生效") # noqa: E501
if temperature >= 1:
temperature = 0.99
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间")
if top_p is not None and top_p != NOT_GIVEN:
if top_p >= 1:
top_p = 0.99
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
if top_p <= 0:
top_p = 0.01
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
if isinstance(messages, list):
for item in messages:
if item.get("content"):
item["content"] = drop_prefix_image_data(item["content"])
body = {
"model": model,
"request_id": request_id,
"user_id": user_id,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"max_tokens": max_tokens,
"seed": seed,
"messages": messages,
"stop": stop,
"sensitive_word_check": sensitive_word_check,
"tools": tools,
"tool_choice": tool_choice,
"meta": meta,
"extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
}
return self._post(
"/async/chat/completions",
body=body,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=_cast_type,
stream=False,
)
def retrieve_completion_result(
self,
id: str,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Union[AsyncCompletion, AsyncTaskStatus]:
_cast_type = Union[AsyncCompletion, AsyncTaskStatus]
return self._get(
path=f"/async-result/{id}",
cast_type=_cast_type,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
)

View File

@ -1,18 +0,0 @@
from typing import TYPE_CHECKING
from ...core import BaseAPI, cached_property
from .async_completions import AsyncCompletions
from .completions import Completions
if TYPE_CHECKING:
pass
class Chat(BaseAPI):
@cached_property
def completions(self) -> Completions:
return Completions(self._client)
@cached_property
def asyncCompletions(self) -> AsyncCompletions: # noqa: N802
return AsyncCompletions(self._client)

View File

@ -1,108 +0,0 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Literal, Optional, Union
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
StreamResponse,
deepcopy_minimal,
drop_prefix_image_data,
make_request_options,
maybe_transform,
)
from ...types.chat.chat_completion import Completion
from ...types.chat.chat_completion_chunk import ChatCompletionChunk
from ...types.chat.code_geex import code_geex_params
from ...types.sensitive_word_check import SensitiveWordCheckRequest
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..._client import ZhipuAI
class Completions(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
user_id: Optional[str] | NotGiven = NOT_GIVEN,
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
messages: Union[str, list[str], list[int], object, None],
stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
tools: Optional[object] | NotGiven = NOT_GIVEN,
tool_choice: str | NotGiven = NOT_GIVEN,
meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Completion | StreamResponse[ChatCompletionChunk]:
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
if temperature is not None and temperature != NOT_GIVEN:
if temperature <= 0:
do_sample = False
temperature = 0.01
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间do_sample重写为:false参数top_p temperature不生效") # noqa: E501
if temperature >= 1:
temperature = 0.99
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间")
if top_p is not None and top_p != NOT_GIVEN:
if top_p >= 1:
top_p = 0.99
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
if top_p <= 0:
top_p = 0.01
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
if isinstance(messages, list):
for item in messages:
if item.get("content"):
item["content"] = drop_prefix_image_data(item["content"])
body = deepcopy_minimal(
{
"model": model,
"request_id": request_id,
"user_id": user_id,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"max_tokens": max_tokens,
"seed": seed,
"messages": messages,
"stop": stop,
"sensitive_word_check": sensitive_word_check,
"stream": stream,
"tools": tools,
"tool_choice": tool_choice,
"meta": meta,
"extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
}
)
return self._post(
"/chat/completions",
body=body,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Completion,
stream=stream or False,
stream_cls=StreamResponse[ChatCompletionChunk],
)

View File

@ -1,50 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Union
import httpx
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
from ..types.embeddings import EmbeddingsResponded
if TYPE_CHECKING:
from .._client import ZhipuAI
class Embeddings(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
input: Union[str, list[str], list[int], list[list[int]]],
model: Union[str],
dimensions: Union[int] | NotGiven = NOT_GIVEN,
encoding_format: str | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> EmbeddingsResponded:
_cast_type = EmbeddingsResponded
if disable_strict_validation:
_cast_type = object
return self._post(
"/embeddings",
body={
"input": input,
"model": model,
"dimensions": dimensions,
"encoding_format": encoding_format,
"user": user,
"request_id": request_id,
"sensitive_word_check": sensitive_word_check,
},
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=_cast_type,
stream=False,
)

View File

@ -1,194 +0,0 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import TYPE_CHECKING, Literal, Optional, cast
import httpx
from ..core import (
NOT_GIVEN,
BaseAPI,
Body,
FileTypes,
Headers,
NotGiven,
_legacy_binary_response,
_legacy_response,
deepcopy_minimal,
extract_files,
make_request_options,
maybe_transform,
)
from ..types.files import FileDeleted, FileObject, ListOfFileObject, UploadDetail, file_create_params
if TYPE_CHECKING:
from .._client import ZhipuAI
__all__ = ["Files", "FilesWithRawResponse"]
class Files(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
file: Optional[FileTypes] = None,
upload_detail: Optional[list[UploadDetail]] = None,
purpose: Literal["fine-tune", "retrieval", "batch"],
knowledge_id: Optional[str] = None,
sentence_size: Optional[int] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FileObject:
if not file and not upload_detail:
raise ValueError("At least one of `file` and `upload_detail` must be provided.")
body = deepcopy_minimal(
{
"file": file,
"upload_detail": upload_detail,
"purpose": purpose,
"knowledge_id": knowledge_id,
"sentence_size": sentence_size,
}
)
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
if files:
# It should be noted that the actual Content-Type header that will be
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return self._post(
"/files",
body=maybe_transform(body, file_create_params.FileCreateParams),
files=files,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FileObject,
)
# def retrieve(
# self,
# file_id: str,
# *,
# extra_headers: Headers | None = None,
# extra_body: Body | None = None,
# timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
# ) -> FileObject:
# """
# Returns information about a specific file.
#
# Args:
# file_id: The ID of the file to retrieve information about
# extra_headers: Send extra headers
#
# extra_body: Add additional JSON properties to the request
#
# timeout: Override the client-level default timeout for this request, in seconds
# """
# if not file_id:
# raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
# return self._get(
# f"/files/{file_id}",
# options=make_request_options(
# extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
# ),
# cast_type=FileObject,
# )
def list(
self,
*,
purpose: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
after: str | NotGiven = NOT_GIVEN,
order: str | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ListOfFileObject:
return self._get(
"/files",
cast_type=ListOfFileObject,
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query={
"purpose": purpose,
"limit": limit,
"after": after,
"order": order,
},
),
)
def delete(
self,
file_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FileDeleted:
"""
Delete a file.
Args:
file_id: The ID of the file to delete
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not file_id:
raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
return self._delete(
f"/files/{file_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FileDeleted,
)
def content(
self,
file_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> _legacy_response.HttpxBinaryResponseContent:
"""
Returns the contents of the specified file.
Args:
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not file_id:
raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
extra_headers = {"Accept": "application/binary", **(extra_headers or {})}
return self._get(
f"/files/{file_id}/content",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=_legacy_binary_response.HttpxBinaryResponseContent,
)
class FilesWithRawResponse:
def __init__(self, files: Files) -> None:
self._files = files
self.create = _legacy_response.to_raw_response_wrapper(
files.create,
)
self.list = _legacy_response.to_raw_response_wrapper(
files.list,
)
self.content = _legacy_response.to_raw_response_wrapper(
files.content,
)
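
A minimal usage sketch for the Files resource above, assuming the vendored client is importable as `ZhipuAI`, that it exposes this resource as `client.files`, and that `FileObject` carries an `id` field; none of those names are confirmed by this listing.

from zhipuai import ZhipuAI  # assumed import path; adjust to the vendored package location

client = ZhipuAI(api_key="your-api-key")  # constructor arguments are an assumption

# upload a JSONL training file; `purpose` must be one of "fine-tune", "retrieval", "batch"
uploaded = client.files.create(file=open("train.jsonl", "rb"), purpose="fine-tune")

# page through uploaded files and fetch the raw contents of one of them
page = client.files.list(purpose="fine-tune", limit=10)
raw = client.files.content(file_id=uploaded.id)  # returns a binary response wrapper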

View File

@ -1,5 +0,0 @@
from .fine_tuning import FineTuning
from .jobs import Jobs
from .models import FineTunedModels
__all__ = ["Jobs", "FineTunedModels", "FineTuning"]

View File

@ -1,18 +0,0 @@
from typing import TYPE_CHECKING
from ...core import BaseAPI, cached_property
from .jobs import Jobs
from .models import FineTunedModels
if TYPE_CHECKING:
pass
class FineTuning(BaseAPI):
@cached_property
def jobs(self) -> Jobs:
return Jobs(self._client)
@cached_property
def models(self) -> FineTunedModels:
return FineTunedModels(self._client)

View File

@ -1,3 +0,0 @@
from .jobs import Jobs
__all__ = ["Jobs"]

View File

@ -1,152 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import httpx
from ....core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
make_request_options,
)
from ....types.fine_tuning import (
FineTuningJob,
FineTuningJobEvent,
ListOfFineTuningJob,
job_create_params,
)
if TYPE_CHECKING:
from ...._client import ZhipuAI
__all__ = ["Jobs"]
class Jobs(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
model: str,
training_file: str,
hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
suffix: Optional[str] | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
validation_file: Optional[str] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
return self._post(
"/fine_tuning/jobs",
body={
"model": model,
"training_file": training_file,
"hyperparameters": hyperparameters,
"suffix": suffix,
"validation_file": validation_file,
"request_id": request_id,
},
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTuningJob,
)
def retrieve(
self,
fine_tuning_job_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
return self._get(
f"/fine_tuning/jobs/{fine_tuning_job_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTuningJob,
)
def list(
self,
*,
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ListOfFineTuningJob:
return self._get(
"/fine_tuning/jobs",
cast_type=ListOfFineTuningJob,
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query={
"after": after,
"limit": limit,
},
),
)
def cancel(
self,
fine_tuning_job_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # noqa: E501
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
if not fine_tuning_job_id:
raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}")
return self._post(
f"/fine_tuning/jobs/{fine_tuning_job_id}/cancel",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTuningJob,
)
def list_events(
self,
fine_tuning_job_id: str,
*,
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJobEvent:
return self._get(
f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
cast_type=FineTuningJobEvent,
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query={
"after": after,
"limit": limit,
},
),
)
def delete(
self,
fine_tuning_job_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
if not fine_tuning_job_id:
raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}")
return self._delete(
f"/fine_tuning/jobs/{fine_tuning_job_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTuningJob,
)
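
A hedged sketch of the fine-tuning job lifecycle exposed by the Jobs resource above. The `.jobs` accessor is confirmed by the FineTuning wrapper listed earlier; the `client.fine_tuning` attribute name, the model name, and the `id` field on `FineTuningJob` are assumptions.

from zhipuai import ZhipuAI  # assumed import path

client = ZhipuAI(api_key="your-api-key")  # constructor arguments are an assumption

# create a job from an uploaded training file, then poll it and read recent events
job = client.fine_tuning.jobs.create(
    model="chatglm3-6b",          # illustrative model name
    training_file="file-abc123",  # placeholder file id
)
job = client.fine_tuning.jobs.retrieve(job.id)
events = client.fine_tuning.jobs.list_events(job.id, limit=20)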

View File

@ -1,3 +0,0 @@
from .fine_tuned_models import FineTunedModels
__all__ = ["FineTunedModels"]

View File

@ -1,41 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import httpx
from ....core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
make_request_options,
)
from ....types.fine_tuning.models import FineTunedModelsStatus
if TYPE_CHECKING:
from ...._client import ZhipuAI
__all__ = ["FineTunedModels"]
class FineTunedModels(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def delete(
self,
fine_tuned_model: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTunedModelsStatus:
if not fine_tuned_model:
raise ValueError(f"Expected a non-empty value for `fine_tuned_model` but received {fine_tuned_model!r}")
return self._delete(
f"fine_tuning/fine_tuned_models/{fine_tuned_model}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTunedModelsStatus,
)

View File

@ -1,59 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import httpx
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
from ..types.image import ImagesResponded
from ..types.sensitive_word_check import SensitiveWordCheckRequest
if TYPE_CHECKING:
from .._client import ZhipuAI
class Images(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def generations(
self,
*,
prompt: str,
model: str | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
quality: Optional[str] | NotGiven = NOT_GIVEN,
response_format: Optional[str] | NotGiven = NOT_GIVEN,
size: Optional[str] | NotGiven = NOT_GIVEN,
style: Optional[str] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
user_id: Optional[str] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ImagesResponded:
_cast_type = ImagesResponded
if disable_strict_validation:
_cast_type = object
return self._post(
"/images/generations",
body={
"prompt": prompt,
"model": model,
"n": n,
"quality": quality,
"response_format": response_format,
"sensitive_word_check": sensitive_word_check,
"size": size,
"style": style,
"user": user,
"user_id": user_id,
"request_id": request_id,
},
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=_cast_type,
stream=False,
)
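
A hedged sketch of calling the image generation endpoint above; the `client.images` attribute name and the model name are assumptions.

from zhipuai import ZhipuAI  # assumed import path

client = ZhipuAI(api_key="your-api-key")  # constructor arguments are an assumption

response = client.images.generations(
    model="cogview-3",  # illustrative model name
    prompt="a watercolor of a lighthouse at dusk",
    size="1024x1024",
)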

View File

@ -1,3 +0,0 @@
from .knowledge import Knowledge
__all__ = ["Knowledge"]

View File

@ -1,3 +0,0 @@
from .document import Document
__all__ = ["Document"]

View File

@ -1,217 +0,0 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import TYPE_CHECKING, Literal, Optional, cast
import httpx
from ....core import (
NOT_GIVEN,
BaseAPI,
Body,
FileTypes,
Headers,
NotGiven,
deepcopy_minimal,
extract_files,
make_request_options,
maybe_transform,
)
from ....types.files import UploadDetail, file_create_params
from ....types.knowledge.document import DocumentData, DocumentObject, document_edit_params, document_list_params
from ....types.knowledge.document.document_list_resp import DocumentPage
if TYPE_CHECKING:
from ...._client import ZhipuAI
__all__ = ["Document"]
class Document(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
file: Optional[FileTypes] = None,
custom_separator: Optional[list[str]] = None,
upload_detail: Optional[list[UploadDetail]] = None,
purpose: Literal["retrieval"],
knowledge_id: Optional[str] = None,
sentence_size: Optional[int] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> DocumentObject:
if not file and not upload_detail:
raise ValueError("At least one of `file` and `upload_detail` must be provided.")
body = deepcopy_minimal(
{
"file": file,
"upload_detail": upload_detail,
"purpose": purpose,
"custom_separator": custom_separator,
"knowledge_id": knowledge_id,
"sentence_size": sentence_size,
}
)
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
if files:
# It should be noted that the actual Content-Type header that will be
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return self._post(
"/files",
body=maybe_transform(body, file_create_params.FileCreateParams),
files=files,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=DocumentObject,
)
def edit(
self,
document_id: str,
knowledge_type: str,
*,
custom_separator: Optional[list[str]] = None,
sentence_size: Optional[int] = None,
callback_url: Optional[str] = None,
callback_header: Optional[dict[str, str]] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> httpx.Response:
"""
Args:
document_id: Knowledge document ID
knowledge_type: Knowledge type:
1: Article knowledge: supports pdf, url, docx
2: QA knowledge - document: supports pdf, url, docx
3: QA knowledge - spreadsheet: supports xlsx
4: Product catalog - spreadsheet: supports xlsx
5: Custom: supports pdf, url, docx
custom_separator: Custom separator(s) used when splitting the document
sentence_size: Chunk size used when splitting the document
callback_url: Callback URL to notify when processing finishes
callback_header: Extra headers to send with the callback request
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not document_id:
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
body = deepcopy_minimal(
{
"id": document_id,
"knowledge_type": knowledge_type,
"custom_separator": custom_separator,
"sentence_size": sentence_size,
"callback_url": callback_url,
"callback_header": callback_header,
}
)
return self._put(
f"/document/{document_id}",
body=maybe_transform(body, document_edit_params.DocumentEditParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=httpx.Response,
)
def list(
self,
knowledge_id: str,
*,
purpose: str | NotGiven = NOT_GIVEN,
page: str | NotGiven = NOT_GIVEN,
limit: str | NotGiven = NOT_GIVEN,
order: Literal["desc", "asc"] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> DocumentPage:
return self._get(
"/files",
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"knowledge_id": knowledge_id,
"purpose": purpose,
"page": page,
"limit": limit,
"order": order,
},
document_list_params.DocumentListParams,
),
),
cast_type=DocumentPage,
)
def delete(
self,
document_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> httpx.Response:
"""
Delete a knowledge document.
Args:
document_id: Knowledge document ID
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not document_id:
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
return self._delete(
f"/document/{document_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=httpx.Response,
)
def retrieve(
self,
document_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> DocumentData:
"""
Args:
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not document_id:
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
return self._get(
f"/document/{document_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=DocumentData,
)
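
A hedged sketch of the knowledge document workflow above. The `.document` accessor is confirmed by the Knowledge wrapper in the next listing; the `client.knowledge` attribute name and the ids are placeholders.

from zhipuai import ZhipuAI  # assumed import path

client = ZhipuAI(api_key="your-api-key")  # constructor arguments are an assumption

doc = client.knowledge.document.create(
    file=open("handbook.pdf", "rb"),
    purpose="retrieval",
    knowledge_id="kb-123",  # placeholder knowledge base id
)
# note that `page` and `limit` are typed as strings in the listing above
docs = client.knowledge.document.list(knowledge_id="kb-123", limit="20", order="desc")
client.knowledge.document.delete(document_id="doc-456")  # placeholder document id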

View File

@ -1,173 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Literal, Optional
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
cached_property,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.knowledge import KnowledgeInfo, KnowledgeUsed, knowledge_create_params, knowledge_list_params
from ...types.knowledge.knowledge_list_resp import KnowledgePage
from .document import Document
if TYPE_CHECKING:
from ..._client import ZhipuAI
__all__ = ["Knowledge"]
class Knowledge(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
@cached_property
def document(self) -> Document:
return Document(self._client)
def create(
self,
embedding_id: int,
name: str,
*,
customer_identifier: Optional[str] = None,
description: Optional[str] = None,
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
bucket_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> KnowledgeInfo:
body = deepcopy_minimal(
{
"embedding_id": embedding_id,
"name": name,
"customer_identifier": customer_identifier,
"description": description,
"background": background,
"icon": icon,
"bucket_id": bucket_id,
}
)
return self._post(
"/knowledge",
body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=KnowledgeInfo,
)
def modify(
self,
knowledge_id: str,
embedding_id: int,
*,
name: str,
description: Optional[str] = None,
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> httpx.Response:
body = deepcopy_minimal(
{
"id": knowledge_id,
"embedding_id": embedding_id,
"name": name,
"description": description,
"background": background,
"icon": icon,
}
)
return self._put(
f"/knowledge/{knowledge_id}",
body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=httpx.Response,
)
def query(
self,
*,
page: int | NotGiven = 1,
size: int | NotGiven = 10,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> KnowledgePage:
return self._get(
"/knowledge",
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"page": page,
"size": size,
},
knowledge_list_params.KnowledgeListParams,
),
),
cast_type=KnowledgePage,
)
def delete(
self,
knowledge_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> httpx.Response:
"""
Delete a knowledge base.
Args:
knowledge_id: Knowledge base ID
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not knowledge_id:
raise ValueError("Expected a non-empty value for `knowledge_id`")
return self._delete(
f"/knowledge/{knowledge_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=httpx.Response,
)
def used(
self,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> KnowledgeUsed:
"""
Returns the capacity usage of the knowledge bases.
Args:
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
return self._get(
"/knowledge/capacity",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=KnowledgeUsed,
)
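
A hedged sketch of managing knowledge bases with the resource above; the `client.knowledge` attribute name and the embedding id are assumptions.

from zhipuai import ZhipuAI  # assumed import path

client = ZhipuAI(api_key="your-api-key")  # constructor arguments are an assumption

kb = client.knowledge.create(
    embedding_id=3,  # illustrative embedding model id
    name="support-docs",
    description="Internal support articles",
)
page = client.knowledge.query(page=1, size=10)
usage = client.knowledge.used()  # capacity usage across knowledge bases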

View File

@ -1,3 +0,0 @@
from .tools import Tools
__all__ = ["Tools"]

View File

@ -1,65 +0,0 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Literal, Optional, Union
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
StreamResponse,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.tools import WebSearch, WebSearchChunk, tools_web_search_params
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..._client import ZhipuAI
__all__ = ["Tools"]
class Tools(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def web_search(
self,
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
messages: Union[str, list[str], list[int], object, None],
scope: Optional[str] | NotGiven = NOT_GIVEN,
location: Optional[str] | NotGiven = NOT_GIVEN,
recent_days: Optional[int] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> WebSearch | StreamResponse[WebSearchChunk]:
body = deepcopy_minimal(
{
"model": model,
"request_id": request_id,
"messages": messages,
"stream": stream,
"scope": scope,
"location": location,
"recent_days": recent_days,
}
)
return self._post(
"/tools",
body=maybe_transform(body, tools_web_search_params.WebSearchParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=WebSearch,
stream=stream or False,
stream_cls=StreamResponse[WebSearchChunk],
)
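
A hedged sketch of the web search tool above in both blocking and streaming modes, assuming the client exposes it as `client.tools` and that `StreamResponse` is iterable; the model name and message shape are illustrative.

from zhipuai import ZhipuAI  # assumed import path

client = ZhipuAI(api_key="your-api-key")  # constructor arguments are an assumption

# blocking call returns a single WebSearch result
result = client.tools.web_search(
    model="web-search-pro",  # illustrative model name
    messages=[{"role": "user", "content": "latest PostgreSQL release"}],
    stream=False,
)

# with stream=True the same call yields WebSearchChunk objects as they arrive
for chunk in client.tools.web_search(
    model="web-search-pro",
    messages=[{"role": "user", "content": "latest PostgreSQL release"}],
    stream=True,
):
    print(chunk)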

View File

@ -1,7 +0,0 @@
from .videos import (
Videos,
)
__all__ = [
"Videos",
]

View File

@ -1,77 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.sensitive_word_check import SensitiveWordCheckRequest
from ...types.video import VideoObject, video_create_params
if TYPE_CHECKING:
from ..._client import ZhipuAI
__all__ = ["Videos"]
class Videos(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def generations(
self,
model: str,
*,
prompt: Optional[str] = None,
image_url: Optional[str] = None,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
request_id: Optional[str] = None,
user_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> VideoObject:
if not model and not prompt:
raise ValueError("At least one of `model` and `prompt` must be provided.")
body = deepcopy_minimal(
{
"model": model,
"prompt": prompt,
"image_url": image_url,
"sensitive_word_check": sensitive_word_check,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/videos/generations",
body=maybe_transform(body, video_create_params.VideoCreateParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=VideoObject,
)
def retrieve_videos_result(
self,
id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> VideoObject:
if not id:
raise ValueError("At least one of `id` must be provided.")
return self._get(
f"/async-result/{id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=VideoObject,
)
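
Video generation above is asynchronous: `generations` starts a task and `retrieve_videos_result` polls the async-result endpoint. A hedged sketch, with `client.videos`, the model name, and the `id` field on `VideoObject` as assumptions.

from zhipuai import ZhipuAI  # assumed import path

client = ZhipuAI(api_key="your-api-key")  # constructor arguments are an assumption

task = client.videos.generations(
    model="cogvideox",  # illustrative model name
    prompt="a timelapse of clouds over mountains",
)
result = client.videos.retrieve_videos_result(id=task.id)  # poll until the task completes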

View File

@ -1,108 +0,0 @@
from ._base_api import BaseAPI
from ._base_compat import (
PYDANTIC_V2,
ConfigDict,
GenericModel,
cached_property,
field_get_default,
get_args,
get_model_config,
get_model_fields,
get_origin,
is_literal_type,
is_union,
parse_obj,
)
from ._base_models import BaseModel, construct_type
from ._base_type import (
NOT_GIVEN,
Body,
FileTypes,
Headers,
IncEx,
ModelT,
NotGiven,
Query,
)
from ._constants import (
ZHIPUAI_DEFAULT_LIMITS,
ZHIPUAI_DEFAULT_MAX_RETRIES,
ZHIPUAI_DEFAULT_TIMEOUT,
)
from ._errors import (
APIAuthenticationError,
APIConnectionError,
APIInternalError,
APIReachLimitError,
APIRequestFailedError,
APIResponseError,
APIResponseValidationError,
APIServerFlowExceedError,
APIStatusError,
APITimeoutError,
ZhipuAIError,
)
from ._files import is_file_content
from ._http_client import HttpClient, make_request_options
from ._sse_client import StreamResponse
from ._utils import (
deepcopy_minimal,
drop_prefix_image_data,
extract_files,
is_given,
is_list,
is_mapping,
maybe_transform,
parse_date,
parse_datetime,
)
__all__ = [
"BaseModel",
"construct_type",
"BaseAPI",
"NOT_GIVEN",
"Headers",
"NotGiven",
"Body",
"IncEx",
"ModelT",
"Query",
"FileTypes",
"PYDANTIC_V2",
"ConfigDict",
"GenericModel",
"get_args",
"is_union",
"parse_obj",
"get_origin",
"is_literal_type",
"get_model_config",
"get_model_fields",
"field_get_default",
"is_file_content",
"ZhipuAIError",
"APIStatusError",
"APIRequestFailedError",
"APIAuthenticationError",
"APIReachLimitError",
"APIInternalError",
"APIServerFlowExceedError",
"APIResponseError",
"APIResponseValidationError",
"APITimeoutError",
"make_request_options",
"HttpClient",
"ZHIPUAI_DEFAULT_TIMEOUT",
"ZHIPUAI_DEFAULT_MAX_RETRIES",
"ZHIPUAI_DEFAULT_LIMITS",
"is_list",
"is_mapping",
"parse_date",
"parse_datetime",
"is_given",
"maybe_transform",
"deepcopy_minimal",
"extract_files",
"StreamResponse",
]

View File

@ -1,19 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .._client import ZhipuAI
class BaseAPI:
_client: ZhipuAI
def __init__(self, client: ZhipuAI) -> None:
self._client = client
self._delete = client.delete
self._get = client.get
self._post = client.post
self._put = client.put
self._patch = client.patch
self._get_api_list = client.get_api_list

View File

@ -1,209 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, overload
import pydantic
from pydantic.fields import FieldInfo
from typing_extensions import Self
from ._base_type import StrBytesIntFloat
_T = TypeVar("_T")
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
# --------------- Pydantic v2 compatibility ---------------
# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
# v1 re-exports
if TYPE_CHECKING:
def parse_date(value: date | StrBytesIntFloat) -> date: ...
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: ...
def get_args(t: type[Any]) -> tuple[Any, ...]: ...
def is_union(tp: type[Any] | None) -> bool: ...
def get_origin(t: type[Any]) -> type[Any] | None: ...
def is_literal_type(type_: type[Any]) -> bool: ...
def is_typeddict(type_: type[Any]) -> bool: ...
else:
if PYDANTIC_V2:
from pydantic.v1.typing import ( # noqa: I001
get_args as get_args, # noqa: PLC0414
is_union as is_union, # noqa: PLC0414
get_origin as get_origin, # noqa: PLC0414
is_typeddict as is_typeddict, # noqa: PLC0414
is_literal_type as is_literal_type, # noqa: PLC0414
)
from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414
else:
from pydantic.typing import ( # noqa: I001
get_args as get_args, # noqa: PLC0414
is_union as is_union, # noqa: PLC0414
get_origin as get_origin, # noqa: PLC0414
is_typeddict as is_typeddict, # noqa: PLC0414
is_literal_type as is_literal_type, # noqa: PLC0414
)
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414
# refactored config
if TYPE_CHECKING:
from pydantic import ConfigDict
else:
if PYDANTIC_V2:
from pydantic import ConfigDict
else:
# TODO: provide an error message here?
ConfigDict = None
# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
if PYDANTIC_V2:
return model.model_validate(value)
else:
# pyright: ignore[reportDeprecated, reportUnnecessaryCast]
return cast(_ModelT, model.parse_obj(value))
def field_is_required(field: FieldInfo) -> bool:
if PYDANTIC_V2:
return field.is_required()
return field.required # type: ignore
def field_get_default(field: FieldInfo) -> Any:
value = field.get_default()
if PYDANTIC_V2:
from pydantic_core import PydanticUndefined
if value == PydanticUndefined:
return None
return value
return value
def field_outer_type(field: FieldInfo) -> Any:
if PYDANTIC_V2:
return field.annotation
return field.outer_type_ # type: ignore
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
if PYDANTIC_V2:
return model.model_config
return model.__config__ # type: ignore
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
if PYDANTIC_V2:
return model.model_fields
return model.__fields__ # type: ignore
def model_copy(model: _ModelT) -> _ModelT:
if PYDANTIC_V2:
return model.model_copy()
return model.copy() # type: ignore
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
if PYDANTIC_V2:
return model.model_dump_json(indent=indent)
return model.json(indent=indent) # type: ignore
def model_dump(
model: pydantic.BaseModel,
*,
exclude_unset: bool = False,
exclude_defaults: bool = False,
) -> dict[str, Any]:
if PYDANTIC_V2:
return model.model_dump(
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
)
return cast(
"dict[str, Any]",
model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
),
)
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
if PYDANTIC_V2:
return model.model_validate(data)
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
# generic models
if TYPE_CHECKING:
class GenericModel(pydantic.BaseModel): ...
else:
if PYDANTIC_V2:
# there is no longer a need for a distinction in v2, but
# we still have to create our own subclass to avoid
# inconsistent MRO ordering errors
class GenericModel(pydantic.BaseModel): ...
else:
import pydantic.generics
class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
# cached properties
if TYPE_CHECKING:
cached_property = property
# we define a separate type (copied from typeshed)
# that represents that `cached_property` is `set`able
# at runtime, which differs from `@property`.
#
# this is a separate type as editors likely special case
# `@property` and we don't want to cause issues just to have
# more helpful internal types.
class typed_cached_property(Generic[_T]): # noqa: N801
func: Callable[[Any], _T]
attrname: str | None
def __init__(self, func: Callable[[Any], _T]) -> None: ...
@overload
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...
@overload
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
raise NotImplementedError()
def __set_name__(self, owner: type[Any], name: str) -> None: ...
# __set__ is not defined at runtime, but @cached_property is designed to be settable
def __set__(self, instance: object, value: _T) -> None: ...
else:
try:
from functools import cached_property
except ImportError:
from cached_property import cached_property
typed_cached_property = cached_property
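
A small sketch of how the compatibility shims above smooth over the Pydantic v1/v2 split: the same call sites work regardless of the installed major version. The import path is an assumption, since this module ships inside the vendored core package.

import pydantic

from core._base_compat import model_dump, model_json, parse_obj  # import path is an assumption

class User(pydantic.BaseModel):
    name: str
    age: int

user = parse_obj(User, {"name": "Ada", "age": 36})  # model_validate() on v2, parse_obj() on v1
print(model_dump(user, exclude_unset=True))          # model_dump() on v2, dict() on v1
print(model_json(user, indent=2))                    # model_dump_json() on v2, json() on v1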

View File

@ -1,670 +0,0 @@
from __future__ import annotations
import inspect
import os
from collections.abc import Callable
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeGuard, TypeVar, cast
import pydantic
import pydantic.generics
from pydantic.fields import FieldInfo
from typing_extensions import (
ParamSpec,
Protocol,
override,
runtime_checkable,
)
from ._base_compat import (
PYDANTIC_V2,
ConfigDict,
field_get_default,
get_args,
get_model_config,
get_model_fields,
get_origin,
is_literal_type,
is_union,
parse_obj,
)
from ._base_compat import (
GenericModel as BaseGenericModel,
)
from ._base_type import (
IncEx,
ModelT,
)
from ._utils import (
PropertyInfo,
coerce_boolean,
extract_type_arg,
is_annotated_type,
is_list,
is_mapping,
parse_date,
parse_datetime,
strip_annotated_type,
)
if TYPE_CHECKING:
from pydantic_core.core_schema import ModelField
__all__ = ["BaseModel", "GenericModel"]
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")
_T = TypeVar("_T")
P = ParamSpec("P")
@runtime_checkable
class _ConfigProtocol(Protocol):
allow_population_by_field_name: bool
class BaseModel(pydantic.BaseModel):
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
)
else:
@property
@override
def model_fields_set(self) -> set[str]:
# a forwards-compat shim for pydantic v2
return self.__fields_set__ # type: ignore
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore
def to_dict(
self,
*,
mode: Literal["json", "python"] = "python",
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> dict[str, object]:
"""Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
mode:
If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
exclude_none: Whether to exclude fields that have a value of `None` from the output.
warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
""" # noqa: E501
return self.model_dump(
mode=mode,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)
def to_json(
self,
*,
indent: int | None = 2,
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> str:
"""Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
""" # noqa: E501
return self.model_dump_json(
indent=indent,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)
@override
def __str__(self) -> str:
# mypy complains about an invalid self arg
return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc]
# Override the 'construct' method in a way that supports recursive parsing without validation.
# Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
@classmethod
@override
def construct(
cls: type[ModelT],
_fields_set: set[str] | None = None,
**values: object,
) -> ModelT:
m = cls.__new__(cls)
fields_values: dict[str, object] = {}
config = get_model_config(cls)
populate_by_name = (
config.allow_population_by_field_name
if isinstance(config, _ConfigProtocol)
else config.get("populate_by_name")
)
if _fields_set is None:
_fields_set = set()
model_fields = get_model_fields(cls)
for name, field in model_fields.items():
key = field.alias
if key is None or (key not in values and populate_by_name):
key = name
if key in values:
fields_values[name] = _construct_field(value=values[key], field=field, key=key)
_fields_set.add(name)
else:
fields_values[name] = field_get_default(field)
_extra = {}
for key, value in values.items():
if key not in model_fields:
if PYDANTIC_V2:
_extra[key] = value
else:
_fields_set.add(key)
fields_values[key] = value
object.__setattr__(m, "__dict__", fields_values) # noqa: PLC2801
if PYDANTIC_V2:
# these properties are copied from Pydantic's `model_construct()` method
object.__setattr__(m, "__pydantic_private__", None) # noqa: PLC2801
object.__setattr__(m, "__pydantic_extra__", _extra) # noqa: PLC2801
object.__setattr__(m, "__pydantic_fields_set__", _fields_set) # noqa: PLC2801
else:
# init_private_attributes() does not exist in v2
m._init_private_attributes() # type: ignore
# copied from Pydantic v1's `construct()` method
object.__setattr__(m, "__fields_set__", _fields_set) # noqa: PLC2801
return m
if not TYPE_CHECKING:
# type checkers incorrectly complain about this assignment
# because the type signatures are technically different
# although not in practice
model_construct = construct
if not PYDANTIC_V2:
# we define aliases for some of the new pydantic v2 methods so
# that we can just document these methods without having to specify
# a specific pydantic version as some users may not know which
# pydantic version they are currently using
@override
def model_dump(
self,
*,
mode: Literal["json", "python"] | str = "python",
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
context: dict[str, Any] | None = None,
serialize_as_any: bool = False,
) -> dict[str, Any]:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
Args:
mode: The mode in which `to_python` should run.
If mode is 'json', the dictionary will only contain JSON serializable types.
If mode is 'python', the dictionary may contain any Python objects.
include: A list of fields to include in the output.
exclude: A list of fields to exclude from the output.
by_alias: Whether to use the field's alias in the dictionary key if defined.
exclude_unset: Whether to exclude fields that are unset or None from the output.
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
exclude_none: Whether to exclude fields that have a value of `None` from the output.
round_trip: Whether to enable serialization and deserialization round-trip support.
warnings: Whether to log warnings when invalid fields are encountered.
Returns:
A dictionary representation of the model.
"""
if mode != "python":
raise ValueError("mode is only supported in Pydantic v2")
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
raise ValueError("warnings is only supported in Pydantic v2")
if context is not None:
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
return super().dict( # pyright: ignore[reportDeprecated]
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
@override
def model_dump_json(
self,
*,
indent: int | None = None,
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
context: dict[str, Any] | None = None,
serialize_as_any: bool = False,
) -> str:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json
Generates a JSON representation of the model using Pydantic's `to_json` method.
Args:
indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
include: Field(s) to include in the JSON output. Can take either a string or set of strings.
exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
by_alias: Whether to serialize using field aliases.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
round_trip: Whether to use serialization/deserialization between JSON and class instance.
warnings: Whether to show any warnings that occurred during serialization.
Returns:
A JSON string representation of the model.
"""
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
raise ValueError("warnings is only supported in Pydantic v2")
if context is not None:
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
return super().json( # type: ignore[reportDeprecated]
indent=indent,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
if value is None:
return field_get_default(field)
if PYDANTIC_V2:
type_ = field.annotation
else:
type_ = cast(type, field.outer_type_) # type: ignore
if type_ is None:
raise RuntimeError(f"Unexpected field type is None for {key}")
return construct_type(value=value, type_=type_)
def is_basemodel(type_: type) -> bool:
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
if is_union(type_):
return any(is_basemodel(variant) for variant in get_args(type_))
return is_basemodel_type(type_)
def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
origin = get_origin(type_) or type_
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
def build(
base_model_cls: Callable[P, _BaseModelT],
*args: P.args,
**kwargs: P.kwargs,
) -> _BaseModelT:
"""Construct a BaseModel class without validation.
This is useful for cases where you need to instantiate a `BaseModel`
from an API response as this provides type-safe params which isn't supported
by helpers like `construct_type()`.
```py
build(MyModel, my_field_a="foo", my_field_b=123)
```
"""
if args:
raise TypeError(
"Received positional arguments which are not supported; Keyword arguments must be used instead",
)
return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))
def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
"""Loose coercion to the expected type with construction of nested values.
Note: the returned value from this function is not guaranteed to match the
given type.
"""
return cast(_T, construct_type(value=value, type_=type_))
def construct_type(*, value: object, type_: type) -> object:
"""Loose coercion to the expected type with construction of nested values.
If the given value does not match the expected type then it is returned as-is.
"""
# we allow `object` as the input type because otherwise, passing things like
# `Literal['value']` will be reported as a type error by type checkers
type_ = cast("type[object]", type_)
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(type_):
meta: tuple[Any, ...] = get_args(type_)[1:]
type_ = extract_type_arg(type_, 0)
else:
meta = ()
# we need to use the origin class for any types that are subscripted generics
# e.g. Dict[str, object]
origin = get_origin(type_) or type_
args = get_args(type_)
if is_union(origin):
try:
return validate_type(type_=cast("type[object]", type_), value=value)
except Exception:
pass
# if the type is a discriminated union then we want to construct the right variant
# in the union, even if the data doesn't match exactly, otherwise we'd break code
# that relies on the constructed class types, e.g.
#
# class FooType:
# kind: Literal['foo']
# value: str
#
# class BarType:
# kind: Literal['bar']
# value: int
#
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
# we'd end up constructing `FooType` when it should be `BarType`.
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value)
if variant_type:
return construct_type(type_=variant_type, value=value)
# if the data is not valid, use the first variant that doesn't fail while deserializing
for variant in args:
try:
return construct_type(value=value, type_=variant)
except Exception:
continue
raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
if origin == dict:
if not is_mapping(value):
return value
_, items_type = get_args(type_) # Dict[_, items_type]
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
if is_list(value):
return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
if is_mapping(value):
if issubclass(type_, BaseModel):
return type_.construct(**value) # type: ignore[arg-type]
return cast(Any, type_).construct(**value)
if origin == list:
if not is_list(value):
return value
inner_type = args[0] # List[inner_type]
return [construct_type(value=entry, type_=inner_type) for entry in value]
if origin == float:
if isinstance(value, int):
coerced = float(value)
if coerced != value:
return value
return coerced
return value
if type_ == datetime:
try:
return parse_datetime(value) # type: ignore
except Exception:
return value
if type_ == date:
try:
return parse_date(value) # type: ignore
except Exception:
return value
return value
@runtime_checkable
class CachedDiscriminatorType(Protocol):
__discriminator__: DiscriminatorDetails
class DiscriminatorDetails:
field_name: str
"""The name of the discriminator field in the variant class, e.g.
```py
class Foo(BaseModel):
type: Literal['foo']
```
Will result in field_name='type'
"""
field_alias_from: str | None
"""The name of the discriminator field in the API response, e.g.
```py
class Foo(BaseModel):
type: Literal['foo'] = Field(alias='type_from_api')
```
Will result in field_alias_from='type_from_api'
"""
mapping: dict[str, type]
"""Mapping of discriminator value to variant type, e.g.
{'foo': FooVariant, 'bar': BarVariant}
"""
def __init__(
self,
*,
mapping: dict[str, type],
discriminator_field: str,
discriminator_alias: str | None,
) -> None:
self.mapping = mapping
self.field_name = discriminator_field
self.field_alias_from = discriminator_alias
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
if isinstance(union, CachedDiscriminatorType):
return union.__discriminator__
discriminator_field_name: str | None = None
for annotation in meta_annotations:
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
discriminator_field_name = annotation.discriminator
break
if not discriminator_field_name:
return None
mapping: dict[str, type] = {}
discriminator_alias: str | None = None
for variant in get_args(union):
variant = strip_annotated_type(variant)
if is_basemodel_type(variant):
if PYDANTIC_V2:
field = _extract_field_schema_pv2(variant, discriminator_field_name)
if not field:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field.get("serialization_alias")
field_schema = field["schema"]
if field_schema["type"] == "literal":
for entry in cast("LiteralSchema", field_schema)["expected"]:
if isinstance(entry, str):
mapping[entry] = variant
else:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field_info.alias
if field_info.annotation and is_literal_type(field_info.annotation):
for entry in get_args(field_info.annotation):
if isinstance(entry, str):
mapping[entry] = variant
if not mapping:
return None
details = DiscriminatorDetails(
mapping=mapping,
discriminator_field=discriminator_field_name,
discriminator_alias=discriminator_alias,
)
cast(CachedDiscriminatorType, union).__discriminator__ = details
return details
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
schema = model.__pydantic_core_schema__
if schema["type"] != "model":
return None
fields_schema = schema["schema"]
if fields_schema["type"] != "model-fields":
return None
fields_schema = cast("ModelFieldsSchema", fields_schema)
field = fields_schema["fields"].get(field_name)
if not field:
return None
return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]
def validate_type(*, type_: type[_T], value: object) -> _T:
"""Strict validation that the given value matches the expected type"""
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
return cast(_T, parse_obj(type_, value))
return cast(_T, _validate_non_model_type(type_=type_, value=value))
# Subclassing here confuses type checkers, so we treat this class as non-inheriting.
if TYPE_CHECKING:
GenericModel = BaseModel
else:
class GenericModel(BaseGenericModel, BaseModel):
pass
if PYDANTIC_V2:
from pydantic import TypeAdapter
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
return TypeAdapter(type_).validate_python(value)
elif not TYPE_CHECKING:
class TypeAdapter(Generic[_T]):
"""Used as a placeholder to easily convert runtime types to a Pydantic format
to provide validation.
For example:
```py
validated = RootModel[int](__root__="5").__root__
# validated: 5
```
"""
def __init__(self, type_: type[_T]):
self.type_ = type_
def validate_python(self, value: Any) -> _T:
if not isinstance(value, self.type_):
raise ValueError(f"Invalid type: {value} is not of type {self.type_}")
return value
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
return TypeAdapter(type_).validate_python(value)

View File

@ -1,170 +0,0 @@
from __future__ import annotations
from collections.abc import Callable, Mapping, Sequence
from os import PathLike
from typing import (
IO,
TYPE_CHECKING,
Any,
Literal,
Optional,
TypeAlias,
TypeVar,
Union,
)
import httpx  # needed for the `httpx.Auth` annotation in HttpxSendArgs below
import pydantic
from httpx import Response
from typing_extensions import Protocol, TypedDict, override, runtime_checkable
Query = Mapping[str, object]
Body = object
AnyMapping = Mapping[str, object]
PrimitiveData = Union[str, int, float, bool, None]
Data = Union[PrimitiveData, list[Any], tuple[Any], "Mapping[str, Any]"]
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
_T = TypeVar("_T")
if TYPE_CHECKING:
NoneType: type[None]
else:
NoneType = type(None)
# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""
A sentinel singleton class used to distinguish omitted keyword arguments
from those passed in with the value None (which may have different behavior).
For example:
```py
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
get(timeout=1) # 1s timeout
get(timeout=None) # No timeout
get() # Default timeout behavior, which may not be statically known at the method definition.
```
"""
def __bool__(self) -> Literal[False]:
return False
@override
def __repr__(self) -> str:
return "NOT_GIVEN"
NotGivenOr = Union[_T, NotGiven]
NOT_GIVEN = NotGiven()
class Omit:
"""In certain situations you need to be able to represent a case where a default value has
to be explicitly removed and `None` is not an appropriate substitute, for example:
```py
# as the default `Content-Type` header is `application/json` that will be sent
client.post('/upload/files', files={'file': b'my raw file content'})
# you can't explicitly override the header as it has to be dynamically generated
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
client.post(..., headers={'Content-Type': 'multipart/form-data'})
# instead you can remove the default `application/json` header by passing Omit
client.post(..., headers={'Content-Type': Omit()})
```
"""
def __bool__(self) -> Literal[False]:
return False
@runtime_checkable
class ModelBuilderProtocol(Protocol):
@classmethod
def build(
cls: type[_T],
*,
response: Response,
data: object,
) -> _T: ...
Headers = Mapping[str, Union[str, Omit]]
class HeadersLikeProtocol(Protocol):
def get(self, __key: str) -> str | None: ...
HeadersLike = Union[Headers, HeadersLikeProtocol]
ResponseT = TypeVar(
"ResponseT",
bound="Union[str, None, BaseModel, list[Any], dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", # noqa: E501
)
StrBytesIntFloat = Union[str, bytes, int, float]
# Note: copied from Pydantic
# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49
IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"
PostParser = Callable[[Any], Any]
@runtime_checkable
class InheritsGeneric(Protocol):
"""Represents a type that has inherited from `Generic`
The `__orig_bases__` property can be used to determine the resolved
type variable for a given base class.
"""
__orig_bases__: tuple[_GenericAlias]
class _GenericAlias(Protocol):
__origin__: type[object]
class HttpxSendArgs(TypedDict, total=False):
auth: httpx.Auth
# for user input files
if TYPE_CHECKING:
Base64FileInput = Union[IO[bytes], PathLike[str]]
FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
Base64FileInput = Union[IO[bytes], PathLike]
FileContent = Union[IO[bytes], bytes, PathLike]
FileTypes = Union[
# file (or bytes)
FileContent,
# (filename, file (or bytes))
tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
tuple[Optional[str], FileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]]
# duplicate of the above but without our custom file support
HttpxFileContent = Union[bytes, IO[bytes]]
HttpxFileTypes = Union[
# file (or bytes)
HttpxFileContent,
# (filename, file (or bytes))
tuple[Optional[str], HttpxFileContent],
# (filename, file (or bytes), content_type)
tuple[Optional[str], HttpxFileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]]

View File

@ -1,12 +0,0 @@
import httpx
RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
# `Timeout` controls the `connect` and `read` timeouts; defaults to `timeout=300.0, connect=8.0`
ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
# The `retry` parameter controls the number of retries; defaults to 3
ZHIPUAI_DEFAULT_MAX_RETRIES = 3
# `Limits` controls the maximum number of connections and keep-alive connections; defaults to `max_connections=50, max_keepalive_connections=10`
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)
INITIAL_RETRY_DELAY = 0.5
MAX_RETRY_DELAY = 8.0
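
These defaults can be overridden per client by passing your own `httpx.Timeout` and retry count; a sketch under the assumption that the `ZhipuAI` constructor accepts `timeout` and `max_retries` keyword arguments.

import httpx

from zhipuai import ZhipuAI  # assumed import path

client = ZhipuAI(
    api_key="your-api-key",
    timeout=httpx.Timeout(timeout=120.0, connect=5.0),  # tighter than the 300s/8s defaults above
    max_retries=5,                                       # default is ZHIPUAI_DEFAULT_MAX_RETRIES (3)
)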

View File

@ -1,86 +0,0 @@
from __future__ import annotations
import httpx
__all__ = [
"ZhipuAIError",
"APIStatusError",
"APIRequestFailedError",
"APIAuthenticationError",
"APIReachLimitError",
"APIInternalError",
"APIServerFlowExceedError",
"APIResponseError",
"APIResponseValidationError",
"APITimeoutError",
"APIConnectionError",
]
class ZhipuAIError(Exception):
def __init__(
self,
message: str,
) -> None:
super().__init__(message)
class APIStatusError(ZhipuAIError):
response: httpx.Response
status_code: int
def __init__(self, message: str, *, response: httpx.Response) -> None:
super().__init__(message)
self.response = response
self.status_code = response.status_code
class APIRequestFailedError(APIStatusError): ...
class APIAuthenticationError(APIStatusError): ...
class APIReachLimitError(APIStatusError): ...
class APIInternalError(APIStatusError): ...
class APIServerFlowExceedError(APIStatusError): ...
class APIResponseError(ZhipuAIError):
message: str
request: httpx.Request
json_data: object
def __init__(self, message: str, request: httpx.Request, json_data: object):
self.message = message
self.request = request
self.json_data = json_data
super().__init__(message)
class APIResponseValidationError(APIResponseError):
status_code: int
response: httpx.Response
def __init__(self, response: httpx.Response, json_data: object | None, *, message: str | None = None) -> None:
super().__init__(
message=message or "Data returned by API invalid for expected schema.",
request=response.request,
json_data=json_data,
)
self.response = response
self.status_code = response.status_code
class APIConnectionError(APIResponseError):
def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
super().__init__(message, request, json_data=None)
class APITimeoutError(APIConnectionError):
def __init__(self, request: httpx.Request) -> None:
super().__init__(message="Request timed out.", request=request)
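
A hedged error-handling sketch using the exception hierarchy above: timeouts are a subclass of connection errors, and every status error carries the original `httpx.Response`. The import paths and the `client.files` attribute name are assumptions.

from zhipuai import ZhipuAI  # assumed import path
from zhipuai.core import APIStatusError, APITimeoutError  # import path is an assumption

client = ZhipuAI(api_key="your-api-key")

try:
    client.files.content(file_id="file-does-not-exist")
except APITimeoutError:
    print("request timed out; consider retrying")
except APIStatusError as exc:
    print(exc.status_code, exc.response.text)  # status code and raw response are always attached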

View File

@ -1,75 +0,0 @@
from __future__ import annotations
import io
import os
import pathlib
from typing import TypeGuard, overload
from ._base_type import (
Base64FileInput,
FileContent,
FileTypes,
HttpxFileContent,
HttpxFileTypes,
HttpxRequestFiles,
RequestFiles,
)
from ._utils import is_mapping_t, is_sequence_t, is_tuple_t
def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
return isinstance(obj, io.IOBase | os.PathLike)
def is_file_content(obj: object) -> TypeGuard[FileContent]:
return isinstance(obj, bytes | tuple | io.IOBase | os.PathLike)
def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
if not is_file_content(obj):
prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads"
) from None
@overload
def to_httpx_files(files: None) -> None: ...
@overload
def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ...
def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None
if is_mapping_t(files):
files = {key: _transform_file(file) for key, file in files.items()}
elif is_sequence_t(files):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
return files
def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = pathlib.Path(file)
return (path.name, path.read_bytes())
return file
if is_tuple_t(file):
return (file[0], _read_file_content(file[1]), *file[2:])
raise TypeError("Expected file types input to be a FileContent type or to be a tuple")
def _read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return pathlib.Path(file).read_bytes()
return file
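A usage sketch with the helpers above in scope (assuming "notes.txt" exists on disk): path-like inputs are read eagerly into (filename, bytes) pairs, while already-normalized tuples pass through unchanged.

```py
import pathlib

files = to_httpx_files({
    "document": pathlib.Path("notes.txt"),                          # becomes ("notes.txt", b"...file bytes...")
    "raw": ("blob.bin", b"\x00\x01", "application/octet-stream"),   # passed through as-is
})
```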

View File

@ -1,910 +0,0 @@
from __future__ import annotations
import inspect
import logging
import time
import warnings
from collections.abc import Iterable, Iterator, Mapping
from itertools import starmap
from random import random
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
import httpx
import pydantic
from httpx import URL, Timeout
from . import _errors, get_origin
from ._base_compat import model_copy
from ._base_models import GenericModel, construct_type, validate_type
from ._base_type import (
NOT_GIVEN,
AnyMapping,
Body,
Data,
Headers,
HttpxSendArgs,
ModelBuilderProtocol,
NotGiven,
Omit,
PostParser,
Query,
RequestFiles,
ResponseT,
)
from ._constants import (
INITIAL_RETRY_DELAY,
MAX_RETRY_DELAY,
RAW_RESPONSE_HEADER,
ZHIPUAI_DEFAULT_LIMITS,
ZHIPUAI_DEFAULT_MAX_RETRIES,
ZHIPUAI_DEFAULT_TIMEOUT,
)
from ._errors import APIConnectionError, APIResponseValidationError, APIStatusError, APITimeoutError
from ._files import to_httpx_files
from ._legacy_response import LegacyAPIResponse
from ._request_opt import FinalRequestOptions, UserRequestInput
from ._response import APIResponse, BaseAPIResponse, extract_response_type
from ._sse_client import StreamResponse
from ._utils import flatten, is_given, is_mapping
log: logging.Logger = logging.getLogger(__name__)
# TODO: make base page type vars covariant
SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]")
# AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]")
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
if TYPE_CHECKING:
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
else:
try:
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
except ImportError:
# taken from https://github.com/encode/httpx/blob/3ba5fe0d7ac70222590e759c31442b1cab263791/httpx/_config.py#L366
HTTPX_DEFAULT_TIMEOUT = Timeout(5.0)
headers = {
"Accept": "application/json",
"Content-Type": "application/json; charset=UTF-8",
}
class PageInfo:
"""Stores the necessary information to build the request to retrieve the next page.
Either `url` or `params` must be set.
"""
url: URL | NotGiven
params: Query | NotGiven
@overload
def __init__(
self,
*,
url: URL,
) -> None: ...
@overload
def __init__(
self,
*,
params: Query,
) -> None: ...
def __init__(
self,
*,
url: URL | NotGiven = NOT_GIVEN,
params: Query | NotGiven = NOT_GIVEN,
) -> None:
self.url = url
self.params = params
class BasePage(GenericModel, Generic[_T]):
"""
Defines the core interface for pagination.
Type Args:
ModelT: The pydantic model that represents an item in the response.
Methods:
has_next_page(): Check if there is another page available
next_page_info(): Get the necessary information to make a request for the next page
"""
_options: FinalRequestOptions = pydantic.PrivateAttr()
_model: type[_T] = pydantic.PrivateAttr()
def has_next_page(self) -> bool:
items = self._get_page_items()
if not items:
return False
return self.next_page_info() is not None
def next_page_info(self) -> Optional[PageInfo]: ...
def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body]
...
def _params_from_url(self, url: URL) -> httpx.QueryParams:
# TODO: do we have to preprocess params here?
return httpx.QueryParams(cast(Any, self._options.params)).merge(url.params)
def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
options = model_copy(self._options)
options._strip_raw_response_header()
if not isinstance(info.params, NotGiven):
options.params = {**options.params, **info.params}
return options
if not isinstance(info.url, NotGiven):
params = self._params_from_url(info.url)
url = info.url.copy_with(params=params)
options.params = dict(url.params)
options.url = str(url)
return options
raise ValueError("Unexpected PageInfo state")
class BaseSyncPage(BasePage[_T], Generic[_T]):
_client: HttpClient = pydantic.PrivateAttr()
def _set_private_attributes(
self,
client: HttpClient,
model: type[_T],
options: FinalRequestOptions,
) -> None:
self._model = model
self._client = client
self._options = options
# Pydantic uses a custom `__iter__` method to support casting BaseModels
# to dictionaries. e.g. dict(model).
# As we want to support `for item in page`, this is inherently incompatible
# with the default pydantic behavior. It is not possible to support both
# use cases at once. Fortunately, this is not a big deal as all other pydantic
# methods should continue to work as expected as there is an alternative method
# to cast a model to a dictionary, model.dict(), which is used internally
# by pydantic.
def __iter__(self) -> Iterator[_T]: # type: ignore
for page in self.iter_pages():
yield from page._get_page_items()
def iter_pages(self: SyncPageT) -> Iterator[SyncPageT]:
page = self
while True:
yield page
if page.has_next_page():
page = page.get_next_page()
else:
return
def get_next_page(self: SyncPageT) -> SyncPageT:
info = self.next_page_info()
if not info:
raise RuntimeError(
"No next page expected; please check `.has_next_page()` before calling `.get_next_page()`."
)
options = self._info_to_options(info)
return self._client._request_api_list(self._model, page=self.__class__, options=options)
class HttpClient:
_client: httpx.Client
_version: str
_base_url: URL
max_retries: int
timeout: Union[float, Timeout, None]
_limits: httpx.Limits
_has_custom_http_client: bool
_default_stream_cls: type[StreamResponse[Any]] | None = None
_strict_response_validation: bool
def __init__(
self,
*,
version: str,
base_url: URL,
_strict_response_validation: bool,
max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
timeout: Union[float, Timeout, None],
limits: httpx.Limits | None = None,
custom_httpx_client: httpx.Client | None = None,
custom_headers: Mapping[str, str] | None = None,
) -> None:
if limits is not None:
warnings.warn(
"The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead", # noqa: E501
category=DeprecationWarning,
stacklevel=3,
)
if custom_httpx_client is not None:
raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`")
else:
limits = ZHIPUAI_DEFAULT_LIMITS
if not is_given(timeout):
if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT:
timeout = custom_httpx_client.timeout
else:
timeout = ZHIPUAI_DEFAULT_TIMEOUT
self.max_retries = max_retries
self.timeout = timeout
self._limits = limits
self._has_custom_http_client = bool(custom_httpx_client)
self._client = custom_httpx_client or httpx.Client(
base_url=base_url,
timeout=self.timeout,
limits=limits,
)
self._version = version
url = URL(url=base_url)
if not url.raw_path.endswith(b"/"):
url = url.copy_with(raw_path=url.raw_path + b"/")
self._base_url = url
self._custom_headers = custom_headers or {}
self._strict_response_validation = _strict_response_validation
def _prepare_url(self, url: str) -> URL:
sub_url = URL(url)
if sub_url.is_relative_url:
request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/")
return self._base_url.copy_with(raw_path=request_raw_url)
return sub_url
@property
def _default_headers(self):
return {
"Accept": "application/json",
"Content-Type": "application/json; charset=UTF-8",
"ZhipuAI-SDK-Ver": self._version,
"source_type": "zhipu-sdk-python",
"x-request-sdk": "zhipu-sdk-python",
**self.auth_headers,
**self._custom_headers,
}
@property
def custom_auth(self) -> httpx.Auth | None:
return None
@property
def auth_headers(self):
return {}
def _prepare_headers(self, options: FinalRequestOptions) -> httpx.Headers:
custom_headers = options.headers or {}
headers_dict = _merge_mappings(self._default_headers, custom_headers)
httpx_headers = httpx.Headers(headers_dict)
return httpx_headers
def _remaining_retries(
self,
remaining_retries: Optional[int],
options: FinalRequestOptions,
) -> int:
return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)
def _calculate_retry_timeout(
self,
remaining_retries: int,
options: FinalRequestOptions,
response_headers: Optional[httpx.Headers] = None,
) -> float:
max_retries = options.get_max_retries(self.max_retries)
# If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
# retry_after = self._parse_retry_after_header(response_headers)
# if retry_after is not None and 0 < retry_after <= 60:
# return retry_after
nb_retries = max_retries - remaining_retries
# Apply exponential backoff, but not more than the max.
sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)
# Apply some jitter: scale the delay down by up to 25%.
jitter = 1 - 0.25 * random()
timeout = sleep_seconds * jitter
return max(timeout, 0)
def _build_request(self, options: FinalRequestOptions) -> httpx.Request:
kwargs: dict[str, Any] = {}
headers = self._prepare_headers(options)
url = self._prepare_url(options.url)
json_data = options.json_data
if options.extra_json is not None:
if json_data is None:
json_data = cast(Body, options.extra_json)
elif is_mapping(json_data):
json_data = _merge_mappings(json_data, options.extra_json)
else:
raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")
content_type = headers.get("Content-Type")
# multipart/form-data; boundary=---abc--
if headers.get("Content-Type") == "multipart/form-data":
if "boundary" not in content_type:
# only remove the header if the boundary hasn't been explicitly set
# as the caller doesn't want httpx to come up with their own boundary
headers.pop("Content-Type")
if json_data:
kwargs["data"] = self._make_multipartform(json_data)
return self._client.build_request(
headers=headers,
timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,
method=options.method,
url=url,
json=json_data,
files=options.files,
params=options.params,
**kwargs,
)
def _object_to_formdata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]:
items = []
if isinstance(value, Mapping):
for k, v in value.items():
items.extend(self._object_to_formdata(f"{key}[{k}]", v))
return items
if isinstance(value, list | tuple):
for v in value:
items.extend(self._object_to_formdata(key + "[]", v))
return items
def _primitive_value_to_str(val) -> str:
# copied from httpx
if val is True:
return "true"
elif val is False:
return "false"
elif val is None:
return ""
return str(val)
str_data = _primitive_value_to_str(value)
if not str_data:
return []
return [(key, str_data)]
def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
items = flatten(list(starmap(self._object_to_formdata, data.items())))
serialized: dict[str, object] = {}
for key, value in items:
if key in serialized:
raise ValueError(f"存在重复的键: {key};")
serialized[key] = value
return serialized
def _process_response_data(
self,
*,
data: object,
cast_type: type[ResponseT],
response: httpx.Response,
) -> ResponseT:
if data is None:
return cast(ResponseT, None)
if cast_type is object:
return cast(ResponseT, data)
try:
if inspect.isclass(cast_type) and issubclass(cast_type, ModelBuilderProtocol):
return cast(ResponseT, cast_type.build(response=response, data=data))
if self._strict_response_validation:
return cast(ResponseT, validate_type(type_=cast_type, value=data))
return cast(ResponseT, construct_type(type_=cast_type, value=data))
except pydantic.ValidationError as err:
raise APIResponseValidationError(response=response, json_data=data) from err
def _should_stream_response_body(self, request: httpx.Request) -> bool:
return request.headers.get(RAW_RESPONSE_HEADER) == "stream" # type: ignore[no-any-return]
def _should_retry(self, response: httpx.Response) -> bool:
# Note: this is not a standard header
should_retry_header = response.headers.get("x-should-retry")
# If the server explicitly says whether or not to retry, obey.
if should_retry_header == "true":
log.debug("Retrying as header `x-should-retry` is set to `true`")
return True
if should_retry_header == "false":
log.debug("Not retrying as header `x-should-retry` is set to `false`")
return False
# Retry on request timeouts.
if response.status_code == 408:
log.debug("Retrying due to status code %i", response.status_code)
return True
# Retry on lock timeouts.
if response.status_code == 409:
log.debug("Retrying due to status code %i", response.status_code)
return True
# Retry on rate limits.
if response.status_code == 429:
log.debug("Retrying due to status code %i", response.status_code)
return True
# Retry internal errors.
if response.status_code >= 500:
log.debug("Retrying due to status code %i", response.status_code)
return True
log.debug("Not retrying")
return False
def is_closed(self) -> bool:
return self._client.is_closed
def close(self):
self._client.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def request(
self,
cast_type: type[ResponseT],
options: FinalRequestOptions,
remaining_retries: Optional[int] = None,
*,
stream: bool = False,
stream_cls: type[StreamResponse] | None = None,
) -> ResponseT | StreamResponse:
return self._request(
cast_type=cast_type,
options=options,
stream=stream,
stream_cls=stream_cls,
remaining_retries=remaining_retries,
)
def _request(
self,
*,
cast_type: type[ResponseT],
options: FinalRequestOptions,
remaining_retries: int | None,
stream: bool,
stream_cls: type[StreamResponse] | None,
) -> ResponseT | StreamResponse:
retries = self._remaining_retries(remaining_retries, options)
request = self._build_request(options)
kwargs: HttpxSendArgs = {}
if self.custom_auth is not None:
kwargs["auth"] = self.custom_auth
try:
response = self._client.send(
request,
stream=stream or self._should_stream_response_body(request=request),
**kwargs,
)
except httpx.TimeoutException as err:
log.debug("Encountered httpx.TimeoutException", exc_info=True)
if retries > 0:
return self._retry_request(
options,
cast_type,
retries,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
)
log.debug("Raising timeout error")
raise APITimeoutError(request=request) from err
except Exception as err:
log.debug("Encountered Exception", exc_info=True)
if retries > 0:
return self._retry_request(
options,
cast_type,
retries,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
)
log.debug("Raising connection error")
raise APIConnectionError(request=request) from err
log.debug(
'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
)
try:
response.raise_for_status()
except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
if retries > 0 and self._should_retry(err.response):
err.response.close()
return self._retry_request(
options,
cast_type,
retries,
err.response.headers,
stream=stream,
stream_cls=stream_cls,
)
# If the response is streamed then we need to explicitly read the response
# to completion before attempting to access the response text.
if not err.response.is_closed:
err.response.read()
log.debug("Re-raising status error")
raise self._make_status_error(err.response) from None
# return self._parse_response(
# cast_type=cast_type,
# options=options,
# response=response,
# stream=stream,
# stream_cls=stream_cls,
# )
return self._process_response(
cast_type=cast_type,
options=options,
response=response,
stream=stream,
stream_cls=stream_cls,
)
def _retry_request(
self,
options: FinalRequestOptions,
cast_type: type[ResponseT],
remaining_retries: int,
response_headers: httpx.Headers | None,
*,
stream: bool,
stream_cls: type[StreamResponse] | None,
) -> ResponseT | StreamResponse:
remaining = remaining_retries - 1
if remaining == 1:
log.debug("1 retry left")
else:
log.debug("%i retries left", remaining)
timeout = self._calculate_retry_timeout(remaining, options, response_headers)
log.info("Retrying request to %s in %f seconds", options.url, timeout)
# In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
# different thread if necessary.
time.sleep(timeout)
return self._request(
options=options,
cast_type=cast_type,
remaining_retries=remaining,
stream=stream,
stream_cls=stream_cls,
)
def _process_response(
self,
*,
cast_type: type[ResponseT],
options: FinalRequestOptions,
response: httpx.Response,
stream: bool,
stream_cls: type[StreamResponse] | None,
) -> ResponseT:
# _legacy_response with raw_response_header to parser method
if response.request.headers.get(RAW_RESPONSE_HEADER) == "true":
return cast(
ResponseT,
LegacyAPIResponse(
raw=response,
client=self,
cast_type=cast_type,
stream=stream,
stream_cls=stream_cls,
options=options,
),
)
origin = get_origin(cast_type) or cast_type
if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse):
if not issubclass(origin, APIResponse):
raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}")
response_cls = cast("type[BaseAPIResponse[Any]]", cast_type)
return cast(
ResponseT,
response_cls(
raw=response,
client=self,
cast_type=extract_response_type(response_cls),
stream=stream,
stream_cls=stream_cls,
options=options,
),
)
if cast_type == httpx.Response:
return cast(ResponseT, response)
api_response = APIResponse(
raw=response,
client=self,
cast_type=cast("type[ResponseT]", cast_type), # pyright: ignore[reportUnnecessaryCast]
stream=stream,
stream_cls=stream_cls,
options=options,
)
if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
return cast(ResponseT, api_response)
return api_response.parse()
def _request_api_list(
self,
model: type[object],
page: type[SyncPageT],
options: FinalRequestOptions,
) -> SyncPageT:
def _parser(resp: SyncPageT) -> SyncPageT:
resp._set_private_attributes(
client=self,
model=model,
options=options,
)
return resp
options.post_parser = _parser
return self.request(page, options, stream=False)
@overload
def get(
self,
path: str,
*,
cast_type: type[ResponseT],
options: UserRequestInput = {},
stream: Literal[False] = False,
) -> ResponseT: ...
@overload
def get(
self,
path: str,
*,
cast_type: type[ResponseT],
options: UserRequestInput = {},
stream: Literal[True],
stream_cls: type[StreamResponse],
) -> StreamResponse: ...
@overload
def get(
self,
path: str,
*,
cast_type: type[ResponseT],
options: UserRequestInput = {},
stream: bool,
stream_cls: type[StreamResponse] | None = None,
) -> ResponseT | StreamResponse: ...
def get(
self,
path: str,
*,
cast_type: type[ResponseT],
options: UserRequestInput = {},
stream: bool = False,
stream_cls: type[StreamResponse] | None = None,
) -> ResponseT:
opts = FinalRequestOptions.construct(method="get", url=path, **options)
return cast(ResponseT, self.request(cast_type, opts, stream=stream, stream_cls=stream_cls))
@overload
def post(
self,
path: str,
*,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
files: RequestFiles | None = None,
stream: Literal[False] = False,
) -> ResponseT: ...
@overload
def post(
self,
path: str,
*,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
files: RequestFiles | None = None,
stream: Literal[True],
stream_cls: type[StreamResponse],
) -> StreamResponse: ...
@overload
def post(
self,
path: str,
*,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
files: RequestFiles | None = None,
stream: bool,
stream_cls: type[StreamResponse] | None = None,
) -> ResponseT | StreamResponse: ...
def post(
self,
path: str,
*,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
files: RequestFiles | None = None,
stream: bool = False,
stream_cls: type[StreamResponse[Any]] | None = None,
) -> ResponseT | StreamResponse:
opts = FinalRequestOptions.construct(
method="post", url=path, json_data=body, files=to_httpx_files(files), **options
)
return cast(ResponseT, self.request(cast_type, opts, stream=stream, stream_cls=stream_cls))
def patch(
self,
path: str,
*,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
) -> ResponseT:
opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options)
return self.request(
cast_type=cast_type,
options=opts,
)
def put(
self,
path: str,
*,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
files: RequestFiles | None = None,
) -> ResponseT | StreamResponse:
opts = FinalRequestOptions.construct(
method="put", url=path, json_data=body, files=to_httpx_files(files), **options
)
return self.request(
cast_type=cast_type,
options=opts,
)
def delete(
self,
path: str,
*,
cast_type: type[ResponseT],
body: Body | None = None,
options: UserRequestInput = {},
) -> ResponseT | StreamResponse:
opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, **options)
return self.request(
cast_type=cast_type,
options=opts,
)
def get_api_list(
self,
path: str,
*,
model: type[object],
page: type[SyncPageT],
body: Body | None = None,
options: UserRequestInput = {},
method: str = "get",
) -> SyncPageT:
opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options)
return self._request_api_list(model, page, opts)
def _make_status_error(self, response) -> APIStatusError:
response_text = response.text.strip()
status_code = response.status_code
error_msg = f"Error code: {status_code}, with error text {response_text}"
if status_code == 400:
return _errors.APIRequestFailedError(message=error_msg, response=response)
elif status_code == 401:
return _errors.APIAuthenticationError(message=error_msg, response=response)
elif status_code == 429:
return _errors.APIReachLimitError(message=error_msg, response=response)
elif status_code == 500:
return _errors.APIInternalError(message=error_msg, response=response)
elif status_code == 503:
return _errors.APIServerFlowExceedError(message=error_msg, response=response)
return APIStatusError(message=error_msg, response=response)
def make_request_options(
*,
query: Query | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
post_parser: PostParser | NotGiven = NOT_GIVEN,
) -> UserRequestInput:
"""Create a dict of type RequestOptions without keys of NotGiven values."""
options: UserRequestInput = {}
if extra_headers is not None:
options["headers"] = extra_headers
if extra_body is not None:
options["extra_json"] = cast(AnyMapping, extra_body)
if query is not None:
options["params"] = query
if extra_query is not None:
options["params"] = {**options.get("params", {}), **extra_query}
if not isinstance(timeout, NotGiven):
options["timeout"] = timeout
if is_given(post_parser):
# internal
options["post_parser"] = post_parser # type: ignore
return options
def _merge_mappings(
obj1: Mapping[_T_co, Union[_T, Omit]],
obj2: Mapping[_T_co, Union[_T, Omit]],
) -> dict[_T_co, _T]:
"""Merge two mappings of the same type, removing any values that are instances of `Omit`.
In cases with duplicate keys the second mapping takes precedence.
"""
merged = {**obj1, **obj2}
return {key: value for key, value in merged.items() if not isinstance(value, Omit)}
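As a sketch with the helpers above in scope: make_request_options keeps only the keys that were actually given, and _merge_mappings lets the second mapping win on duplicate keys.

```py
opts = make_request_options(
    query={"page": 1},
    extra_query={"page_size": 20},
    extra_headers={"X-Trace-Id": "abc123"},
    timeout=30.0,
)
# -> {"headers": {"X-Trace-Id": "abc123"}, "params": {"page": 1, "page_size": 20}, "timeout": 30.0}

merged = _merge_mappings({"a": 1, "b": 2}, {"b": 3})
# -> {"a": 1, "b": 3}
```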

View File

@ -1,31 +0,0 @@
import time
import cachetools.func
import jwt
# Cache the generated token for 3 minutes
CACHE_TTL_SECONDS = 3 * 60
# Keep the token valid for 30 seconds longer than the cache TTL
API_TOKEN_TTL_SECONDS = CACHE_TTL_SECONDS + 30
@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS)
def generate_token(apikey: str):
try:
api_key, secret = apikey.split(".")
except Exception as e:
raise Exception("invalid api_key", e)
payload = {
"api_key": api_key,
"exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
"timestamp": int(round(time.time() * 1000)),
}
ret = jwt.encode(
payload,
secret,
algorithm="HS256",
headers={"alg": "HS256", "sign_type": "SIGN"},
)
return ret
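Because of the ttl_cache decorator, repeated calls with the same key reuse one signed JWT for the cache window; a sketch with a made-up key in the expected "<id>.<secret>" form:

```py
token_a = generate_token("my-key-id.my-key-secret")   # hypothetical API key
token_b = generate_token("my-key-id.my-key-secret")
assert token_a == token_b                             # second call is served from the 3-minute cache
```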

View File

@ -1,207 +0,0 @@
from __future__ import annotations
import os
from collections.abc import AsyncIterator, Iterator
from typing import Any
import anyio  # used by astream_to_file below
import httpx
class HttpxResponseContent:
@property
def content(self) -> bytes:
raise NotImplementedError("This method is not implemented for this class.")
@property
def text(self) -> str:
raise NotImplementedError("This method is not implemented for this class.")
@property
def encoding(self) -> str | None:
raise NotImplementedError("This method is not implemented for this class.")
@property
def charset_encoding(self) -> str | None:
raise NotImplementedError("This method is not implemented for this class.")
def json(self, **kwargs: Any) -> Any:
raise NotImplementedError("This method is not implemented for this class.")
def read(self) -> bytes:
raise NotImplementedError("This method is not implemented for this class.")
def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
raise NotImplementedError("This method is not implemented for this class.")
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
raise NotImplementedError("This method is not implemented for this class.")
def iter_lines(self) -> Iterator[str]:
raise NotImplementedError("This method is not implemented for this class.")
def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]:
raise NotImplementedError("This method is not implemented for this class.")
def write_to_file(
self,
file: str | os.PathLike[str],
) -> None:
raise NotImplementedError("This method is not implemented for this class.")
def stream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
raise NotImplementedError("This method is not implemented for this class.")
def close(self) -> None:
raise NotImplementedError("This method is not implemented for this class.")
async def aread(self) -> bytes:
raise NotImplementedError("This method is not implemented for this class.")
async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
raise NotImplementedError("This method is not implemented for this class.")
async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
raise NotImplementedError("This method is not implemented for this class.")
async def aiter_lines(self) -> AsyncIterator[str]:
raise NotImplementedError("This method is not implemented for this class.")
async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
raise NotImplementedError("This method is not implemented for this class.")
async def astream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
raise NotImplementedError("This method is not implemented for this class.")
async def aclose(self) -> None:
raise NotImplementedError("This method is not implemented for this class.")
class HttpxBinaryResponseContent(HttpxResponseContent):
response: httpx.Response
def __init__(self, response: httpx.Response) -> None:
self.response = response
@property
def content(self) -> bytes:
return self.response.content
@property
def encoding(self) -> str | None:
return self.response.encoding
@property
def charset_encoding(self) -> str | None:
return self.response.charset_encoding
def read(self) -> bytes:
return self.response.read()
def text(self) -> str:
raise NotImplementedError("Not implemented for binary response content")
def json(self, **kwargs: Any) -> Any:
raise NotImplementedError("Not implemented for binary response content")
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
raise NotImplementedError("Not implemented for binary response content")
def iter_lines(self) -> Iterator[str]:
raise NotImplementedError("Not implemented for binary response content")
async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
raise NotImplementedError("Not implemented for binary response content")
async def aiter_lines(self) -> AsyncIterator[str]:
raise NotImplementedError("Not implemented for binary response content")
def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
return self.response.iter_bytes(chunk_size)
def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]:
return self.response.iter_raw(chunk_size)
def write_to_file(
self,
file: str | os.PathLike[str],
) -> None:
"""Write the output to the given file.
Accepts a filename or any path-like object, e.g. pathlib.Path
Note: if you want to stream the data to the file instead of writing
all at once then you should use `.with_streaming_response` when making
the API request, e.g. `client.with_streaming_response.foo().stream_to_file('my_filename.txt')`
"""
with open(file, mode="wb") as f:
for data in self.response.iter_bytes():
f.write(data)
def stream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
with open(file, mode="wb") as f:
for data in self.response.iter_bytes(chunk_size):
f.write(data)
def close(self) -> None:
return self.response.close()
async def aread(self) -> bytes:
return await self.response.aread()
async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
return self.response.aiter_bytes(chunk_size)
async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
return self.response.aiter_raw(chunk_size)
async def astream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
path = anyio.Path(file)
async with await path.open(mode="wb") as f:
async for data in self.response.aiter_bytes(chunk_size):
await f.write(data)
async def aclose(self) -> None:
return await self.response.aclose()
class HttpxTextBinaryResponseContent(HttpxBinaryResponseContent):
response: httpx.Response
@property
def text(self) -> str:
return self.response.text
def json(self, **kwargs: Any) -> Any:
return self.response.json(**kwargs)
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
return self.response.iter_text(chunk_size)
def iter_lines(self) -> Iterator[str]:
return self.response.iter_lines()
async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
return self.response.aiter_text(chunk_size)
async def aiter_lines(self) -> AsyncIterator[str]:
return self.response.aiter_lines()
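A small sketch with the classes above in scope: wrap a fully-read httpx response and persist its bytes to a local file (the file name is arbitrary).

```py
import httpx

response = httpx.Response(200, content=b"col_a,col_b\n1,2\n")
HttpxBinaryResponseContent(response).write_to_file("download.csv")
```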

View File

@ -1,341 +0,0 @@
from __future__ import annotations
import datetime
import functools
import inspect
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload
import httpx
import pydantic
from typing_extensions import ParamSpec, override
from ._base_models import BaseModel, is_basemodel
from ._base_type import NoneType
from ._constants import RAW_RESPONSE_HEADER
from ._errors import APIResponseValidationError
from ._legacy_binary_response import HttpxResponseContent, HttpxTextBinaryResponseContent
from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type
from ._utils import extract_type_arg, is_annotated_type, is_given
if TYPE_CHECKING:
from ._http_client import HttpClient
from ._request_opt import FinalRequestOptions
P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")
log: logging.Logger = logging.getLogger(__name__)
class LegacyAPIResponse(Generic[R]):
"""This is a legacy class as it will be replaced by `APIResponse`
and `AsyncAPIResponse` in the `_response.py` file in the next major
release.
For the sync client this will mostly be the same with the exception
of `content` & `text` will be methods instead of properties. In the
async client, all methods will be async.
A migration script will be provided & the migration in general should
be smooth.
"""
_cast_type: type[R]
_client: HttpClient
_parsed_by_type: dict[type[Any], Any]
_stream: bool
_stream_cls: type[StreamResponse[Any]] | None
_options: FinalRequestOptions
http_response: httpx.Response
def __init__(
self,
*,
raw: httpx.Response,
cast_type: type[R],
client: HttpClient,
stream: bool,
stream_cls: type[StreamResponse[Any]] | None,
options: FinalRequestOptions,
) -> None:
self._cast_type = cast_type
self._client = client
self._parsed_by_type = {}
self._stream = stream
self._stream_cls = stream_cls
self._options = options
self.http_response = raw
@property
def request_id(self) -> str | None:
return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return]
@overload
def parse(self, *, to: type[_T]) -> _T: ...
@overload
def parse(self) -> R: ...
def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
NOTE: For the async client: this will become a coroutine in the next major version.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
You can customize the type that the response is parsed into through
the `to` argument, e.g.
```py
from zhipuai import BaseModel
class MyModel(BaseModel):
foo: str
obj = response.parse(to=MyModel)
print(obj.foo)
```
We support parsing:
- `BaseModel`
- `dict`
- `list`
- `Union`
- `str`
- `int`
- `float`
- `httpx.Response`
"""
cache_key = to if to is not None else self._cast_type
cached = self._parsed_by_type.get(cache_key)
if cached is not None:
return cached # type: ignore[no-any-return]
parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
self._parsed_by_type[cache_key] = parsed
return parsed
@property
def headers(self) -> httpx.Headers:
return self.http_response.headers
@property
def http_request(self) -> httpx.Request:
return self.http_response.request
@property
def status_code(self) -> int:
return self.http_response.status_code
@property
def url(self) -> httpx.URL:
return self.http_response.url
@property
def method(self) -> str:
return self.http_request.method
@property
def content(self) -> bytes:
"""Return the binary response content.
NOTE: this will be removed in favour of `.read()` in the
next major version.
"""
return self.http_response.content
@property
def text(self) -> str:
"""Return the decoded response content.
NOTE: this will be turned into a method in the next major version.
"""
return self.http_response.text
@property
def http_version(self) -> str:
return self.http_response.http_version
@property
def is_closed(self) -> bool:
return self.http_response.is_closed
@property
def elapsed(self) -> datetime.timedelta:
"""The time taken for the complete request/response cycle to complete."""
return self.http_response.elapsed
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
# unwrap `Annotated[T, ...]` -> `T`
if to and is_annotated_type(to):
to = extract_type_arg(to, 0)
if self._stream:
if to:
if not is_stream_class_type(to):
raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}")
return cast(
_T,
to(
cast_type=extract_stream_chunk_type(
to,
failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]", # noqa: E501
),
response=self.http_response,
client=cast(Any, self._client),
),
)
if self._stream_cls:
return cast(
R,
self._stream_cls(
cast_type=extract_stream_chunk_type(self._stream_cls),
response=self.http_response,
client=cast(Any, self._client),
),
)
stream_cls = cast("type[StreamResponse[Any]] | None", self._client._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return cast(
R,
stream_cls(
cast_type=self._cast_type,
response=self.http_response,
client=cast(Any, self._client),
),
)
cast_type = to if to is not None else self._cast_type
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(cast_type):
cast_type = extract_type_arg(cast_type, 0)
if cast_type is NoneType:
return cast(R, None)
response = self.http_response
if cast_type == str:
return cast(R, response.text)
if cast_type == int:
return cast(R, int(response.text))
if cast_type == float:
return cast(R, float(response.text))
origin = get_origin(cast_type) or cast_type
if inspect.isclass(origin) and issubclass(origin, HttpxResponseContent):
# in the response, e.g. mime file
*_, filename = response.headers.get("content-disposition", "").split("filename=")
# Files whose name ends with .jsonl or .xlsx are text-readable, so wrap them in HttpxTextBinaryResponseContent
if filename and (filename.endswith(".jsonl") or filename.endswith(".xlsx")):
return cast(R, HttpxTextBinaryResponseContent(response))
else:
return cast(R, cast_type(response)) # type: ignore
if origin == LegacyAPIResponse:
raise RuntimeError("Unexpected state - cast_type is `APIResponse`")
if inspect.isclass(origin) and issubclass(origin, httpx.Response):
# Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
# and pass that class to our request functions. We cannot change the variance to be either
# covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
# the response class ourselves but that is something that should be supported directly in httpx
# as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
if cast_type != httpx.Response:
raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`")
return cast(R, response)
if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
if (
cast_type is not object
and origin is not list
and origin is not dict
and origin is not Union
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." # noqa: E501
)
# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type", "*").split(";")
if content_type != "application/json":
if is_basemodel(cast_type):
try:
data = response.json()
except Exception as exc:
log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
else:
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=response,
)
if self._client._strict_response_validation:
raise APIResponseValidationError(
response=response,
message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.", # noqa: E501
json_data=response.text,
)
# If the API responds with content that isn't JSON then we just return
# the (decoded) text without performing any parsing so that you can still
# handle the response however you need to.
return response.text # type: ignore
data = response.json()
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=response,
)
@override
def __repr__(self) -> str:
return f"<APIResponse [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_type}>"
class MissingStreamClassError(TypeError):
def __init__(self) -> None:
super().__init__(
"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference", # noqa: E501
)
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "true"
kwargs["extra_headers"] = extra_headers
return cast(LegacyAPIResponse[R], func(*args, **kwargs))
return wrapped

View File

@ -1,97 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, ClassVar, Union, cast
import pydantic.generics
from httpx import Timeout
from typing_extensions import Required, TypedDict, Unpack, final
from ._base_compat import PYDANTIC_V2, ConfigDict
from ._base_type import AnyMapping, Body, Headers, HttpxRequestFiles, NotGiven, Query
from ._constants import RAW_RESPONSE_HEADER
from ._utils import is_given, strip_not_given
class UserRequestInput(TypedDict, total=False):
headers: Headers
max_retries: int
timeout: float | Timeout | None
params: Query
extra_json: AnyMapping
class FinalRequestOptionsInput(TypedDict, total=False):
method: Required[str]
url: Required[str]
params: Query
headers: Headers
max_retries: int
timeout: float | Timeout | None
files: HttpxRequestFiles | None
json_data: Body
extra_json: AnyMapping
@final
class FinalRequestOptions(pydantic.BaseModel):
method: str
url: str
params: Query = {}
headers: Union[Headers, NotGiven] = NotGiven()
max_retries: Union[int, NotGiven] = NotGiven()
timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
files: Union[HttpxRequestFiles, None] = None
idempotency_key: Union[str, None] = None
post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
# It should be noted that we cannot use `json` here as that would override
# a BaseModel method in an incompatible fashion.
json_data: Union[Body, None] = None
extra_json: Union[AnyMapping, None] = None
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
else:
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
arbitrary_types_allowed: bool = True
def get_max_retries(self, max_retries: int) -> int:
if isinstance(self.max_retries, NotGiven):
return max_retries
return self.max_retries
def _strip_raw_response_header(self) -> None:
if not is_given(self.headers):
return
if self.headers.get(RAW_RESPONSE_HEADER):
self.headers = {**self.headers}
self.headers.pop(RAW_RESPONSE_HEADER)
# override the `construct` method so that we can run custom transformations.
# this is necessary as we don't want to do any actual runtime type checking
# (which means we can't use validators) but we do want to ensure that `NotGiven`
# values are not present
#
# type ignore required because we're adding explicit types to `**values`
@classmethod
def construct( # type: ignore
cls,
_fields_set: set[str] | None = None,
**values: Unpack[UserRequestInput],
) -> FinalRequestOptions:
kwargs: dict[str, Any] = {
# we unconditionally call `strip_not_given` on any value
# as it will just ignore any non-mapping types
key: strip_not_given(value)
for key, value in values.items()
}
if PYDANTIC_V2:
return super().model_construct(_fields_set, **kwargs)
return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
if not TYPE_CHECKING:
# type checkers incorrectly complain about this assignment
model_construct = construct
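A sketch with the class above in scope: construct() skips runtime validation and leaves unset fields as NotGiven sentinels, which get_max_retries then resolves against the client-level default (the body values here are illustrative only).

```py
opts = FinalRequestOptions.construct(method="post", url="/chat/completions", json_data={"model": "glm-4"})
assert opts.get_max_retries(3) == 3   # max_retries was not given, so the client default wins
```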

View File

@ -1,398 +0,0 @@
from __future__ import annotations
import datetime
import inspect
import logging
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload
import httpx
import pydantic
from typing_extensions import ParamSpec, override
from ._base_models import BaseModel, is_basemodel
from ._base_type import NoneType
from ._errors import APIResponseValidationError, ZhipuAIError
from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type
from ._utils import extract_type_arg, extract_type_var_from_base, is_annotated_type, is_given
if TYPE_CHECKING:
from ._http_client import HttpClient
from ._request_opt import FinalRequestOptions
P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")
_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]")
log: logging.Logger = logging.getLogger(__name__)
class BaseAPIResponse(Generic[R]):
_cast_type: type[R]
_client: HttpClient
_parsed_by_type: dict[type[Any], Any]
_is_sse_stream: bool
_stream_cls: type[StreamResponse[Any]]
_options: FinalRequestOptions
http_response: httpx.Response
def __init__(
self,
*,
raw: httpx.Response,
cast_type: type[R],
client: HttpClient,
stream: bool,
stream_cls: type[StreamResponse[Any]] | None = None,
options: FinalRequestOptions,
) -> None:
self._cast_type = cast_type
self._client = client
self._parsed_by_type = {}
self._is_sse_stream = stream
self._stream_cls = stream_cls
self._options = options
self.http_response = raw
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
# unwrap `Annotated[T, ...]` -> `T`
if to and is_annotated_type(to):
to = extract_type_arg(to, 0)
if self._is_sse_stream:
if to:
if not is_stream_class_type(to):
raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}")
return cast(
_T,
to(
cast_type=extract_stream_chunk_type(
to,
failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]", # noqa: E501
),
response=self.http_response,
client=cast(Any, self._client),
),
)
if self._stream_cls:
return cast(
R,
self._stream_cls(
cast_type=extract_stream_chunk_type(self._stream_cls),
response=self.http_response,
client=cast(Any, self._client),
),
)
stream_cls = cast("type[Stream[Any]] | None", self._client._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return cast(
R,
stream_cls(
cast_type=self._cast_type,
response=self.http_response,
client=cast(Any, self._client),
),
)
cast_type = to if to is not None else self._cast_type
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(cast_type):
cast_type = extract_type_arg(cast_type, 0)
if cast_type is NoneType:
return cast(R, None)
response = self.http_response
if cast_type == str:
return cast(R, response.text)
if cast_type == bytes:
return cast(R, response.content)
if cast_type == int:
return cast(R, int(response.text))
if cast_type == float:
return cast(R, float(response.text))
origin = get_origin(cast_type) or cast_type
# handle the legacy binary response case
if inspect.isclass(cast_type) and cast_type.__name__ == "HttpxBinaryResponseContent":
return cast(R, cast_type(response)) # type: ignore
if origin == APIResponse:
raise RuntimeError("Unexpected state - cast_type is `APIResponse`")
if inspect.isclass(origin) and issubclass(origin, httpx.Response):
# Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
# and pass that class to our request functions. We cannot change the variance to be either
# covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
# the response class ourselves but that is something that should be supported directly in httpx
# as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
if cast_type != httpx.Response:
raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`")
return cast(R, response)
if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
if (
cast_type is not object
and origin is not list
and origin is not dict
and origin is not Union
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." # noqa: E501
)
# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type", "*").split(";")
if content_type != "application/json":
if is_basemodel(cast_type):
try:
data = response.json()
except Exception as exc:
log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
else:
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=response,
)
if self._client._strict_response_validation:
raise APIResponseValidationError(
response=response,
message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.", # noqa: E501
json_data=response.text,
)
# If the API responds with content that isn't JSON then we just return
# the (decoded) text without performing any parsing so that you can still
# handle the response however you need to.
return response.text # type: ignore
data = response.json()
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=response,
)
@property
def headers(self) -> httpx.Headers:
return self.http_response.headers
@property
def http_request(self) -> httpx.Request:
"""Returns the httpx Request instance associated with the current response."""
return self.http_response.request
@property
def status_code(self) -> int:
return self.http_response.status_code
@property
def url(self) -> httpx.URL:
"""Returns the URL for which the request was made."""
return self.http_response.url
@property
def method(self) -> str:
return self.http_request.method
@property
def http_version(self) -> str:
return self.http_response.http_version
@property
def elapsed(self) -> datetime.timedelta:
"""The time taken for the complete request/response cycle to complete."""
return self.http_response.elapsed
@property
def is_closed(self) -> bool:
"""Whether or not the response body has been closed.
If this is False then there is response data that has not been read yet.
You must either fully consume the response body or call `.close()`
before discarding the response to prevent resource leaks.
"""
return self.http_response.is_closed
@override
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_type}>" # noqa: E501
class APIResponse(BaseAPIResponse[R]):
@property
def request_id(self) -> str | None:
return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return]
@overload
def parse(self, *, to: type[_T]) -> _T: ...
@overload
def parse(self) -> R: ...
def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
You can customize the type that the response is parsed into through
the `to` argument, e.g.
```py
from openai import BaseModel
class MyModel(BaseModel):
foo: str
obj = response.parse(to=MyModel)
print(obj.foo)
```
We support parsing:
- `BaseModel`
- `dict`
- `list`
- `Union`
- `str`
- `int`
- `float`
- `httpx.Response`
"""
cache_key = to if to is not None else self._cast_type
cached = self._parsed_by_type.get(cache_key)
if cached is not None:
return cached # type: ignore[no-any-return]
if not self._is_sse_stream:
self.read()
parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
self._parsed_by_type[cache_key] = parsed
return parsed
def read(self) -> bytes:
"""Read and return the binary response content."""
try:
return self.http_response.read()
except httpx.StreamConsumed as exc:
# The default error raised by httpx isn't very
# helpful in our case so we re-raise it with
# a different error message.
raise StreamAlreadyConsumed() from exc
def text(self) -> str:
"""Read and decode the response content into a string."""
self.read()
return self.http_response.text
def json(self) -> object:
"""Read and decode the JSON response content."""
self.read()
return self.http_response.json()
def close(self) -> None:
"""Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self.http_response.close()
def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
"""
A byte-iterator over the decoded response content.
This automatically handles gzip, deflate and brotli encoded responses.
"""
yield from self.http_response.iter_bytes(chunk_size)
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
"""A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
yield from self.http_response.iter_text(chunk_size)
def iter_lines(self) -> Iterator[str]:
"""Like `iter_text()` but will only yield chunks for each line"""
yield from self.http_response.iter_lines()
class MissingStreamClassError(TypeError):
def __init__(self) -> None:
super().__init__(
"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference", # noqa: E501
)
class StreamAlreadyConsumed(ZhipuAIError): # noqa: N818
"""
Attempted to read or stream content, but the content has already
been streamed.
This can happen if you use a method like `.iter_lines()` and then attempt
to read the entire response body afterwards, e.g.
```py
response = await client.post(...)
async for line in response.iter_lines():
... # do something with `line`
content = await response.read()
# ^ error
```
If you want this behavior you'll need to either manually accumulate the response
content or call `await response.read()` before iterating over the stream.
"""
def __init__(self) -> None:
message = (
"Attempted to read or stream some content, but the content has "
"already been streamed. "
"This could be due to attempting to stream the response "
"content more than once."
"\n\n"
"You can fix this by manually accumulating the response content while streaming "
"or by calling `.read()` before starting to stream."
)
super().__init__(message)
def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:
"""Given a type like `APIResponse[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyResponse(APIResponse[bytes]):
...
extract_response_type(MyResponse) -> bytes
```
"""
return extract_type_var_from_base(
typ,
generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse)),
index=0,
)

View File

@ -1,206 +0,0 @@
from __future__ import annotations
import inspect
import json
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING, Generic, TypeGuard, cast
import httpx
from . import get_origin
from ._base_type import ResponseT
from ._errors import APIResponseError
from ._utils import extract_type_var_from_base, is_mapping
_FIELD_SEPARATOR = ":"
if TYPE_CHECKING:
from ._http_client import HttpClient
class StreamResponse(Generic[ResponseT]):
response: httpx.Response
_cast_type: type[ResponseT]
def __init__(
self,
*,
cast_type: type[ResponseT],
response: httpx.Response,
client: HttpClient,
) -> None:
self.response = response
self._cast_type = cast_type
self._data_process_func = client._process_response_data
self._stream_chunks = self.__stream__()
def __next__(self) -> ResponseT:
return self._stream_chunks.__next__()
def __iter__(self) -> Iterator[ResponseT]:
yield from self._stream_chunks
def __stream__(self) -> Iterator[ResponseT]:
sse_line_parser = SSELineParser()
iterator = sse_line_parser.iter_lines(self.response.iter_lines())
for sse in iterator:
if sse.data.startswith("[DONE]"):
break
if sse.event is None:
data = sse.json_data()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"
raise APIResponseError(
message=message,
request=self.response.request,
json_data=data["error"],
)
yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)
else:
data = sse.json_data()
if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"
raise APIResponseError(
message=message,
request=self.response.request,
json_data=data["error"],
)
yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)
# Drain any remaining events so the underlying HTTP response is fully consumed.
for sse in iterator:
pass
class Event:
def __init__(
self, event: str | None = None, data: str | None = None, id: str | None = None, retry: int | None = None
):
self._event = event
self._data = data
self._id = id
self._retry = retry
def __repr__(self):
data_len = len(self._data) if self._data else 0
return (
f"Event(event={self._event}, data={self._data}, data_length={data_len}, id={self._id}, retry={self._retry})"
)
@property
def event(self):
return self._event
@property
def data(self):
return self._data
def json_data(self):
return json.loads(self._data)
@property
def id(self):
return self._id
@property
def retry(self):
return self._retry
class SSELineParser:
_data: list[str]
_event: str | None
_retry: int | None
_id: str | None
def __init__(self):
self._event = None
self._data = []
self._id = None
self._retry = None
def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]:
for line in lines:
line = line.rstrip("\n")
if not line:
if self._event is None and not self._data and self._id is None and self._retry is None:
continue
sse_event = Event(event=self._event, data="\n".join(self._data), id=self._id, retry=self._retry)
self._event = None
self._data = []
self._id = None
self._retry = None
yield sse_event
self.decode_line(line)
def decode_line(self, line: str):
if line.startswith(":") or not line:
return
field, _p, value = line.partition(":")
value = value.removeprefix(" ")
if field == "data":
self._data.append(value)
elif field == "event":
self._event = value
elif field == "retry":
try:
self._retry = int(value)
except (TypeError, ValueError):
pass
return
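# Illustrative sketch (not part of the original file): feeding raw SSE lines to
# SSELineParser. A blank line dispatches the event accumulated so far, matching
# the Server-Sent Events framing handled by decode_line() above.
def _sse_parser_example() -> list[Event]:
    raw_lines = iter(
        [
            "event: error",
            'data: {"error": {"message": "boom"}}',
            "",  # blank line -> first Event is emitted
            'data: {"id": "chatcmpl-1"}',
            "",  # blank line -> second Event (event field stays None)
        ]
    )
    return list(SSELineParser().iter_lines(raw_lines))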
def is_stream_class_type(typ: type) -> TypeGuard[type[StreamResponse[object]]]:
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
origin = get_origin(typ) or typ
return inspect.isclass(origin) and issubclass(origin, StreamResponse)
def extract_stream_chunk_type(
stream_cls: type,
*,
failure_message: str | None = None,
) -> type:
"""Given a type like `StreamResponse[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyStream(StreamResponse[bytes]):
...
extract_stream_chunk_type(MyStream) -> bytes
```
"""
return extract_type_var_from_base(
stream_cls,
index=0,
generic_bases=cast("tuple[type, ...]", (StreamResponse,)),
failure_message=failure_message,
)

View File

@ -1,52 +0,0 @@
from ._utils import ( # noqa: I001
remove_notgiven_indict as remove_notgiven_indict, # noqa: PLC0414
flatten as flatten, # noqa: PLC0414
is_dict as is_dict, # noqa: PLC0414
is_list as is_list, # noqa: PLC0414
is_given as is_given, # noqa: PLC0414
is_tuple as is_tuple, # noqa: PLC0414
is_mapping as is_mapping, # noqa: PLC0414
is_tuple_t as is_tuple_t, # noqa: PLC0414
parse_date as parse_date, # noqa: PLC0414
is_iterable as is_iterable, # noqa: PLC0414
is_sequence as is_sequence, # noqa: PLC0414
coerce_float as coerce_float, # noqa: PLC0414
is_mapping_t as is_mapping_t, # noqa: PLC0414
removeprefix as removeprefix, # noqa: PLC0414
removesuffix as removesuffix, # noqa: PLC0414
extract_files as extract_files, # noqa: PLC0414
is_sequence_t as is_sequence_t, # noqa: PLC0414
required_args as required_args, # noqa: PLC0414
coerce_boolean as coerce_boolean, # noqa: PLC0414
coerce_integer as coerce_integer, # noqa: PLC0414
file_from_path as file_from_path, # noqa: PLC0414
parse_datetime as parse_datetime, # noqa: PLC0414
strip_not_given as strip_not_given, # noqa: PLC0414
deepcopy_minimal as deepcopy_minimal, # noqa: PLC0414
get_async_library as get_async_library, # noqa: PLC0414
maybe_coerce_float as maybe_coerce_float, # noqa: PLC0414
get_required_header as get_required_header, # noqa: PLC0414
maybe_coerce_boolean as maybe_coerce_boolean, # noqa: PLC0414
maybe_coerce_integer as maybe_coerce_integer, # noqa: PLC0414
drop_prefix_image_data as drop_prefix_image_data, # noqa: PLC0414
)
from ._typing import (
is_list_type as is_list_type, # noqa: PLC0414
is_union_type as is_union_type, # noqa: PLC0414
extract_type_arg as extract_type_arg, # noqa: PLC0414
is_iterable_type as is_iterable_type, # noqa: PLC0414
is_required_type as is_required_type, # noqa: PLC0414
is_annotated_type as is_annotated_type, # noqa: PLC0414
strip_annotated_type as strip_annotated_type, # noqa: PLC0414
extract_type_var_from_base as extract_type_var_from_base, # noqa: PLC0414
)
from ._transform import (
PropertyInfo as PropertyInfo, # noqa: PLC0414
transform as transform, # noqa: PLC0414
async_transform as async_transform, # noqa: PLC0414
maybe_transform as maybe_transform, # noqa: PLC0414
async_maybe_transform as async_maybe_transform, # noqa: PLC0414
)

View File

@ -1,383 +0,0 @@
from __future__ import annotations
import base64
import io
import pathlib
from collections.abc import Mapping
from datetime import date, datetime
from typing import Any, Literal, TypeVar, cast, get_args, get_type_hints
import anyio
import pydantic
from typing_extensions import override
from .._base_compat import is_typeddict, model_dump
from .._files import is_base64_file_input
from ._typing import (
extract_type_arg,
is_annotated_type,
is_iterable_type,
is_list_type,
is_required_type,
is_union_type,
strip_annotated_type,
)
from ._utils import (
is_iterable,
is_list,
is_mapping,
)
_T = TypeVar("_T")
# TODO: support for drilling globals() and locals()
# TODO: ensure works correctly with forward references in all cases
PropertyFormat = Literal["iso8601", "base64", "custom"]
class PropertyInfo:
"""Metadata class to be used in Annotated types to provide information about a given type.
For example:
class MyParams(TypedDict):
account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]
This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
""" # noqa: E501
alias: str | None
format: PropertyFormat | None
format_template: str | None
discriminator: str | None
def __init__(
self,
*,
alias: str | None = None,
format: PropertyFormat | None = None,
format_template: str | None = None,
discriminator: str | None = None,
) -> None:
self.alias = alias
self.format = format
self.format_template = format_template
self.discriminator = discriminator
@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')" # noqa: E501
def maybe_transform(
data: object,
expected_type: object,
) -> Any | None:
"""Wrapper over `transform()` that allows `None` to be passed.
See `transform()` for more details.
"""
if data is None:
return None
return transform(data, expected_type)
# Wrapper over _transform_recursive providing fake types
def transform(
data: _T,
expected_type: object,
) -> _T:
"""Transform dictionaries based off of type information from the given type, for example:
```py
class Params(TypedDict, total=False):
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
transformed = transform({"card_id": "<my card ID>"}, Params)
# {'cardID': '<my card ID>'}
```
Any keys / data that do not have type information will be included as-is.
It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = _transform_recursive(data, annotation=cast(type, expected_type))
return cast(_T, transformed)
def _get_annotated_type(type_: type) -> type | None:
"""If the given type is an `Annotated` type then it is returned, if not `None` is returned.
This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`
"""
if is_required_type(type_):
# Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`
type_ = get_args(type_)[0]
if is_annotated_type(type_):
return type_
return None
def _maybe_transform_key(key: str, type_: type) -> str:
"""Transform the given `data` based on the annotations provided in `type_`.
Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata.
"""
annotated_type = _get_annotated_type(type_)
if annotated_type is None:
# no `Annotated` definition for this type, no transformation needed
return key
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.alias is not None:
return annotation.alias
return key
def _transform_recursive(
data: object,
*,
annotation: type,
inner_type: type | None = None,
) -> object:
"""Transform the given data against the expected type.
Args:
annotation: The direct type annotation given to the particular piece of data.
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
the list can be transformed using the metadata from the container type.
Defaults to the same value as the `annotation` argument.
"""
if inner_type is None:
inner_type = annotation
stripped_type = strip_annotated_type(inner_type)
if is_typeddict(stripped_type) and is_mapping(data):
return _transform_typeddict(data, stripped_type)
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
):
inner_type = extract_type_arg(stripped_type, 0)
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
return data
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return _format_data(data, annotation.format, annotation.format_template)
return data
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
if isinstance(data, date | datetime):
if format_ == "iso8601":
return data.isoformat()
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)
if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None
if isinstance(data, pathlib.Path):
binary = data.read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()
if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()
if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
return base64.b64encode(binary).decode("ascii")
return data
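# Illustrative sketch (not part of the original file): a PropertyInfo format
# annotation makes transform() run the value through _format_data(), here turning
# a date into an ISO-8601 string. Only names defined in this module are used.
def _format_transform_example() -> object:
    from datetime import date
    from typing import Annotated  # imported locally just for this sketch

    annotation = Annotated[date, PropertyInfo(format="iso8601")]
    return transform(date(2024, 1, 1), annotation)  # -> "2024-01-01"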
def _transform_typeddict(
data: Mapping[str, object],
expected_type: type,
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
return result
async def async_maybe_transform(
data: object,
expected_type: object,
) -> Any | None:
"""Wrapper over `async_transform()` that allows `None` to be passed.
See `async_transform()` for more details.
"""
if data is None:
return None
return await async_transform(data, expected_type)
async def async_transform(
data: _T,
expected_type: object,
) -> _T:
"""Transform dictionaries based off of type information from the given type, for example:
```py
class Params(TypedDict, total=False):
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
transformed = transform({"card_id": "<my card ID>"}, Params)
# {'cardID': '<my card ID>'}
```
Any keys / data that do not have type information will be included as-is.
It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
return cast(_T, transformed)
async def _async_transform_recursive(
data: object,
*,
annotation: type,
inner_type: type | None = None,
) -> object:
"""Transform the given data against the expected type.
Args:
annotation: The direct type annotation given to the particular piece of data.
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
the list can be transformed using the metadata from the container type.
Defaults to the same value as the `annotation` argument.
"""
if inner_type is None:
inner_type = annotation
stripped_type = strip_annotated_type(inner_type)
if is_typeddict(stripped_type) and is_mapping(data):
return await _async_transform_typeddict(data, stripped_type)
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
):
inner_type = extract_type_arg(stripped_type, 0)
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
return data
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return await _async_format_data(data, annotation.format, annotation.format_template)
return data
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
if isinstance(data, date | datetime):
if format_ == "iso8601":
return data.isoformat()
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)
if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None
if isinstance(data, pathlib.Path):
binary = await anyio.Path(data).read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()
if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()
if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
return base64.b64encode(binary).decode("ascii")
return data
async def _async_transform_typeddict(
data: Mapping[str, object],
expected_type: type,
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
return result

View File

@ -1,122 +0,0 @@
from __future__ import annotations
from collections import abc as _c_abc
from collections.abc import Iterable
from typing import Annotated, Any, TypeVar, cast, get_args, get_origin
from typing_extensions import Required
from .._base_compat import is_union as _is_union
from .._base_type import InheritsGeneric
def is_annotated_type(typ: type) -> bool:
return get_origin(typ) == Annotated
def is_list_type(typ: type) -> bool:
return (get_origin(typ) or typ) == list
def is_iterable_type(typ: type) -> bool:
"""If the given type is `typing.Iterable[T]`"""
origin = get_origin(typ) or typ
return origin in {Iterable, _c_abc.Iterable}
def is_union_type(typ: type) -> bool:
return _is_union(get_origin(typ))
def is_required_type(typ: type) -> bool:
return get_origin(typ) == Required
def is_typevar(typ: type) -> bool:
# type ignore is required because type checkers
# think this expression will always return False
return type(typ) == TypeVar # type: ignore
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
return strip_annotated_type(cast(type, get_args(typ)[0]))
return typ
def extract_type_arg(typ: type, index: int) -> type:
args = get_args(typ)
try:
return cast(type, args[index])
except IndexError as err:
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
def extract_type_var_from_base(
typ: type,
*,
generic_bases: tuple[type, ...],
index: int,
failure_message: str | None = None,
) -> type:
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyResponse(Foo[bytes]):
...
extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
```
And where a generic subclass is given:
```py
_T = TypeVar('_T')
class MyResponse(Foo[_T]):
...
extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
```
"""
cls = cast(object, get_origin(typ) or typ)
if cls in generic_bases:
# we're given the class directly
return extract_type_arg(typ, index)
# if a subclass is given
# ---
# this is needed as __orig_bases__ is not present in the typeshed stubs
# because it is intended to be for internal use only, however there does
# not seem to be a way to resolve generic TypeVars for inherited subclasses
# without using it.
if isinstance(cls, InheritsGeneric):
target_base_class: Any | None = None
for base in cls.__orig_bases__:
if base.__origin__ in generic_bases:
target_base_class = base
break
if target_base_class is None:
raise RuntimeError(
"Could not find the generic base class;\n"
"This should never happen;\n"
f"Does {cls} inherit from one of {generic_bases} ?"
)
extracted = extract_type_arg(target_base_class, index)
if is_typevar(extracted):
# If the extracted type argument is itself a type variable
# then that means the subclass itself is generic, so we have
# to resolve the type argument from the class itself, not
# the base class.
#
# Note: if there is more than 1 type argument, the subclass could
# change the ordering of the type arguments, this is not currently
# supported.
return extract_type_arg(typ, index)
return extracted
raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")

View File

@ -1,409 +0,0 @@
from __future__ import annotations
import functools
import inspect
import os
import re
from collections.abc import Callable, Iterable, Mapping, Sequence
from pathlib import Path
from typing import (
Any,
TypeGuard,
TypeVar,
Union,
cast,
overload,
)
import sniffio
from .._base_compat import parse_date as parse_date # noqa: PLC0414
from .._base_compat import parse_datetime as parse_datetime # noqa: PLC0414
from .._base_type import FileTypes, Headers, HeadersLike, NotGiven, NotGivenOr
def remove_notgiven_indict(obj):
if obj is None or (not isinstance(obj, Mapping)):
return obj
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
_T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=tuple[object, ...])
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
return [item for sublist in t for item in sublist]
def extract_files(
# TODO: this needs to take Dict but variance issues.....
# create protocol type ?
query: Mapping[str, object],
*,
paths: Sequence[Sequence[str]],
) -> list[tuple[str, FileTypes]]:
"""Recursively extract files from the given dictionary based on specified paths.
A path may look like this ['foo', 'files', '<array>', 'data'].
Note: this mutates the given dictionary.
"""
files: list[tuple[str, FileTypes]] = []
for path in paths:
files.extend(_extract_items(query, path, index=0, flattened_key=None))
return files
def _extract_items(
obj: object,
path: Sequence[str],
*,
index: int,
flattened_key: str | None,
) -> list[tuple[str, FileTypes]]:
try:
key = path[index]
except IndexError:
if isinstance(obj, NotGiven):
# no value was provided - we can safely ignore
return []
# cyclical import
from .._files import assert_is_file_content
# We have exhausted the path, return the entry we found.
assert_is_file_content(obj, key=flattened_key)
assert flattened_key is not None
return [(flattened_key, cast(FileTypes, obj))]
index += 1
if is_dict(obj):
try:
# We are at the last entry in the path so we must remove the field
if (len(path)) == index:
item = obj.pop(key)
else:
item = obj[key]
except KeyError:
# Key was not present in the dictionary, this is not indicative of an error
# as the given path may not point to a required field. We also do not want
# to enforce required fields as the API may differ from the spec in some cases.
return []
if flattened_key is None:
flattened_key = key
else:
flattened_key += f"[{key}]"
return _extract_items(
item,
path,
index=index,
flattened_key=flattened_key,
)
elif is_list(obj):
if key != "<array>":
return []
return flatten(
[
_extract_items(
item,
path,
index=index,
flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
)
for item in obj
]
)
# Something unexpected was passed, just ignore it.
return []
def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
return not isinstance(obj, NotGiven)
# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this causes Pyright to rightfully report errors. As we know we don't
# care about the contained types we can safely use `object` in its place.
#
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
# `is_*` is for when you're dealing with an unknown input
# `is_*_t` is for when you're narrowing a known union type to a specific subset
def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
return isinstance(obj, tuple)
def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
return isinstance(obj, tuple)
def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
return isinstance(obj, Sequence)
def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
return isinstance(obj, Sequence)
def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
return isinstance(obj, Mapping)
def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
return isinstance(obj, Mapping)
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
return isinstance(obj, dict)
def is_list(obj: object) -> TypeGuard[list[object]]:
return isinstance(obj, list)
def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
return isinstance(obj, Iterable)
def deepcopy_minimal(item: _T) -> _T:
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
- mappings, e.g. `dict`
- list
This is done for performance reasons.
"""
if is_mapping(item):
return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})
if is_list(item):
return cast(_T, [deepcopy_minimal(entry) for entry in item])
return item
# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
size = len(seq)
if size == 0:
return ""
if size == 1:
return seq[0]
if size == 2:
return f"{seq[0]} {final} {seq[1]}"
return delim.join(seq[:-1]) + f" {final} {seq[-1]}"
def quote(string: str) -> str:
"""Add single quotation marks around the given string. Does *not* do any escaping."""
return f"'{string}'"
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
"""Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
Useful for enforcing runtime validation of overloaded functions.
Example usage:
```py
@overload
def foo(*, a: str) -> str:
...
@overload
def foo(*, b: bool) -> str:
...
# This enforces the same constraints that a static type checker would
# i.e. that either a or b must be passed to the function
@required_args(["a"], ["b"])
def foo(*, a: str | None = None, b: bool | None = None) -> str:
...
```
"""
def inner(func: CallableT) -> CallableT:
params = inspect.signature(func).parameters
positional = [
name
for name, param in params.items()
if param.kind
in {
param.POSITIONAL_ONLY,
param.POSITIONAL_OR_KEYWORD,
}
]
@functools.wraps(func)
def wrapper(*args: object, **kwargs: object) -> object:
given_params: set[str] = set()
for i in range(len(args)):
try:
given_params.add(positional[i])
except IndexError:
raise TypeError(
f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
) from None
given_params.update(kwargs.keys())
for variant in variants:
matches = all(param in given_params for param in variant)
if matches:
break
else: # no break
if len(variants) > 1:
variations = human_join(
["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
)
msg = f"Missing required arguments; Expected either {variations} arguments to be given"
else:
# TODO: this error message is not deterministic
missing = list(set(variants[0]) - given_params)
if len(missing) > 1:
msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
else:
msg = f"Missing required argument: {quote(missing[0])}"
raise TypeError(msg)
return func(*args, **kwargs)
return wrapper # type: ignore
return inner
_K = TypeVar("_K")
_V = TypeVar("_V")
@overload
def strip_not_given(obj: None) -> None: ...
@overload
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...
@overload
def strip_not_given(obj: object) -> object: ...
def strip_not_given(obj: object | None) -> object:
"""Remove all top-level keys where their values are instances of `NotGiven`"""
if obj is None:
return None
if not is_mapping(obj):
return obj
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
def coerce_integer(val: str) -> int:
return int(val, base=10)
def coerce_float(val: str) -> float:
return float(val)
def coerce_boolean(val: str) -> bool:
return val in {"true", "1", "on"}
def maybe_coerce_integer(val: str | None) -> int | None:
if val is None:
return None
return coerce_integer(val)
def maybe_coerce_float(val: str | None) -> float | None:
if val is None:
return None
return coerce_float(val)
def maybe_coerce_boolean(val: str | None) -> bool | None:
if val is None:
return None
return coerce_boolean(val)
def removeprefix(string: str, prefix: str) -> str:
"""Remove a prefix from a string.
Backport of `str.removeprefix` for Python < 3.9
"""
if string.startswith(prefix):
return string[len(prefix) :]
return string
def removesuffix(string: str, suffix: str) -> str:
"""Remove a suffix from a string.
Backport of `str.removesuffix` for Python < 3.9
"""
if string.endswith(suffix):
return string[: -len(suffix)]
return string
def file_from_path(path: str) -> FileTypes:
contents = Path(path).read_bytes()
file_name = os.path.basename(path)
return (file_name, contents)
def get_required_header(headers: HeadersLike, header: str) -> str:
lower_header = header.lower()
if isinstance(headers, Mapping):
headers = cast(Headers, headers)
for k, v in headers.items():
if k.lower() == lower_header and isinstance(v, str):
return v
""" to deal with the case where the header looks like Stainless-Event-Id """
intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
value = headers.get(normalized_header)
if value:
return value
raise ValueError(f"Could not find {header} header")
def get_async_library() -> str:
try:
return sniffio.current_async_library()
except Exception:
return "false"
def drop_prefix_image_data(content: Union[str, list[dict]]) -> Union[str, list[dict]]:
"""
删除 ;base64, 前缀
:param image_data:
:return:
"""
if isinstance(content, list):
for data in content:
if data.get("type") == "image_url":
image_data = data.get("image_url").get("url")
if image_data.startswith("data:image/"):
image_data = image_data.split("base64,")[-1]
data["image_url"]["url"] = image_data
return content
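# Illustrative sketch (not part of the original file): drop_prefix_image_data()
# strips the data-URL prefix so only the raw base64 payload is passed on; the
# base64 string below is a truncated placeholder.
def _drop_prefix_image_data_example() -> Union[str, list[dict]]:
    content = [
        {
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="},
        }
    ]
    return drop_prefix_image_data(content)
    # -> [{"type": "image_url", "image_url": {"url": "iVBORw0KGgo="}}]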

View File

@ -1,78 +0,0 @@
import logging
import os
import time
logger = logging.getLogger(__name__)
class LoggerNameFilter(logging.Filter):
def filter(self, record):
# return record.name.startswith("loom_core") or record.name in "ERROR" or (
# record.name.startswith("uvicorn.error")
# and record.getMessage().startswith("Uvicorn running on")
# )
return True
def get_log_file(log_path: str, sub_dir: str):
"""
sub_dir should contain a timestamp.
"""
log_dir = os.path.join(log_path, sub_dir)
# A new directory should be created on each run, hence `exist_ok=False`
os.makedirs(log_dir, exist_ok=False)
return os.path.join(log_dir, "zhipuai.log")
def get_config_dict(log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int) -> dict:
# for windows, the path should be a raw string.
log_file_path = log_file_path.encode("unicode-escape").decode() if os.name == "nt" else log_file_path
log_level = log_level.upper()
config_dict = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"formatter": {"format": ("%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s")},
},
"filters": {
"logger_name_filter": {
"()": __name__ + ".LoggerNameFilter",
},
},
"handlers": {
"stream_handler": {
"class": "logging.StreamHandler",
"formatter": "formatter",
"level": log_level,
# "stream": "ext://sys.stdout",
# "filters": ["logger_name_filter"],
},
"file_handler": {
"class": "logging.handlers.RotatingFileHandler",
"formatter": "formatter",
"level": log_level,
"filename": log_file_path,
"mode": "a",
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf8",
},
},
"loggers": {
"loom_core": {
"handlers": ["stream_handler", "file_handler"],
"level": log_level,
"propagate": False,
}
},
"root": {
"level": log_level,
"handlers": ["stream_handler", "file_handler"],
},
}
return config_dict
def get_timestamp_ms():
t = time.time()
return int(round(t * 1000))

View File

@ -1,62 +0,0 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Any, Generic, Optional, TypeVar, cast
from typing_extensions import Protocol, override, runtime_checkable
from ._http_client import BasePage, BaseSyncPage, PageInfo
__all__ = ["SyncPage", "SyncCursorPage"]
_T = TypeVar("_T")
@runtime_checkable
class CursorPageItem(Protocol):
id: Optional[str]
class SyncPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
"""Note: no pagination actually occurs yet, this is for forwards-compatibility."""
data: list[_T]
object: str
@override
def _get_page_items(self) -> list[_T]:
data = self.data
if not data:
return []
return data
@override
def next_page_info(self) -> None:
"""
This page represents a response that isn't actually paginated at the API level
so there will never be a next page.
"""
return None
class SyncCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
data: list[_T]
@override
def _get_page_items(self) -> list[_T]:
data = self.data
if not data:
return []
return data
@override
def next_page_info(self) -> Optional[PageInfo]:
data = self.data
if not data:
return None
item = cast(Any, data[-1])
if not isinstance(item, CursorPageItem) or item.id is None:
# TODO emit warning log
return None
return PageInfo(params={"after": item.id})
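# Illustrative sketch (not part of the original file): cursor pagination only
# continues while the current page has data and its last item exposes an id,
# which next_page_info() turns into the `after` cursor for the next request.
def _has_next_page_example(page: "SyncCursorPage[CursorPageItem]") -> bool:
    return bool(page.data) and page.next_page_info() is not None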

View File

@ -1,5 +0,0 @@
from .assistant_completion import AssistantCompletion
__all__ = [
"AssistantCompletion",
]

View File

@ -1,40 +0,0 @@
from typing import Any, Optional
from ...core import BaseModel
from .message import MessageContent
__all__ = ["AssistantCompletion", "CompletionUsage"]
class ErrorInfo(BaseModel):
code: str # error code
message: str # error message
class AssistantChoice(BaseModel):
index: int # result index
delta: MessageContent # output message body for the current conversation turn
finish_reason: str
"""
Reason the generation finished: `stop` means generation ended naturally or hit a stop
word; `sensitive` means the output was blocked by the content-safety review (users
should decide whether to withdraw such content if already published); `network_error`
means the model inference service failed.
"""
metadata: dict # metadata, extension field
class CompletionUsage(BaseModel):
prompt_tokens: int # number of input tokens
completion_tokens: int # number of output tokens
total_tokens: int # total number of tokens
class AssistantCompletion(BaseModel):
id: str # request ID
conversation_id: str # conversation ID
assistant_id: str # assistant ID
created: int # request creation time as a Unix timestamp
status: str # status: `completed` (finished), `in_progress` (generating), `failed` (error)
last_error: Optional[ErrorInfo] # error information
choices: list[AssistantChoice] # incrementally returned content
metadata: Optional[dict[str, Any]] # metadata, extension field
usage: Optional[CompletionUsage] # token usage statistics

View File

@ -1,7 +0,0 @@
from typing import TypedDict
class ConversationParameters(TypedDict, total=False):
assistant_id: str # assistant ID
page: int # current page number
page_size: int # page size

View File

@ -1,29 +0,0 @@
from ...core import BaseModel
__all__ = ["ConversationUsageListResp"]
class Usage(BaseModel):
prompt_tokens: int # number of tokens in the user input
completion_tokens: int # number of tokens in the model output
total_tokens: int # total number of tokens
class ConversationUsage(BaseModel):
id: str # conversation id
assistant_id: str # assistant id
create_time: int # creation time
update_time: int # update time
usage: Usage # token usage statistics for the conversation
class ConversationUsageList(BaseModel):
assistant_id: str # assistant id
has_more: bool # whether more pages are available
conversation_list: list[ConversationUsage] # returned conversation usage entries
class ConversationUsageListResp(BaseModel):
code: int
msg: str
data: ConversationUsageList

View File

@ -1,32 +0,0 @@
from typing import Optional, TypedDict, Union
class AssistantAttachments:
file_id: str
class MessageTextContent:
type: str # currently only type = "text" is supported
text: str
MessageContent = Union[MessageTextContent]
class ConversationMessage(TypedDict):
"""Conversation message body."""
role: str # input role, e.g. 'user'
content: list[MessageContent] # content of the conversation message
class AssistantParameters(TypedDict, total=False):
"""Assistant invocation parameters."""
assistant_id: str # assistant ID
conversation_id: Optional[str] # conversation ID; omit to create a new conversation
model: str # model name, defaults to 'GLM-4-Assistant'
stream: bool # whether to stream via SSE; must be True
messages: list[ConversationMessage] # conversation messages
attachments: Optional[list[AssistantAttachments]] # files attached to the conversation, optional
metadata: Optional[dict] # metadata, extension field, optional
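# Illustrative sketch (not part of the original file): building the parameters
# for a streamed assistant call. The assistant_id is a hypothetical placeholder,
# conversation_id is omitted so a new conversation is created, and the plain
# dict in `content` stands in for a MessageTextContent entry.
def _assistant_parameters_example() -> AssistantParameters:
    return AssistantParameters(
        assistant_id="my-assistant-id",
        stream=True,
        messages=[
            ConversationMessage(
                role="user",
                content=[{"type": "text", "text": "Hello"}],
            )
        ],
    )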

View File

@ -1,21 +0,0 @@
from ...core import BaseModel
__all__ = ["AssistantSupportResp"]
class AssistantSupport(BaseModel):
assistant_id: str # assistant id, used for assistant conversations
created_at: int # creation time
updated_at: int # update time
name: str # assistant name
avatar: str # assistant avatar
description: str # assistant description
status: str # assistant status; currently only "publish"
tools: list[str] # names of the tools the assistant supports
starter_prompts: list[str] # recommended starter prompts for the assistant
class AssistantSupportResp(BaseModel):
code: int
msg: str
data: list[AssistantSupport] # list of assistants

View File

@ -1,3 +0,0 @@
from .message_content import MessageContent
__all__ = ["MessageContent"]

View File

@ -1,13 +0,0 @@
from typing import Annotated, TypeAlias, Union
from ....core._utils import PropertyInfo
from .text_content_block import TextContentBlock
from .tools_delta_block import ToolsDeltaBlock
__all__ = ["MessageContent"]
MessageContent: TypeAlias = Annotated[
Union[ToolsDeltaBlock, TextContentBlock],
PropertyInfo(discriminator="type"),
]

View File

@ -1,14 +0,0 @@
from typing import Literal
from ....core import BaseModel
__all__ = ["TextContentBlock"]
class TextContentBlock(BaseModel):
content: str
role: str = "assistant"
type: Literal["content"] = "content"
"""Always `content`."""

View File

@ -1,27 +0,0 @@
from typing import Literal
__all__ = ["CodeInterpreterToolBlock"]
from .....core import BaseModel
class CodeInterpreterToolOutput(BaseModel):
"""代码工具输出结果"""
type: str # 代码执行日志,目前只有 logs
logs: str # 代码执行的日志结果
error_msg: str # 错误信息
class CodeInterpreter(BaseModel):
"""代码解释器"""
input: str # 生成的代码片段,输入给代码沙盒
outputs: list[CodeInterpreterToolOutput] # 代码执行后的输出结果
class CodeInterpreterToolBlock(BaseModel):
"""代码工具块"""
code_interpreter: CodeInterpreter # 代码解释器对象
type: Literal["code_interpreter"] # 调用工具的类型,始终为 `code_interpreter`

View File

@ -1,21 +0,0 @@
from typing import Literal
from .....core import BaseModel
__all__ = ["DrawingToolBlock"]
class DrawingToolOutput(BaseModel):
image: str
class DrawingTool(BaseModel):
input: str
outputs: list[DrawingToolOutput]
class DrawingToolBlock(BaseModel):
drawing_tool: DrawingTool
type: Literal["drawing_tool"]
"""Always `drawing_tool`."""

View File

@ -1,22 +0,0 @@
from typing import Literal, Union
__all__ = ["FunctionToolBlock"]
from .....core import BaseModel
class FunctionToolOutput(BaseModel):
content: str
class FunctionTool(BaseModel):
name: str
arguments: Union[str, dict]
outputs: list[FunctionToolOutput]
class FunctionToolBlock(BaseModel):
function: FunctionTool
type: Literal["function"]
"""Always `drawing_tool`."""

View File

@ -1,41 +0,0 @@
from typing import Literal
from .....core import BaseModel
class RetrievalToolOutput(BaseModel):
"""
This class represents the output of a retrieval tool.
Attributes:
- text (str): The text snippet retrieved from the knowledge base.
- document (str): The name of the document from which the text snippet was retrieved, returned only in intelligent configuration.
""" # noqa: E501
text: str
document: str
class RetrievalTool(BaseModel):
"""
This class represents the outputs of a retrieval tool.
Attributes:
- outputs (List[RetrievalToolOutput]): A list of text snippets and their respective document names retrieved from the knowledge base.
""" # noqa: E501
outputs: list[RetrievalToolOutput]
class RetrievalToolBlock(BaseModel):
"""
This class represents a block for invoking the retrieval tool.
Attributes:
- retrieval (RetrievalTool): An instance of the RetrievalTool class containing the retrieval outputs.
- type (Literal["retrieval"]): The type of tool being used, always set to "retrieval".
"""
retrieval: RetrievalTool
type: Literal["retrieval"]
"""Always `retrieval`."""

View File

@ -1,16 +0,0 @@
from typing import Annotated, TypeAlias, Union
from .....core._utils import PropertyInfo
from .code_interpreter_delta_block import CodeInterpreterToolBlock
from .drawing_tool_delta_block import DrawingToolBlock
from .function_delta_block import FunctionToolBlock
from .retrieval_delta_black import RetrievalToolBlock
from .web_browser_delta_block import WebBrowserToolBlock
__all__ = ["ToolsType"]
ToolsType: TypeAlias = Annotated[
Union[DrawingToolBlock, CodeInterpreterToolBlock, WebBrowserToolBlock, RetrievalToolBlock, FunctionToolBlock],
PropertyInfo(discriminator="type"),
]

View File

@ -1,48 +0,0 @@
from typing import Literal
from .....core import BaseModel
__all__ = ["WebBrowserToolBlock"]
class WebBrowserOutput(BaseModel):
"""
This class represents the output of a web browser search result.
Attributes:
- title (str): The title of the search result.
- link (str): The URL link to the search result's webpage.
- content (str): The textual content extracted from the search result.
- error_msg (str): Any error message encountered during the search or retrieval process.
"""
title: str
link: str
content: str
error_msg: str
class WebBrowser(BaseModel):
"""
This class represents the input and outputs of a web browser search.
Attributes:
- input (str): The input query for the web browser search.
- outputs (List[WebBrowserOutput]): A list of search results returned by the web browser.
"""
input: str
outputs: list[WebBrowserOutput]
class WebBrowserToolBlock(BaseModel):
"""
This class represents a block for invoking the web browser tool.
Attributes:
- web_browser (WebBrowser): An instance of the WebBrowser class containing the search input and outputs.
- type (Literal["web_browser"]): The type of tool being used, always set to "web_browser".
"""
web_browser: WebBrowser
type: Literal["web_browser"]

View File

@ -1,16 +0,0 @@
from typing import Literal
from ....core import BaseModel
from .tools.tools_type import ToolsType
__all__ = ["ToolsDeltaBlock"]
class ToolsDeltaBlock(BaseModel):
tool_calls: list[ToolsType]
"""The index of the content part in the message."""
role: str = "tool"
type: Literal["tool_calls"] = "tool_calls"
"""Always `tool_calls`."""

View File

@ -1,82 +0,0 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
import builtins
from typing import Literal, Optional
from ..core import BaseModel
from .batch_error import BatchError
from .batch_request_counts import BatchRequestCounts
__all__ = ["Batch", "Errors"]
class Errors(BaseModel):
data: Optional[list[BatchError]] = None
object: Optional[str] = None
"""这个类型,一直是`list`。"""
class Batch(BaseModel):
id: str
completion_window: str
"""用于执行请求的地址信息。"""
created_at: int
"""这是 Unix timestamp (in seconds) 表示的创建时间。"""
endpoint: str
"""这是ZhipuAI endpoint的地址。"""
input_file_id: str
"""标记为batch的输入文件的ID。"""
object: Literal["batch"]
"""这个类型,一直是`batch`."""
status: Literal[
"validating", "failed", "in_progress", "finalizing", "completed", "expired", "cancelling", "cancelled"
]
"""batch 的状态。"""
cancelled_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示的取消时间。"""
cancelling_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示发起取消的请求时间 """
completed_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示的完成时间。"""
error_file_id: Optional[str] = None
"""这个文件id包含了执行请求失败的请求的输出。"""
errors: Optional[Errors] = None
expired_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示的将在过期时间。"""
expires_at: Optional[int] = None
"""Unix timestamp (in seconds) 触发过期"""
failed_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示的失败时间。"""
finalizing_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示的最终时间。"""
in_progress_at: Optional[int] = None
"""Unix timestamp (in seconds) 表示的开始处理时间。"""
metadata: Optional[builtins.object] = None
"""
key:value形式的元数据以便将信息存储
结构化格式键的长度是64个字符值最长512个字符
"""
output_file_id: Optional[str] = None
"""完成请求的输出文件的ID。"""
request_counts: Optional[BatchRequestCounts] = None
"""批次中不同状态的请求计数"""

View File

@ -1,37 +0,0 @@
from __future__ import annotations
from typing import Literal, Optional
from typing_extensions import Required, TypedDict
__all__ = ["BatchCreateParams"]
class BatchCreateParams(TypedDict, total=False):
completion_window: Required[str]
"""The time frame within which the batch should be processed.
Currently only `24h` is supported.
"""
endpoint: Required[Literal["/v1/chat/completions", "/v1/embeddings"]]
"""The endpoint to be used for all requests in the batch.
Currently `/v1/chat/completions` and `/v1/embeddings` are supported.
"""
input_file_id: Required[str]
"""The ID of an uploaded file that contains requests for the new batch.
See [upload file](https://platform.openai.com/docs/api-reference/files/create)
for how to upload a file.
Your input file must be formatted as a
[JSONL file](https://platform.openai.com/docs/api-reference/batch/requestInput),
and must be uploaded with the purpose `batch`.
"""
metadata: Optional[dict[str, str]]
"""Optional custom metadata for the batch."""
auto_delete_input_file: Optional[bool]
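# Illustrative sketch (not part of the original file): a BatchCreateParams dict
# for a chat-completions batch. The input_file_id and metadata values are
# hypothetical placeholders; completion_window currently only accepts "24h".
def _batch_create_params_example() -> BatchCreateParams:
    return BatchCreateParams(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id="file-abc123",
        metadata={"job": "nightly-eval"},
        auto_delete_input_file=True,
    )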

View File

@ -1,21 +0,0 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from ..core import BaseModel
__all__ = ["BatchError"]
class BatchError(BaseModel):
code: Optional[str] = None
"""定义的业务错误码"""
line: Optional[int] = None
"""文件中的行号"""
message: Optional[str] = None
"""关于对话文件中的错误的描述"""
param: Optional[str] = None
"""参数名称,如果有的话"""

View File

@ -1,20 +0,0 @@
from __future__ import annotations
from typing_extensions import TypedDict
__all__ = ["BatchListParams"]
class BatchListParams(TypedDict, total=False):
after: str
"""分页的游标,用于获取下一页的数据。
`after` 是一个指向当前页面的游标用于获取下一页的数据如果没有提供 `after`则返回第一页的数据
list.
"""
limit: int
"""这个参数用于限制返回的结果数量。
Limit 用于限制返回的结果数量默认值为 10
"""

View File

@ -1,14 +0,0 @@
from ..core import BaseModel
__all__ = ["BatchRequestCounts"]
class BatchRequestCounts(BaseModel):
completed: int
"""这个数字表示已经完成的请求。"""
failed: int
"""这个数字表示失败的请求。"""
total: int
"""这个数字表示总的请求。"""

View File

@ -1,22 +0,0 @@
from typing import Optional
from ...core import BaseModel
from .chat_completion import CompletionChoice, CompletionUsage
__all__ = ["AsyncTaskStatus", "AsyncCompletion"]
class AsyncTaskStatus(BaseModel):
id: Optional[str] = None
request_id: Optional[str] = None
model: Optional[str] = None
task_status: Optional[str] = None
class AsyncCompletion(BaseModel):
id: Optional[str] = None
request_id: Optional[str] = None
model: Optional[str] = None
task_status: str
choices: list[CompletionChoice]
usage: CompletionUsage

Some files were not shown because too many files have changed in this diff.