diff --git a/tensorflow_probability/python/experimental/substrates/numpy/rewrite.py b/tensorflow_probability/python/experimental/substrates/numpy/rewrite.py index a138b844c3..f7a1c742fc 100644 --- a/tensorflow_probability/python/experimental/substrates/numpy/rewrite.py +++ b/tensorflow_probability/python/experimental/substrates/numpy/rewrite.py @@ -18,11 +18,17 @@ from __future__ import division from __future__ import print_function +import sys +if not sys.path[0].endswith('.runfiles'): + sys.path.pop(0) + +# pylint: disable=g-import-not-at-top,g-bad-import-order import collections # Dependency imports from absl import app from absl import flags +# pylint: enable=g-import-not-at-top,g-bad-import-order flags.DEFINE_boolean('numpy_to_jax', False, 'Whether or not to rewrite numpy imports to jax.numpy') diff --git a/tensorflow_probability/python/internal/backend/numpy/gen_linear_operators.py b/tensorflow_probability/python/internal/backend/numpy/gen_linear_operators.py index 6c9b11202d..8eb9943b84 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen_linear_operators.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen_linear_operators.py @@ -18,14 +18,18 @@ from __future__ import division from __future__ import print_function +import sys +if not sys.path[0].endswith('.runfiles'): + sys.path.pop(0) + +# pylint: disable=g-import-not-at-top,g-bad-import-order import importlib import inspect import re -# Dependency imports - from absl import app from absl import flags +# pylint: enable=g-import-not-at-top,g-bad-import-order FLAGS = flags.FLAGS