Skip to content

Commit

Permalink
Style updates to scram sasl support
Browse files Browse the repository at this point in the history
  • Loading branch information
dpkp committed Dec 29, 2019
1 parent ee1c4a4 commit e3362ac
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 82 deletions.
83 changes: 5 additions & 78 deletions kafka/conn.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from __future__ import absolute_import, division

import base64
import copy
import errno
import hashlib
import hmac
import io
import logging
from random import shuffle, uniform

from uuid import uuid4

# selectors in stdlib as of py3.4
try:
import selectors # pylint: disable=import-error
Expand All @@ -34,6 +29,7 @@
from kafka.protocol.metadata import MetadataRequest
from kafka.protocol.parser import KafkaProtocol
from kafka.protocol.types import Int32, Int8
from kafka.scram import ScramClient
from kafka.version import __version__


Expand All @@ -42,12 +38,6 @@
TimeoutError = socket.error
BlockingIOError = Exception

def xor_bytes(left, right):
return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right))
else:
def xor_bytes(left, right):
return bytes(lb ^ rb for lb, rb in zip(left, right))

log = logging.getLogger(__name__)

DEFAULT_KAFKA_PORT = 9092
Expand Down Expand Up @@ -107,69 +97,6 @@ class ConnectionStates(object):
AUTHENTICATING = '<authenticating>'


class ScramClient:
MECHANISMS = {
'SCRAM-SHA-256': hashlib.sha256,
'SCRAM-SHA-512': hashlib.sha512
}

def __init__(self, user, password, mechanism):
self.nonce = str(uuid4()).replace('-', '')
self.auth_message = ''
self.salted_password = None
self.user = user
self.password = password.encode()
self.hashfunc = self.MECHANISMS[mechanism]
self.hashname = ''.join(mechanism.lower().split('-')[1:3])
self.stored_key = None
self.client_key = None
self.client_signature = None
self.client_proof = None
self.server_key = None
self.server_signature = None

def first_message(self):
client_first_bare = 'n={},r={}'.format(self.user, self.nonce)
self.auth_message += client_first_bare
return 'n,,' + client_first_bare

def process_server_first_message(self, server_first_message):
self.auth_message += ',' + server_first_message
params = dict(pair.split('=', 1) for pair in server_first_message.split(','))
server_nonce = params['r']
if not server_nonce.startswith(self.nonce):
raise ValueError("Server nonce, did not start with client nonce!")
self.nonce = server_nonce
self.auth_message += ',c=biws,r=' + self.nonce

salt = base64.b64decode(params['s'].encode())
iterations = int(params['i'])
self.create_salted_password(salt, iterations)

self.client_key = self.hmac(self.salted_password, b'Client Key')
self.stored_key = self.hashfunc(self.client_key).digest()
self.client_signature = self.hmac(self.stored_key, self.auth_message.encode())
self.client_proof = xor_bytes(self.client_key, self.client_signature)
self.server_key = self.hmac(self.salted_password, b'Server Key')
self.server_signature = self.hmac(self.server_key, self.auth_message.encode())

def hmac(self, key, msg):
return hmac.new(key, msg, digestmod=self.hashfunc).digest()

def create_salted_password(self, salt, iterations):
self.salted_password = hashlib.pbkdf2_hmac(
self.hashname, self.password, salt, iterations
)

def final_message(self):
client_final_no_proof = 'c=biws,r=' + self.nonce
return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode())

def process_server_final_message(self, server_final_message):
params = dict(pair.split('=', 1) for pair in server_final_message.split(','))
if self.server_signature != base64.b64decode(params['v'].encode()):
raise ValueError("Server sent wrong signature!")

