Skip to content

Commit

Permalink
fixing a ton of encoding errors related to ascii system encoding, closes
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Moffat committed Sep 7, 2013
1 parent 486f084 commit 4b76001
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 58 deletions.
144 changes: 95 additions & 49 deletions sh.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,36 @@
basestring = str


def encode_to_py3bytes_or_py2str(s):
    """ takes anything and attempts to return a py2 string or py3 bytes. this
    is typically used when creating command + arguments to be executed via
    os.exec*

    text is encoded with DEFAULT_ENCODING; if that encoding cannot represent
    it, utf8 is used instead. non-string objects (numbers, etc) are converted
    to their string form first. on py3, a value that is already bytes is
    returned unchanged """

    fallback_encoding = "utf8"

    if IS_PY3:
        # bytes are already in their final form. running them through str()
        # would produce the literal text "b'...'" instead of the payload, so
        # hand them back untouched
        if isinstance(s, bytes):
            return s

        s = str(s)
        try:
            s = bytes(s, DEFAULT_ENCODING)
        except UnicodeEncodeError:
            s = bytes(s, fallback_encoding)
    else:
        # attempt to convert the thing to unicode from the system's encoding
        try:
            s = unicode(s, DEFAULT_ENCODING)
        # if the thing is already unicode, or it's a number, it can't be
        # coerced to unicode with an encoding argument, but if we leave out
        # the encoding argument, it will convert it to a string, then to unicode
        except TypeError:
            s = unicode(s)

        # now that we have guaranteed unicode, encode to our system encoding,
        # but fall back to utf8 if the system encoding can't represent it.
        # only encoding failures trigger the fallback, mirroring the py3
        # branch above
        try:
            s = s.encode(DEFAULT_ENCODING)
        except UnicodeEncodeError:
            s = s.encode(fallback_encoding)
    return s


class ErrorReturnCode(Exception):
Expand All @@ -103,22 +133,23 @@ def __init__(self, full_cmd, stdout, stderr):
self.stderr = stderr


if self.stdout is None: tstdout = "<redirected>"
if self.stdout is None: exc_stdout = "<redirected>"
else:
tstdout = self.stdout[:self.truncate_cap]
out_delta = len(self.stdout) - len(tstdout)
exc_stdout = self.stdout[:self.truncate_cap]
out_delta = len(self.stdout) - len(exc_stdout)
if out_delta:
tstdout += ("... (%d more, please see e.stdout)" % out_delta).encode()
exc_stdout += ("... (%d more, please see e.stdout)" % out_delta).encode()

if self.stderr is None: tstderr = "<redirected>"
if self.stderr is None: exc_stderr = "<redirected>"
else:
tstderr = self.stderr[:self.truncate_cap]
err_delta = len(self.stderr) - len(tstderr)
exc_stderr = self.stderr[:self.truncate_cap]
err_delta = len(self.stderr) - len(exc_stderr)
if err_delta:
tstderr += ("... (%d more, please see e.stderr)" % err_delta).encode()
exc_stderr += ("... (%d more, please see e.stderr)" % err_delta).encode()

msg = "\n\n RAN: %r\n\n STDOUT:\n%s\n\n STDERR:\n%s" %\
(full_cmd, tstdout.decode(DEFAULT_ENCODING), tstderr.decode(DEFAULT_ENCODING))
(full_cmd, exc_stdout.decode(DEFAULT_ENCODING, "replace"),
exc_stderr.decode(DEFAULT_ENCODING, "replace"))
super(ErrorReturnCode, self).__init__(msg)


Expand Down Expand Up @@ -236,7 +267,15 @@ def __init__(self, cmd, call_args, stdin, stdout, stderr):
self.log = Logger("command", logger_str)
self.call_args = call_args
self.cmd = cmd
self.ran = " ".join(cmd)

# self.ran records the command line that was actually executed, for
# auditing purposes: it is what gets reported in exceptions, and it can
# be inspected after the command has run
if IS_PY3:
self.ran = " ".join([arg.decode(DEFAULT_ENCODING, "ignore") for arg in cmd])
else:
self.ran = " ".join(cmd)

self.process = None

# this flag is for whether or not we've handled the exit code (like
Expand Down Expand Up @@ -304,7 +343,7 @@ def _handle_exit_code(self, code):
if code not in self.call_args["ok_code"] and \
(code > 0 or -code in SIGNALS_THAT_SHOULD_THROW_EXCEPTION):
raise get_rc_exc(code)(
" ".join(self.cmd),
self.ran,
self.process.stdout,
self.process.stderr
)
Expand Down Expand Up @@ -498,14 +537,16 @@ def _create(cls, program, **default_kwargs):
if not path: raise CommandNotFound(program)

