Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow SiteID and/or AccountID in CommandLine #79

Merged
merged 18 commits into from
Nov 9, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 143 additions & 76 deletions products/sentinel_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,62 +60,13 @@ def __init__(self, profile: str, creds_file: str, account_id: Optional[list[str]

self._last_request = 0.0

super().__init__(self.product, profile, **kwargs)

config = configparser.ConfigParser()
config.read(self.creds_file)

# instantiate site_ids and account_ids if not set
site_ids = site_id if site_id else list()
account_ids = account_id if account_id else list()
account_names = account_name if account_name else list()

# extract account/site ID from configuration if set
if 'account_id' in config[profile] and config[profile]['account_id'] not in account_ids:
account_ids.append(config[profile]['account_id'])

if 'site_id' in config[profile] and config[profile]['site_id'] not in site_ids:
site_ids.append(config[profile]['site_id'])

if 'account_name' in config[profile] and config[profile]['account_name'] not in account_names:
account_names.append(config[profile]['account_name'])

# determine site IDs to query (default is all)
self._site_ids = site_ids

if self._site_ids: # only run this if there is an actual list of site IDs to iterate over
# ensure specified site IDs are valid
site_response_data = self._get_all_paginated_data(self._build_url('/web/api/v2.1/sites'),
params={'siteIds': ','.join(site_ids)},
add_default_params=False)
existing_site_ids = set[int]()
for response in site_response_data:
for site in response['sites']:
existing_site_ids.add(site['id'])

for scope_id in self._site_ids:
if scope_id not in existing_site_ids:
raise ValueError(f'Site with ID {scope_id} does not exist')

# get site IDs for each specified account id
for scope_id in account_ids:
for response in self._get_all_paginated_data(self._build_url('/web/api/v2.1/sites'),
params={'accountId': scope_id},
add_default_params=False):
for site in response['sites']:
if site['id'] not in self._site_ids:
self._site_ids.append(site['id'])

for name in account_names:
for response in self._get_all_paginated_data(self._build_url('/web/api/v2.1/sites'),
params={'name': name},
add_default_params=False):
for site in response['sites']:
if site['id'] not in self._site_ids:
self._site_ids.append(site['id'])

self.log.debug(f'Site IDs: {self._site_ids}')
# Save these values to `self` for reference in _authenticate()
self.site_id = site_id
self.account_id = account_id
self.account_name = account_name

super().__init__(self.product, profile, **kwargs)

def _authenticate(self):
config = configparser.ConfigParser()
config.read(self.creds_file)
Expand All @@ -129,9 +80,6 @@ def _authenticate(self):
if 'url' not in section:
raise ValueError(f'S1 configuration invalid, ensure "url" is specified')

if 'site_id' not in section and 'account_id' not in section:
raise ValueError(f'S1 configuration invalid, specify a site_id or account_id')

# extract required information from configuration
if 'token' in section:
self._token = section['token']
Expand All @@ -141,9 +89,6 @@ def _authenticate(self):
f'environment variable')
self._token = os.environ['S1_TOKEN']

self._site_id = section['site_id'] if 'site_id' in section else None
self._account_id = section['account_id'] if 'account_id' in section else None

self._url = section['url'].rstrip('/')

if not self._url.startswith('https://'):
Expand All @@ -153,15 +98,136 @@ def _authenticate(self):
self._session = requests.session()
self._session.mount('https://', HTTPAdapter(pool_connections=10, pool_maxsize=10, max_retries=3))

# test API key by retrieving the sensor count, which is a fast operation
data = self._session.get(self._build_url('/web/api/v2.1/agents/count'),
headers=self._get_default_header(),
params=self._get_default_body()).json()
if 'errors' in data:
if data['errors'][0]['code'] == 4010010:
raise ValueError(f'Failed to authenticate to SentinelOne: {data}')
else:
raise ValueError(f'Error when authenticating to SentinelOne: {data}')
# generate a list of site_ids based on config file and cmdline input
# this will also test API keys as it goes
self._get_site_ids(self.site_id, self.account_id, self.account_name)

if len(self._site_ids) < 1 and len(self._account_ids) < 1:
raise ValueError(f'S1 configuration invalid, specify a site_id, account_id, or account_name')

def _get_site_ids(self, site_id, account_id, account_name):
config = configparser.ConfigParser()
config.read(self.creds_file)

# check if any cmdline stuff was input - that will take precedence over config file stuff
site_ids = (site_id) if site_id else list()
account_ids = (account_id) if account_id else list()
account_names = (account_name) if account_name else list()

if not site_ids and not account_ids and not account_names:
# extract account/site ID from configuration if set
if 'account_id' in config[self.profile]:
for id in config[self.profile]['account_id'].split(','):
if id not in account_ids:
account_ids.append(id.strip())

if 'site_id' in config[self.profile]:
for id in config[self.profile]['site_id'].split(','):
if id not in site_ids:
site_ids.append(id.strip())

if 'account_name' in config[self.profile]:
for name in config[self.profile]['account_name'].split(','):
if name not in account_names:
account_names.append(name.strip())

# determine site and account IDs to query (default is all)
self._site_ids = list()
self._account_ids = list()

if account_ids: # verify provided account IDs are valid
# create batch of 10 account IDs per call
counter = 0
temp_list = []
i = 0
while i < len(account_ids):
temp_list.append(account_ids[i])
counter += 1
if counter == 10 or i == len(account_ids) - 1:
response = self._get_all_paginated_data(self._build_url(f'/web/api/v2.1/accounts'),
params={'states': "active", 'ids': ','.join(temp_list)},
add_default_params=False)

rc-csmith marked this conversation as resolved.
Show resolved Hide resolved
if 'errors' in response:
if response['errors'][0]['code'] == 4010010:
raise ValueError(f'Failed to authenticate to SentinelOne: {response}')
else:
raise ValueError(f'Error when authenticating to SentinelOne: {response}')

for account in response:
if account['id'] not in self._account_ids:
self._account_ids.append(account['id'])

counter = 0
temp_list = []
i += 1

diff = list(set(account_ids) - set(self._account_ids))
if len(diff) > 0:
self.log.warning(f'Account IDs {",".join(diff)} not found.')

if account_names: # verify provided account names are valid
temp_account_name = list()
for name in account_names:
response = self._get_all_paginated_data(self._build_url('/web/api/v2.1/accounts'),
params={'states': "active", 'name': name},
add_default_params=False)

if 'errors' in response:
if response['errors'][0]['code'] == 4010010:
raise ValueError(f'Failed to authenticate to SentinelOne: {response}')
else:
raise ValueError(f'Error when authenticating to SentinelOne: {response}')

for account in response:
temp_account_name.append(account['name'])
if account['id'] not in self._account_ids:
self._account_ids.append(account['id'])

diff = list(set(account_names) - set(temp_account_name))
if len(diff) > 0:
self.log.warning(f'Account names {",".join(diff)} not found')

if site_ids: # ensure specified site IDs are valid and not already covered by the account_ids listed above
temp_site_ids = list()
# create batches of 10 site_ids
counter = 0
temp_list = []
i = 0
while i < len(site_ids):
temp_list.append(site_ids[i])
counter += 1
if counter == 10 or i == len(site_ids) - 1:
response = self._get_all_paginated_data(self._build_url('/web/api/v2.1/sites'),
params={'state': "active", 'siteIds': ','.join(site_ids)},
add_default_params=False)

if 'errors' in response:
if response['errors'][0]['code'] == 4010010:
raise ValueError(f'Failed to authenticate to SentinelOne: {response}')
else:
raise ValueError(f'Error when authenticating to SentinelOne: {response}')

for item in response:
for site in item['sites']:
temp_site_ids.append(site['id'])
if site['accountId'] not in self._account_ids and site['id'] not in self._site_ids:
self._site_ids.append(site['id'])
counter = 0
temp_list = []
i += 1

diff = list(set(site_ids) - set(temp_site_ids))
if len(diff) > 0:
self.log.warning(f'Site IDs {",".join(diff)} not found')

# remove unncessary variables from self
self.__dict__.pop('site_id',None)
self.__dict__.pop('account_id',None)
self.__dict__.pop('account_name',None)

self.log.debug(f'Site IDs: {self._site_ids}')
self.log.debug(f'Account IDs: {self._account_ids}')

def _build_url(self, stem: str):
"""
Expand All @@ -176,7 +242,12 @@ def _get_default_body(self) -> dict:
"""
Get the default request body for a SentinelOne API query.
"""
return {"siteIds": [self._site_id]} if self._site_id else {"accountIds": [self._account_id]}
body = {}
if self._site_ids:
body['siteIds'] = self._site_ids
if self._account_ids:
body['accountIds'] = self._account_ids
return body

def _get_default_header(self):
"""
Expand Down Expand Up @@ -316,6 +387,7 @@ def _get_dv_events(self, query_id: str) -> list[dict]:
return self._get_all_paginated_data(self._build_url('/web/api/v2.1/dv/events'),
params={'queryId': query_id},
no_progress=False,
add_default_params=False,
progress_desc='Retrieving query results')
else:
# query-status endpoint has a one request per second rate limit
Expand Down Expand Up @@ -437,11 +509,6 @@ def _process_queries(self):
if len(self._query_base):
# add base_query filter to merged query string
merged_query = f'{self._query_base} AND ({merged_query})'

if len(self._site_ids):
# restrict query to specified sites
# S1QL does not support restricting a query to a specified account ID
merged_query = f'SiteID in contains ("' + '", "'.join(self._site_ids) + f'") AND ({merged_query})'

# build request body for DV API call
params = self._get_default_body()
Expand Down