-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ARROW-10487 [FlightRPC][C++] Header-based auth in clients #8724
Changes from 1 commit
c2ba2c7
ab7d6a3
ecb533f
73be3e7
74ef8ea
29a3192
fac0bd0
7fd1279
0b5a08e
6861cf0
01a134b
de78c6b
c44698f
e975fd8
37889fa
e7ac27c
ba7cb9f
1426252
516d993
3000ecb
1de10fa
065af4a
d4da03b
911fcc7
f41edce
6b6fbbe
199b655
47aa581
477d865
1cc3fdb
d27465d
d21006f
6cd8a45
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,7 +105,7 @@ struct ClientRpc { | |
std::chrono::system_clock::now() + options.timeout); | ||
context.set_deadline(deadline); | ||
} | ||
for (auto metadata: options.metadata) { | ||
for (auto metadata : options.metadata) { | ||
context.AddMetadata(metadata.first, metadata.second); | ||
} | ||
} | ||
|
@@ -998,7 +998,8 @@ class FlightClient::FlightClientImpl { | |
return Status::OK(); | ||
} | ||
|
||
Status AuthenticateBasicToken(std::string username, std::string password, std::pair<std::string, std::string>* bearer_token) { | ||
Status AuthenticateBasicToken(std::string username, std::string password, | ||
std::pair<std::string, std::string>* bearer_token) { | ||
// Add bearer token factory to middleware so it can intercept the bearer token. | ||
middleware.push_back(std::make_shared<ClientBearerTokenFactory>(bearer_token)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks odd to create the shared pointer after you've passed in the raw pointer....it seems like the method itself should take a shared pointer. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The raw pointer is unpopulated, so it's passed to the BearerTokenFactory's constructor, which stores it and populated it when it receives the bearer token. I could make the client pass the whole factory in with the bearer token already inside it, but it's more work and requires they understand what's going on more than they otherwise would need to. |
||
ClientRpc rpc({}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You may actually want to allow passing call options so things like timeouts can be set. |
||
|
@@ -1227,7 +1228,7 @@ Status FlightClient::Authenticate(const FlightCallOptions& options, | |
} | ||
|
||
Status FlightClient::AuthenticateBasicToken( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we make this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, I will make this change. |
||
std::string username, std::string password, | ||
std::string username, std::string password, | ||
std::pair<std::string, std::string>* bearer_token) { | ||
return impl_->AuthenticateBasicToken(username, password, bearer_token); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -194,10 +194,12 @@ class ARROW_FLIGHT_EXPORT FlightClient { | |
std::unique_ptr<ClientAuthHandler> auth_handler); | ||
|
||
/// \brief Authenticate to the server using the given handler. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's no handler in play here. |
||
/// \param[in] options Per-RPC options | ||
/// \param[in] auth_handler The authentication mechanism to use | ||
/// \param[in] username Username to use | ||
/// \param[in] password Password to use | ||
/// \param[in] bearer_token Bearer token retreived if applicable | ||
/// \return Status OK if the client authenticated successfully | ||
Status AuthenticateBasicToken(std::string username, std::string password, std::pair<std::string, std::string>* bearer_token); | ||
Status AuthenticateBasicToken(std::string username, std::string password, | ||
std::pair<std::string, std::string>* bearer_token); | ||
|
||
/// \brief Perform the indicated action, returning an iterator to the stream | ||
/// of results, if any | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,12 +28,14 @@ namespace flight { | |
|
||
std::string base64_encode(const std::string& input); | ||
|
||
ClientBearerTokenMiddleware::ClientBearerTokenMiddleware(std::pair<std::string, std::string>* bearer_token_) | ||
ClientBearerTokenMiddleware::ClientBearerTokenMiddleware( | ||
std::pair<std::string, std::string>* bearer_token_) | ||
: bearer_token(bearer_token_) { } | ||
|
||
void ClientBearerTokenMiddleware::SendingHeaders(AddCallHeaders* outgoing_headers) { } | ||
|
||
void ClientBearerTokenMiddleware::ReceivedHeaders(const CallHeaders& incoming_headers) { | ||
void ClientBearerTokenMiddleware::ReceivedHeaders( | ||
const CallHeaders& incoming_headers) { | ||
// Grab the auth token if one exists. | ||
auto bearer_iter = incoming_headers.find(AUTH_HEADER); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const? |
||
if (bearer_iter == incoming_headers.end()) { | ||
|
@@ -43,9 +45,9 @@ namespace flight { | |
// Check if the value of the auth token starts with the bearer prefix, latch the token. | ||
std::string bearer_val = bearer_iter->second.to_string(); | ||
if (bearer_val.size() > BEARER_PREFIX.size()) { | ||
bool hasPrefix = std::equal(bearer_val.begin(), bearer_val.begin() + BEARER_PREFIX.size(), BEARER_PREFIX.begin(), | ||
[] (const char& char1, const char& char2) { | ||
return (std::toupper(char1) == std::toupper(char2)); | ||
bool hasPrefix = std::equal(bearer_val.begin(), bearer_val.begin() + BEARER_PREFIX.size(), BEARER_PREFIX.begin(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like they use snake_case instead of camelCase. |
||
[] (const char& char1, const char& char2) { | ||
return (std::toupper(char1) == std::toupper(char2)); | ||
} | ||
); | ||
if (hasPrefix) { | ||
|
@@ -55,7 +57,7 @@ namespace flight { | |
} | ||
|
||
void ClientBearerTokenMiddleware::CallCompleted(const Status& status) { } | ||
|
||
void ClientBearerTokenFactory::StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be better to pass a reference instead of a pointer? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't this is a method of the base class, I think it's done this way to allow you to assign a new unique pointer to it without exposing the other middlewares they are already holding in their vector. |
||
*middleware = std::unique_ptr<ClientBearerTokenMiddleware>(new ClientBearerTokenMiddleware(bearer_token)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. std::make_unique? |
||
} | ||
|
@@ -68,28 +70,28 @@ namespace flight { | |
std::string string_format(const std::string& format, const Args... args) { | ||
// Check size requirement for new string and increment by 1 for null terminator. | ||
size_t size = std::snprintf(nullptr, 0, format.c_str(), args ...) + 1; | ||
if(size <= 0){ | ||
throw std::runtime_error("Error during string formatting. Format: '" + format + "'."); | ||
if(size <= 0){ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: spacing between if (, ){ |
||
throw std::runtime_error("Error during string formatting. Format: '" + format + "'."); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And Arrow disallows exceptions. |
||
} | ||
|
||
// Create buffer for new string and write string in. | ||
std::unique_ptr<char[]> buf(new char[size]); | ||
std::unique_ptr<char[]> buf(new char[size]); | ||
std::snprintf(buf.get(), size, format.c_str(), args...); | ||
|
||
// Convert to std::string, subtracting size by 1 to trim null terminator. | ||
return std::string(buf.get(), buf.get() + size - 1); | ||
} | ||
|
||
void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password) { | ||
const std::string formatted_credentials = string_format("%s:%s", username.c_str(), password.c_str()); | ||
context->AddMetadata(AUTH_HEADER, BASIC_PREFIX + base64_encode(formatted_credentials)); | ||
} | ||
|
||
std::string base64_encode(const std::string& input) { | ||
static const std::string base64_chars = | ||
static const std::string base64_chars = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't exist in the codebase already? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I found it. Will remove this. |
||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; | ||
auto get_encoded_length = [] (const std::string& in) { | ||
return 4 * ((in.size() + 2) / 3); | ||
auto get_encoded_length = [] (const std::string& in) { | ||
return 4 * ((in.size() + 2) / 3); | ||
}; | ||
auto get_overwrite_count = [] (const std::string& in) { | ||
const std::string::size_type remainder = in.length() % 3; | ||
|
@@ -110,10 +112,12 @@ namespace flight { | |
encoded.push_back(base64_chars[(octriple >> j * 6) & 0x3F]); | ||
} | ||
} | ||
|
||
// Round up to nearest multiple of 3 and replace characters at end based on rounding. | ||
int overwrite_count = get_overwrite_count(input); | ||
encoded.replace(encoded.length() - overwrite_count, encoded.length(), overwrite_count, '='); | ||
encoded.replace(encoded.length() - overwrite_count, | ||
encoded.length(), | ||
overwrite_count, '='); | ||
return encoded; | ||
} | ||
} // namespace flight | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,9 +20,9 @@ | |
|
||
#pragma once | ||
|
||
#include "client_middleware.h" | ||
#include "client_auth.h" | ||
#include "client.h" | ||
#include "arrow/flight/client_middleware.h" | ||
#include "arrow/flight/client_auth.h" | ||
#include "arrow/flight/client.h" | ||
|
||
#ifdef GRPCPP_PP_INCLUDE | ||
#include <grpcpp/grpcpp.h> | ||
|
@@ -45,11 +45,14 @@ const std::string BASIC_PREFIX = "Basic "; | |
namespace arrow { | ||
namespace flight { | ||
|
||
void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password); | ||
void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is internal - it shouldn't be in a public header. (Ditto for the grpc include.) |
||
const std::string& username, | ||
const std::string& password); | ||
|
||
class ARROW_FLIGHT_EXPORT ClientBearerTokenMiddleware : public ClientMiddleware { | ||
public: | ||
explicit ClientBearerTokenMiddleware(std::pair<std::string, std::string>* bearer_token_); | ||
explicit ClientBearerTokenMiddleware( | ||
std::pair<std::string, std::string>* bearer_token_); | ||
|
||
void SendingHeaders(AddCallHeaders* outgoing_headers); | ||
void ReceivedHeaders(const CallHeaders& incoming_headers); | ||
|
@@ -61,11 +64,12 @@ class ARROW_FLIGHT_EXPORT ClientBearerTokenMiddleware : public ClientMiddleware | |
|
||
class ARROW_FLIGHT_EXPORT ClientBearerTokenFactory : public ClientMiddlewareFactory { | ||
public: | ||
explicit ClientBearerTokenFactory(std::pair<std::string, std::string>* bearer_token_) : bearer_token(bearer_token_) {} | ||
explicit ClientBearerTokenFactory(std::pair<std::string, std::string>* bearer_token_) | ||
: bearer_token(bearer_token_) {} | ||
|
||
void StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware); | ||
void Reset(); | ||
|
||
private: | ||
std::pair<std::string, std::string>* bearer_token; | ||
}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should username/password be passed by const ref?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes