diff --git a/wol.py b/wol.py index 835e4aa..4465059 100755 --- a/wol.py +++ b/wol.py @@ -1,12 +1,15 @@ #!/usr/bin/env python3 +from pathlib import Path from pyparsing import * from typing import Dict, Optional import socketserver import argparse import logging import psutil +import stat import socket +import subprocess import sys import os import re @@ -74,9 +77,16 @@ def sendMagicPacket(macAddress: str, iface: str) -> bool: logging.exception('Sending magic packet to {} (on {}) failed'.format(macAddress, iface)) return False +SPECIAL_WAKEUP_DIR = Path("/etc/wake-on-lan/special-wakeup.d") + def wake(hostname: str) -> bool: global hosts - if hostname in hosts: + special_wakeup_file = SPECIAL_WAKEUP_DIR / hostname + if special_wakeup_file.is_file() and os.access(special_wakeup_file, os.X_OK): + with subprocess.Popen(special_wakeup_file) as proc: + returncode = proc.wait() + return returncode == 0 + elif hostname in hosts: logging.info('Waking up {}...'.format(hostname)) host = hosts[hostname] if not 'iface' in host: @@ -89,6 +99,8 @@ def wake(hostname: str) -> bool: logging.warning('Unknown host "{}"'.format(hostname)) return False +LDH_RE = re.compile(r"[0-9A-Za-z\-]+") + class WakeRequestHandler(socketserver.StreamRequestHandler): def handle(self): self.connection.settimeout(6) @@ -96,13 +108,21 @@ def handle(self): logging.debug('Connected {}'.format(client)) try: while self.rfile: - hostname = self.rfile.readline().strip() - if hostname: + hostname = self.rfile.readline().decode('ascii').strip() + if not hostname: + break + + # Check if nostname matches the letter-digits-hyphen rule of DNS. + # This also prevents remote code execution in case hostname is + # e.g., "../../../bin/sh". + if LDH_RE.match(hostname): logging.info('Request WoL at "{}" from {}'.format(hostname, client)) success = wake(hostname) - self.wfile.write(b"success\n" if success else b"failed\n") else: - break + success = False + logging.warning(f"Invalid hostname: {hostname}") + self.wfile.write(b"success\n" if success else b"failed\n") + except socket.timeout: logging.debug('Timeout of {}'.format(client))