Skip to content

Commit

Permalink
Discv5 Protocol: Add support for banning nodes (#769)
Browse files Browse the repository at this point in the history
* Add banned nodes to routing table.

* Filter out banned nodes in lookups and cleanup expired bans in refreshLoop.

* Don't respond to messages from banned nodes.

* Prevent sending messages to banned nodes.
  • Loading branch information
bhartnett authored Jan 30, 2025
1 parent e589cc0 commit c640d3c
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 18 deletions.
87 changes: 72 additions & 15 deletions eth/p2p/discoveryv5/protocol.nim
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ const
defaultResponseTimeout* = 4.seconds ## timeout for the response of a request-response
## call

## Ban durations for banned nodes in the routing table
NodeBanDurationInvalidResponse = 15.minutes

type
OptAddress* = object
ip*: Opt[IpAddress]
Expand All @@ -142,6 +145,7 @@ type
bindAddress: OptAddress ## UDP binding address
pendingRequests: Table[AESGCMNonce, PendingRequest]
routingTable*: RoutingTable
banNodes: bool
codec*: Codec
awaitedMessages: Table[(NodeId, RequestId), Future[Opt[Message]]]
refreshLoop: Future[void]
Expand All @@ -157,6 +161,7 @@ type
responseTimeout: Duration
rng*: ref HmacDrbgContext


PendingRequest = object
node: Node
message: seq[byte]
Expand Down Expand Up @@ -192,10 +197,13 @@ proc addNode*(d: Protocol, node: Node): bool =
##
## Returns true only when `Node` was added as a new entry to a bucket in the
## routing table.
if d.routingTable.addNode(node) == Added:
let r = d.routingTable.addNode(node)
if r == Added:
return true
else:
return false

if r == Banned:
debug "Banned node not added to routing table", nodeId = node.id
return false

proc addNode*(d: Protocol, r: Record): bool =
## Add `Node` from a `Record` to discovery routing table.
Expand Down Expand Up @@ -429,6 +437,30 @@ proc sendWhoareyou(d: Protocol, toId: NodeId, a: Address,
else:
debug "Node with this id already has ongoing handshake, ignoring packet"

proc replaceNode(d: Protocol, n: Node) =
  ## Hand `n` to the routing table's `replaceNode`, unless its record is one
  ## of the configured bootstrap records, in which case only a debug message
  ## is logged and the routing table is left untouched.
  if n.record in d.bootstrapRecords:
    # For now we never remove bootstrap nodes. It might make sense to actually
    # do so and to retry them only in case we drop to a really low amount of
    # peers in the routing table.
    debug "Message request to bootstrap node failed", enr = toURI(n.record)
    return

  d.routingTable.replaceNode(n)

proc banNode*(d: Protocol, n: Node, banPeriod: chronos.Duration) =
  ## Ban `n` in the routing table for `banPeriod`.
  ##
  ## Bootstrap nodes are never banned or replaced; for them only a debug
  ## message is logged. When banning is disabled (`d.banNodes` is false) the
  ## node is passed to the routing table's `replaceNode` instead of being
  ## banned.
  if n.record in d.bootstrapRecords:
    # For now we never remove bootstrap nodes. It might make sense to actually
    # do so and to retry them only in case we drop to a really low amount of
    # peers in the routing table.
    debug "Message request to bootstrap node failed", enr = toURI(n.record)
    return

  if d.banNodes:
    d.routingTable.banNode(n.id, banPeriod) # banNode also replaces the node
  else:
    d.routingTable.replaceNode(n)

proc isBanned*(d: Protocol, nodeId: NodeId): bool =
  ## Returns true only when banning is enabled on this protocol instance and
  ## the routing table reports `nodeId` as banned. With banning disabled the
  ## routing table is not consulted at all.
  if not d.banNodes:
    return false

  d.routingTable.isBanned(nodeId)

proc receive*(d: Protocol, a: Address, packet: openArray[byte]) =
discv5_network_bytes.inc(packet.len.int64, labelValues = [$Direction.In])

Expand All @@ -437,6 +469,10 @@ proc receive*(d: Protocol, a: Address, packet: openArray[byte]) =
let packet = decoded[]
case packet.flag
of OrdinaryMessage:
if d.isBanned(packet.srcId):
trace "Ignoring received OrdinaryMessage from banned node", nodeId = packet.srcId
return

if packet.messageOpt.isSome():
let message = packet.messageOpt.get()
trace "Received message packet", srcId = packet.srcId, address = a,
Expand Down Expand Up @@ -464,6 +500,10 @@ proc receive*(d: Protocol, a: Address, packet: openArray[byte]) =
else:
debug "Timed out or unrequested whoareyou packet", address = a
of HandshakeMessage:
if d.isBanned(packet.srcIdHs):
trace "Ignoring received HandshakeMessage from banned node", nodeId = packet.srcIdHs
return

trace "Received handshake message packet", srcId = packet.srcIdHs,
address = a, kind = packet.message.kind
d.handleMessage(packet.srcIdHs, a, packet.message, packet.node)
Expand Down Expand Up @@ -494,14 +534,7 @@ proc processClient(transp: DatagramTransport, raddr: TransportAddress):

proto.receive(Address(ip: raddr.toIpAddress(), port: raddr.port), buf)

proc replaceNode(d: Protocol, n: Node) =
## Pass `n` to the routing table's `replaceNode` unless its record is one of
## the configured bootstrap records; bootstrap nodes are only logged.
if n.record notin d.bootstrapRecords:
d.routingTable.replaceNode(n)
else:
# For now we never remove bootstrap nodes. It might make sense to actually
# do so and to retry them only in case we drop to a really low amount of
# peers in the routing table.
debug "Message request to bootstrap node failed", enr = toURI(n.record)


# TODO: This could be improved to do the clean-up immediately in case a non
# whoareyou response does arrive, but we would need to store the AuthTag
Expand Down Expand Up @@ -546,9 +579,11 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId):
break
return ok(res)
else:
d.banNode(fromNode, NodeBanDurationInvalidResponse)
discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"])
return err("Invalid response to find node message")
else:
d.replaceNode(fromNode)
discovery_message_requests_outgoing.inc(labelValues = ["no_response"])
return err("Nodes message not received in time")

Expand All @@ -574,6 +609,10 @@ proc ping*(d: Protocol, toNode: Node):
## Send a discovery ping message.
##
## Returns the received pong message or an error.

if d.isBanned(toNode.id):
return err("toNode is banned")

let reqId = d.sendMessage(toNode,
PingMessage(enrSeq: d.localNode.record.seqNum))
let resp = await d.waitMessage(toNode, reqId)
Expand All @@ -583,7 +622,7 @@ proc ping*(d: Protocol, toNode: Node):
d.routingTable.setJustSeen(toNode)
return ok(resp.get().pong)
else:
d.replaceNode(toNode)
d.banNode(toNode, NodeBanDurationInvalidResponse)
discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"])
return err("Invalid response to ping message")
else:
Expand All @@ -597,22 +636,29 @@ proc findNode*(d: Protocol, toNode: Node, distances: seq[uint16]):
##
## Returns the received nodes or an error.
## Received ENRs are already validated and converted to `Node`.

if d.isBanned(toNode.id):
return err("toNode is banned")

let reqId = d.sendMessage(toNode, FindNodeMessage(distances: distances))
let nodes = await d.waitNodes(toNode, reqId)

if nodes.isOk:
let res = verifyNodesRecords(nodes.get(), toNode, findNodeResultLimit, distances)
d.routingTable.setJustSeen(toNode)
return ok(res)
return ok(res.filterIt(not d.isBanned(it.id)))
else:
d.replaceNode(toNode)
return err(nodes.error)

proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]):
Future[DiscResult[seq[byte]]] {.async: (raises: [CancelledError]).} =
## Send a discovery talkreq message.
##
## Returns the received talkresp message or an error.