class BrokerConnection(object):
"""Initialize a Kafka broker connection
Expand Down Expand Up @@ -747,20 +674,20 @@ def _try_authenticate_scram(self, future):
close = False
else:
try:
client_first = scram_client.first_message().encode()
client_first = scram_client.first_message().encode('utf-8')
size = Int32.encode(len(client_first))
self._send_bytes_blocking(size + client_first)

(data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4))
server_first = self._recv_bytes_blocking(data_len).decode()
server_first = self._recv_bytes_blocking(data_len).decode('utf-8')
scram_client.process_server_first_message(server_first)

client_final = scram_client.final_message().encode()
client_final = scram_client.final_message().encode('utf-8')
size = Int32.encode(len(client_final))
self._send_bytes_blocking(size + client_final)

(data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4))
server_final = self._recv_bytes_blocking(data_len).decode()
server_final = self._recv_bytes_blocking(data_len).decode('utf-8')
scram_client.process_server_final_message(server_final)

except (ConnectionError, TimeoutError) as e:
Expand Down
82 changes: 82 additions & 0 deletions kafka/scram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import absolute_import

import base64
import hashlib
import hmac
import uuid

from kafka.vendor import six


if six.PY2:
def xor_bytes(left, right):
return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right))
else:
def xor_bytes(left, right):
return bytes(lb ^ rb for lb, rb in zip(left, right))


class ScramClient:
MECHANISMS = {
'SCRAM-SHA-256': hashlib.sha256,
'SCRAM-SHA-512': hashlib.sha512
}

def __init__(self, user, password, mechanism):
self.nonce = str(uuid.uuid4()).replace('-', '')
self.auth_message = ''
self.salted_password = None
self.user = user
self.password = password.encode('utf-8')
self.hashfunc = self.MECHANISMS[mechanism]
self.hashname = ''.join(mechanism.lower().split('-')[1:3])
self.stored_key = None
self.client_key = None
self.client_signature = None
self.client_proof = None
self.server_key = None
self.server_signature = None

def first_message(self):
client_first_bare = 'n={},r={}'.format(self.user, self.nonce)
self.auth_message += client_first_bare
return 'n,,' + client_first_bare

def process_server_first_message(self, server_first_message):
self.auth_message += ',' + server_first_message
params = dict(pair.split('=', 1) for pair in server_first_message.split(','))
server_nonce = params['r']
if not server_nonce.startswith(self.nonce):
raise ValueError("Server nonce, did not start with client nonce!")
self.nonce = server_nonce
self.auth_message += ',c=biws,r=' + self.nonce

salt = base64.b64decode(params['s'].encode('utf-8'))
iterations = int(params['i'])
self.create_salted_password(salt, iterations)

self.client_key = self.hmac(self.salted_password, b'Client Key')
self.stored_key = self.hashfunc(self.client_key).digest()
self.client_signature = self.hmac(self.stored_key, self.auth_message.encode('utf-8'))
self.client_proof = xor_bytes(self.client_key, self.client_signature)
self.server_key = self.hmac(self.salted_password, b'Server Key')
self.server_signature = self.hmac(self.server_key, self.auth_message.encode('utf-8'))

def hmac(self, key, msg):
return hmac.new(key, msg, digestmod=self.hashfunc).digest()

def create_salted_password(self, salt, iterations):
self.salted_password = hashlib.pbkdf2_hmac(
self.hashname, self.password, salt, iterations
)

def final_message(self):
client_final_no_proof = 'c=biws,r=' + self.nonce
return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode('utf-8'))

def process_server_final_message(self, server_final_message):
params = dict(pair.split('=', 1) for pair in server_final_message.split(','))
if self.server_signature != base64.b64decode(params['v'].encode('utf-8')):
raise ValueError("Server sent wrong signature!")


10 changes: 6 additions & 4 deletions test/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,10 @@ def _sasl_config(self):
if not self.sasl_enabled:
return ''

sasl_config = "sasl.enabled.mechanisms={mechanism}\n"
sasl_config += "sasl.mechanism.inter.broker.protocol={mechanism}\n"
sasl_config = (
'sasl.enabled.mechanisms={mechanism}\n'
'sasl.mechanism.inter.broker.protocol={mechanism}\n'
)
return sasl_config.format(mechanism=self.sasl_mechanism)

def _jaas_config(self):
Expand All @@ -328,12 +330,12 @@ def _jaas_config(self):

elif self.sasl_mechanism == 'PLAIN':
jaas_config = (
"org.apache.kafka.common.security.plain.PlainLoginModule required\n"
'org.apache.kafka.common.security.plain.PlainLoginModule required\n'
' username="{user}" password="{password}" user_{user}="{password}";\n'
)
elif self.sasl_mechanism in ("SCRAM-SHA-256", "SCRAM-SHA-512"):
jaas_config = (
"org.apache.kafka.common.security.scram.ScramLoginModule required\n"
'org.apache.kafka.common.security.scram.ScramLoginModule required\n'
' username="{user}" password="{password}";\n'
)
else:
Expand Down

0 comments on commit e3362ac

Please sign in to comment.