[FLINK-22348][python] Fix DataStream.execute_and_collect which doesn't declare managed memory for Python operators

This closes #15665.
HuangXingBo authored and dianfu committed Apr 21, 2021
1 parent 7dee2c2 commit 9f40251
Showing 8 changed files with 218 additions and 159 deletions.
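For context, the user-facing path this commit fixes looks roughly like the minimal, hypothetical sketch below. It assumes a working PyFlink setup; before the fix, execute_and_collect submitted the job without the configuration step that declares managed memory for Python operators — a step the regular execute() path evidently already ran, since only the collect path was affected:

    from pyflink.datastream import StreamExecutionEnvironment

    env = StreamExecutionEnvironment.get_execution_environment()
    # A Python lambda in the pipeline makes Flink plan a Python operator,
    # which needs its managed memory declared before the job graph is built.
    ds = env.from_collection([1, 2, 3]).map(lambda x: x + 1)

    # With this commit, execute_and_collect() first routes the environment
    # through PythonConfigUtil.configPythonOperator, then collects results.
    with ds.execute_and_collect() as results:
        print([r for r in results])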
2 changes: 2 additions & 0 deletions flink-python/pyflink/datastream/data_stream.py
@@ -641,6 +641,8 @@ def execute_and_collect(self, job_execution_name: str = None, limit: int = None)
         :param job_execution_name: The name of the job execution.
         :param limit: The limit for the collected elements.
         """
+        JPythonConfigUtil = get_gateway().jvm.org.apache.flink.python.util.PythonConfigUtil
+        JPythonConfigUtil.configPythonOperator(self._j_data_stream.getExecutionEnvironment())
         if job_execution_name is None and limit is None:
             return CloseableIterator(self._j_data_stream.executeAndCollect(), self.get_type())
         elif job_execution_name is not None and limit is None:
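The two added lines above invoke the operator-configuration utility through the Py4J gateway before collecting. On affected versions, the equivalent call could in principle be issued manually as a workaround; a hedged sketch using the same internal API that appears in this diff (internal names, subject to change between releases):

    from pyflink.java_gateway import get_gateway

    # `env` is a pyflink StreamExecutionEnvironment and
    # `_j_stream_execution_environment` its wrapped Java counterpart
    # (an internal attribute, also used by the tests in this commit).
    j_util = get_gateway().jvm.org.apache.flink.python.util.PythonConfigUtil
    j_util.configPythonOperator(env._j_stream_execution_environment)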
10 changes: 5 additions & 5 deletions flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -387,11 +387,11 @@ def test_execute_and_collect(self):
                          decimal.Decimal('2000000000000000000.061111111111111'
                                          '11111111111111'))]
         expected = test_data
-        ds = self.env.from_collection(test_data)
+        ds = self.env.from_collection(test_data).map(lambda a: a)
         with ds.execute_and_collect() as results:
-            actual = []
-            for result in results:
-                actual.append(result)
+            actual = [result for result in results]
             actual.sort()
         expected.sort()
         self.assertEqual(expected, actual)

     def test_key_by_map(self):
@@ -942,7 +942,7 @@ def test_partition_custom(self):
         expected_num_partitions = 5

         def my_partitioner(key, num_partitions):
-            assert expected_num_partitions, num_partitions
+            assert expected_num_partitions == num_partitions
             return key % num_partitions

         partitioned_stream = ds.map(lambda x: x, output_type=Types.ROW([Types.STRING(),
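A note on the assert fix above: the old form parses as "assert <expression>, <message>", so it merely checked that expected_num_partitions was truthy and used num_partitions as the would-be failure message; the two values were never compared. A standalone illustration:

    expected_num_partitions = 5
    num_partitions = 3
    assert expected_num_partitions, num_partitions    # old form: always passes (5 is truthy)
    assert expected_num_partitions == num_partitions  # fixed form: raises AssertionError here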
@@ -38,9 +38,11 @@
 from pyflink.find_flink_home import _find_flink_source_root
 from pyflink.java_gateway import get_gateway
 from pyflink.pyflink_gateway_server import on_windows
-from pyflink.table import DataTypes, CsvTableSource, CsvTableSink, StreamTableEnvironment
+from pyflink.table import DataTypes, CsvTableSource, CsvTableSink, StreamTableEnvironment, \
+    EnvironmentSettings
 from pyflink.testing.test_case_utils import PyFlinkTestCase, exec_insert_table, \
     invoke_java_object_method
+from pyflink.util.java_utils import get_j_env_configuration


 class StreamExecutionEnvironmentTests(PyFlinkTestCase):
@@ -337,20 +339,48 @@ def test_add_python_file(self):
         import uuid
         python_file_dir = os.path.join(self.tempdir, "python_file_dir_" + str(uuid.uuid4()))
         os.mkdir(python_file_dir)
-        python_file_path = os.path.join(python_file_dir, "test_stream_dependency_manage_lib.py")
+        python_file_path = os.path.join(python_file_dir, "test_dep1.py")
         with open(python_file_path, 'w') as f:
             f.write("def add_two(a):\n    return a + 2")

         def plus_two_map(value):
-            from test_stream_dependency_manage_lib import add_two
+            from test_dep1 import add_two
             return add_two(value)

+        get_j_env_configuration(self.env._j_stream_execution_environment).\
+            setString("taskmanager.numberOfTaskSlots", "10")
         self.env.add_python_file(python_file_path)
         ds = self.env.from_collection([1, 2, 3, 4, 5])
-        ds.map(plus_two_map).add_sink(self.test_sink)
-        self.env.execute("test add python file")
+        ds = ds.map(plus_two_map, Types.LONG()) \
+            .slot_sharing_group("data_stream") \
+            .map(lambda i: i, Types.LONG()) \
+            .slot_sharing_group("table")
+
+        python_file_path = os.path.join(python_file_dir, "test_dep2.py")
+        with open(python_file_path, 'w') as f:
+            f.write("def add_three(a):\n    return a + 3")
+
+        def plus_three(value):
+            from test_dep2 import add_three
+            return add_three(value)
+
+        t_env = StreamTableEnvironment.create(
+            stream_execution_environment=self.env,
+            environment_settings=EnvironmentSettings.new_instance().use_blink_planner().build())
+        self.env.add_python_file(python_file_path)
+
+        from pyflink.table.udf import udf
+        from pyflink.table.expressions import col
+        add_three = udf(plus_three, result_type=DataTypes.BIGINT())
+
+        tab = t_env.from_data_stream(ds, 'a') \
+            .select(add_three(col('a')))
+        t_env.to_append_stream(tab, Types.ROW([Types.LONG()])) \
+            .map(lambda i: i[0]) \
+            .add_sink(self.test_sink)
+        self.env.execute("test add_python_file")
         result = self.test_sink.get_results(True)
-        expected = ['3', '4', '5', '6', '7']
+        expected = ['6', '7', '8', '9', '10']
         result.sort()
         expected.sort()
         self.assertEqual(expected, result)
4 changes: 2 additions & 2 deletions flink-python/pyflink/table/table_environment.py
@@ -1548,7 +1548,7 @@ def from_pandas(self, pdf,

     def _set_python_executable_for_local_executor(self):
         jvm = get_gateway().jvm
-        j_config = get_j_env_configuration(self)
+        j_config = get_j_env_configuration(self._get_j_env())
         if not j_config.containsKey(jvm.PythonOptions.PYTHON_EXECUTABLE.key()) \
                 and is_local_deployment(j_config):
             j_config.setString(jvm.PythonOptions.PYTHON_EXECUTABLE.key(), sys.executable)
@@ -1559,7 +1559,7 @@ def _add_jars_to_j_env_config(self, config_key):
         if jar_urls is not None:
             # normalize and remove duplicates
             jar_urls_set = set([jvm.java.net.URL(url).toString() for url in jar_urls.split(";")])
-            j_configuration = get_j_env_configuration(self)
+            j_configuration = get_j_env_configuration(self._get_j_env())
             if j_configuration.containsKey(config_key):
                 for url in j_configuration.getString(config_key, "").split(";"):
                     jar_urls_set.add(url)
@@ -51,7 +51,7 @@ class TableEnvironmentTest(object):

     def test_set_sys_executable_for_local_mode(self):
         jvm = get_gateway().jvm
-        actual_executable = get_j_env_configuration(self.t_env) \
+        actual_executable = get_j_env_configuration(self.t_env._get_j_env()) \
             .getString(jvm.PythonOptions.PYTHON_EXECUTABLE.key(), None)
         self.assertEqual(sys.executable, actual_executable)

8 changes: 4 additions & 4 deletions flink-python/pyflink/util/java_utils.py
@@ -80,12 +80,12 @@ def is_instance_of(java_object, java_class):
                              param, java_object)


-def get_j_env_configuration(t_env):
-    if is_instance_of(t_env._get_j_env(), "org.apache.flink.api.java.ExecutionEnvironment"):
-        return t_env._get_j_env().getConfiguration()
+def get_j_env_configuration(j_env):
+    if is_instance_of(j_env, "org.apache.flink.api.java.ExecutionEnvironment"):
+        return j_env.getConfiguration()
     else:
         return invoke_method(
-            t_env._get_j_env(),
+            j_env,
             "org.apache.flink.streaming.api.environment.StreamExecutionEnvironment",
             "getConfiguration"
         )
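With the new signature, callers pass the Java environment object itself instead of a table environment wrapper, so the helper now also serves the DataStream API. A usage sketch mirroring the call sites in this commit (the underscore-prefixed attributes are PyFlink internals):

    from pyflink.util.java_utils import get_j_env_configuration

    # DataStream API: pass the wrapped Java StreamExecutionEnvironment.
    j_config = get_j_env_configuration(env._j_stream_execution_environment)
    j_config.setString("taskmanager.numberOfTaskSlots", "10")

    # Table API: unwrap the Java environment first, as the updated call sites do.
    j_config = get_j_env_configuration(t_env._get_j_env())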
@@ -84,10 +84,10 @@ public class PythonConfig implements Serializable {
     private final boolean isUsingManagedMemory;

     /** The Configuration that contains execution configs and dependencies info. */
-    private final Configuration mergedConfig;
+    private final Configuration config;

     public PythonConfig(Configuration config) {
-        mergedConfig = config;
+        this.config = config;
         maxBundleSize = config.get(PythonOptions.MAX_BUNDLE_SIZE);
         maxBundleTimeMills = config.get(PythonOptions.MAX_BUNDLE_TIME_MILLS);
         maxArrowBatchSize = config.get(PythonOptions.MAX_ARROW_BATCH_SIZE);
@@ -148,7 +148,7 @@ public boolean isUsingManagedMemory() {
         return isUsingManagedMemory;
     }

-    public Configuration getMergedConfig() {
-        return mergedConfig;
+    public Configuration getConfig() {
+        return config;
     }
 }