if d.isBanned(toNode.id):
return err("toNode is banned")

let reqId = d.sendMessage(toNode,
TalkReqMessage(protocol: protocol, request: request))
let resp = await d.waitMessage(toNode, reqId)
Expand All @@ -622,7 +668,7 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]):
d.routingTable.setJustSeen(toNode)
return ok(resp.get().talkResp.response)
else:
d.replaceNode(toNode)
d.banNode(toNode, NodeBanDurationInvalidResponse)
discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"])
return err("Invalid response to talk request message")
else:
Expand Down Expand Up @@ -797,6 +843,12 @@ proc resolve*(d: Protocol, id: NodeId): Future[Opt[Node]] {.async: (raises: [Can
if id == d.localNode.id:
return Opt.some(d.localNode)

# No point in trying to resolve a banned node because it won't exist in the
# routing table and it will be filtered out of any responses in the lookup call
if d.isBanned(id):
debug "Not resolving banned node", nodeId = id
return Opt.none(Node)

let node = d.getNode(id)
if node.isSome():
let request = await d.findNode(node.get(), @[0'u16])
Expand Down Expand Up @@ -882,6 +934,9 @@ proc refreshLoop(d: Protocol) {.async: (raises: []).} =
trace "Discovered nodes in random target query", nodes = randomQuery.len
debug "Total nodes in discv5 routing table", total = d.routingTable.len()

# Remove the expired bans from routing table to limit memory usage
d.routingTable.cleanupExpiredBans()

await sleepAsync(refreshInterval)
except CancelledError:
trace "refreshLoop canceled"
Expand Down Expand Up @@ -985,6 +1040,7 @@ proc newProtocol*(
bindPort: Port,
bindIp = IPv4_any(),
enrAutoUpdate = false,
banNodes = false,
config = defaultDiscoveryConfig,
rng = newRng()):
Protocol =
Expand Down Expand Up @@ -1034,6 +1090,7 @@ proc newProtocol*(
enrAutoUpdate: enrAutoUpdate,
routingTable: RoutingTable.init(
node, config.bitsPerHop, config.tableIpLimits, rng),
banNodes: banNodes,
handshakeTimeout: config.handshakeTimeout,
responseTimeout: config.responseTimeout,
rng: rng)
Expand Down
2 changes: 1 addition & 1 deletion eth/p2p/discoveryv5/routing_table.nim
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func ipLimitDec(r: var RoutingTable, b: KBucket, n: Node) =
r.ipLimits.dec(ip)

func getNode*(r: RoutingTable, id: NodeId): Opt[Node]
proc replaceNode*(r: var RoutingTable, n: Node)
proc replaceNode*(r: var RoutingTable, n: Node) {.gcsafe.}

proc banNode*(r: var RoutingTable, nodeId: NodeId, period: chronos.Duration) =
## Ban a node from the routing table for the given period. The node is removed
Expand Down
6 changes: 4 additions & 2 deletions tests/p2p/discv5_test_helper.nim
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ proc initDiscoveryNode*(
address: Address,
bootstrapRecords: openArray[Record] = [],
localEnrFields: openArray[(string, seq[byte])] = [],
previousRecord = Opt.none(enr.Record)):
previousRecord = Opt.none(enr.Record),
banNodes = false):
discv5_protocol.Protocol =
# set bucketIpLimit to allow bucket split
let config = DiscoveryConfig.init(1000, 24, 5)
Expand All @@ -36,7 +37,8 @@ proc initDiscoveryNode*(
localEnrFields = localEnrFields,
previousRecord = previousRecord,
config = config,
rng = rng)
rng = rng,
banNodes = banNodes)

protocol.open()

Expand Down
113 changes: 113 additions & 0 deletions tests/p2p/test_discoveryv5.nim
Original file line number Diff line number Diff line change
Expand Up @@ -926,3 +926,116 @@ suite "Discovery v5 Tests":

await node1.closeWait()
await node2.closeWait()

asyncTest "Banned nodes are removed and cannot be added":
# Single-node scenario with banning enabled: a generated node is added,
# banned for one minute and then added again; `getNode` is checked after
# each step.
let
node = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20302), banNodes = true)
targetNode = generateNode(PrivateKey.random(rng[]))

# add the node
check:
node.addNode(targetNode) == true
node.getNode(targetNode.id).isSome()

# banning the node should remove it from the routing table
node.banNode(targetNode, 1.minutes)
check node.getNode(targetNode.id).isNone()

# cannot add a banned node
check:
node.addNode(targetNode) == false
node.getNode(targetNode.id).isNone()

await node.closeWait()

asyncTest "FindNode filters out banned nodes":
# Two-node scenario: testNode is created with mainNode's record as its
# bootstrap record and sends findNode requests to mainNode. The same
# distance is queried before and after mainNode bans its closest neighbour.
let
mainNode = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20301),
banNodes = true)
testNode = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20302),
@[mainNode.localNode.record], banNodes = true)

