1
0

formatter pass

This commit is contained in:
2025-10-12 18:20:08 -04:00
parent ca021408ca
commit f0bd973461

View File

@@ -1,5 +1,11 @@
from typing import List from typing import List
from cvmlib import guac_decode, guac_encode, CollabVMRank, CollabVMState, CollabVMClientRenameStatus from cvmlib import (
guac_decode,
guac_encode,
CollabVMRank,
CollabVMState,
CollabVMClientRenameStatus,
)
import config import config
import os, random, websockets, asyncio import os, random, websockets, asyncio
from websockets import Subprotocol, Origin from websockets import Subprotocol, Origin
@@ -30,6 +36,7 @@ vms = {}
vm_botuser = {} vm_botuser = {}
STATE = CollabVMState.WS_DISCONNECTED STATE = CollabVMState.WS_DISCONNECTED
def get_origin_from_ws_url(ws_url: str) -> str: def get_origin_from_ws_url(ws_url: str) -> str:
domain = ( domain = (
ws_url.removeprefix("ws:") ws_url.removeprefix("ws:")
@@ -41,61 +48,75 @@ def get_origin_from_ws_url(ws_url: str) -> str:
is_wss = ws_url.startswith("wss:") is_wss = ws_url.startswith("wss:")
return f"http{'s' if is_wss else ''}://{domain}/" return f"http{'s' if is_wss else ''}://{domain}/"
async def send_chat_message(websocket, message: str): async def send_chat_message(websocket, message: str):
log.debug(f"Sending chat message: {message}") log.debug(f"Sending chat message: {message}")
await websocket.send(guac_encode(["chat", message])) await websocket.send(guac_encode(["chat", message]))
async def send_guac(websocket, *args: str): async def send_guac(websocket, *args: str):
await websocket.send(guac_encode(list(args))) await websocket.send(guac_encode(list(args)))
async def periodic_snapshot_task(): async def periodic_snapshot_task():
"""Background task that saves VM framebuffers as snapshots in WEBP format.""" """Background task that saves VM framebuffers as snapshots in WEBP format."""
log.info("Starting periodic snapshot task") log.info("Starting periodic snapshot task")
while True: while True:
try: try:
await asyncio.sleep(10) # Wait 10 seconds await asyncio.sleep(10) # Wait 10 seconds
log.debug("Running periodic framebuffer snapshot capture...") log.debug("Running periodic framebuffer snapshot capture...")
# Save framebuffers for all VMs # Save framebuffers for all VMs
timestamp = int(datetime.now(timezone.utc).timestamp()) # Use Unix timestamp for filenames timestamp = int(
datetime.now(timezone.utc).timestamp()
) # Use Unix timestamp for filenames
current_day = datetime.now(timezone.utc).strftime("%b-%d-%Y") current_day = datetime.now(timezone.utc).strftime("%b-%d-%Y")
for vm_name, vm_data in vms.items(): for vm_name, vm_data in vms.items():
# Create snapshots directory for each VM, then by current day # Create snapshots directory for each VM, then by current day
snapshot_dir = os.path.join("logs", "webp", vm_name, current_day) snapshot_dir = os.path.join("logs", "webp", vm_name, current_day)
os.makedirs(snapshot_dir, exist_ok=True) os.makedirs(snapshot_dir, exist_ok=True)
framebuffer = vm_data.get("framebuffer") framebuffer = vm_data.get("framebuffer")
if framebuffer: if framebuffer:
framebuffer.seek(0) framebuffer.seek(0)
# Calculate hash of the framebuffer to detect duplicates # Calculate hash of the framebuffer to detect duplicates
framebuffer_data = framebuffer.getvalue() framebuffer_data = framebuffer.getvalue()
current_hash = hashlib.md5(framebuffer_data).hexdigest() current_hash = hashlib.md5(framebuffer_data).hexdigest()
# Check if this frame is the same as the last one # Check if this frame is the same as the last one
if vm_data.get("last_frame_hash") == current_hash: if vm_data.get("last_frame_hash") == current_hash:
log.debug(f"Skipping duplicate frame for VM '{vm_name}'") log.debug(f"Skipping duplicate frame for VM '{vm_name}'")
continue continue
# Save the new frame # Save the new frame
framebuffer.seek(0) framebuffer.seek(0)
image = Image.open(framebuffer) image = Image.open(framebuffer)
snapshot_path = os.path.join(snapshot_dir, f"{vm_name}_{timestamp}.webp") snapshot_path = os.path.join(
image.save(snapshot_path, format="WEBP", quality=65, optimize=True, method=6) snapshot_dir, f"{vm_name}_{timestamp}.webp"
)
image.save(
snapshot_path,
format="WEBP",
quality=65,
optimize=True,
method=6,
)
# Update the hash for this VM # Update the hash for this VM
vm_data["last_frame_hash"] = current_hash vm_data["last_frame_hash"] = current_hash
log.info(f"Saved snapshot for VM '{vm_name}' to {snapshot_path}") log.info(f"Saved snapshot for VM '{vm_name}' to {snapshot_path}")
else: else:
log.warning(f"No framebuffer available for VM '{vm_name}'") log.warning(f"No framebuffer available for VM '{vm_name}'")
except Exception as e: except Exception as e:
log.error(f"Error in periodic snapshot task: {e}") log.error(f"Error in periodic snapshot task: {e}")
# Continue running even if there's an error # Continue running even if there's an error
async def connect(vm_name: str): async def connect(vm_name: str):
global STATE global STATE
global vms global vms
@@ -103,9 +124,17 @@ async def connect(vm_name: str):
if vm_name not in config.vms: if vm_name not in config.vms:
log.error(f"VM '{vm_name}' not found in configuration.") log.error(f"VM '{vm_name}' not found in configuration.")
return return
vms[vm_name] = {"turn_queue": [], "active_turn_user": None, "users": {}, "framebuffer": None, "last_frame_hash": None} vms[vm_name] = {
"turn_queue": [],
"active_turn_user": None,
"users": {},
"framebuffer": None,
"last_frame_hash": None,
}
uri = config.vms[vm_name] uri = config.vms[vm_name]
log_file_path = os.path.join(getattr(config, "log_directory", "logs"), f"{vm_name}.json") log_file_path = os.path.join(
getattr(config, "log_directory", "logs"), f"{vm_name}.json"
)
if not os.path.exists(log_file_path): if not os.path.exists(log_file_path):
with open(log_file_path, "w") as log_file: with open(log_file_path, "w") as log_file:
log_file.write("{}") log_file.write("{}")
@@ -113,7 +142,7 @@ async def connect(vm_name: str):
uri=uri, uri=uri,
subprotocols=[Subprotocol("guacamole")], subprotocols=[Subprotocol("guacamole")],
origin=Origin(get_origin_from_ws_url(uri)), origin=Origin(get_origin_from_ws_url(uri)),
user_agent_header="cvmsentry/1 (https://git.nixlabs.dev/clair/cvmsentry)" user_agent_header="cvmsentry/1 (https://git.nixlabs.dev/clair/cvmsentry)",
) as websocket: ) as websocket:
STATE = CollabVMState.WS_CONNECTED STATE = CollabVMState.WS_CONNECTED
log.info(f"Connected to VM '{vm_name}' at {uri}") log.info(f"Connected to VM '{vm_name}' at {uri}")
@@ -129,33 +158,58 @@ async def connect(vm_name: str):
await send_guac(websocket, "nop") await send_guac(websocket, "nop")
case ["auth", config.auth_server]: case ["auth", config.auth_server]:
await asyncio.sleep(1) await asyncio.sleep(1)
await send_guac(websocket, "login", config.credentials["session_auth"]) await send_guac(
case ["connect", connection_status, turns_enabled, votes_enabled, uploads_enabled]: websocket, "login", config.credentials["session_auth"]
)
case [
"connect",
connection_status,
turns_enabled,
votes_enabled,
uploads_enabled,
]:
if connection_status == "1": if connection_status == "1":
STATE = CollabVMState.VM_CONNECTED STATE = CollabVMState.VM_CONNECTED
log.info(f"Connected to VM '{vm_name}' successfully. Turns enabled: {bool(int(turns_enabled))}, Votes enabled: {bool(int(votes_enabled))}, Uploads enabled: {bool(int(uploads_enabled))}") log.info(
f"Connected to VM '{vm_name}' successfully. Turns enabled: {bool(int(turns_enabled))}, Votes enabled: {bool(int(votes_enabled))}, Uploads enabled: {bool(int(uploads_enabled))}"
)
else: else:
log.error(f"Failed to connect to VM '{vm_name}'. Connection status: {connection_status}") log.error(
f"Failed to connect to VM '{vm_name}'. Connection status: {connection_status}"
)
STATE = CollabVMState.WS_DISCONNECTED STATE = CollabVMState.WS_DISCONNECTED
await websocket.close() await websocket.close()
case ["rename", *instructions]: case ["rename", *instructions]:
match instructions: match instructions:
case ["0", status, new_name]: case ["0", status, new_name]:
if CollabVMClientRenameStatus(int(status)) == CollabVMClientRenameStatus.SUCCEEDED: if (
log.debug(f"({STATE.name} - {vm_name}) Bot rename on VM {vm_name}: {vm_botuser[vm_name]} -> {new_name}") CollabVMClientRenameStatus(int(status))
== CollabVMClientRenameStatus.SUCCEEDED
):
log.debug(
f"({STATE.name} - {vm_name}) Bot rename on VM {vm_name}: {vm_botuser[vm_name]} -> {new_name}"
)
vm_botuser[vm_name] = new_name vm_botuser[vm_name] = new_name
else: else:
log.debug(f"({STATE.name} - {vm_name}) Bot rename on VM {vm_name} failed with status {CollabVMClientRenameStatus(int(status)).name}") log.debug(
f"({STATE.name} - {vm_name}) Bot rename on VM {vm_name} failed with status {CollabVMClientRenameStatus(int(status)).name}"
)
case ["1", old_name, new_name]: case ["1", old_name, new_name]:
if old_name in vms[vm_name]["users"]: if old_name in vms[vm_name]["users"]:
log.debug(f"({STATE.name} - {vm_name}) User rename on VM {vm_name}: {old_name} -> {new_name}") log.debug(
vms[vm_name]["users"][new_name] = vms[vm_name]["users"].pop(old_name) f"({STATE.name} - {vm_name}) User rename on VM {vm_name}: {old_name} -> {new_name}"
)
vms[vm_name]["users"][new_name] = vms[vm_name][
"users"
].pop(old_name)
case ["login", "1"]: case ["login", "1"]:
STATE = CollabVMState.LOGGED_IN STATE = CollabVMState.LOGGED_IN
if config.send_autostart and config.autostart_messages: if config.send_autostart and config.autostart_messages:
await send_chat_message(websocket, random.choice(config.autostart_messages)) await send_chat_message(
websocket, random.choice(config.autostart_messages)
)
case ["chat", user, message, *backlog]: case ["chat", user, message, *backlog]:
system_message = (user == "") system_message = user == ""
if system_message: if system_message:
continue continue
if not backlog: if not backlog:
@@ -163,14 +217,17 @@ async def connect(vm_name: str):
def get_rank(username: str) -> CollabVMRank: def get_rank(username: str) -> CollabVMRank:
return vms[vm_name]["users"].get(username, {}).get("rank") return vms[vm_name]["users"].get(username, {}).get("rank")
def admin_check(username: str) -> bool: def admin_check(username: str) -> bool:
return username in config.admins and get_rank(username) > CollabVMRank.Unregistered return (
username in config.admins
and get_rank(username) > CollabVMRank.Unregistered
)
utc_now = datetime.now(timezone.utc) utc_now = datetime.now(timezone.utc)
utc_day = utc_now.strftime("%Y-%m-%d") utc_day = utc_now.strftime("%Y-%m-%d")
timestamp = utc_now.isoformat() timestamp = utc_now.isoformat()
with open(log_file_path, "r+") as log_file: with open(log_file_path, "r+") as log_file:
try: try:
log_data = json.load(log_file) log_data = json.load(log_file)
@@ -193,48 +250,75 @@ async def connect(vm_name: str):
# "message": backlog_message # "message": backlog_message
# }) # })
log_data[utc_day].append({ log_data[utc_day].append(
"type": "chat", {
"timestamp": timestamp, "type": "chat",
"username": user, "timestamp": timestamp,
"message": message "username": user,
}) "message": message,
}
)
log_file.seek(0) log_file.seek(0)
json.dump(log_data, log_file, indent=4) json.dump(log_data, log_file, indent=4)
log_file.truncate() log_file.truncate()
if config.commands["enabled"] and message.startswith(config.commands["prefix"]): if config.commands["enabled"] and message.startswith(
command = message[len(config.commands["prefix"]):].strip().lower() config.commands["prefix"]
):
command = (
message[len(config.commands["prefix"]) :].strip().lower()
)
match command: match command:
case "whoami": case "whoami":
await send_chat_message(websocket, f"You are {user} with rank {get_rank(user).name}.") await send_chat_message(
websocket,
f"You are {user} with rank {get_rank(user).name}.",
)
case "about": case "about":
await send_chat_message(websocket, config.responses.get("about", "CVM-Sentry (NO RESPONSE CONFIGURED)")) await send_chat_message(
websocket,
config.responses.get(
"about", "CVM-Sentry (NO RESPONSE CONFIGURED)"
),
)
case "dump": case "dump":
if not admin_check(user): if not admin_check(user):
continue continue
log.debug(f"({STATE.name} - {vm_name}) Dumping user list for VM {vm_name}: {vms[vm_name]['users']}") log.debug(
await send_chat_message(websocket, f"Dumped user list to console.") f"({STATE.name} - {vm_name}) Dumping user list for VM {vm_name}: {vms[vm_name]['users']}"
)
await send_chat_message(
websocket, f"Dumped user list to console."
)
case ["adduser", count, *list]: case ["adduser", count, *list]:
for i in range(int(count)): for i in range(int(count)):
user = list[i * 2] user = list[i * 2]
rank = CollabVMRank(int(list[i * 2 + 1])) rank = CollabVMRank(int(list[i * 2 + 1]))
if user in vms[vm_name]["users"]: if user in vms[vm_name]["users"]:
vms[vm_name]["users"][user]["rank"] = rank vms[vm_name]["users"][user]["rank"] = rank
log.info(f"[{vm_name}] User '{user}' rank updated to {rank.name}.") log.info(
f"[{vm_name}] User '{user}' rank updated to {rank.name}."
)
else: else:
vms[vm_name]["users"][user] = {"rank": rank} vms[vm_name]["users"][user] = {"rank": rank}
log.info(f"[{vm_name}] User '{user}' connected with rank {rank.name}.") log.info(
f"[{vm_name}] User '{user}' connected with rank {rank.name}."
)
case ["turn", _, "0"]: case ["turn", _, "0"]:
if STATE < CollabVMState.LOGGED_IN: if STATE < CollabVMState.LOGGED_IN:
continue continue
if vms[vm_name]["active_turn_user"] is None and not vms[vm_name]["turn_queue"]: if (
#log.debug(f"({STATE.name} - {vm_name}) Incoming queue exhaustion matches the VM's state. Dropping update.") vms[vm_name]["active_turn_user"] is None
and not vms[vm_name]["turn_queue"]
):
# log.debug(f"({STATE.name} - {vm_name}) Incoming queue exhaustion matches the VM's state. Dropping update.")
continue continue
vms[vm_name]["active_turn_user"] = None vms[vm_name]["active_turn_user"] = None
vms[vm_name]["turn_queue"] = [] vms[vm_name]["turn_queue"] = []
log.debug(f"({STATE.name} - {vm_name}) Turn queue is naturally exhausted.") log.debug(
f"({STATE.name} - {vm_name}) Turn queue is naturally exhausted."
)
case ["png", "0", "0", "0", "0", initial_frame_b64]: case ["png", "0", "0", "0", "0", initial_frame_b64]:
# Decode the base64 image data # Decode the base64 image data
initial_frame_data = base64.b64decode(initial_frame_b64) initial_frame_data = base64.b64decode(initial_frame_b64)
@@ -250,7 +334,9 @@ async def connect(vm_name: str):
# Assign the in-memory framebuffer to the VM's dictionary # Assign the in-memory framebuffer to the VM's dictionary
vms[vm_name]["framebuffer"] = framebuffer vms[vm_name]["framebuffer"] = framebuffer
framebuffer_size = framebuffer.getbuffer().nbytes framebuffer_size = framebuffer.getbuffer().nbytes
log.info(f"({STATE.name} - {vm_name}) !!! WHOLE FRAME UPDATE !!! ({framebuffer_size} bytes)") log.info(
f"({STATE.name} - {vm_name}) !!! WHOLE FRAME UPDATE !!! ({framebuffer_size} bytes)"
)
case ["png", "0", "0", x, y, rect_b64]: case ["png", "0", "0", x, y, rect_b64]:
# Decode the base64 image data for the rectangle # Decode the base64 image data for the rectangle
rect_data = base64.b64decode(rect_b64) rect_data = base64.b64decode(rect_b64)
@@ -277,54 +363,69 @@ async def connect(vm_name: str):
# Log the updated framebuffer size # Log the updated framebuffer size
framebuffer_size = framebuffer.getbuffer().nbytes framebuffer_size = framebuffer.getbuffer().nbytes
log.debug(f"({STATE.name} - {vm_name}) Updated framebuffer size: {framebuffer_size} bytes") log.debug(
f"({STATE.name} - {vm_name}) Updated framebuffer size: {framebuffer_size} bytes"
)
else: else:
continue continue
case ["turn", turn_time, count, current_turn, *queue]: case ["turn", turn_time, count, current_turn, *queue]:
if queue == vms[vm_name]["turn_queue"] and current_turn == vms[vm_name]["active_turn_user"]: if (
#log.debug(f"({STATE.name} - {vm_name}) Incoming turn update matches the VM's state. Dropping update.") queue == vms[vm_name]["turn_queue"]
continue and current_turn == vms[vm_name]["active_turn_user"]
for user in vms[vm_name]["users"]: ):
vms[vm_name]["turn_queue"] = queue # log.debug(f"({STATE.name} - {vm_name}) Incoming turn update matches the VM's state. Dropping update.")
vms[vm_name]["active_turn_user"] = current_turn if current_turn != "" else None continue
if current_turn: for user in vms[vm_name]["users"]:
utc_now = datetime.now(timezone.utc) vms[vm_name]["turn_queue"] = queue
utc_day = utc_now.strftime("%Y-%m-%d") vms[vm_name]["active_turn_user"] = (
timestamp = utc_now.isoformat() current_turn if current_turn != "" else None
)
if current_turn:
utc_now = datetime.now(timezone.utc)
utc_day = utc_now.strftime("%Y-%m-%d")
timestamp = utc_now.isoformat()
with open(log_file_path, "r+") as log_file: with open(log_file_path, "r+") as log_file:
try: try:
log_data = json.load(log_file) log_data = json.load(log_file)
except json.JSONDecodeError: except json.JSONDecodeError:
log_data = {} log_data = {}
if utc_day not in log_data: if utc_day not in log_data:
log_data[utc_day] = [] log_data[utc_day] = []
log_data[utc_day].append({ log_data[utc_day].append(
{
"type": "turn", "type": "turn",
"timestamp": timestamp, "timestamp": timestamp,
"active_turn_user": current_turn, "active_turn_user": current_turn,
"queue": queue "queue": queue,
}) }
)
log_file.seek(0) log_file.seek(0)
json.dump(log_data, log_file, indent=4) json.dump(log_data, log_file, indent=4)
log_file.truncate() log_file.truncate()
log.debug(f"({STATE.name} - {vm_name}) Turn update: turn_time={turn_time}, count={count}, current_turn={current_turn}, queue={queue}") log.debug(
f"({STATE.name} - {vm_name}) Turn update: turn_time={turn_time}, count={count}, current_turn={current_turn}, queue={queue}"
)
case ["remuser", count, *list]: case ["remuser", count, *list]:
for i in range(int(count)): for i in range(int(count)):
username = list[i] username = list[i]
if username in vms[vm_name]["users"]: if username in vms[vm_name]["users"]:
del vms[vm_name]["users"][username] del vms[vm_name]["users"][username]
log.info(f"[{vm_name}] User '{username}' left.") log.info(f"[{vm_name}] User '{username}' left.")
case ["flag", *args] | ["size", *args] | ["png", *args] | ["sync", *args]: case (
["flag", *args] | ["size", *args] | ["png", *args] | ["sync", *args]
):
continue continue
case _: case _:
if decoded is not None: if decoded is not None:
log.debug(f"({STATE.name} - {vm_name}) Unhandled message: {decoded}") log.debug(
f"({STATE.name} - {vm_name}) Unhandled message: {decoded}"
)
log.info(f"({STATE.name}) CVM-Sentry started") log.info(f"({STATE.name}) CVM-Sentry started")
@@ -334,35 +435,45 @@ for vm in config.vms.keys():
asyncio.run(connect(vm_name)) asyncio.run(connect(vm_name))
async def main(): async def main():
async def connect_with_reconnect(vm_name: str): async def connect_with_reconnect(vm_name: str):
while True: while True:
try: try:
await connect(vm_name) await connect(vm_name)
except websockets.exceptions.ConnectionClosedError as e: except websockets.exceptions.ConnectionClosedError as e:
log.warning(f"Connection to VM '{vm_name}' closed with error: {e}. Reconnecting...") log.warning(
f"Connection to VM '{vm_name}' closed with error: {e}. Reconnecting..."
)
await asyncio.sleep(5) # Wait before attempting to reconnect await asyncio.sleep(5) # Wait before attempting to reconnect
except websockets.exceptions.ConnectionClosedOK: except websockets.exceptions.ConnectionClosedOK:
log.warning(f"Connection to VM '{vm_name}' closed cleanly (code 1005). Reconnecting...") log.warning(
f"Connection to VM '{vm_name}' closed cleanly (code 1005). Reconnecting..."
)
await asyncio.sleep(5) # Wait before attempting to reconnect await asyncio.sleep(5) # Wait before attempting to reconnect
except websockets.exceptions.InvalidStatus as e: except websockets.exceptions.InvalidStatus as e:
log.error(f"Failed to connect to VM '{vm_name}' with status code: {e}. Reconnecting...") log.error(
f"Failed to connect to VM '{vm_name}' with status code: {e}. Reconnecting..."
)
await asyncio.sleep(10) # Wait longer for HTTP errors await asyncio.sleep(10) # Wait longer for HTTP errors
except websockets.exceptions.WebSocketException as e: except websockets.exceptions.WebSocketException as e:
log.error(f"WebSocket error connecting to VM '{vm_name}': {e}. Reconnecting...") log.error(
f"WebSocket error connecting to VM '{vm_name}': {e}. Reconnecting..."
)
await asyncio.sleep(5) await asyncio.sleep(5)
except Exception as e: except Exception as e:
log.error(f"Unexpected error connecting to VM '{vm_name}': {e}. Reconnecting...") log.error(
f"Unexpected error connecting to VM '{vm_name}': {e}. Reconnecting..."
)
await asyncio.sleep(10) # Wait longer for unexpected errors await asyncio.sleep(10) # Wait longer for unexpected errors
# Create tasks for VM connections # Create tasks for VM connections
vm_tasks = [connect_with_reconnect(vm) for vm in config.vms.keys()] vm_tasks = [connect_with_reconnect(vm) for vm in config.vms.keys()]
# Add periodic snapshot task # Add periodic snapshot task
snapshot_task = periodic_snapshot_task() snapshot_task = periodic_snapshot_task()
# Run all tasks concurrently # Run all tasks concurrently
all_tasks = [snapshot_task] + vm_tasks all_tasks = [snapshot_task] + vm_tasks
await asyncio.gather(*all_tasks) await asyncio.gather(*all_tasks)
asyncio.run(main()) asyncio.run(main())