cmd = cls(path)
if default_kwargs: cmd = cmd.bake(**default_kwargs)
if default_kwargs:
cmd = cmd.bake(**default_kwargs)

return cmd


def __init__(self, path):
path = which(path)
if not path: raise CommandNotFound(path)
if not path:
raise CommandNotFound(path)
self._path = path

self._partial = False
Expand Down Expand Up @@ -551,22 +592,6 @@ def _extract_call_args(kwargs, to_override={}):

return call_args, kwargs


def _format_arg(self, arg):
    """ normalizes a single argument (a number, a string, or whatever) into
    a string. on py3 this is simply str(arg); on py2 the result is a byte
    string encoded in the system's default encoding """

    if IS_PY3:
        return str(arg)

    try:
        # only works for byte strings: decoding with an explicit encoding
        # raises TypeError for unicode objects, numbers, and the like
        return unicode(arg, DEFAULT_ENCODING).encode(DEFAULT_ENCODING)
    except TypeError:
        # already unicode, or not a string at all, so convert without an
        # encoding argument and then encode
        return unicode(arg).encode(DEFAULT_ENCODING)


def _aggregate_keywords(self, keywords, sep, raw=False):
processed = []
Expand All @@ -575,20 +600,22 @@ def _aggregate_keywords(self, keywords, sep, raw=False):
# cut(d="\t")
if len(k) == 1:
if v is not False:
processed.append("-" + k)
processed.append(encode_to_py3bytes_or_py2str("-" + k))
if v is not True:
processed.append(self._format_arg(v))
processed.append(encode_to_py3bytes_or_py2str(v))

# we're doing a long arg
else:
if not raw: k = k.replace("_", "-")
if not raw:
k = k.replace("_", "-")

if v is True:
processed.append("--" + k)
processed.append(encode_to_py3bytes_or_py2str("--" + k))
elif v is False:
pass
else:
processed.append("--%s%s%s" % (k, sep, self._format_arg(v)))
arg = encode_to_py3bytes_or_py2str("--%s%s%s" % (k, sep, v))
processed.append(arg)
return processed


Expand All @@ -600,12 +627,13 @@ def _compile_args(self, args, kwargs, sep):
if isinstance(arg, (list, tuple)):
if not arg:
warnings.warn("Empty list passed as an argument to %r. \
If you're using glob.glob(), please use sh.glob() instead." % self.path, stacklevel=3)
for sub_arg in arg: processed_args.append(self._format_arg(sub_arg))
If you're using glob.glob(), please use sh.glob() instead." % self._path, stacklevel=3)
for sub_arg in arg:
processed_args.append(encode_to_py3bytes_or_py2str(sub_arg))
elif isinstance(arg, dict):
processed_args += self._aggregate_keywords(arg, sep, raw=True)
else:
processed_args.append(self._format_arg(arg))
processed_args.append(encode_to_py3bytes_or_py2str(arg))

# aggregate the keyword arguments
processed_args += self._aggregate_keywords(kwargs, sep)
Expand Down Expand Up @@ -635,19 +663,25 @@ def bake(self, *args, **kwargs):
return fn

def __str__(self):
if IS_PY3: return self.__unicode__()
else: return unicode(self).encode(DEFAULT_ENCODING)
if IS_PY3:
return self.__unicode__()
else:
return unicode(self).encode(DEFAULT_ENCODING)


def __eq__(self, other):
try: return str(self) == str(other)
except: return False


def __repr__(self):
return "<Command %r>" % str(self)


def __unicode__(self):
baked_args = " ".join(self._partial_baked_args)
if baked_args: baked_args = " " + baked_args
if baked_args:
baked_args = " " + baked_args
return self._path + baked_args

def __enter__(self):
Expand All @@ -674,7 +708,10 @@ def __call__(self, *args, **kwargs):
call_args.update(pcall_args)
cmd.extend(prepend.cmd)

cmd.append(self._path)
if IS_PY3:
cmd.append(bytes(self._path, call_args["encoding"]))
else:
cmd.append(self._path)

