1
0

formatter pass

This commit is contained in:
2025-10-12 18:20:08 -04:00
parent ca021408ca
commit f0bd973461

View File

@@ -1,5 +1,11 @@
from typing import List
from cvmlib import guac_decode, guac_encode, CollabVMRank, CollabVMState, CollabVMClientRenameStatus
from cvmlib import (
guac_decode,
guac_encode,
CollabVMRank,
CollabVMState,
CollabVMClientRenameStatus,
)
import config
import os, random, websockets, asyncio
from websockets import Subprotocol, Origin
@@ -30,6 +36,7 @@ vms = {}
vm_botuser = {}
STATE = CollabVMState.WS_DISCONNECTED
def get_origin_from_ws_url(ws_url: str) -> str:
domain = (
ws_url.removeprefix("ws:")
@@ -41,61 +48,75 @@ def get_origin_from_ws_url(ws_url: str) -> str:
is_wss = ws_url.startswith("wss:")
return f"http{'s' if is_wss else ''}://{domain}/"
async def send_chat_message(websocket, message: str):
log.debug(f"Sending chat message: {message}")
await websocket.send(guac_encode(["chat", message]))
async def send_guac(websocket, *args: str):
await websocket.send(guac_encode(list(args)))
async def periodic_snapshot_task():
"""Background task that saves VM framebuffers as snapshots in WEBP format."""
log.info("Starting periodic snapshot task")
while True:
try:
await asyncio.sleep(10) # Wait 10 seconds
log.debug("Running periodic framebuffer snapshot capture...")
# Save framebuffers for all VMs
timestamp = int(datetime.now(timezone.utc).timestamp()) # Use Unix timestamp for filenames
timestamp = int(
datetime.now(timezone.utc).timestamp()
) # Use Unix timestamp for filenames
current_day = datetime.now(timezone.utc).strftime("%b-%d-%Y")
for vm_name, vm_data in vms.items():
# Create snapshots directory for each VM, then by current day
snapshot_dir = os.path.join("logs", "webp", vm_name, current_day)
os.makedirs(snapshot_dir, exist_ok=True)
framebuffer = vm_data.get("framebuffer")
if framebuffer:
framebuffer.seek(0)
# Calculate hash of the framebuffer to detect duplicates
framebuffer_data = framebuffer.getvalue()
current_hash = hashlib.md5(framebuffer_data).hexdigest()
# Check if this frame is the same as the last one
if vm_data.get("last_frame_hash") == current_hash:
log.debug(f"Skipping duplicate frame for VM '{vm_name}'")
continue
# Save the new frame
framebuffer.seek(0)
image = Image.open(framebuffer)
snapshot_path = os.path.join(snapshot_dir, f"{vm_name}_{timestamp}.webp")
image.save(snapshot_path, format="WEBP", quality=65, optimize=True, method=6)
snapshot_path = os.path.join(
snapshot_dir, f"{vm_name}_{timestamp}.webp"
)
image.save(
snapshot_path,
format="WEBP",
quality=65,
optimize=True,
method=6,
)
# Update the hash for this VM
vm_data["last_frame_hash"] = current_hash
log.info(f"Saved snapshot for VM '{vm_name}' to {snapshot_path}")
else:
log.warning(f"No framebuffer available for VM '{vm_name}'")
except Exception as e:
log.error(f"Error in periodic snapshot task: {e}")
# Continue running even if there's an error
async def connect(vm_name: str):
global STATE
global vms
@@ -103,9 +124,17 @@ async def connect(vm_name: str):
if vm_name not in config.vms:
log.error(f"VM '{vm_name}' not found in configuration.")
return
vms[vm_name] = {"turn_queue": [], "active_turn_user": None, "users": {}, "framebuffer": None, "last_frame_hash": None}
vms[vm_name] = {
"turn_queue": [],
"active_turn_user": None,
"users": {},
"framebuffer": None,
"last_frame_hash": None,
}
uri = config.vms[vm_name]
log_file_path = os.path.join(getattr(config, "log_directory", "logs"), f"{vm_name}.json")
log_file_path = os.path.join(
getattr(config, "log_directory", "logs"), f"{vm_name}.json"
)
if not os.path.exists(log_file_path):
with open(log_file_path, "w") as log_file:
log_file.write("{}")
@@ -113,7 +142,7 @@ async def connect(vm_name: str):
uri=uri,
subprotocols=[Subprotocol("guacamole")],
origin=Origin(get_origin_from_ws_url(uri)),
user_agent_header="cvmsentry/1 (https://git.nixlabs.dev/clair/cvmsentry)"
user_agent_header="cvmsentry/1 (https://git.nixlabs.dev/clair/cvmsentry)",
) as websocket:
STATE = CollabVMState.WS_CONNECTED
log.info(f"Connected to VM '{vm_name}' at {uri}")
@@ -129,33 +158,58 @@ async def connect(vm_name: str):
await send_guac(websocket, "nop")
case ["auth", config.auth_server]:
await asyncio.sleep(1)
await send_guac(websocket, "login", config.credentials["session_auth"])
case ["connect", connection_status, turns_enabled, votes_enabled, uploads_enabled]:
await send_guac(
websocket, "login", config.credentials["session_auth"]
)
case [
"connect",
connection_status,
turns_enabled,
votes_enabled,
uploads_enabled,
]:
if connection_status == "1":
STATE = CollabVMState.VM_CONNECTED
log.info(f"Connected to VM '{vm_name}' successfully. Turns enabled: {bool(int(turns_enabled))}, Votes enabled: {bool(int(votes_enabled))}, Uploads enabled: {bool(int(uploads_enabled))}")
log.info(
f"Connected to VM '{vm_name}' successfully. Turns enabled: {bool(int(turns_enabled))}, Votes enabled: {bool(int(votes_enabled))}, Uploads enabled: {bool(int(uploads_enabled))}"
)
else:
log.error(f"Failed to connect to VM '{vm_name}'. Connection status: {connection_status}")
log.error(
f"Failed to connect to VM '{vm_name}'. Connection status: {connection_status}"
)
STATE = CollabVMState.WS_DISCONNECTED
await websocket.close()
case ["rename", *instructions]:
match instructions:
case ["0", status, new_name]:
if CollabVMClientRenameStatus(int(status)) == CollabVMClientRenameStatus.SUCCEEDED:
log.debug(f"({STATE.name} - {vm_name}) Bot rename on VM {vm_name}: {vm_botuser[vm_name]} -> {new_name}")
if (
CollabVMClientRenameStatus(int(status))
== CollabVMClientRenameStatus.SUCCEEDED
):
log.debug(
f"({STATE.name} - {vm_name}) Bot rename on VM {vm_name}: {vm_botuser[vm_name]} -> {new_name}"
)
vm_botuser[vm_name] = new_name
else:
log.debug(f"({STATE.name} - {vm_name}) Bot rename on VM {vm_name} failed with status {CollabVMClientRenameStatus(int(status)).name}")
log.debug(
f"({STATE.name} - {vm_name}) Bot rename on VM {vm_name} failed with status {CollabVMClientRenameStatus(int(status)).name}"
)
case ["1", old_name, new_name]:
if old_name in vms[vm_name]["users"]:
log.debug(f"({STATE.name} - {vm_name}) User rename on VM {vm_name}: {old_name} -> {new_name}")
vms[vm_name]["users"][new_name] = vms[vm_name]["users"].pop(old_name)
log.debug(
f"({STATE.name} - {vm_name}) User rename on VM {vm_name}: {old_name} -> {new_name}"
)
vms[vm_name]["users"][new_name] = vms[vm_name][
"users"
].pop(old_name)
case ["login", "1"]:
STATE = CollabVMState.LOGGED_IN
if config.send_autostart and config.autostart_messages:
await send_chat_message(websocket, random.choice(config.autostart_messages))
await send_chat_message(
websocket, random.choice(config.autostart_messages)
)
case ["chat", user, message, *backlog]:
system_message = (user == "")
system_message = user == ""
if system_message:
continue
if not backlog:
@@ -163,14 +217,17 @@ async def connect(vm_name: str):
def get_rank(username: str) -> CollabVMRank:
return vms[vm_name]["users"].get(username, {}).get("rank")
def admin_check(username: str) -> bool:
return username in config.admins and get_rank(username) > CollabVMRank.Unregistered
return (
username in config.admins
and get_rank(username) > CollabVMRank.Unregistered
)
utc_now = datetime.now(timezone.utc)
utc_day = utc_now.strftime("%Y-%m-%d")
timestamp = utc_now.isoformat()
with open(log_file_path, "r+") as log_file:
try:
log_data = json.load(log_file)
@@ -193,48 +250,75 @@ async def connect(vm_name: str):
# "message": backlog_message
# })
log_data[utc_day].append({
"type": "chat",
"timestamp": timestamp,
"username": user,
"message": message
})
log_data[utc_day].append(
{
"type": "chat",
"timestamp": timestamp,
"username": user,
"message": message,
}
)
log_file.seek(0)
json.dump(log_data, log_file, indent=4)
log_file.truncate()
if config.commands["enabled"] and message.startswith(config.commands["prefix"]):
command = message[len(config.commands["prefix"]):].strip().lower()
if config.commands["enabled"] and message.startswith(
config.commands["prefix"]
):
command = (
message[len(config.commands["prefix"]) :].strip().lower()
)
match command:
case "whoami":
await send_chat_message(websocket, f"You are {user} with rank {get_rank(user).name}.")
await send_chat_message(
websocket,
f"You are {user} with rank {get_rank(user).name}.",
)
case "about":
await send_chat_message(websocket, config.responses.get("about", "CVM-Sentry (NO RESPONSE CONFIGURED)"))
await send_chat_message(
websocket,
config.responses.get(
"about", "CVM-Sentry (NO RESPONSE CONFIGURED)"
),
)
case "dump":
if not admin_check(user):
continue
log.debug(f"({STATE.name} - {vm_name}) Dumping user list for VM {vm_name}: {vms[vm_name]['users']}")
await send_chat_message(websocket, f"Dumped user list to console.")
log.debug(
f"({STATE.name} - {vm_name}) Dumping user list for VM {vm_name}: {vms[vm_name]['users']}"
)
await send_chat_message(
websocket, f"Dumped user list to console."
)
case ["adduser", count, *list]:
for i in range(int(count)):
user = list[i * 2]
rank = CollabVMRank(int(list[i * 2 + 1]))
if user in vms[vm_name]["users"]:
vms[vm_name]["users"][user]["rank"] = rank
log.info(f"[{vm_name}] User '{user}' rank updated to {rank.name}.")
log.info(
f"[{vm_name}] User '{user}' rank updated to {rank.name}."
)
else:
vms[vm_name]["users"][user] = {"rank": rank}
log.info(f"[{vm_name}] User '{user}' connected with rank {rank.name}.")
log.info(
f"[{vm_name}] User '{user}' connected with rank {rank.name}."
)
case ["turn", _, "0"]:
if STATE < CollabVMState.LOGGED_IN:
continue
if vms[vm_name]["active_turn_user"] is None and not vms[vm_name]["turn_queue"]:
#log.debug(f"({STATE.name} - {vm_name}) Incoming queue exhaustion matches the VM's state. Dropping update.")
if (
vms[vm_name]["active_turn_user"] is None
and not vms[vm_name]["turn_queue"]
):
# log.debug(f"({STATE.name} - {vm_name}) Incoming queue exhaustion matches the VM's state. Dropping update.")
continue
vms[vm_name]["active_turn_user"] = None
vms[vm_name]["turn_queue"] = []
log.debug(f"({STATE.name} - {vm_name}) Turn queue is naturally exhausted.")
log.debug(
f"({STATE.name} - {vm_name}) Turn queue is naturally exhausted."
)
case ["png", "0", "0", "0", "0", initial_frame_b64]:
# Decode the base64 image data
initial_frame_data = base64.b64decode(initial_frame_b64)
@@ -250,7 +334,9 @@ async def connect(vm_name: str):
# Assign the in-memory framebuffer to the VM's dictionary
vms[vm_name]["framebuffer"] = framebuffer
framebuffer_size = framebuffer.getbuffer().nbytes
log.info(f"({STATE.name} - {vm_name}) !!! WHOLE FRAME UPDATE !!! ({framebuffer_size} bytes)")
log.info(
f"({STATE.name} - {vm_name}) !!! WHOLE FRAME UPDATE !!! ({framebuffer_size} bytes)"
)
case ["png", "0", "0", x, y, rect_b64]:
# Decode the base64 image data for the rectangle
rect_data = base64.b64decode(rect_b64)
@@ -277,54 +363,69 @@ async def connect(vm_name: str):
# Log the updated framebuffer size
framebuffer_size = framebuffer.getbuffer().nbytes
log.debug(f"({STATE.name} - {vm_name}) Updated framebuffer size: {framebuffer_size} bytes")
log.debug(
f"({STATE.name} - {vm_name}) Updated framebuffer size: {framebuffer_size} bytes"
)
else:
continue
case ["turn", turn_time, count, current_turn, *queue]:
if queue == vms[vm_name]["turn_queue"] and current_turn == vms[vm_name]["active_turn_user"]:
#log.debug(f"({STATE.name} - {vm_name}) Incoming turn update matches the VM's state. Dropping update.")
continue
for user in vms[vm_name]["users"]:
vms[vm_name]["turn_queue"] = queue
vms[vm_name]["active_turn_user"] = current_turn if current_turn != "" else None
if current_turn:
utc_now = datetime.now(timezone.utc)
utc_day = utc_now.strftime("%Y-%m-%d")
timestamp = utc_now.isoformat()
if (
queue == vms[vm_name]["turn_queue"]
and current_turn == vms[vm_name]["active_turn_user"]
):
# log.debug(f"({STATE.name} - {vm_name}) Incoming turn update matches the VM's state. Dropping update.")
continue
for user in vms[vm_name]["users"]:
vms[vm_name]["turn_queue"] = queue
vms[vm_name]["active_turn_user"] = (
current_turn if current_turn != "" else None
)
if current_turn:
utc_now = datetime.now(timezone.utc)
utc_day = utc_now.strftime("%Y-%m-%d")
timestamp = utc_now.isoformat()
with open(log_file_path, "r+") as log_file:
try:
log_data = json.load(log_file)
except json.JSONDecodeError:
log_data = {}
with open(log_file_path, "r+") as log_file:
try:
log_data = json.load(log_file)
except json.JSONDecodeError:
log_data = {}
if utc_day not in log_data:
log_data[utc_day] = []
if utc_day not in log_data:
log_data[utc_day] = []
log_data[utc_day].append({
log_data[utc_day].append(
{
"type": "turn",
"timestamp": timestamp,
"active_turn_user": current_turn,
"queue": queue
})
"queue": queue,
}
)
log_file.seek(0)
json.dump(log_data, log_file, indent=4)
log_file.truncate()
log.debug(f"({STATE.name} - {vm_name}) Turn update: turn_time={turn_time}, count={count}, current_turn={current_turn}, queue={queue}")
log_file.seek(0)
json.dump(log_data, log_file, indent=4)
log_file.truncate()
log.debug(
f"({STATE.name} - {vm_name}) Turn update: turn_time={turn_time}, count={count}, current_turn={current_turn}, queue={queue}"
)
case ["remuser", count, *list]:
for i in range(int(count)):
username = list[i]
if username in vms[vm_name]["users"]:
del vms[vm_name]["users"][username]
log.info(f"[{vm_name}] User '{username}' left.")
case ["flag", *args] | ["size", *args] | ["png", *args] | ["sync", *args]:
case (
["flag", *args] | ["size", *args] | ["png", *args] | ["sync", *args]
):
continue
case _:
if decoded is not None:
log.debug(f"({STATE.name} - {vm_name}) Unhandled message: {decoded}")
log.debug(
f"({STATE.name} - {vm_name}) Unhandled message: {decoded}"
)
log.info(f"({STATE.name}) CVM-Sentry started")
@@ -334,35 +435,45 @@ for vm in config.vms.keys():
asyncio.run(connect(vm_name))
async def main():
async def connect_with_reconnect(vm_name: str):
while True:
try:
await connect(vm_name)
except websockets.exceptions.ConnectionClosedError as e:
log.warning(f"Connection to VM '{vm_name}' closed with error: {e}. Reconnecting...")
log.warning(
f"Connection to VM '{vm_name}' closed with error: {e}. Reconnecting..."
)
await asyncio.sleep(5) # Wait before attempting to reconnect
except websockets.exceptions.ConnectionClosedOK:
log.warning(f"Connection to VM '{vm_name}' closed cleanly (code 1005). Reconnecting...")
log.warning(
f"Connection to VM '{vm_name}' closed cleanly (code 1005). Reconnecting..."
)
await asyncio.sleep(5) # Wait before attempting to reconnect
except websockets.exceptions.InvalidStatus as e:
log.error(f"Failed to connect to VM '{vm_name}' with status code: {e}. Reconnecting...")
log.error(
f"Failed to connect to VM '{vm_name}' with status code: {e}. Reconnecting..."
)
await asyncio.sleep(10) # Wait longer for HTTP errors
except websockets.exceptions.WebSocketException as e:
log.error(f"WebSocket error connecting to VM '{vm_name}': {e}. Reconnecting...")
log.error(
f"WebSocket error connecting to VM '{vm_name}': {e}. Reconnecting..."
)
await asyncio.sleep(5)
except Exception as e:
log.error(f"Unexpected error connecting to VM '{vm_name}': {e}. Reconnecting...")
log.error(
f"Unexpected error connecting to VM '{vm_name}': {e}. Reconnecting..."
)
await asyncio.sleep(10) # Wait longer for unexpected errors
# Create tasks for VM connections
vm_tasks = [connect_with_reconnect(vm) for vm in config.vms.keys()]
# Add periodic snapshot task
snapshot_task = periodic_snapshot_task()
# Run all tasks concurrently
all_tasks = [snapshot_task] + vm_tasks
await asyncio.gather(*all_tasks)
asyncio.run(main())
asyncio.run(main())