# Generate 100 random nodes and add to our main node's routing table
for i in 0 ..< 100:
discard mainNode.addSeenNode(generateNode(PrivateKey.random(rng[])))

# Pick mainNode's closest neighbour and the log distance at which it is
# requested in the findNode calls below.
let
neighbours = mainNode.neighbours(mainNode.localNode.id)
closest = neighbours[0]
closestDistance = logDistance(closest.id, mainNode.localNode.id)

block:
# the closest node is returned
let discovered = await testNode.findNode(mainNode.localNode, @[closestDistance])
check discovered.isOk
check closest in discovered[]

# ban the closest node
mainNode.banNode(closest, 1.minutes)

block:
# the banned node is not returned
let discovered = await testNode.findNode(mainNode.localNode, @[closestDistance])
check discovered.isOk
check closest notin discovered[]

await mainNode.closeWait()
await testNode.closeWait()

asyncTest "Cannot send messages to banned nodes":
# Sender-side check: node1 bans node2 and then tries to ping, findNode and
# resolve it. ping and findNode must fail with the "toNode is banned" error
# and resolve must return none.
let
node1 = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20302),
banNodes = true)
node2 = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20301),
banNodes = true)

# ban node2 in node1's routing table
node1.banNode(node2.localNode, 1.minutes)

block:
let pong = await node1.ping(node2.localNode)
check:
pong.isErr()
pong.error() == "toNode is banned"

block:
let nodes = await node1.findNode(node2.localNode, @[0.uint16])
check:
nodes.isErr()
nodes.error() == "toNode is banned"

block:
let node = await node1.resolve(node2.localNode.id)
check node.isNone()

await node2.closeWait()
await node1.closeWait()

asyncTest "Ignore messages from banned nodes":
# Receiver-side check: node2 bans node1, so requests sent by node1 are
# expected to go unanswered. ping and findNode must fail with their
# "not received in time" errors and resolve must return none.
let
node1 = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20302),
banNodes = true)
node2 = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20301),
banNodes = true)

# ban node1 in node2's routing table
node2.banNode(node1.localNode, 1.minutes)

block:
let pong = await node1.ping(node2.localNode)
check:
pong.isErr()
pong.error() == "Pong message not received in time"

block:
let nodes = await node1.findNode(node2.localNode, @[0.uint16])
check:
nodes.isErr()
nodes.error() == "Nodes message not received in time"

block:
let node = await node1.resolve(node2.localNode.id)
check node.isNone()

await node2.closeWait()
await node1.closeWait()

0 comments on commit c640d3c

Please sign in to comment.