# here we extract the special kwargs and override any
# special kwargs from the possibly baked command
Expand Down Expand Up @@ -842,8 +879,10 @@ def __init__(self, cmd, stdin, stdout, stderr, call_args,
self.setwinsize(1)

# actually execute the process
if self.call_args["env"] is None: os.execv(cmd[0], cmd)
else: os.execve(cmd[0], cmd, self.call_args["env"])
if self.call_args["env"] is None:
os.execv(cmd[0], cmd)
else:
os.execve(cmd[0], cmd, self.call_args["env"])

os._exit(255)

Expand Down Expand Up @@ -1592,8 +1631,10 @@ def __getitem__(self, k):
raise AttributeError

# how about an environment variable?
try: return os.environ[k]
except KeyError: pass
try:
return os.environ[k]
except KeyError:
pass

# is it a custom builtin?
builtin = getattr(self, "b_"+k, None)
Expand Down Expand Up @@ -1682,21 +1723,26 @@ def __call__(self, **kwargs):
if arg == "test":
import subprocess

def run_test(version):
def run_test(version, locale):
py_version = "python%s" % version
py_bin = which(py_version)

if py_bin:
print("Testing %s" % py_version.capitalize())
print("Testing %s, locale %r" % (py_version.capitalize(), locale))

env = os.environ.copy()
env["LC_ALL"] = locale
p = subprocess.Popen([py_bin, os.path.join(THIS_DIR, "test.py")]
+ sys.argv[1:])
+ sys.argv[1:], env=env)
p.wait()
else:
print("Couldn't find %s, skipping" % py_version.capitalize())

versions = ("2.6", "2.7", "3.1", "3.2", "3.3")
for version in versions: run_test(version)
locales = ("en_US.UTF-8", "C")
for locale in locales:
for version in versions:
run_test(version, locale)

else:
env = Environment(globals())
Expand Down
35 changes: 26 additions & 9 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@ def wrapper(thing): return thing
return wrapper

requires_posix = skipUnless(os.name == "posix", "Requires POSIX")

requires_utf8 = skipUnless(sh.DEFAULT_ENCODING == "UTF-8", "System encoding must be UTF-8")


def create_tmp_test(code):
""" creates a temporary test file that lives on disk, on which we can run
python with sh """

py = tempfile.NamedTemporaryFile()
if IS_PY3: code = bytes(code, "UTF-8")
py.write(code)
Expand All @@ -41,22 +44,26 @@ def create_tmp_test(code):


@requires_posix
class Basic(unittest.TestCase):
class FunctionalTests(unittest.TestCase):

def test_print_command(self):
    """ str() of a command object should be the path that which() resolves
    for the same program """
    from sh import ls, which

    rendered = str(ls)
    expected_path = which("ls")
    self.assertEqual(rendered, expected_path)


def test_unicode_arg(self):
from sh import echo

test = "漢字"
if not IS_PY3: test = test.decode("utf8")
if not IS_PY3:
test = test.decode("utf8")

p = echo(test, _encoding="utf8")
output = p.strip()
self.assertEqual(test, output)

p = echo(test).strip()
self.assertEqual(test, p)

def test_number_arg(self):
py = create_tmp_test("""
Expand Down Expand Up @@ -1263,9 +1270,13 @@ def test_failure_with_large_output(self):
# an UnicodeDecodeError
def test_non_ascii_error(self):
from sh import ls, ErrorReturnCode

test = "/á"
if not IS_PY3:

# coerce to unicode
if IS_PY3:
pass
else:
test = test.decode("utf8")

self.assertRaises(ErrorReturnCode, ls, test)
Expand Down Expand Up @@ -1337,7 +1348,13 @@ def test_decode_error_handling(self):
py = create_tmp_test("""
# -*- coding: utf8 -*-
import sys
sys.stdout.write("te漢字st")
import os
sys.stdout = os.fdopen(sys.stdout.fileno(), 'wb')
IS_PY3 = sys.version_info[0] == 3
if IS_PY3:
sys.stdout.write(bytes("te漢字st", "utf8"))
else:
sys.stdout.write("te漢字st")
""")
fn = partial(python, py.name, _encoding="ascii")
def s(fn): str(fn())
Expand Down Expand Up @@ -1441,5 +1458,5 @@ def test_file_output_isnt_buffered(self):
if len(sys.argv) > 1:
unittest.main()
else:
suite = unittest.TestLoader().loadTestsFromTestCase(Basic)
suite = unittest.TestLoader().loadTestsFromTestCase(FunctionalTests)
unittest.TextTestRunner(verbosity=2).run(suite)

0 comments on commit 4b76001

Please sign in to comment.