from typing import Dict, Union
import time
from dbt.adapters.base import PythonJobHelper
from dbt.adapters.bigquery import BigQueryConnectionManager, BigQueryCredentials
from google.api_core import retry
from google.api_core.client_options import ClientOptions
from google.cloud import storage, dataproc_v1 # type: ignore

OPERATION_RETRY_TIME = 10


class BaseDataProcHelper(PythonJobHelper):
    def __init__(self, parsed_model: Dict, credential: BigQueryCredentials) -> None:
        """Shared setup for submitting a dbt Python model as a Dataproc job.

        Args:
            parsed_model (Dict): the parsed dbt model node, including its config.
            credential (BigQueryCredentials): credentials from the BigQuery profile.
        """
        # validate that all additional config required for python models is set
        schema = parsed_model["schema"]
        identifier = parsed_model["alias"]
        self.parsed_model = parsed_model
        python_required_configs = [
            "dataproc_region",
            "gcs_bucket",
        ]
        for required_config in python_required_configs:
            if not getattr(credential, required_config):
                raise ValueError(
                    f"Need to supply {required_config} in profile to submit python job"
                )
        self.model_file_name = f"{schema}/{identifier}.py"
        self.credential = credential
        self.GoogleCredentials = BigQueryConnectionManager.get_credentials(credential)
        self.storage_client = storage.Client(
            project=self.credential.database, credentials=self.GoogleCredentials
        )
        self.gcs_location = "gs://{}/{}".format(self.credential.gcs_bucket, self.model_file_name)

        # set retry policy, default to timeout after 24 hours
        self.timeout = self.parsed_model["config"].get(
            "timeout", self.credential.job_execution_timeout_seconds or 60 * 60 * 24
        )
        self.retry = retry.Retry(maximum=10.0, deadline=self.timeout)
        self.client_options = ClientOptions(
            api_endpoint="{}-dataproc.googleapis.com:443".format(self.credential.dataproc_region)
        )
        self.job_client = self._get_job_client()
    def _upload_to_gcs(self, filename: str, compiled_code: str) -> None:
        bucket = self.storage_client.get_bucket(self.credential.gcs_bucket)
        blob = bucket.blob(filename)
        blob.upload_from_string(compiled_code)

    def submit(self, compiled_code: str) -> dataproc_v1.types.jobs.Job:
        # upload python file to GCS
        self._upload_to_gcs(self.model_file_name, compiled_code)
        # submit dataproc job
        return self._submit_dataproc_job()

    def _get_job_client(
        self,
    ) -> Union[dataproc_v1.JobControllerClient, dataproc_v1.BatchControllerClient]:
        raise NotImplementedError("_get_job_client not implemented")

    def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job:
        raise NotImplementedError("_submit_dataproc_job not implemented")

    def _wait_operation(self, operation):
        # can't use due to https://github.com/googleapis/python-api-core/issues/458
        # response = operation.result(retry=self.retry)
        # Temp solution to wait for the job to finish
        start = time.time()
        while not operation.done(retry=None) and time.time() - start < self.timeout:
            time.sleep(OPERATION_RETRY_TIME)


class ClusterDataprocHelper(BaseDataProcHelper):
    def _get_job_client(self) -> dataproc_v1.JobControllerClient:
        if not self._get_cluster_name():
            raise ValueError(
                "Need to supply dataproc_cluster_name in profile or config to submit python job with cluster submission method"
            )
        return dataproc_v1.JobControllerClient(  # type: ignore
            client_options=self.client_options, credentials=self.GoogleCredentials
        )

    def _get_cluster_name(self) -> str:
        return self.parsed_model["config"].get(
            "dataproc_cluster_name", self.credential.dataproc_cluster_name
        )

    def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job:
        job = {
            "placement": {"cluster_name": self._get_cluster_name()},
            "pyspark_job": {
                "main_python_file_uri": self.gcs_location,
            },
        }
        operation = self.job_client.submit_job_as_operation(  # type: ignore
            request={
                "project_id": self.credential.database,
                "region": self.credential.dataproc_region,
                "job": job,
            }
        )
        self._wait_operation(operation)
        response = operation.metadata
        # check if the job failed (state 6 is ERROR in Dataproc's JobStatus.State enum)
        if response.status.state == 6:
            raise ValueError(response.status.details)
        return response
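

# Illustrative sketch (assumption, not part of the adapter): a dbt Python model selecting
# the cluster submission method might configure itself like the hypothetical snippet
# below; the cluster and table names are placeholders.
#
#     def model(dbt, session):
#         dbt.config(
#             submission_method="cluster",
#             dataproc_cluster_name="my-dataproc-cluster",  # hypothetical cluster
#         )
#         return session.table("my_project.my_dataset.some_table")  # hypothetical table
#
# The same settings (dataproc_region, gcs_bucket, dataproc_cluster_name) can instead come
# from the profile target, which is the fallback read by _get_cluster_name above.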


class ServerlessDataProcHelper(BaseDataProcHelper):
    def _get_job_client(self) -> dataproc_v1.BatchControllerClient:
        return dataproc_v1.BatchControllerClient(
            client_options=self.client_options, credentials=self.GoogleCredentials
        )

    def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job:
        # create the Dataproc Serverless job config
        batch = dataproc_v1.Batch()
        batch.pyspark_batch.main_python_file_uri = self.gcs_location
        # how to keep this up to date?
        # we should probably also open this up to be configurable
        jar_file_uri = self.parsed_model["config"].get(
            "jar_file_uri",
            "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.12-0.21.1.jar",
        )
        batch.pyspark_batch.jar_file_uris = [jar_file_uri]
        # should we make all of these spark/dataproc properties configurable?
        # https://cloud.google.com/dataproc-serverless/docs/concepts/properties
        # https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#runtimeconfig
        batch.runtime_config.properties = {
            "spark.executor.instances": "2",
        }
        parent = f"projects/{self.credential.database}/locations/{self.credential.dataproc_region}"
        request = dataproc_v1.CreateBatchRequest(
            parent=parent,
            batch=batch,
        )
        # make the request
        operation = self.job_client.create_batch(request=request)  # type: ignore
        # this takes quite a while, waiting on GCP response to resolve
        response = operation.result(retry=self.retry)
        return response
        # there might be useful results here that we can parse and return
        # Dataproc job output is saved to the Cloud Storage bucket
        # allocated to the job. Use regex to obtain the bucket and blob info.
        # matches = re.match("gs://(.*?)/(.*)", response.driver_output_resource_uri)
        # output = (
        #     self.storage_client
        #     .get_bucket(matches.group(1))
        #     .blob(f"{matches.group(2)}.000000000")
        #     .download_as_string()
        # )
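

# Minimal end-to-end sketch (assumption, not part of the adapter): this is roughly how
# dbt-bigquery is expected to drive these helpers when materializing a Python model.
# The parsed_model dict and compiled_code values are hypothetical placeholders; a real
# BigQueryCredentials object from an active profile is required for this to run.
#
#     def _example_submit(credential: BigQueryCredentials, compiled_code: str) -> None:
#         parsed_model = {
#             "schema": "analytics",        # hypothetical target schema
#             "alias": "my_python_model",   # hypothetical model alias
#             "config": {"timeout": 3600},  # optional per-model overrides
#         }
#         helper = ServerlessDataProcHelper(parsed_model, credential)
#         # uploads <schema>/<alias>.py to gcs_bucket, then runs it on Dataproc Serverless
#         job = helper.submit(compiled_code)
#         print(job)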