Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): add capability to read unity catalog (uc://) uris #3113

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 99 additions & 2 deletions crates/catalog-unity/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ compile_error!(
);

use reqwest::header::{HeaderValue, InvalidHeaderValue, AUTHORIZATION};
use std::collections::HashMap;
use std::str::FromStr;

use crate::credential::{
Expand Down Expand Up @@ -201,6 +202,11 @@ pub enum UnityCatalogConfigKey {
/// - `azure_use_azure_cli`
/// - `use_azure_cli`
UseAzureCli,

/// Allow http url (e.g. http://localhost:8080/api/2.1/...)
/// Supported keys:
/// - `unity_allow_http_url`
AllowHttpUrl,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows users to work with a local (non-https) Unity Catalog REST API with delta-rs.

}

impl FromStr for UnityCatalogConfigKey {
Expand Down Expand Up @@ -246,6 +252,7 @@ impl FromStr for UnityCatalogConfigKey {
| "unity_workspace_url"
| "databricks_workspace_url"
| "databricks_host" => Ok(UnityCatalogConfigKey::WorkspaceUrl),
"allow_http_url" | "unity_allow_http_url" => Ok(UnityCatalogConfigKey::AllowHttpUrl),
_ => Err(DataCatalogError::UnknownConfigKey {
catalog: "unity",
key: s.to_string(),
Expand All @@ -259,6 +266,7 @@ impl AsRef<str> for UnityCatalogConfigKey {
fn as_ref(&self) -> &str {
match self {
UnityCatalogConfigKey::AccessToken => "unity_access_token",
UnityCatalogConfigKey::AllowHttpUrl => "unity_allow_http_url",
UnityCatalogConfigKey::AuthorityHost => "unity_authority_host",
UnityCatalogConfigKey::AuthorityId => "unity_authority_id",
UnityCatalogConfigKey::ClientId => "unity_client_id",
Expand Down Expand Up @@ -311,6 +319,9 @@ pub struct UnityCatalogBuilder {
/// When set to true, azure cli has to be used for acquiring access token
use_azure_cli: bool,

/// When set to true, http will be allowed in the catalog url
allow_http_url: bool,

/// Retry config
retry_config: RetryConfig,

Expand All @@ -333,6 +344,9 @@ impl UnityCatalogBuilder {
) -> DataCatalogResult<Self> {
match UnityCatalogConfigKey::from_str(key.as_ref())? {
UnityCatalogConfigKey::AccessToken => self.bearer_token = Some(value.into()),
UnityCatalogConfigKey::AllowHttpUrl => {
self.allow_http_url = str_is_truthy(&value.into())
}
UnityCatalogConfigKey::ClientId => self.client_id = Some(value.into()),
UnityCatalogConfigKey::ClientSecret => self.client_secret = Some(value.into()),
UnityCatalogConfigKey::AuthorityId => self.authority_id = Some(value.into()),
Expand Down Expand Up @@ -431,6 +445,50 @@ impl UnityCatalogBuilder {
self
}

/// Returns true if table uri is a valid Unity Catalog URI, false otherwise.
pub fn is_unity_catalog_uri(table_uri: &str) -> bool {
table_uri.starts_with("uc://")
}

/// Returns the storage location and temporary token to be used with the
/// Unity Catalog table.
pub async fn get_uc_location_and_token(
table_uri: &str,
) -> Result<(String, HashMap<String, String>), UnityCatalogError> {
let uri_parts: Vec<&str> = table_uri[5..].split('.').collect();
if uri_parts.len() != 3 {
panic!("Invalid Unity Catalog URI: {}", table_uri);
}

let catalog_id = uri_parts[0];
let database_name = uri_parts[1];
let table_name = uri_parts[2];

let unity_catalog = match UnityCatalogBuilder::from_env().build() {
Ok(uc) => uc,
Err(_e) => panic!("Unable to build Unity Catalog."),
};
let storage_location = match unity_catalog
.get_table_storage_location(Some(catalog_id.to_string()), database_name, table_name)
.await
{
Ok(s) => s,
Err(_e) => panic!("Unable to find the table's storage location."),
};
let temp_creds_res = unity_catalog
.get_temp_table_credentials(catalog_id, database_name, table_name)
.await?;
let credentials = match temp_creds_res {
TableTempCredentialsResponse::Success(temp_creds) => {
temp_creds.get_credentials().unwrap()
}
TableTempCredentialsResponse::Error(_error) => {
panic!("Unable to get temporary credentials from Unity Catalog.")
}
};
Ok((storage_location, credentials))
}

fn get_credential_provider(&self) -> Option<CredentialProvider> {
if let Some(token) = self.bearer_token.as_ref() {
return Some(CredentialProvider::BearerToken(token.clone()));
Expand Down Expand Up @@ -488,7 +546,12 @@ impl UnityCatalogBuilder {
.trim_end_matches('/')
.to_string();

let client = self.client_options.client()?;
let client_options = if self.allow_http_url {
self.client_options.with_allow_http(true)
} else {
self.client_options
};
let client = client_options.client()?;

Ok(UnityCatalog {
client,
Expand Down Expand Up @@ -649,7 +712,7 @@ impl UnityCatalog {
self.catalog_url(),
catalog_id.as_ref(),
database_name.as_ref(),
table_name.as_ref()
table_name.as_ref(),
))
.header(AUTHORIZATION, token)
.send()
Expand Down Expand Up @@ -692,6 +755,29 @@ impl UnityCatalog {
}
}

pub trait CatalogFactory {
fn with_table_uri(
table_uri: &str,
) -> impl std::future::Future<
Output = Result<(String, HashMap<String, String>), UnityCatalogError>,
> + Send;
}

pub struct UnityCatalogFactory {}

impl CatalogFactory for UnityCatalogFactory {
async fn with_table_uri(
table_uri: &str,
) -> Result<(String, HashMap<String, String>), UnityCatalogError> {
let (table_path, temp_creds) =
match UnityCatalogBuilder::get_uc_location_and_token(table_uri).await {
Ok(tup) => tup,
Err(err) => return Err(err),
};
Ok((table_path, temp_creds))
}
}

#[async_trait::async_trait]
impl DataCatalog for UnityCatalog {
type Error = UnityCatalogError;
Expand Down Expand Up @@ -731,6 +817,7 @@ mod tests {
use crate::models::tests::{GET_SCHEMA_RESPONSE, GET_TABLE_RESPONSE, LIST_SCHEMAS_RESPONSE};
use crate::models::*;
use crate::UnityCatalogBuilder;
use deltalake_core::DataCatalog;
use httpmock::prelude::*;

#[tokio::test]
Expand Down Expand Up @@ -788,5 +875,15 @@ mod tests {
get_table_response.unwrap(),
GetTableResponse::Success(_)
));

let storage_location = client
.get_table_storage_location(
Some("catalog_name".to_string()),
"schema_name",
"table_name",
)
.await
.unwrap();
assert!(storage_location.eq_ignore_ascii_case("string"));
}
}
3 changes: 3 additions & 0 deletions python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ doc = false
[dependencies]
delta_kernel.workspace = true

# deltalake_catalog_unity - local crate
deltalake-catalog-unity = { path = "../crates/catalog-unity" }

# arrow
arrow-schema = { workspace = true, features = ["serde"] }

Expand Down
39 changes: 28 additions & 11 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ use crate::merge::PyMergeBuilder;
use crate::query::PyQueryBuilder;
use crate::schema::{schema_to_pyobject, Field};
use crate::utils::rt;
use deltalake_catalog_unity::{CatalogFactory, UnityCatalogBuilder, UnityCatalogFactory};

#[cfg(all(target_family = "unix", not(target_os = "emscripten")))]
use jemallocator::Jemalloc;
Expand Down Expand Up @@ -170,6 +171,29 @@ impl RawDeltaTable {
original.state = state;
Ok(())
}

fn get_builder(
table_uri: &str,
storage_options: Option<HashMap<String, String>>,
) -> (DeltaTableBuilder, String, HashMap<String, String>) {
let (table_path, temp_creds) = if UnityCatalogBuilder::is_unity_catalog_uri(table_uri) {
rt().block_on(UnityCatalogFactory::with_table_uri(table_uri))
.unwrap()
} else {
(table_uri.to_string(), HashMap::new())
};
let mut options = storage_options.clone().unwrap_or_default();
if !temp_creds.is_empty() {
options.extend(temp_creds);
}
let mut builder = deltalake::DeltaTableBuilder::from_uri(&table_path)
.with_io_runtime(IORuntime::default());
if !options.is_empty() {
builder = builder.with_storage_options(options.clone());
}

(builder, table_path, options)
}
}

#[pymethods]
Expand All @@ -185,12 +209,8 @@ impl RawDeltaTable {
log_buffer_size: Option<usize>,
) -> PyResult<Self> {
py.allow_threads(|| {
let mut builder = deltalake::DeltaTableBuilder::from_uri(table_uri)
.with_io_runtime(IORuntime::default());
let options = storage_options.clone().unwrap_or_default();
if let Some(storage_options) = storage_options {
builder = builder.with_storage_options(storage_options)
}
let (mut builder, table_path, options) =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking more in the lines of this, then you don't have to modify any of the python/lib code

let factory = Arc::new(AzureFactory {});
for scheme in ["az", "adl", "azure", "abfs", "abfss"].iter() {
let url = Url::parse(&format!("{}://", scheme)).unwrap();
factories().insert(url.clone(), factory.clone());

RawDeltaTable::get_builder(table_uri, storage_options);
if let Some(version) = version {
builder = builder.with_version(version)
}
Expand All @@ -207,7 +227,7 @@ impl RawDeltaTable {
Ok(RawDeltaTable {
_table: Arc::new(Mutex::new(table)),
_config: FsConfig {
root_url: table_uri.into(),
root_url: table_path,
options,
},
})
Expand All @@ -220,10 +240,7 @@ impl RawDeltaTable {
table_uri: &str,
storage_options: Option<HashMap<String, String>>,
) -> PyResult<bool> {
let mut builder = deltalake::DeltaTableBuilder::from_uri(table_uri);
if let Some(storage_options) = storage_options {
builder = builder.with_storage_options(storage_options)
}
let (builder, _, _) = RawDeltaTable::get_builder(table_uri, storage_options);
Ok(rt()
.block_on(async {
match builder.build() {
Expand Down
Loading