diff --git a/src/sswsdk/dbconnector.py b/src/sswsdk/dbconnector.py index da4780e176..a1bac2bb37 100644 --- a/src/sswsdk/dbconnector.py +++ b/src/sswsdk/dbconnector.py @@ -9,6 +9,9 @@ class SonicV1Connector(DBInterface): + def __init__(self, **kwargs): + super(SonicV1Connector, self).__init__(**kwargs) + pass @@ -19,6 +22,9 @@ class SonicV1Connector(DBInterface): class SonicV2Connector(DBInterface): + def __init__(self, **kwargs): + super(SonicV2Connector, self).__init__(**kwargs) + pass diff --git a/src/sswsdk/interface.py b/src/sswsdk/interface.py index bad9a65bc4..ab66df6a3d 100644 --- a/src/sswsdk/interface.py +++ b/src/sswsdk/interface.py @@ -134,10 +134,15 @@ class DBInterface(object): db_map = dict() - def __init__(self): + def __init__(self, **kwargs): super(DBInterface, self).__init__() + # Store the arguments for redis client + self.redis_kwargs = kwargs + if len(self.redis_kwargs) == 0: + self.redis_kwargs['unix_socket_path'] = self.REDIS_UNIX_SOCKET_PATH + # For thread safety as recommended by python-redis # Create a separate client for each database self.redis_clients = DBRegistry() @@ -183,12 +188,7 @@ def _onetime_connect(self, db_name): if db_id is None: raise ValueError("No database ID configured for '{}'".format(db_name)) - kwargs = dict( - unix_socket_path=self.REDIS_UNIX_SOCKET_PATH - ) - kwargs.update(self.db_map[db_name]) - - client = redis.StrictRedis(**kwargs) + client = redis.StrictRedis(db=self.db_map[db_name]['db'], **self.redis_kwargs) # Enable the notification mechanism for keyspace events in Redis client.config_set('notify-keyspace-events', self.KEYSPACE_EVENTS) diff --git a/src/sswsdk/util.py b/src/sswsdk/util.py index 9f2f75d66e..a63c98a38f 100644 --- a/src/sswsdk/util.py +++ b/src/sswsdk/util.py @@ -2,6 +2,7 @@ Syslog and daemon script utility library. """ +from __future__ import print_function import json import logging import logging.config @@ -12,7 +13,7 @@ # TODO: move to dbsync project. def usage(script_name): print('Usage: python ', script_name, - '-d [logging_level] -f [update_frequency] -h [help]') + '-t [host] -p [port] -s [unix_socket_path] -d [logging_level] -f [update_frequency] -h [help]') # TODO: move to dbsync project. @@ -20,16 +21,26 @@ def process_options(script_name): """ Process command line options """ - options, remainders = getopt(sys.argv[1:], "d:f:h", ["debug=", "frequency=", "help"]) + options, remainders = getopt(sys.argv[1:], "t:p:s:d:f:h", ["host=", "port=", "unix_socket_path=", "debug=", "frequency=", "help"]) args = {} for (opt, arg) in options: - if opt in ('-d', '--debug'): - args['log_level'] = int(arg) - elif opt in ('-f', '--frequency'): - args['update_frequency'] = int(arg) - elif opt in ('-h', '--help'): - usage(script_name) + try: + if opt in ('-d', '--debug'): + args['log_level'] = int(arg) + elif opt in ('-t', '--host'): + args['host'] = arg + elif opt in ('-p', '--port'): + args['port'] = int(arg) + elif opt in ('-s', 'unix_socket_path'): + args['unix_socket_path'] = arg + elif opt in ('-f', '--frequency'): + args['update_frequency'] = int(arg) + elif opt in ('-h', '--help'): + usage(script_name) + except ValueError as e: + print('Invalid option for {}: {}'.format(opt, e)) + sys.exit(1) return args