Compare commits
31 Commits
48ebbd4f60
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 04b289861a | |||
| 925afed7a5 | |||
| 0959d17403 | |||
| bc5bb81330 | |||
| cdfbc7e55f | |||
| c479a86f29 | |||
| 5e1fcf37d0 | |||
| a48ebd4b72 | |||
| 63dc3600d5 | |||
| 2154c29515 | |||
| 81ba086b39 | |||
| 78b57f10c4 | |||
| 8070f79164 | |||
| 918fae093f | |||
| 195e2799a5 | |||
| 359e366fe0 | |||
| b1f608d14b | |||
| 09d1a77ea5 | |||
| 299aa4e0b1 | |||
| 167ac1858b | |||
| e80b3f764f | |||
| f846d55d44 | |||
| c7ae849d0e | |||
| 9727eba274 | |||
| 906f26c220 | |||
| c5d7c7d24c | |||
| fd40af02bd | |||
| 8b64179f41 | |||
| f0bd973461 | |||
| ca021408ca | |||
| 49725095cb |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -175,4 +175,6 @@ cython_debug/
|
||||
.pypirc
|
||||
|
||||
config.py
|
||||
logs/
|
||||
|
||||
logs/
|
||||
old_logs/
|
||||
@@ -1,3 +1,5 @@
|
||||
# CVM-Sentry
|
||||
Python application for taking screenshots and logging chat from a CollabVM instance.
|
||||
|
||||
Python application for logging chat (Actual usefulness... EVENTUALLY)
|
||||
# HEAVY DISCLAIMER
|
||||
A lot of the code was written by the geriatric Claude by Anthropic, in a mix of laziness and inability to write good code. Some of it has been cleaned up, and the bot is considered to be in a stable state. Pull requests in the form of patches sent to `clair@nixlabs.dev` are welcome.
|
||||
529
cvmsentry.py
529
cvmsentry.py
@@ -1,5 +1,12 @@
|
||||
from typing import List
|
||||
from cvmlib import guac_decode, guac_encode, CollabVMRank, CollabVMState, CollabVMClientRenameStatus
|
||||
from urllib.parse import urlparse
|
||||
from cvmlib import (
|
||||
guac_decode,
|
||||
guac_encode,
|
||||
CollabVMRank,
|
||||
CollabVMState,
|
||||
CollabVMClientRenameStatus,
|
||||
)
|
||||
import config
|
||||
import os, random, websockets, asyncio
|
||||
from websockets import Subprotocol, Origin
|
||||
@@ -7,16 +14,14 @@ import logging
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
from snapper import snap_all_vms
|
||||
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
import base64
|
||||
import imagehash
|
||||
LOG_LEVEL = getattr(config, "log_level", "INFO")
|
||||
|
||||
# Prepare logs
|
||||
if not os.path.exists("logs"):
|
||||
os.makedirs("logs")
|
||||
log_format = logging.Formatter(
|
||||
"[%(asctime)s:%(name)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
)
|
||||
log_format = logging.Formatter("[%(asctime)s:%(name)s] %(levelname)s - %(message)s")
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
stdout_handler.setFormatter(log_format)
|
||||
log = logging.getLogger("CVMSentry")
|
||||
@@ -25,7 +30,6 @@ log.addHandler(stdout_handler)
|
||||
|
||||
vms = {}
|
||||
vm_botuser = {}
|
||||
STATE = CollabVMState.WS_DISCONNECTED
|
||||
|
||||
def get_origin_from_ws_url(ws_url: str) -> str:
|
||||
domain = (
|
||||
@@ -38,255 +42,464 @@ def get_origin_from_ws_url(ws_url: str) -> str:
|
||||
is_wss = ws_url.startswith("wss:")
|
||||
return f"http{'s' if is_wss else ''}://{domain}/"
|
||||
|
||||
|
||||
async def send_chat_message(websocket, message: str):
|
||||
log.debug(f"Sending chat message: {message}")
|
||||
await websocket.send(guac_encode(["chat", message]))
|
||||
|
||||
|
||||
async def send_guac(websocket, *args: str):
|
||||
await websocket.send(guac_encode(list(args)))
|
||||
|
||||
|
||||
async def periodic_snapshot_task():
|
||||
"""Background task that captures VM snapshots."""
|
||||
"""Background task that saves VM framebuffers as snapshots in WEBP format."""
|
||||
log.info("Starting periodic snapshot task")
|
||||
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(30) # Wait 30 seconds
|
||||
log.debug("Running periodic snapshot capture...")
|
||||
|
||||
# Create snapshots directory with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
snapshot_dir = os.path.join("logs", timestamp)
|
||||
|
||||
# Capture all VMs
|
||||
await snap_all_vms(snapshot_dir)
|
||||
|
||||
await asyncio.sleep(config.snapshot_cadence)
|
||||
log.debug("Running periodic framebuffer snapshot capture...")
|
||||
|
||||
save_tasks = []
|
||||
for vm_name, vm_data in vms.items():
|
||||
# Skip if VM doesn't have a framebuffer
|
||||
if not vm_data.get("framebuffer"):
|
||||
continue
|
||||
|
||||
# Create directory structure if it doesn't exist - [date]/[vm] structure in UTC
|
||||
date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
snapshot_dir = os.path.join(config.log_directory, "webp", date_str, vm_name)
|
||||
os.makedirs(snapshot_dir, exist_ok=True)
|
||||
|
||||
# Generate formatted timestamp in UTC
|
||||
timestamp = datetime.now(timezone.utc).strftime("%H-%M-%S")
|
||||
filename = f"{timestamp}.webp"
|
||||
filepath = os.path.join(snapshot_dir, filename)
|
||||
|
||||
# Get framebuffer reference (no copy needed)
|
||||
framebuffer = vm_data["framebuffer"]
|
||||
if not framebuffer:
|
||||
continue
|
||||
|
||||
# Calculate difference hash asynchronously to avoid blocking
|
||||
current_hash = await asyncio.to_thread(
|
||||
lambda: str(imagehash.dhash(framebuffer))
|
||||
)
|
||||
|
||||
# Only save if the framebuffer has changed since last snapshot
|
||||
if current_hash != vm_data.get("last_frame_hash"):
|
||||
# Pass framebuffer directly without copying
|
||||
save_tasks.append(
|
||||
asyncio.create_task(
|
||||
save_image_async(
|
||||
framebuffer, filepath, vm_name, vm_data, current_hash
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for all save tasks to complete
|
||||
if save_tasks:
|
||||
await asyncio.gather(*save_tasks)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error in periodic snapshot task: {e}")
|
||||
# Continue running even if there's an error
|
||||
|
||||
async def connect(vm_name: str):
|
||||
global STATE
|
||||
|
||||
async def save_image_async(image, filepath, vm_name, vm_data, current_hash):
|
||||
"""Save an image to disk asynchronously."""
|
||||
try:
|
||||
# Run the image saving in a thread pool to avoid blocking
|
||||
await asyncio.to_thread(
|
||||
image.save, filepath, format="WEBP", quality=65, method=6, minimize_size=True
|
||||
)
|
||||
vm_data["last_frame_hash"] = current_hash
|
||||
log.info(f"Saved snapshot of {vm_name} ({datetime.now(timezone.utc).strftime('%H:%M:%S')} UTC)")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to save snapshot for {vm_name}: {e}")
|
||||
|
||||
|
||||
async def connect(vm_obj: dict):
|
||||
log.debug(f"Connecting to VM at {vm_obj['ws_url']} with origin {get_origin_from_ws_url(vm_obj['ws_url'])}")
|
||||
global vms
|
||||
global vm_botuser
|
||||
if vm_name not in config.vms:
|
||||
log.error(f"VM '{vm_name}' not found in configuration.")
|
||||
return
|
||||
vms[vm_name] = {"turn_queue": [], "active_turn_user": None, "users": {}}
|
||||
uri = config.vms[vm_name]
|
||||
log_file_path = os.path.join(getattr(config, "log_directory", "logs"), f"{vm_name}.json")
|
||||
if not os.path.exists(log_file_path):
|
||||
with open(log_file_path, "w") as log_file:
|
||||
log_file.write("{}")
|
||||
fqdn = urlparse(vm_obj["ws_url"]).netloc
|
||||
STATE = CollabVMState.WS_DISCONNECTED
|
||||
log_label = vm_obj.get("log_label") or f"{fqdn}-{vm_obj.get('node', '')}"
|
||||
vms[log_label] = {
|
||||
"turn_queue": [],
|
||||
"active_turn_user": None,
|
||||
"users": {},
|
||||
"framebuffer": None,
|
||||
"last_frame_hash": None,
|
||||
"size": (0, 0),
|
||||
}
|
||||
ws_url = vm_obj["ws_url"]
|
||||
log_directory = getattr(config, "log_directory", "./logs")
|
||||
# Create VM-specific log directory
|
||||
vm_log_directory = os.path.join(log_directory, log_label)
|
||||
os.makedirs(vm_log_directory, exist_ok=True)
|
||||
|
||||
origin = Origin(vm_obj.get("origin_override", get_origin_from_ws_url(ws_url)))
|
||||
|
||||
async with websockets.connect(
|
||||
uri=uri,
|
||||
uri=ws_url,
|
||||
subprotocols=[Subprotocol("guacamole")],
|
||||
origin=Origin(get_origin_from_ws_url(uri)),
|
||||
user_agent_header="cvmsentry/1 (https://git.nixlabs.dev/clair/cvmsentry)"
|
||||
origin=Origin(origin),
|
||||
user_agent_header="cvmsentry/1 (https://git.nixlabs.dev/clair/cvmsentry)",
|
||||
) as websocket:
|
||||
STATE = CollabVMState.WS_CONNECTED
|
||||
log.info(f"Connected to VM '{vm_name}' at {uri}")
|
||||
await send_guac(websocket, "rename", "")
|
||||
await send_guac(websocket, "connect", vm_name)
|
||||
if vm_name not in vm_botuser:
|
||||
vm_botuser[vm_name] = ""
|
||||
# response = await websocket.recv()
|
||||
log.info(f"Connected to VM '{log_label}' at {ws_url}")
|
||||
await send_guac(websocket, "rename", config.unauth_name)
|
||||
await send_guac(websocket, "connect", vm_obj["node"])
|
||||
if log_label not in vm_botuser:
|
||||
vm_botuser[log_label] = ""
|
||||
async for message in websocket:
|
||||
decoded: List[str] = guac_decode(str(message))
|
||||
match decoded:
|
||||
case ["nop"]:
|
||||
await send_guac(websocket, "nop")
|
||||
case ["auth", config.auth_server]:
|
||||
case ["auth", auth_server]:
|
||||
await asyncio.sleep(1)
|
||||
await send_guac(websocket, "login", config.credentials["session_auth"])
|
||||
case ["connect", connection_status, turns_enabled, votes_enabled, uploads_enabled]:
|
||||
if vm_obj.get("auth"):
|
||||
await send_guac(
|
||||
websocket,
|
||||
"login",
|
||||
vm_obj["auth"]["session_auth"],
|
||||
)
|
||||
else:
|
||||
log.error(
|
||||
f"Auth server '{auth_server}' not recognized for VM '{log_label}'"
|
||||
)
|
||||
case [
|
||||
"connect",
|
||||
connection_status,
|
||||
turns_enabled,
|
||||
votes_enabled,
|
||||
uploads_enabled,
|
||||
]:
|
||||
if connection_status == "1":
|
||||
STATE = CollabVMState.VM_CONNECTED
|
||||
log.info(f"Connected to VM '{vm_name}' successfully. Turns enabled: {bool(int(turns_enabled))}, Votes enabled: {bool(int(votes_enabled))}, Uploads enabled: {bool(int(uploads_enabled))}")
|
||||
log.info(
|
||||
f"Connected to VM '{log_label}' successfully. Turns enabled: {bool(int(turns_enabled))}, Votes enabled: {bool(int(votes_enabled))}, Uploads enabled: {bool(int(uploads_enabled))}"
|
||||
)
|
||||
else:
|
||||
log.error(f"Failed to connect to VM '{vm_name}'. Connection status: {connection_status}")
|
||||
log.debug(
|
||||
f"Failed to connect to VM '{log_label}'. Connection status: {connection_status}"
|
||||
)
|
||||
STATE = CollabVMState.WS_DISCONNECTED
|
||||
await websocket.close()
|
||||
case ["rename", *instructions]:
|
||||
match instructions:
|
||||
case ["0", status, new_name]:
|
||||
if CollabVMClientRenameStatus(int(status)) == CollabVMClientRenameStatus.SUCCEEDED:
|
||||
log.debug(f"({STATE.name} - {vm_name}) Bot rename on VM {vm_name}: {vm_botuser[vm_name]} -> {new_name}")
|
||||
vm_botuser[vm_name] = new_name
|
||||
if (
|
||||
CollabVMClientRenameStatus(int(status))
|
||||
== CollabVMClientRenameStatus.SUCCEEDED
|
||||
):
|
||||
log.debug(
|
||||
f"({STATE.name} - {log_label}) Bot rename on VM {log_label}: {vm_botuser[log_label]} -> {new_name}"
|
||||
)
|
||||
vm_botuser[log_label] = new_name
|
||||
else:
|
||||
log.debug(f"({STATE.name} - {vm_name}) Bot rename on VM {vm_name} failed with status {CollabVMClientRenameStatus(int(status)).name}")
|
||||
log.debug(
|
||||
f"({STATE.name} - {log_label}) Bot rename on VM {log_label} failed with status {CollabVMClientRenameStatus(int(status)).name}"
|
||||
)
|
||||
case ["1", old_name, new_name]:
|
||||
if old_name in vms[vm_name]["users"]:
|
||||
log.debug(f"({STATE.name} - {vm_name}) User rename on VM {vm_name}: {old_name} -> {new_name}")
|
||||
vms[vm_name]["users"][new_name] = vms[vm_name]["users"].pop(old_name)
|
||||
if old_name in vms[log_label]["users"]:
|
||||
log.debug(
|
||||
f"({STATE.name} - {log_label}) User rename on VM {log_label}: {old_name} -> {new_name}"
|
||||
)
|
||||
vms[log_label]["users"][new_name] = vms[log_label][
|
||||
"users"
|
||||
].pop(old_name)
|
||||
case ["login", "1"]:
|
||||
STATE = CollabVMState.LOGGED_IN
|
||||
if config.send_autostart and config.autostart_messages:
|
||||
await send_chat_message(websocket, random.choice(config.autostart_messages))
|
||||
await send_chat_message(
|
||||
websocket, random.choice(config.autostart_messages)
|
||||
)
|
||||
case ["chat", user, message, *backlog]:
|
||||
system_message = (user == "")
|
||||
if system_message:
|
||||
system_message = user == ""
|
||||
if system_message or backlog:
|
||||
continue
|
||||
if not backlog:
|
||||
log.info(f"[{vm_name} - {user}]: {message}")
|
||||
log.info(f"[{log_label} - {user}]: {message}")
|
||||
|
||||
def get_rank(username: str) -> CollabVMRank:
|
||||
return vms[vm_name]["users"].get(username, {}).get("rank")
|
||||
|
||||
return vms[log_label]["users"].get(username, {}).get("rank")
|
||||
|
||||
def admin_check(username: str) -> bool:
|
||||
return username in config.admins and get_rank(username) > CollabVMRank.Unregistered
|
||||
|
||||
return (
|
||||
username in config.admins
|
||||
and get_rank(username) > CollabVMRank.Unregistered
|
||||
)
|
||||
|
||||
utc_now = datetime.now(timezone.utc)
|
||||
utc_day = utc_now.strftime("%Y-%m-%d")
|
||||
timestamp = utc_now.isoformat()
|
||||
|
||||
# Get daily log file path
|
||||
daily_log_path = os.path.join(vm_log_directory, f"{utc_day}.json")
|
||||
|
||||
with open(log_file_path, "r+") as log_file:
|
||||
try:
|
||||
log_data = json.load(log_file)
|
||||
except json.JSONDecodeError:
|
||||
log_data = {}
|
||||
# Load existing log data or create new
|
||||
if os.path.exists(daily_log_path):
|
||||
with open(daily_log_path, "r") as log_file:
|
||||
try:
|
||||
log_data = json.load(log_file)
|
||||
except json.JSONDecodeError:
|
||||
log_data = []
|
||||
else:
|
||||
log_data = []
|
||||
|
||||
if utc_day not in log_data:
|
||||
log_data[utc_day] = []
|
||||
|
||||
if backlog:
|
||||
pass
|
||||
# for i in range(0, len(backlog), 2):
|
||||
# backlog_user = backlog[i]
|
||||
# backlog_message = backlog[i + 1]
|
||||
# if not any(entry["message"] == backlog_message and entry["username"] == backlog_user for entry in log_data[utc_day]):
|
||||
# log.info(f"[{vm_name} - {backlog_user} (backlog)]: {backlog_message}")
|
||||
# log_data[utc_day].append({
|
||||
# "timestamp": timestamp,
|
||||
# "username": backlog_user,
|
||||
# "message": backlog_message
|
||||
# })
|
||||
|
||||
log_data[utc_day].append({
|
||||
log_data.append(
|
||||
{
|
||||
"type": "chat",
|
||||
"timestamp": timestamp,
|
||||
"username": user,
|
||||
"message": message
|
||||
})
|
||||
|
||||
log_file.seek(0)
|
||||
"message": message,
|
||||
}
|
||||
)
|
||||
|
||||
with open(daily_log_path, "w") as log_file:
|
||||
json.dump(log_data, log_file, indent=4)
|
||||
log_file.truncate()
|
||||
|
||||
if config.commands["enabled"] and message.startswith(config.commands["prefix"]):
|
||||
command = message[len(config.commands["prefix"]):].strip().lower()
|
||||
if config.commands["enabled"] and message.startswith(
|
||||
config.commands["prefix"]
|
||||
):
|
||||
command_full = message[len(config.commands["prefix"]):].strip().lower()
|
||||
command = command_full.split(" ")[0] if " " in command_full else command_full
|
||||
match command:
|
||||
case "whoami":
|
||||
await send_chat_message(websocket, f"You are {user} with rank {get_rank(user).name}.")
|
||||
await send_chat_message(
|
||||
websocket,
|
||||
f"You are {user} with rank {get_rank(user).name}.",
|
||||
)
|
||||
case "about":
|
||||
await send_chat_message(websocket, config.responses.get("about", "CVM-Sentry (NO RESPONSE CONFIGURED)"))
|
||||
await send_chat_message(
|
||||
websocket,
|
||||
config.responses.get(
|
||||
"about", "CVM-Sentry (NO RESPONSE CONFIGURED)"
|
||||
),
|
||||
)
|
||||
case "dump":
|
||||
if not admin_check(user):
|
||||
continue
|
||||
log.debug(f"({STATE.name} - {vm_name}) Dumping user list for VM {vm_name}: {vms[vm_name]['users']}")
|
||||
await send_chat_message(websocket, f"Dumped user list to console.")
|
||||
log.info(
|
||||
f"({STATE.name} - {log_label}) Dumping user list for VM {log_label}: {vms[log_label]['users']}"
|
||||
)
|
||||
await send_chat_message(
|
||||
websocket, f"Dumped user list to console."
|
||||
)
|
||||
case ["adduser", count, *list]:
|
||||
for i in range(int(count)):
|
||||
user = list[i * 2]
|
||||
rank = CollabVMRank(int(list[i * 2 + 1]))
|
||||
if user in vms[vm_name]["users"]:
|
||||
vms[vm_name]["users"][user]["rank"] = rank
|
||||
log.info(f"[{vm_name}] User '{user}' rank updated to {rank.name}.")
|
||||
|
||||
if user in vms[log_label]["users"]:
|
||||
vms[log_label]["users"][user]["rank"] = rank
|
||||
log.info(
|
||||
f"[{log_label}] User '{user}' rank updated to {rank.name}."
|
||||
)
|
||||
else:
|
||||
vms[vm_name]["users"][user] = {"rank": rank}
|
||||
log.info(f"[{vm_name}] User '{user}' connected with rank {rank.name}.")
|
||||
vms[log_label]["users"][user] = {"rank": rank}
|
||||
log.info(
|
||||
f"[{log_label}] User '{user}' connected with rank {rank.name}."
|
||||
)
|
||||
case ["turn", _, "0"]:
|
||||
if STATE < CollabVMState.LOGGED_IN:
|
||||
continue
|
||||
if vms[vm_name]["active_turn_user"] is None and not vms[vm_name]["turn_queue"]:
|
||||
#log.debug(f"({STATE.name} - {vm_name}) Incoming queue exhaustion matches the VM's state. Dropping update.")
|
||||
if (
|
||||
vms[log_label]["active_turn_user"] is None
|
||||
and not vms[log_label]["turn_queue"]
|
||||
):
|
||||
# log.debug(f"({STATE.name} - {log_label}) Incoming queue exhaustion matches the VM's state. Dropping update.")
|
||||
continue
|
||||
vms[vm_name]["active_turn_user"] = None
|
||||
vms[vm_name]["turn_queue"] = []
|
||||
log.debug(f"({STATE.name} - {vm_name}) Turn queue is naturally exhausted.")
|
||||
case ["turn", turn_time, count, current_turn, *queue]:
|
||||
if queue == vms[vm_name]["turn_queue"] and current_turn == vms[vm_name]["active_turn_user"]:
|
||||
#log.debug(f"({STATE.name} - {vm_name}) Incoming turn update matches the VM's state. Dropping update.")
|
||||
continue
|
||||
for user in vms[vm_name]["users"]:
|
||||
vms[vm_name]["turn_queue"] = queue
|
||||
vms[vm_name]["active_turn_user"] = current_turn if current_turn != "" else None
|
||||
if current_turn:
|
||||
utc_now = datetime.now(timezone.utc)
|
||||
utc_day = utc_now.strftime("%Y-%m-%d")
|
||||
timestamp = utc_now.isoformat()
|
||||
vms[log_label]["active_turn_user"] = None
|
||||
vms[log_label]["turn_queue"] = []
|
||||
log.debug(
|
||||
f"({STATE.name} - {log_label}) Turn queue is naturally exhausted."
|
||||
)
|
||||
case ["size", "0", width, height]:
|
||||
log.debug(
|
||||
f"({STATE.name} - {log_label}) !!! Framebuffer size update: {width}x{height} !!!"
|
||||
)
|
||||
vms[log_label]["size"] = (int(width), int(height))
|
||||
case ["png", "0", "0", "0", "0", full_frame_b64]:
|
||||
try:
|
||||
log.debug(
|
||||
f"({STATE.name} - {log_label}) !!! Received full framebuffer update !!!"
|
||||
)
|
||||
expected_width, expected_height = vms[log_label]["size"]
|
||||
|
||||
with open(log_file_path, "r+") as log_file:
|
||||
# Decode the base64 data to get the PNG image
|
||||
frame_data = base64.b64decode(full_frame_b64)
|
||||
frame_img = Image.open(BytesIO(frame_data))
|
||||
|
||||
# Validate image size and handle partial frames
|
||||
if expected_width > 0 and expected_height > 0:
|
||||
if frame_img.size != (expected_width, expected_height):
|
||||
log.debug(
|
||||
f"({STATE.name} - {log_label}) Partial framebuffer update: "
|
||||
f"expected {expected_width}x{expected_height}, got {frame_img.size}"
|
||||
)
|
||||
|
||||
# Create a new image of expected size if no framebuffer exists
|
||||
if vms[log_label]["framebuffer"] is None:
|
||||
vms[log_label]["framebuffer"] = Image.new(
|
||||
"RGB", (expected_width, expected_height)
|
||||
)
|
||||
|
||||
# Only update the portion that was received - modify in place
|
||||
if vms[log_label]["framebuffer"]:
|
||||
# Paste directly onto existing framebuffer
|
||||
vms[log_label]["framebuffer"].paste(frame_img, (0, 0))
|
||||
frame_img = vms[log_label]["framebuffer"]
|
||||
|
||||
# Update the framebuffer with the new image
|
||||
vms[log_label]["framebuffer"] = frame_img
|
||||
log.debug(
|
||||
f"({STATE.name} - {log_label}) Framebuffer updated with full frame, size: {frame_img.size}"
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"({STATE.name} - {log_label}) Failed to process full framebuffer update: {e}"
|
||||
)
|
||||
case ["png", "0", "0", x, y, rect_b64]:
|
||||
try:
|
||||
log.debug(
|
||||
f"({STATE.name} - {log_label}) Received partial framebuffer update at position ({x}, {y})"
|
||||
)
|
||||
x, y = int(x), int(y)
|
||||
|
||||
# Decode the base64 data to get the PNG image fragment
|
||||
frame_data = base64.b64decode(rect_b64)
|
||||
fragment_img = Image.open(BytesIO(frame_data))
|
||||
|
||||
# If we don't have a framebuffer yet or it's incompatible, create one
|
||||
if vms[log_label]["framebuffer"] is None:
|
||||
# drop
|
||||
continue
|
||||
|
||||
# If we have a valid framebuffer, update it with the fragment
|
||||
if vms[log_label]["framebuffer"]:
|
||||
# Paste directly onto existing framebuffer (no copy needed)
|
||||
vms[log_label]["framebuffer"].paste(fragment_img, (x, y))
|
||||
log.debug(
|
||||
f"({STATE.name} - {log_label}) Updated framebuffer with fragment at ({x}, {y}), fragment size: {fragment_img.size}"
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
f"({STATE.name} - {log_label}) Cannot update framebuffer - no base framebuffer exists"
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"({STATE.name} - {log_label}) Failed to process partial framebuffer update: {e}"
|
||||
)
|
||||
case ["turn", turn_time, count, current_turn, *queue]:
|
||||
if (
|
||||
queue == vms[log_label]["turn_queue"]
|
||||
and current_turn == vms[log_label]["active_turn_user"]
|
||||
):
|
||||
continue
|
||||
for user in vms[log_label]["users"]:
|
||||
vms[log_label]["turn_queue"] = queue
|
||||
vms[log_label]["active_turn_user"] = (
|
||||
current_turn if current_turn != "" else None
|
||||
)
|
||||
if current_turn:
|
||||
log.info(
|
||||
f"[{log_label}] It's now {current_turn}'s turn. Queue: {queue}"
|
||||
)
|
||||
|
||||
utc_now = datetime.now(timezone.utc)
|
||||
utc_day = utc_now.strftime("%Y-%m-%d")
|
||||
timestamp = utc_now.isoformat()
|
||||
|
||||
# Get daily log file path
|
||||
daily_log_path = os.path.join(vm_log_directory, f"{utc_day}.json")
|
||||
|
||||
# Load existing log data or create new
|
||||
if os.path.exists(daily_log_path):
|
||||
with open(daily_log_path, "r") as log_file:
|
||||
try:
|
||||
log_data = json.load(log_file)
|
||||
except json.JSONDecodeError:
|
||||
log_data = {}
|
||||
|
||||
if utc_day not in log_data:
|
||||
log_data[utc_day] = []
|
||||
|
||||
log_data[utc_day].append({
|
||||
"type": "turn",
|
||||
"timestamp": timestamp,
|
||||
"active_turn_user": current_turn,
|
||||
"queue": queue
|
||||
})
|
||||
|
||||
log_file.seek(0)
|
||||
json.dump(log_data, log_file, indent=4)
|
||||
log_file.truncate()
|
||||
log.debug(f"({STATE.name} - {vm_name}) Turn update: turn_time={turn_time}, count={count}, current_turn={current_turn}, queue={queue}")
|
||||
log_data = []
|
||||
else:
|
||||
log_data = []
|
||||
|
||||
log_data.append(
|
||||
{
|
||||
"type": "turn",
|
||||
"timestamp": timestamp,
|
||||
"active_turn_user": current_turn,
|
||||
"queue": queue,
|
||||
}
|
||||
)
|
||||
|
||||
with open(daily_log_path, "w") as log_file:
|
||||
json.dump(log_data, log_file, indent=4)
|
||||
|
||||
case ["remuser", count, *list]:
|
||||
for i in range(int(count)):
|
||||
username = list[i]
|
||||
if username in vms[vm_name]["users"]:
|
||||
del vms[vm_name]["users"][username]
|
||||
log.info(f"[{vm_name}] User '{username}' left.")
|
||||
case ["flag", *args] | ["size", *args] | ["png", *args] | ["sync", *args]:
|
||||
if username in vms[log_label]["users"]:
|
||||
del vms[log_label]["users"][username]
|
||||
log.info(f"[{log_label}] User '{username}' left.")
|
||||
case ["flag", *args] | ["png", *args] | ["sync", *args]:
|
||||
continue
|
||||
case _:
|
||||
if decoded is not None:
|
||||
log.debug(f"({STATE.name} - {vm_name}) Unhandled message: {decoded}")
|
||||
log.debug(
|
||||
f"({STATE.name} - {log_label}) Unhandled message: {decoded}"
|
||||
)
|
||||
|
||||
log.info(f"({STATE.name}) CVM-Sentry started")
|
||||
log.info(f"CVM-Sentry started")
|
||||
|
||||
for vm in config.vms.keys():
|
||||
for vm_dict_label, vm_obj in config.vms.items():
|
||||
|
||||
def start_vm_thread(vm_name: str):
|
||||
asyncio.run(connect(vm_name))
|
||||
def start_vm_thread(vm_obj: dict):
|
||||
asyncio.run(connect(vm_obj))
|
||||
|
||||
async def main():
|
||||
|
||||
async def connect_with_reconnect(vm_name: str):
|
||||
|
||||
async def connect_with_reconnect(vm_obj: dict):
|
||||
while True:
|
||||
try:
|
||||
await connect(vm_name)
|
||||
await connect(vm_obj)
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
log.warning(f"Connection to VM '{vm_name}' closed with error: {e}. Reconnecting...")
|
||||
await asyncio.sleep(5) # Wait before attempting to reconnect
|
||||
log.error(
|
||||
f"Connection to VM '{vm_obj['ws_url']}' closed with error: {e}. Reconnecting..."
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
log.warning(f"Connection to VM '{vm_name}' closed cleanly (code 1005). Reconnecting...")
|
||||
await asyncio.sleep(5) # Wait before attempting to reconnect
|
||||
log.warning(
|
||||
f"Connection to VM '{vm_obj['ws_url']}' closed cleanly (code 1005). Reconnecting..."
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
except websockets.exceptions.InvalidStatus as e:
|
||||
log.error(f"Failed to connect to VM '{vm_name}' with status code: {e}. Reconnecting...")
|
||||
await asyncio.sleep(10) # Wait longer for HTTP errors
|
||||
log.debug(
|
||||
f"Failed to connect to VM '{vm_obj['ws_url']}' with status code: {e}. Reconnecting..."
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
except websockets.exceptions.WebSocketException as e:
|
||||
log.error(f"WebSocket error connecting to VM '{vm_name}': {e}. Reconnecting...")
|
||||
log.error(
|
||||
f"WebSocket error connecting to VM '{vm_obj['ws_url']}': {e}. Reconnecting..."
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
except Exception as e:
|
||||
log.error(f"Unexpected error connecting to VM '{vm_name}': {e}. Reconnecting...")
|
||||
await asyncio.sleep(10) # Wait longer for unexpected errors
|
||||
log.error(
|
||||
f"Unexpected error connecting to VM '{vm_obj['ws_url']}': {e}. Reconnecting..."
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Create tasks for VM connections
|
||||
vm_tasks = [connect_with_reconnect(vm) for vm in config.vms.keys()]
|
||||
|
||||
vm_tasks = [connect_with_reconnect(vm) for vm in config.vms.values()]
|
||||
|
||||
# Add periodic snapshot task
|
||||
snapshot_task = periodic_snapshot_task()
|
||||
|
||||
|
||||
# Run all tasks concurrently
|
||||
all_tasks = [snapshot_task] + vm_tasks
|
||||
await asyncio.gather(*all_tasks)
|
||||
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
||||
2
poetry.lock
generated
2
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "imagehash"
|
||||
|
||||
@@ -1,241 +0,0 @@
|
||||
from typing import List
|
||||
from cvmlib import guac_decode, guac_encode, CollabVMRank, CollabVMState, CollabVMClientRenameStatus
|
||||
import config
|
||||
import os, websockets, asyncio, base64
|
||||
from websockets import Subprotocol, Origin
|
||||
import logging
|
||||
import sys
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from datetime import datetime
|
||||
import imagehash
|
||||
import glob
|
||||
|
||||
LOG_LEVEL = getattr(config, "log_level", "INFO")
|
||||
|
||||
# Setup logging
|
||||
log_format = logging.Formatter(
|
||||
"[%(asctime)s:%(name)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
)
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
stdout_handler.setFormatter(log_format)
|
||||
log = logging.getLogger("CVMSnapper")
|
||||
log.setLevel(LOG_LEVEL)
|
||||
log.addHandler(stdout_handler)
|
||||
|
||||
STATE = CollabVMState.WS_DISCONNECTED
|
||||
|
||||
def get_origin_from_ws_url(ws_url: str) -> str:
|
||||
domain = (
|
||||
ws_url.removeprefix("ws:")
|
||||
.removeprefix("wss:")
|
||||
.removeprefix("/")
|
||||
.removeprefix("/")
|
||||
.split("/", 1)[0]
|
||||
)
|
||||
is_wss = ws_url.startswith("wss:")
|
||||
return f"http{'s' if is_wss else ''}://{domain}/"
|
||||
|
||||
def get_file_dhash(file_path: str) -> str:
|
||||
"""Get dhash (difference hash) of an image file."""
|
||||
try:
|
||||
with Image.open(file_path) as img:
|
||||
hash_obj = imagehash.dhash(img)
|
||||
return str(hash_obj)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to get dhash for file {file_path}: {e}")
|
||||
return ""
|
||||
|
||||
def get_image_dhash_from_data(image_data: bytes) -> str:
|
||||
"""Get dhash (difference hash) of image data."""
|
||||
try:
|
||||
with Image.open(BytesIO(image_data)) as img:
|
||||
hash_obj = imagehash.dhash(img)
|
||||
return str(hash_obj)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to get dhash for image data: {e}")
|
||||
return ""
|
||||
|
||||
def get_latest_snapshot_path(vm_name: str) -> str:
|
||||
"""Get the path of the most recent snapshot for a VM."""
|
||||
try:
|
||||
snapshot_dir = os.path.join(config.log_directory, "webp", vm_name)
|
||||
if not os.path.exists(snapshot_dir):
|
||||
return ""
|
||||
|
||||
# Get all .webp files in the directory
|
||||
pattern = os.path.join(snapshot_dir, "snapshot_*.webp")
|
||||
files = glob.glob(pattern)
|
||||
|
||||
if not files:
|
||||
return ""
|
||||
|
||||
# Sort by modification time and return the most recent
|
||||
latest_file = max(files, key=os.path.getmtime)
|
||||
return latest_file
|
||||
except Exception as e:
|
||||
log.error(f"Failed to get latest snapshot for VM {vm_name}: {e}")
|
||||
return ""
|
||||
|
||||
def images_are_identical(image_data: bytes, existing_file_path: str) -> bool:
|
||||
"""Compare image data with an existing file using dhash to check if they're visually similar."""
|
||||
try:
|
||||
if not os.path.exists(existing_file_path):
|
||||
return False
|
||||
|
||||
# Get dhash of new image data
|
||||
new_hash = get_image_dhash_from_data(image_data)
|
||||
if not new_hash:
|
||||
return False
|
||||
|
||||
# Get dhash of existing file
|
||||
existing_hash = get_file_dhash(existing_file_path)
|
||||
if not existing_hash:
|
||||
return False
|
||||
|
||||
# Compare dhashes - they should be identical for very similar images
|
||||
# dhash is more forgiving than SHA256 and will detect visually identical images
|
||||
return new_hash == existing_hash
|
||||
except Exception as e:
|
||||
log.error(f"Failed to compare images: {e}")
|
||||
return False
|
||||
|
||||
async def send_guac(websocket, *args: str):
|
||||
await websocket.send(guac_encode(list(args)))
|
||||
|
||||
def convert_png_to_webp(b64_png_data: str, output_path: str, vm_name: str) -> bool:
|
||||
"""Convert base64 PNG data to WebP format and save to file, checking for duplicates."""
|
||||
try:
|
||||
# Decode base64 PNG data
|
||||
png_data = base64.b64decode(b64_png_data)
|
||||
|
||||
# Check if this image is identical to the latest snapshot
|
||||
latest_snapshot = get_latest_snapshot_path(vm_name)
|
||||
if latest_snapshot:
|
||||
# Convert PNG to WebP in memory for comparison
|
||||
with Image.open(BytesIO(png_data)) as img:
|
||||
webp_buffer = BytesIO()
|
||||
img.save(webp_buffer, "WEBP", quality=55, method=6, minimize_size=True)
|
||||
webp_data = webp_buffer.getvalue()
|
||||
|
||||
if images_are_identical(webp_data, latest_snapshot):
|
||||
log.debug(f"Snapshot for VM '{vm_name}' is identical to the previous one, skipping save to avoid duplicate")
|
||||
return True # Return True because the operation was successful (no error, just no need to save)
|
||||
|
||||
# Open PNG image from bytes and save as WebP
|
||||
with Image.open(BytesIO(png_data)) as img:
|
||||
# Convert and save as WebP
|
||||
img.save(output_path, "WEBP", quality=55, method=6, minimize_size=True)
|
||||
log.debug(f"Successfully converted and saved WebP image to: {output_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Failed to convert PNG to WebP: {e}")
|
||||
return False
|
||||
|
||||
async def snap_vm(vm_name: str, output_filename: str = "snapshots"):
|
||||
"""Connect to a VM and capture the initial frame as WebP."""
|
||||
global STATE
|
||||
|
||||
if vm_name not in config.vms:
|
||||
log.error(f"VM '{vm_name}' not found in configuration.")
|
||||
return False
|
||||
|
||||
# Ensure output directory exists
|
||||
|
||||
uri = config.vms[vm_name]
|
||||
|
||||
try:
|
||||
async with websockets.connect(
|
||||
uri=uri,
|
||||
subprotocols=[Subprotocol("guacamole")],
|
||||
origin=Origin(get_origin_from_ws_url(uri)),
|
||||
user_agent_header="cvmsnapper/1 (https://git.nixlabs.dev/clair/cvmsentry)",
|
||||
close_timeout=5, # Wait max 5 seconds for close handshake
|
||||
ping_interval=None # Disable ping for short-lived connections
|
||||
) as websocket:
|
||||
STATE = CollabVMState.WS_CONNECTED
|
||||
log.debug(f"Connected to VM '{vm_name}' at {uri}")
|
||||
|
||||
# Send connection commands
|
||||
await send_guac(websocket, "rename", "")
|
||||
await send_guac(websocket, "connect", vm_name)
|
||||
|
||||
# Wait for messages
|
||||
async for message in websocket:
|
||||
decoded: List[str] = guac_decode(str(message))
|
||||
match decoded:
|
||||
case ["nop"]:
|
||||
await send_guac(websocket, "nop")
|
||||
case ["auth", config.auth_server]:
|
||||
#await send_guac(websocket, "login", config.credentials["scrotter_auth"])
|
||||
continue
|
||||
case ["connect", connection_status, turns_enabled, votes_enabled, uploads_enabled]:
|
||||
if connection_status == "1":
|
||||
STATE = CollabVMState.VM_CONNECTED
|
||||
log.debug(f"Connected to VM '{vm_name}' successfully. Waiting for initial frame...")
|
||||
else:
|
||||
log.error(f"Failed to connect to VM '{vm_name}'. Connection status: {connection_status}")
|
||||
STATE = CollabVMState.WS_DISCONNECTED
|
||||
await websocket.close()
|
||||
return False
|
||||
case ["login", status, error]:
|
||||
if status == "0":
|
||||
log.debug(f"Authentication successful for VM '{vm_name}'")
|
||||
STATE = CollabVMState.LOGGED_IN
|
||||
else:
|
||||
log.error(f"Authentication failed for VM '{vm_name}'. Error: {error}")
|
||||
STATE = CollabVMState.WS_DISCONNECTED
|
||||
continue
|
||||
case ["png", "0", "0", "0", "0", b64_rect]:
|
||||
# This is the initial full frame
|
||||
log.debug(f"Received initial frame from VM '{vm_name}' ({len(b64_rect)} bytes)")
|
||||
|
||||
# Ensure the output directory exists
|
||||
os.makedirs(config.log_directory + f"/webp/{vm_name}", exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
output_filename = f"snapshot_{timestamp}.webp"
|
||||
output_path = os.path.join(config.log_directory, "webp", vm_name, output_filename)
|
||||
|
||||
# Convert PNG to WebP
|
||||
if convert_png_to_webp(b64_rect, output_path, vm_name):
|
||||
# Give a small delay before closing to ensure proper handshake
|
||||
await asyncio.sleep(0.1)
|
||||
try:
|
||||
await websocket.close(code=1000, reason="Screenshot captured")
|
||||
except Exception as close_error:
|
||||
log.debug(f"Error during close handshake for VM '{vm_name}': {close_error}")
|
||||
return True
|
||||
case _:
|
||||
#log.debug(f"Received unhandled message from VM '{vm_name}': {decoded}")
|
||||
continue
|
||||
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
log.debug(f"Connection to VM '{vm_name}' closed during snapshot capture (code {e.code}): {e.reason}")
|
||||
# This is expected when we close after getting the screenshot
|
||||
return True
|
||||
except websockets.exceptions.ConnectionClosedOK as e:
|
||||
log.debug(f"Connection to VM '{vm_name}' closed cleanly during snapshot capture")
|
||||
return True
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
log.debug(f"Connection to VM '{vm_name}' closed during snapshot capture (code {e.code}): {e.reason}")
|
||||
# This catches the "1000 no close frame received" errors
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Unexpected error while capturing VM '{vm_name}': {e}")
|
||||
return False
|
||||
|
||||
async def snap_all_vms(output_dir: str = "snapshots"):
    """Capture snapshots of all configured VMs, one at a time.

    Args:
        output_dir: Forwarded to snap_vm (whose filename is actually derived
            from a timestamp, so this value is effectively unused there).
    """
    log.info("Starting snapshot capture for all VMs...")

    # Snapshot each VM sequentially. The old code built a list of coroutines
    # under a comment claiming concurrency, then awaited them one by one —
    # a plain loop expresses the actual (sequential) behavior directly and
    # avoids holding unstarted coroutine objects around.
    for vm_name in config.vms:
        log.debug(f"Starting snapshot capture for VM: {vm_name}")
        await snap_vm(vm_name, output_dir)
|
||||
Reference in New Issue
Block a user