refac: code intepreter

This commit is contained in:
Timothy Jaeryang Baek 2025-02-10 13:12:05 -08:00
parent 610f9d039a
commit a273cba0fb
2 changed files with 50 additions and 23 deletions

View File

@ -18,6 +18,7 @@ async def execute_code_jupyter(
:param password: Jupyter password (optional) :param password: Jupyter password (optional)
:param timeout: WebSocket timeout in seconds (default: 10s) :param timeout: WebSocket timeout in seconds (default: 10s)
:return: Dictionary with stdout, stderr, and result :return: Dictionary with stdout, stderr, and result
- Images are prefixed with "base64:image/png," and separated by newlines if multiple.
""" """
session = requests.Session() # Maintain cookies session = requests.Session() # Maintain cookies
headers = {} # Headers for requests headers = {} # Headers for requests
@ -28,20 +29,15 @@ async def execute_code_jupyter(
login_url = urljoin(jupyter_url, "/login") login_url = urljoin(jupyter_url, "/login")
response = session.get(login_url) response = session.get(login_url)
response.raise_for_status() response.raise_for_status()
# Retrieve `_xsrf` token
xsrf_token = session.cookies.get("_xsrf") xsrf_token = session.cookies.get("_xsrf")
if not xsrf_token: if not xsrf_token:
raise ValueError("Failed to fetch _xsrf token") raise ValueError("Failed to fetch _xsrf token")
# Send login request
login_data = {"_xsrf": xsrf_token, "password": password} login_data = {"_xsrf": xsrf_token, "password": password}
login_response = session.post( login_response = session.post(
login_url, data=login_data, cookies=session.cookies login_url, data=login_data, cookies=session.cookies
) )
login_response.raise_for_status() login_response.raise_for_status()
# Update headers with `_xsrf`
headers["X-XSRFToken"] = xsrf_token headers["X-XSRFToken"] = xsrf_token
except Exception as e: except Exception as e:
return { return {
@ -55,18 +51,15 @@ async def execute_code_jupyter(
kernel_url = urljoin(jupyter_url, f"/api/kernels{params}") kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
try: try:
# Include cookies if authenticating with password
response = session.post(kernel_url, headers=headers, cookies=session.cookies) response = session.post(kernel_url, headers=headers, cookies=session.cookies)
response.raise_for_status() response.raise_for_status()
kernel_id = response.json()["id"] kernel_id = response.json()["id"]
# Construct WebSocket URL
websocket_url = urljoin( websocket_url = urljoin(
jupyter_url.replace("http", "ws"), jupyter_url.replace("http", "ws"),
f"/api/kernels/{kernel_id}/channels{params}", f"/api/kernels/{kernel_id}/channels{params}",
) )
# **IMPORTANT:** Include authentication cookies for WebSockets
ws_headers = {} ws_headers = {}
if password and not token: if password and not token:
ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf") ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
@ -75,13 +68,10 @@ async def execute_code_jupyter(
[f"{name}={value}" for name, value in cookies.items()] [f"{name}={value}" for name, value in cookies.items()]
) )
# Connect to the WebSocket
async with websockets.connect( async with websockets.connect(
websocket_url, additional_headers=ws_headers websocket_url, additional_headers=ws_headers
) as ws: ) as ws:
msg_id = str(uuid.uuid4()) msg_id = str(uuid.uuid4())
# Send execution request
execute_request = { execute_request = {
"header": { "header": {
"msg_id": msg_id, "msg_id": msg_id,
@ -105,37 +95,47 @@ async def execute_code_jupyter(
} }
await ws.send(json.dumps(execute_request)) await ws.send(json.dumps(execute_request))
# Collect execution results stdout, stderr, result = "", "", []
stdout, stderr, result = "", "", None
while True: while True:
try: try:
message = await asyncio.wait_for(ws.recv(), timeout) message = await asyncio.wait_for(ws.recv(), timeout)
message_data = json.loads(message) message_data = json.loads(message)
if message_data.get("parent_header", {}).get("msg_id") == msg_id: if message_data.get("parent_header", {}).get("msg_id") == msg_id:
msg_type = message_data.get("msg_type") msg_type = message_data.get("msg_type")
if msg_type == "stream": if msg_type == "stream":
if message_data["content"]["name"] == "stdout": if message_data["content"]["name"] == "stdout":
stdout += message_data["content"]["text"] stdout += message_data["content"]["text"]
elif message_data["content"]["name"] == "stderr": elif message_data["content"]["name"] == "stderr":
stderr += message_data["content"]["text"] stderr += message_data["content"]["text"]
elif msg_type in ("execute_result", "display_data"): elif msg_type in ("execute_result", "display_data"):
result = message_data["content"]["data"].get( data = message_data["content"]["data"]
"text/plain", "" if "image/png" in data:
result.append(
f"data:image/png;base64,{data['image/png']}"
) )
elif "text/plain" in data:
result.append(data["text/plain"])
elif msg_type == "error": elif msg_type == "error":
stderr += "\n".join(message_data["content"]["traceback"]) stderr += "\n".join(message_data["content"]["traceback"])
elif ( elif (
msg_type == "status" msg_type == "status"
and message_data["content"]["execution_state"] == "idle" and message_data["content"]["execution_state"] == "idle"
): ):
break break
except asyncio.TimeoutError: except asyncio.TimeoutError:
stderr += "\nExecution timed out." stderr += "\nExecution timed out."
break break
except Exception as e: except Exception as e:
return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""} return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}
finally: finally:
# Shutdown the kernel
if kernel_id: if kernel_id:
requests.delete( requests.delete(
f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
@ -144,10 +144,5 @@ async def execute_code_jupyter(
return { return {
"stdout": stdout.strip(), "stdout": stdout.strip(),
"stderr": stderr.strip(), "stderr": stderr.strip(),
"result": result.strip() if result else "", "result": "\n".join(result).strip() if result else "",
} }
# Example Usage
# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", token="your-token"))
# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", password="your-password"))

View File

@ -1723,6 +1723,38 @@ async def process_chat_response(
) )
output["stdout"] = "\n".join(stdoutLines) output["stdout"] = "\n".join(stdoutLines)
result = output.get("result", "")
if result:
resultLines = result.split("\n")
for idx, line in enumerate(resultLines):
if "data:image/png;base64" in line:
id = str(uuid4())
# ensure the path exists
os.makedirs(
os.path.join(CACHE_DIR, "images"),
exist_ok=True,
)
image_path = os.path.join(
CACHE_DIR,
f"images/{id}.png",
)
with open(image_path, "wb") as f:
f.write(
base64.b64decode(
line.split(",")[1]
)
)
resultLines[idx] = (
f"![Output Image {idx}](/cache/images/{id}.png)"
)
output["result"] = "\n".join(resultLines)
except Exception as e: except Exception as e:
output = str(e) output = str(e)