Skip to content

Commit

Permalink
Update some checks
Browse files Browse the repository at this point in the history
  • Loading branch information
tommaso-ascani committed Nov 4, 2024
1 parent fe919a3 commit 8a24475
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 34 deletions.
29 changes: 1 addition & 28 deletions core/imageroot/usr/local/agent/pypkg/node/ports_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#

import sqlite3
import agent

class PortError(Exception):
"""Base class for all port-related exceptions."""
Expand Down Expand Up @@ -37,12 +36,6 @@ def __init__(self, message="The number of required ports must be at least 1."):
self.message = message
super().__init__(self.message)

class ExceededPortsDemand(PortError):
"""Exception raised when the requested number of ports is higher than the maxium assigned to the image."""
def __init__(self, message="The number of required ports is higher than the maxium assigned to the image."):
self.message = message
super().__init__(self.message)

def create_tables(cursor: sqlite3.Cursor):
# Create TCP table if it doesn't exist
cursor.execute("""
Expand Down Expand Up @@ -108,23 +101,6 @@ def allocate_ports(required_ports: int, module_name: str, protocol: str, keep_ex
cursor.execute("SELECT start,end,module FROM UDP_PORTS ORDER BY start;")
ports_used = cursor.fetchall()

# Ensure number of ports required
rdb = agent.redis_connect(privileged=False)
if protocol == 'tcp':
ports_demand = rdb.hgetall('cluster/tcp_ports_demand')
elif protocol == 'udp':
ports_demand = rdb.hgetall('cluster/udp_ports_demand')

total_ports_required = required_ports

if ports_demand:
for port in ports_used:
if port[2] == module_name:
total_ports_required += (port[1] - port[0] + 1)

if total_ports_required > int(ports_demand.get(module_name)):
raise ExceededPortsDemand()

if len(ports_used) == 0:
write_range(range_start, range_start + required_ports - 1, module_name, protocol, database)
return (range_start, range_start + required_ports - 1)
Expand Down Expand Up @@ -229,10 +205,7 @@ def get_ports_by_module(module_name: str):

ports = tcp_ports + udp_ports

if ports:
return [(port[0], port[1], port[2]) for port in ports]
else:
raise ModuleNotFoundError(module_name) # Raise error if the module is not found
return [(port[0], port[1], port[2]) for port in ports]

except sqlite3.Error as e:
raise StorageError(f"Database error: {e}") from e # Raise custom database error
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,22 @@ request = json.load(sys.stdin)
module_env = os.getenv("AGENT_TASK_USER")

if module_env != "" and module_env != f"module/{request['module_id']}":
print(agent.SD_ERR + f" Agent {module_env} does not have permission to change the port allocation for {request['module_id']}.", file=sys.stderr)
sys.exit(1)

raise Exception(f"Agent {module_env} does not have permission to change the port allocation for {request['module_id']}.")

# Ensure number of ports required
rdb = agent.redis_connect(privileged=False)

current_node_id = os.getenv("NODE_ID")
module_node_id = rdb.hget('cluster/module_node', request['module_id'])

# Verify that the module exists
if not module_node_id:
raise Exception(f"Error: Module {request['module_id']} does not exist.")

# Verify that the module is present on the current node
if module_node_id != current_node_id:
raise Exception(f"Error: Module {request['module_id']} is not located on the current node {current_node_id}.")

if request['protocol'] == 'tcp':
ports_demand = rdb.hgetall('cluster/tcp_ports_demand')
elif request['protocol'] == 'udp':
Expand All @@ -34,9 +45,8 @@ if ports_demand:
if port[2] == request['module_id']:
total_ports_required += (port[1] - port[0] + 1)

if total_ports_required > int(ports_demand.get(request['module_id'])):
print(agent.SD_ERR + " Error: Exceeded the allowed number of ports.", file=sys.stderr)
sys.exit(1)
if total_ports_required > int(ports_demand.get(request['module_id'], 0)):
raise Exception("Error: Exceeded the allowed number of ports.")

range = node.ports_manager.allocate_ports(int(request['ports']), request['module_id'], request['protocol'], request['keep_existing'])

Expand Down

0 comments on commit 8a24475

Please sign in to comment.