diff --git a/be/src/rpc/auth-provider.h b/be/src/rpc/auth-provider.h index 30986ac76..1f2ab776a 100644 --- a/be/src/rpc/auth-provider.h +++ b/be/src/rpc/auth-provider.h @@ -83,7 +83,8 @@ class AuthProvider { class SecureAuthProvider : public AuthProvider { public: SecureAuthProvider(bool is_internal) - : has_ldap_(false), has_saml_(false), has_jwt_(false), is_internal_(is_internal) {} + : has_ldap_(false), has_saml_(false), has_jwt_(false), has_oauth_(false), + is_internal_(is_internal) {} /// Performs initialization of external state. /// If we're using ldap, set up appropriate certificate usage. @@ -133,6 +134,8 @@ class SecureAuthProvider : public AuthProvider { void InitJwt() { has_jwt_ = true; } + void InitOauth() { has_oauth_ = true; } + /// Used for testing const std::string& principal() const { return principal_; } const std::string& service_name() const { return service_name_; } @@ -148,6 +151,8 @@ class SecureAuthProvider : public AuthProvider { bool has_jwt_; + bool has_oauth_; + /// Hostname of this machine - if kerberos, derived from principal. If there /// is no kerberos, but LDAP is used, then acquired via GetHostname(). std::string hostname_; diff --git a/be/src/rpc/authentication.cc b/be/src/rpc/authentication.cc index 11804325a..020a41356 100644 --- a/be/src/rpc/authentication.cc +++ b/be/src/rpc/authentication.cc @@ -218,6 +218,57 @@ DEFINE_bool_hidden(jwt_allow_without_tls, false, "When this configuration is set to true, Impala allows JWT authentication on " "unsecure channel. This should be only enabled for testing, or development for which " "TLS is handled by proxy."); + +// OAuth flags +// If set, Impala will support OAuth based authentication using a token read from the +// HTTP header. 
+DEFINE_bool(oauth_token_auth, false, + "When true, read the OAuth token out of the HTTP Header and extract user name from " + "the token payload."); +// The last segment of an OAuth token is the signature, which is used to verify that the +// token was signed by the sender and not altered in any way. By default, it's required +// to validate the signature of the OAuth tokens. Otherwise it may expose security issue. +DEFINE_bool(oauth_jwt_validate_signature, true, + "When true, validate the signature of OAuth token with pre-installed JWKS. " + "This should only be set to false for development / testing"); +// JWKS contains the public keys used by the signing party to the clients that need to +// validate signatures. It represents cryptographic keys in JSON data structure. +DEFINE_string(oauth_jwks_file_path, "", + "File path of the pre-installed JSON Web Key Set (JWKS) for OAuth verification"); +// This specifies the URL from which the OAuth JWKS is downloaded. +DEFINE_string(oauth_jwks_url, "", "URL of the OAuth Endpoint for token verification"); +// Enables retrieving the OAuth JWKS from the specified URL without verifying the +// presented TLS certificate from the server. +DEFINE_bool(oauth_jwks_verify_server_certificate, true, + "Specifies if the TLS certificate of the JWKS server is verified when retrieving " + "the JWKS from the specified JWKS URL. A certificate is considered valid if a " + "trust chain can be established for it, and if the certificate has a common name or " + "SAN that matches the server's hostname. This should only be set to false for " + "development / testing."); +// Enables defining a custom pem bundle file containing root certificates to trust. 
+DEFINE_string(oauth_jwks_ca_certificate, "", "File path of a pem bundle of root ca " + "certificates that will be trusted when retrieving the JWKS from the " + "specified JWKS URL."); +DEFINE_int32(oauth_jwks_update_frequency_s, 60, + "(Advanced) The time in seconds to wait for refreshing the OAuth token " + "from the OAuth URL."); +DEFINE_int32(oauth_jwks_pulling_timeout_s, 10, + "(Advanced) The time in seconds for connection timed out when verifying OAuth token " + "from the specified OAuth server."); +// This specifies the custom claim in the OAuth token that contains the "username" for +// the session. +DEFINE_string(oauth_jwt_custom_claim_username, "username", + "Custom claim of the token that " + "contains the username"); +// If set, Impala allows OAuth authentication on unsecure channel. +// OAuth is only secure when used with TLS. But in some deployment scenarios, TLS is +// handled by proxy so that it does not show up as TLS to Impala. +DEFINE_bool_hidden(oauth_allow_without_tls, false, + "When this configuration is set to true, Impala allows OAuth authentication on " + "unsecure channel. This should be only enabled for testing, or development for which " + "TLS is handled by proxy."); +// End OAuth + DEFINE_bool(enable_group_filter_check_for_authenticated_kerberos_user, false, "If this configuration is set to true, Impala checks the provided " "LDAP group filter, if any, with the authenticated Kerberos user. 
" @@ -726,11 +777,14 @@ bool JWTTokenAuth(ThriftServer::ConnectionContext* connection_context, return false; } if (FLAGS_jwt_validate_signature) { - status = JWTHelper::GetInstance()->Verify(decoded_token.get()); + status = ExecEnv::GetInstance()->GetJWTHelperInstance()->Verify(decoded_token.get()); if (!status.ok()) { LOG(ERROR) << "Error verifying JWT token received from: " << TNetworkAddressToString(connection_context->network_address) << " Error: " << status; + connection_context->return_headers.push_back( + Substitute("WWW-Authenticate: Bearer error=\"invalid_token\",\ +error_description=\"$0 \"", status.GetDetail())); return false; } } @@ -750,6 +804,46 @@ bool JWTTokenAuth(ThriftServer::ConnectionContext* connection_context, return true; } +bool OAuthTokenAuth(ThriftServer::ConnectionContext* connection_context, + const AuthenticationHash& hash, const string& token) { + JWTHelper::UniqueJWTDecodedToken decoded_token; + Status status = JWTHelper::Decode(token, decoded_token); + if (!status.ok()) { + LOG(ERROR) << "Error decoding OAuth token received from: " + << TNetworkAddressToString(connection_context->network_address) + << " Error: " << status; + return false; + } + if (FLAGS_oauth_jwt_validate_signature) { + status = ExecEnv::GetInstance()->GetOAuthHelperInstance()->Verify( + decoded_token.get()); + if (!status.ok()) { + LOG(ERROR) << "Error verifying OAuth token received from: " + << TNetworkAddressToString(connection_context->network_address) + << " Error: " << status; + connection_context->return_headers.push_back( + Substitute("WWW-Authenticate: Bearer error=\"invalid_token\",\ +error_description=\"$0 \"", status.GetDetail())); + return false; + } + } + + DCHECK(!FLAGS_oauth_jwt_custom_claim_username.empty()); + string username; + status = JWTHelper::GetCustomClaimUsername( + decoded_token.get(), FLAGS_oauth_jwt_custom_claim_username, username); + if (!status.ok()) { + LOG(ERROR) << "Error extracting username from OAuth token received from: " + << 
TNetworkAddressToString(connection_context->network_address) + << " Error: " << status; + return false; + } + connection_context->username = username; + // TODO: cookies are not added, but are not needed right now + + return true; +} + // Performs a step of SPNEGO auth for the HTTP transport and sets the username and // kerberos_user_principal on 'connection_context' if auth is successful. // 'header_token' is the value from an 'Authorization: Negotiate" header. @@ -1309,7 +1403,7 @@ Status SecureAuthProvider::Start() { Status SecureAuthProvider::GetServerTransportFactory( ThriftServer::TransportType underlying_transport_type, const std::string& server_name, MetricGroup* metrics, std::shared_ptr* factory) { - DCHECK(!principal_.empty() || has_ldap_ || has_saml_ || has_jwt_); + DCHECK(!principal_.empty() || has_ldap_ || has_saml_ || has_jwt_ || has_oauth_); if (underlying_transport_type == ThriftServer::HTTP) { bool has_kerberos = !principal_.empty(); @@ -1318,7 +1412,7 @@ Status SecureAuthProvider::GetServerTransportFactory( bool check_trusted_auth_header = !FLAGS_trusted_auth_header.empty(); factory->reset(new THttpServerTransportFactory(server_name, metrics, has_ldap_, has_kerberos, use_cookies, check_trusted_domain, check_trusted_auth_header, - has_saml_, has_jwt_)); + has_saml_, has_jwt_, has_oauth_)); return Status::OK(); } @@ -1451,10 +1545,14 @@ void SecureAuthProvider::SetupConnectionContext( callbacks.validate_saml2_bearer_fn = std::bind(ValidateSaml2Bearer, connection_ptr.get(), hash_); } if (has_jwt_) { callbacks.jwt_token_auth_fn = std::bind(JWTTokenAuth, connection_ptr.get(), hash_, std::placeholders::_1); } + if (has_oauth_) { + callbacks.oauth_token_auth_fn = + std::bind(OAuthTokenAuth, connection_ptr.get(), hash_, std::placeholders::_1); + } if (!FLAGS_trusted_auth_header.empty()) { callbacks.trusted_auth_header_handle_fn = std::bind( HandleTrustedAuthHeader, connection_ptr.get(), hash_, std::placeholders::_1); @@ -1581,6 
+1679,20 @@ Status AuthManager::Init() { } } + if (FLAGS_oauth_token_auth) { + if (!IsExternalTlsConfigured()) { + if (!FLAGS_oauth_allow_without_tls) { + return Status("OAuth authentication should be only used with TLS enabled."); + } + LOG(WARNING) << "OAuth authentication is used without TLS."; + } + if (FLAGS_oauth_jwt_custom_claim_username.empty()) { + return Status( + "OAuth authentication requires oauth_jwt_custom_claim_username to be " + "specified."); + } + } + // Get all of the flag validation out of the way if (FLAGS_enable_ldap_auth) { RETURN_IF_ERROR( @@ -1665,6 +1777,10 @@ Status AuthManager::Init() { LOG(INFO) << "External communication can be also authenticated with JWT"; sap->InitJwt(); } + if (FLAGS_oauth_token_auth) { + LOG(INFO) << "External communication can be also authenticated with OAuth"; + sap->InitOauth(); + } } else { external_auth_provider_.reset(new NoAuthProvider()); LOG(INFO) << "External communication is not authenticated for binary protocols"; @@ -1674,6 +1790,18 @@ Status AuthManager::Init() { sap->InitSaml(); LOG(INFO) << "External communication is authenticated for hs2-http protocol with " "SAML2 SSO"; + } else if (FLAGS_oauth_token_auth) { + SecureAuthProvider* sap = nullptr; + external_http_auth_provider_.reset(sap = new SecureAuthProvider(false)); + sap->InitOauth(); + LOG(INFO) + << "External communication is authenticated for hs2-http protocol with Oauth"; + if (use_jwt) { + // JWT may be enabled together with OAuth: register it on the same provider so + // that JWT bearer tokens are not silently rejected on hs2-http. + LOG(INFO) << "External communication can be also authenticated with JWT"; + sap->InitJwt(); + } } else if (use_jwt) { SecureAuthProvider* sap = nullptr; external_http_auth_provider_.reset(sap = new SecureAuthProvider(false)); diff --git a/be/src/runtime/exec-env.cc b/be/src/runtime/exec-env.cc index b22fa816a..eb614421e 100644 --- a/be/src/runtime/exec-env.cc +++ b/be/src/runtime/exec-env.cc @@ -60,6 +60,8 @@ #include "util/default-path-handlers.h" #include "util/hdfs-bulk-ops.h" #include "util/impalad-metrics.h" +#include "util/jwt-util-internal.h" +#include "util/jwt-util.h" #include "util/mem-info.h" #include "util/memory-metrics.h" #include "util/metrics.h" @@ 
-534,6 +536,9 @@ Status ExecEnv::Init() { AiFunctions::set_api_key(api_key); } + jwt_helper_ = new JWTHelper(); + oauth_helper_ = new JWTHelper(); + return Status::OK(); } diff --git a/be/src/runtime/exec-env.h b/be/src/runtime/exec-env.h index bee6e4d94..2043ead77 100644 --- a/be/src/runtime/exec-env.h +++ b/be/src/runtime/exec-env.h @@ -30,6 +30,8 @@ #include "common/status.h" #include "runtime/client-cache-types.h" #include "testutil/gtest-util.h" #include "util/hdfs-bulk-ops-defs.h" // For declaration of HdfsOpThreadPool +#include "util/jwt-util-internal.h" +#include "util/jwt-util.h" #include "util/network-util.h" #include "util/spinlock.h" @@ -102,6 +104,10 @@ class ExecEnv { /// we return the most recently created instance. static ExecEnv* GetInstance() { return exec_env_; } + /// Returns the JWT and OAuth helper instances; both are null until Init() creates them. + JWTHelper* GetJWTHelperInstance() { return jwt_helper_; } + JWTHelper* GetOAuthHelperInstance() { return oauth_helper_; } + /// Destructor - only used in backend tests that create new environment per test. ~ExecEnv(); @@ -291,6 +297,8 @@ class ExecEnv { FRIEND_TEST(HdfsUtilTest, CheckFilesystemsAndBucketsMatch); static ExecEnv* exec_env_; + JWTHelper* jwt_helper_ = nullptr; + JWTHelper* oauth_helper_ = nullptr; bool is_fe_tests_ = false; /// The network address that the backend KRPC service is listening on: diff --git a/be/src/service/impala-server.cc b/be/src/service/impala-server.cc index 30b4c2482..480e545e7 100644 --- a/be/src/service/impala-server.cc +++ b/be/src/service/impala-server.cc @@ -403,6 +403,14 @@ DECLARE_string(jwks_url); DECLARE_bool(jwks_verify_server_certificate); DECLARE_string(jwks_ca_certificate); +// Flags for OAuth token based authentication. 
+DECLARE_bool(oauth_token_auth); +DECLARE_bool(oauth_jwt_validate_signature); +DECLARE_string(oauth_jwks_file_path); +DECLARE_string(oauth_jwks_url); +DECLARE_bool(oauth_jwks_verify_server_certificate); +DECLARE_string(oauth_jwks_ca_certificate); + namespace { using namespace impala; @@ -3102,11 +3110,13 @@ Status ImpalaServer::Start(int32_t beeswax_port, int32_t hs2_port, // Load JWKS from file if validation for signature of JWT token is enabled. if (FLAGS_jwt_token_auth && FLAGS_jwt_validate_signature) { if (!FLAGS_jwks_file_path.empty()) { - RETURN_IF_ERROR(JWTHelper::GetInstance()->Init(FLAGS_jwks_file_path)); + RETURN_IF_ERROR(ExecEnv::GetInstance()->GetJWTHelperInstance()->Init( + FLAGS_jwks_file_path)); } else if (!FLAGS_jwks_url.empty()) { if (TestInfo::is_test()) sleep(1); - RETURN_IF_ERROR(JWTHelper::GetInstance()->Init(FLAGS_jwks_url, - FLAGS_jwks_verify_server_certificate, FLAGS_jwks_ca_certificate, false)); + RETURN_IF_ERROR(ExecEnv::GetInstance()->GetJWTHelperInstance()->Init( + FLAGS_jwks_url, FLAGS_jwks_verify_server_certificate, + FLAGS_jwks_ca_certificate, false)); } else { LOG(ERROR) << "JWKS file is not specified when the validation of JWT signature " << " is enabled."; @@ -3114,6 +3124,23 @@ Status ImpalaServer::Start(int32_t beeswax_port, int32_t hs2_port, } } + // Load JWKS from file if validation for signature of OAuth token is enabled. 
+ if (FLAGS_oauth_token_auth && FLAGS_oauth_jwt_validate_signature) { + if (!FLAGS_oauth_jwks_file_path.empty()) { + RETURN_IF_ERROR(ExecEnv::GetInstance()->GetOAuthHelperInstance()->Init( + FLAGS_oauth_jwks_file_path)); + } else if (!FLAGS_oauth_jwks_url.empty()) { + if (TestInfo::is_test()) sleep(1); + RETURN_IF_ERROR(ExecEnv::GetInstance()->GetOAuthHelperInstance()->Init( + FLAGS_oauth_jwks_url, FLAGS_oauth_jwks_verify_server_certificate, + FLAGS_oauth_jwks_ca_certificate, false)); + } else { + LOG(ERROR) << "JWKS file is not specified when the validation of OAuth signature " + << " is enabled."; + return Status("JWKS file for OAuth is not specified"); + } + } + // Initialize the client servers. shared_ptr handler = shared_from_this(); if (beeswax_port > 0 || (TestInfo::is_test() && beeswax_port == 0)) { diff --git a/be/src/transport/THttpServer.cpp b/be/src/transport/THttpServer.cpp index 4091d03f4..a6697fb77 100644 --- a/be/src/transport/THttpServer.cpp +++ b/be/src/transport/THttpServer.cpp @@ -52,7 +52,7 @@ using strings::Substitute; THttpServerTransportFactory::THttpServerTransportFactory(const std::string& server_name, impala::MetricGroup* metrics, bool has_ldap, bool has_kerberos, bool use_cookies, bool check_trusted_domain, bool check_trusted_auth_header, bool has_saml, - bool has_jwt) + bool has_jwt, bool has_oauth) : has_ldap_(has_ldap), has_kerberos_(has_kerberos), use_cookies_(use_cookies), @@ -60,6 +60,7 @@ THttpServerTransportFactory::THttpServerTransportFactory(const std::string& serv check_trusted_auth_header_(check_trusted_auth_header), has_saml_(has_saml), has_jwt_(has_jwt), + has_oauth_(has_oauth), metrics_enabled_(metrics != nullptr) { if (metrics_enabled_) { if (has_ldap_) { @@ -100,12 +101,18 @@ THttpServerTransportFactory::THttpServerTransportFactory(const std::string& serv http_metrics_.total_jwt_token_auth_failure_ = metrics->AddCounter( Substitute("$0.total-jwt-token-auth-failure", server_name), 0); } + if (has_oauth_) { + 
http_metrics_.total_oauth_token_auth_success_ = metrics->AddCounter( + Substitute("$0.total-oauth-token-auth-success", server_name), 0); + http_metrics_.total_oauth_token_auth_failure_ = metrics->AddCounter( + Substitute("$0.total-oauth-token-auth-failure", server_name), 0); + } } } THttpServer::THttpServer(std::shared_ptr transport, bool has_ldap, bool has_kerberos, bool has_saml, bool use_cookies, bool check_trusted_domain, - bool check_trusted_auth_header, bool has_jwt, bool metrics_enabled, + bool check_trusted_auth_header, bool has_jwt, bool has_oauth, bool metrics_enabled, HttpMetrics* http_metrics) : THttpTransport(move(transport)), has_ldap_(has_ldap), @@ -115,6 +122,7 @@ THttpServer::THttpServer(std::shared_ptr transport, bool has_ldap, check_trusted_domain_(check_trusted_domain), check_trusted_auth_header_(check_trusted_auth_header), has_jwt_(has_jwt), + has_oauth_(has_oauth), metrics_enabled_(metrics_enabled), http_metrics_(http_metrics) {} @@ -167,7 +175,7 @@ void THttpServer::parseHeader(char* header) { contentLength_ = atoi(value); } else if (MatchesHeader(header, HEADER_X_FORWARDED_FOR, sz)) { origin_ = value; - } else if ((has_ldap_ || has_kerberos_ || has_saml_ || has_jwt_) + } else if ((has_ldap_ || has_kerberos_ || has_saml_ || has_jwt_ || has_oauth_) && MatchesHeader(header, HEADER_AUTHORIZATION, sz)) { auth_value_ = string(value); } else if (use_cookies_ && MatchesHeader(header, HEADER_COOKIE, sz)) { @@ -278,7 +286,7 @@ void THttpServer::headersDone() { // Store the truncated value of the 'X-Forwarded-For' header in the Connection Context. callbacks_.set_http_origin_fn(origin); - if (!has_ldap_ && !has_kerberos_ && !has_saml_ && !has_jwt_) { + if (!has_ldap_ && !has_kerberos_ && !has_saml_ && !has_jwt_ && !has_oauth_) { // We don't need to authenticate. 
resetAuthState(); return; @@ -309,7 +317,7 @@ void THttpServer::headersDone() { } } - if (!authorized && has_jwt_ && !auth_value_.empty() + if (!authorized && (has_jwt_ || has_oauth_) && !auth_value_.empty() && auth_value_.find('.') != string::npos) { // Check Authorization header with the Bearer authentication scheme as: // Authorization: Bearer @@ -319,11 +327,21 @@ void THttpServer::headersDone() { string jwt_token; bool got_bearer_auth = TryStripPrefixString(auth_value_, "Bearer ", &jwt_token); if (got_bearer_auth) { - if (callbacks_.jwt_token_auth_fn(jwt_token)) { + if (has_jwt_ && callbacks_.jwt_token_auth_fn(jwt_token)) { authorized = true; - if (metrics_enabled_) http_metrics_->total_jwt_token_auth_success_->Increment(1); - } else { - if (metrics_enabled_) http_metrics_->total_jwt_token_auth_failure_->Increment(1); + if (metrics_enabled_) + http_metrics_->total_jwt_token_auth_success_->Increment(1); + } + if (!authorized && has_oauth_ && callbacks_.oauth_token_auth_fn(jwt_token)) { + authorized = true; + if (metrics_enabled_) + http_metrics_->total_oauth_token_auth_success_->Increment(1); + } + if (!authorized) { + if (has_jwt_ && metrics_enabled_) + http_metrics_->total_jwt_token_auth_failure_->Increment(1); + if (has_oauth_ && metrics_enabled_) + http_metrics_->total_oauth_token_auth_failure_->Increment(1); } } } diff --git a/be/src/transport/THttpServer.h b/be/src/transport/THttpServer.h index 62331def5..82708c849 100644 --- a/be/src/transport/THttpServer.h +++ b/be/src/transport/THttpServer.h @@ -61,6 +61,9 @@ struct HttpMetrics { impala::IntCounter* total_jwt_token_auth_success_ = nullptr; impala::IntCounter* total_jwt_token_auth_failure_ = nullptr; + + impala::IntCounter* total_oauth_token_auth_success_ = nullptr; + impala::IntCounter* total_oauth_token_auth_failure_ = nullptr; }; /* @@ -143,11 +146,18 @@ public: std::function jwt_token_auth_fn = [&](const std::string&) { return false; }; + + // Function that takes the OAuth token from the header, and 
returns true + // if verification for the token is successful. + std::function oauth_token_auth_fn = + [&](const std::string&) { + return false; + }; }; THttpServer(std::shared_ptr transport, bool has_ldap, bool has_kerberos, bool has_saml, bool use_cookies, bool check_trusted_domain, - bool check_trusted_auth_header, bool has_jwt, bool metrics_enabled, + bool check_trusted_auth_header, bool has_jwt, bool has_oauth, bool metrics_enabled, HttpMetrics* http_metrics); virtual ~THttpServer(); @@ -188,9 +198,9 @@ protected: void resetAuthState(); private: // If either of the following is true, a '401 - Unauthorized' will be returned to the - // client on requests that do not contain a valid 'Authorization' of SAML SSO or JWT - // related header. If 'has_ldap_' is true, 'Basic' auth headers will be processed, and - // if 'has_kerberos_' is true 'Negotiate' auth headers will be processed. + // client on requests that do not contain a valid 'Authorization' of SAML SSO, JWT or + // OAuth related header. If 'has_ldap_' is true, 'Basic' auth headers will be processed, + // and if 'has_kerberos_' is true 'Negotiate' auth headers will be processed. bool has_ldap_ = false; bool has_kerberos_ = false; @@ -238,6 +248,9 @@ protected: // If set, support for trusting an authentication based on JWT token. bool has_jwt_ = false; + // If set, support for trusting an authentication based on OAuth token. 
+ bool has_oauth_ = false; + bool metrics_enabled_ = false; HttpMetrics* http_metrics_ = nullptr; @@ -268,14 +281,14 @@ public: THttpServerTransportFactory(const std::string& server_name, impala::MetricGroup* metrics, bool has_ldap, bool has_kerberos, bool use_cookies, bool check_trusted_domain, - bool check_trusted_auth_header, bool has_saml, bool has_jwt); + bool check_trusted_auth_header, bool has_saml, bool has_jwt, bool has_oauth); virtual ~THttpServerTransportFactory() {} virtual std::shared_ptr getTransport(std::shared_ptr trans) { return std::shared_ptr(new THttpServer(trans, has_ldap_, has_kerberos_, has_saml_, use_cookies_, check_trusted_domain_, check_trusted_auth_header_, - has_jwt_, metrics_enabled_, &http_metrics_)); + has_jwt_, has_oauth_, metrics_enabled_, &http_metrics_)); } private: @@ -286,6 +299,7 @@ public: bool check_trusted_auth_header_ = false; bool has_saml_ = false; bool has_jwt_ = false; + bool has_oauth_ = false; // Metrics for every transport produced by this factory. bool metrics_enabled_ = false; diff --git a/be/src/util/jwt-util.cc b/be/src/util/jwt-util.cc index d362e7c77..8843790de 100644 --- a/be/src/util/jwt-util.cc +++ b/be/src/util/jwt-util.cc @@ -849,8 +849,6 @@ struct JWTHelper::JWTDecodedToken { DecodedJWT decoded_jwt_; }; -JWTHelper* JWTHelper::jwt_helper_ = new JWTHelper(); - void JWTHelper::TokenDeleter::operator()(JWTHelper::JWTDecodedToken* token) const { if (token != nullptr) delete token; }; diff --git a/be/src/util/jwt-util.h b/be/src/util/jwt-util.h index ce777b0ba..2b4f12ee8 100644 --- a/be/src/util/jwt-util.h +++ b/be/src/util/jwt-util.h @@ -51,9 +51,6 @@ class JWTHelper { /// facilitate automatic reference counting. typedef std::unique_ptr UniqueJWTDecodedToken; - /// Return the single instance. - static JWTHelper* GetInstance() { return jwt_helper_; } - /// Load JWKS from a given local JSON file. Returns an error if problems were /// encountered. 
Status Init(const std::string& jwks_file_path); @@ -82,9 +79,6 @@ class JWTHelper { std::shared_ptr GetJWKS() const; private: - /// Single instance. - static JWTHelper* jwt_helper_; - /// Set it as TRUE when Init() is called. bool initialized_ = false; diff --git a/be/src/util/webserver.cc b/be/src/util/webserver.cc index 543a6e1da..8471cb545 100644 --- a/be/src/util/webserver.cc +++ b/be/src/util/webserver.cc @@ -168,6 +168,9 @@ DECLARE_bool(jwt_validate_signature); DECLARE_string(jwt_custom_claim_username); DECLARE_string(trusted_auth_header); DECLARE_string(spnego_keytab_file); +DECLARE_bool(oauth_token_auth); +DECLARE_bool(oauth_jwt_validate_signature); +DECLARE_string(oauth_jwt_custom_claim_username); static const char* DOC_FOLDER = "/www/"; static const int DOC_FOLDER_LEN = strlen(DOC_FOLDER); @@ -320,7 +323,8 @@ Webserver::Webserver(const string& interface, const int port, MetricGroup* metri use_cookies_(FLAGS_max_cookie_lifetime_s > 0), check_trusted_domain_(!FLAGS_trusted_domain.empty()), check_trusted_auth_header_(!FLAGS_trusted_auth_header.empty()), - use_jwt_(FLAGS_jwt_token_auth) { + use_jwt_(FLAGS_jwt_token_auth), + use_oauth_(FLAGS_oauth_token_auth) { http_address_ = MakeNetworkAddress(interface.empty() ? "0.0.0.0" : interface, port); Init(); @@ -358,6 +362,12 @@ Webserver::Webserver(const string& interface, const int port, MetricGroup* metri total_jwt_token_auth_failure_ = metrics->AddCounter("impala.webserver.total-jwt-token-auth-failure", 0); } + if (use_oauth_) { + total_oauth_token_auth_success_ = + metrics->AddCounter("impala.webserver.total-oauth-token-auth-success", 0); + total_oauth_token_auth_failure_ = + metrics->AddCounter("impala.webserver.total-oauth-token-auth-failure", 0); + } } Webserver::~Webserver() { @@ -673,7 +683,7 @@ sq_callback_result_t Webserver::BeginRequestCallback(struct sq_connection* conne bool cookie_authenticated = false; // Try authenticating with JWT token first, if enabled. 
- if (use_jwt_) { + if (use_jwt_ || use_oauth_) { const char* auth_value = nullptr; const char* value = sq_get_header(connection, "Authorization"); if (value != nullptr) auth_value = StripLeadingWhiteSpace(value); @@ -683,17 +693,34 @@ sq_callback_result_t Webserver::BeginRequestCallback(struct sq_connection* conne // separated by dots (.). if (auth_value != nullptr && strncasecmp(auth_value, "Bearer ", 7) == 0 && strchr(auth_value, '.') != nullptr) { - string jwt_token = string(auth_value + 7); - StripWhiteSpace(&jwt_token); - if (!jwt_token.empty()) { - if (JWTTokenAuth(jwt_token, connection, request_info)) { - total_jwt_token_auth_success_->Increment(1); - authenticated = true; - check_csrf_protection = false; - // TODO: cookies are not added, but are not needed right now - } else { - LOG(INFO) << "Invalid JWT token provided: " << jwt_token; - total_jwt_token_auth_failure_->Increment(1); + string bearer_token = string(auth_value + 7); + StripWhiteSpace(&bearer_token); + if (!bearer_token.empty()) { + if (use_jwt_) { + if (JWTTokenAuth(bearer_token, connection, request_info)) { + total_jwt_token_auth_success_->Increment(1); + authenticated = true; + check_csrf_protection = false; + // TODO: cookies are not added, but are not needed right now + } + } + if (!authenticated && use_oauth_) { + if (OAuthTokenAuth(bearer_token, connection, request_info)) { + total_oauth_token_auth_success_->Increment(1); + authenticated = true; + check_csrf_protection = false; + // TODO: cookies are not added, but are not needed right now + } + } + if (!authenticated) { + if (use_jwt_) { + LOG(INFO) << "Invalid JWT token provided: " << bearer_token; + total_jwt_token_auth_failure_->Increment(1); + } + if (use_oauth_) { + LOG(INFO) << "Invalid OAuth token provided: " << bearer_token; + total_oauth_token_auth_failure_->Increment(1); + } } } } @@ -1049,7 +1076,7 @@ bool Webserver::JWTTokenAuth(const std::string& jwt_token, return false; } if (FLAGS_jwt_validate_signature) { - status = 
JWTHelper::GetInstance()->Verify(decoded_token.get()); + status = ExecEnv::GetInstance()->GetJWTHelperInstance()->Verify(decoded_token.get()); if (!status.ok()) { LOG(ERROR) << "Error verifying JWT token in Authorization header, " << "Error: " << status; @@ -1070,6 +1097,39 @@ bool Webserver::JWTTokenAuth(const std::string& jwt_token, return true; } +bool Webserver::OAuthTokenAuth(const std::string& oauth_token, + struct sq_connection* connection, struct sq_request_info* request_info) { + JWTHelper::UniqueJWTDecodedToken decoded_token; + Status status = JWTHelper::Decode(oauth_token, decoded_token); + if (!status.ok()) { + LOG(ERROR) << "Error decoding OAuth token in Authorization header, " + << "Error: " << status; + return false; + } + if (FLAGS_oauth_jwt_validate_signature) { + status = ExecEnv::GetInstance()->GetOAuthHelperInstance()->Verify( + decoded_token.get()); + if (!status.ok()) { + LOG(ERROR) << "Error verifying OAuth token in Authorization header, " + << "Error: " << status; + return false; + } + } + + DCHECK(!FLAGS_oauth_jwt_custom_claim_username.empty()); + string username; + status = JWTHelper::GetCustomClaimUsername( + decoded_token.get(), FLAGS_oauth_jwt_custom_claim_username, username); + if (!status.ok()) { + LOG(ERROR) << "Cannot retrieve username from OAuth token in Authorization header, " + << "Error: " << status; + return false; + } + request_info->remote_user = strdup(username.c_str()); + + return true; +} + Status Webserver::HandleBasic(struct sq_connection* connection, struct sq_request_info* request_info, vector* response_headers) { const char* authz_header = sq_get_header(connection, "Authorization"); diff --git a/be/src/util/webserver.h b/be/src/util/webserver.h index b8ad7ddfb..bfa71f171 100644 --- a/be/src/util/webserver.h +++ b/be/src/util/webserver.h @@ -209,6 +209,11 @@ class Webserver { bool JWTTokenAuth(const std::string& jwt_token, struct sq_connection* connection, struct sq_request_info* request_info); + /// Checks and returns 
true if the OAuth token in Authorization header could be verified + /// and the token has a valid username. + bool OAuthTokenAuth(const std::string& oauth_token, struct sq_connection* connection, + struct sq_request_info* request_info); + // Handle Basic authentication for this request. Returns an error if authentication was // unsuccessful. Status HandleBasic(struct sq_connection* connection, @@ -290,6 +295,10 @@ class Webserver { /// An incoming connection will be accepted if the JWT token could be verified. bool use_jwt_ = false; + /// If true, the OAuth token in Authorization header will be used for authentication. + /// An incoming connection will be accepted if the OAuth token could be verified. + bool use_oauth_ = false; + /// Used to validate usernames/passwords If LDAP authentication is in use. std::unique_ptr ldap_; @@ -320,6 +329,11 @@ class Webserver { /// attempts. IntCounter* total_jwt_token_auth_success_ = nullptr; IntCounter* total_jwt_token_auth_failure_ = nullptr; + + /// If 'use_oauth_' is true, metrics for the number of successful and failed OAuth auth + /// attempts. 
+ IntCounter* total_oauth_token_auth_success_ = nullptr; + IntCounter* total_oauth_token_auth_failure_ = nullptr; }; } diff --git a/common/thrift/generate_error_codes.py b/common/thrift/generate_error_codes.py index ba8f7cb4f..80a275f90 100755 --- a/common/thrift/generate_error_codes.py +++ b/common/thrift/generate_error_codes.py @@ -491,7 +491,9 @@ error_codes = ( ("JDBC_CONFIGURATION_ERROR", 159, "Error in JDBC table configuration: $0."), - ("TUPLE_CACHE_INCONSISTENCY", 160, "Inconsistent tuple cache found: $0.") + ("TUPLE_CACHE_INCONSISTENCY", 160, "Inconsistent tuple cache found: $0."), + + ("OAUTH_VERIFY_FAILED", 161, "Error verifying OAuth Token: $0.") ) import sys diff --git a/common/thrift/metrics.json b/common/thrift/metrics.json index 8bd530351..fd74c8e7e 100644 --- a/common/thrift/metrics.json +++ b/common/thrift/metrics.json @@ -1873,6 +1873,26 @@ "kind": "COUNTER", "key": "impala.thrift-server.hiveserver2-http-frontend.total-jwt-token-auth-failure" }, + { + "description": "The number of HiveServer2 HTTP API connection requests to this Impala Daemon that were successfully authenticated using OAuth Token.", + "contexts": [ + "IMPALAD" + ], + "label": "HiveServer2 HTTP API Connection OAuth Token Success", + "units": "NONE", + "kind": "COUNTER", + "key": "impala.thrift-server.hiveserver2-http-frontend.total-oauth-token-auth-success" + }, + { + "description": "The number of HiveServer2 HTTP API connection requests to this Impala Daemon that were attempted to authenticate using OAuth Token but were unsuccessful.", + "contexts": [ + "IMPALAD" + ], + "label": "HiveServer2 HTTP API Connection OAuth Token Failure", + "units": "NONE", + "kind": "COUNTER", + "key": "impala.thrift-server.hiveserver2-http-frontend.total-oauth-token-auth-failure" + }, { "description": "The amount of memory freed by the last memory tracker garbage collection.", "contexts": [ @@ -3804,6 +3824,30 @@ "kind": "COUNTER", "key": "impala.webserver.total-jwt-token-auth-failure" }, + { + 
"description": "The number of HTTP connection requests to this daemon's webserver that were successfully authenticated using OAuth token.", + "contexts": [ + "IMPALAD", + "CATALOGSERVER", + "STATESTORE" + ], + "label": "Webserver HTTP Connection OAuth Token Auth Success", + "units": "NONE", + "kind": "COUNTER", + "key": "impala.webserver.total-oauth-token-auth-success" + }, + { + "description": "The number of HTTP connection requests to this daemon's webserver that provided an invalid OAuth token.", + "contexts": [ + "IMPALAD", + "CATALOGSERVER", + "STATESTORE" + ], + "label": "Webserver HTTP Connection OAuth Token Auth Failure", + "units": "NONE", + "kind": "COUNTER", + "key": "impala.webserver.total-oauth-token-auth-failure" + }, { "description": "The number of times the FAIL debug action returned an error. For testing only.", "contexts": [ diff --git a/fe/src/test/java/org/apache/impala/customcluster/JwtWebserverTest.java b/fe/src/test/java/org/apache/impala/customcluster/JwtWebserverTest.java index be4d60e66..366e95a20 100644 --- a/fe/src/test/java/org/apache/impala/customcluster/JwtWebserverTest.java +++ b/fe/src/test/java/org/apache/impala/customcluster/JwtWebserverTest.java @@ -56,14 +56,17 @@ public class JwtWebserverTest { client_.Close(); } - private void verifyJwtAuthMetrics( - Range expectedAuthSuccess, Range expectedAuthFailure) throws Exception { + private void verifyAuthMetrics( + Range expectedAuthSuccess, Range expectedAuthFailure, String auth_type) + throws Exception { long actualAuthSuccess = - (long) client_.getMetric("impala.webserver.total-jwt-token-auth-success"); + (long) client_.getMetric("impala.webserver.total-" + auth_type + + "-token-auth-success"); assertTrue("Expected: " + expectedAuthSuccess + ", Actual: " + actualAuthSuccess, expectedAuthSuccess.contains(actualAuthSuccess)); long actualAuthFailure = - (long) client_.getMetric("impala.webserver.total-jwt-token-auth-failure"); + (long) client_.getMetric("impala.webserver.total-" + 
auth_type + + "-token-auth-failure"); assertTrue("Expected: " + expectedAuthFailure + ", Actual: " + actualAuthFailure, expectedAuthFailure.contains(actualAuthFailure)); } @@ -95,7 +98,7 @@ public class JwtWebserverTest { + "bZd0GbD_MQQ8x7WRE4nluU-5Fl4N2Wo8T9fNTuxALPiuVeIczO25b5n4fryfKasSgaZfmk0C" + "oOJzqbtmQxqiK9QNSJAiH2kaqMwLNgAdgn8fbd-lB1RAEGeyPH8Px8ipqcKsPk0bg"; attemptConnection("Bearer " + jwtToken, "127.0.0.1"); - verifyJwtAuthMetrics(Range.closed(1L, 1L), zero); + verifyAuthMetrics(Range.closed(1L, 1L), zero, "jwt"); // Case 2: Failed with invalid JWT Token. String invalidJwtToken = @@ -107,7 +110,7 @@ public class JwtWebserverTest { } catch (IOException e) { assertTrue(e.getMessage().contains("Server returned HTTP response code: 401")); } - verifyJwtAuthMetrics(Range.closed(1L, 1L), Range.closed(1L, 1L)); + verifyAuthMetrics(Range.closed(1L, 1L), Range.closed(1L, 1L), "jwt"); // Case 3: Failed without "Bearer" token. try { @@ -116,7 +119,7 @@ public class JwtWebserverTest { assertTrue(e.getMessage().contains("Server returned HTTP response code: 401")); } // JWT authentication is not invoked. - verifyJwtAuthMetrics(Range.closed(1L, 1L), Range.closed(1L, 1L)); + verifyAuthMetrics(Range.closed(1L, 1L), Range.closed(1L, 1L), "jwt"); // Case 4: Failed without "Authorization" header. try { @@ -125,10 +128,70 @@ public class JwtWebserverTest { assertTrue(e.getMessage().contains("Server returned HTTP response code: 401")); } // JWT authentication is not invoked. - verifyJwtAuthMetrics(Range.closed(1L, 1L), Range.closed(1L, 1L)); + verifyAuthMetrics(Range.closed(1L, 1L), Range.closed(1L, 1L), "jwt"); } - // Helper method to make a "get" call to the Web Server using the input JWT auth token + /** + * Tests if sessions are authenticated by verifying the OAuth token for connections + * to the Web Server. + * Since we don't have Java version of JWT library, we use pre-calculated JWT token + * and JWKS. 
The token and JWK set used in this test case were generated by using + * BE unit-test function JwtUtilTest::VerifyJwtRS256. + */ + @Test + public void testWebserverOAuthAuth() throws Exception { + String jwksFilename = + new File(System.getenv("IMPALA_HOME"), "testdata/jwt/jwks_rs256.json").getPath(); + setUp(String.format( + "--oauth_token_auth=true --oauth_jwt_validate_signature=true " + + "--oauth_jwks_file_path=%s --oauth_allow_without_tls=true", + jwksFilename), + ""); + + // Case 1: Authenticate with valid OAuth Token in HTTP header. + String oauthToken = + "eyJhbGciOiJSUzI1NiIsImtpZCI6InB1YmxpYzpjNDI0YjY3Yi1mZTI4LTQ1ZDctYjAxNS1m" + + "NzlkYTUwYjViMjEiLCJ0eXAiOiJKV1MifQ.eyJpc3MiOiJhdXRoMCIsInVzZXJuYW1lIjoia" + + "W1wYWxhIn0.OW5H2SClLlsotsCarTHYEbqlbRh43LFwOyo9WubpNTwE7hTuJDsnFoVrvHiWI" + + "02W69TZNat7DYcC86A_ogLMfNXagHjlMFJaRnvG5Ekag8NRuZNJmHVqfX-qr6x7_8mpOdU55" + + "4kc200pqbpYLhhuK4Qf7oT7y9mOrtNrUKGDCZ0Q2y_mizlbY6SMg4RWqSz0RQwJbRgXIWSgc" + + "bZd0GbD_MQQ8x7WRE4nluU-5Fl4N2Wo8T9fNTuxALPiuVeIczO25b5n4fryfKasSgaZfmk0C" + + "oOJzqbtmQxqiK9QNSJAiH2kaqMwLNgAdgn8fbd-lB1RAEGeyPH8Px8ipqcKsPk0bg"; + attemptConnection("Bearer " + oauthToken, "127.0.0.1"); + verifyAuthMetrics(Range.closed(1L, 1L), zero, "oauth"); + + // Case 2: Failed with invalid OAuth Token. + String invalidOAuthToken = + "eyJhbGciOiJSUzI1NiIsImtpZCI6InB1YmxpYzpjNDI0YjY3Yi1mZTI4LTQ1ZDctYjAxNS1m" + + "NzlkYTUwYjViMjEiLCJ0eXAiOiJKV1MifQ.eyJpc3MiOiJhdXRoMCIsInVzZXJuYW1lIjoia" + + "W1wYWxhIn0."; + try { + attemptConnection("Bearer " + invalidOAuthToken, "127.0.0.1"); + } catch (IOException e) { + assertTrue(e.getMessage().contains("Server returned HTTP response code: 401")); + } + verifyAuthMetrics(Range.closed(1L, 1L), Range.closed(1L, 1L), "oauth"); + + // Case 3: Failed without "Bearer" token. 
+ try { + attemptConnection("Basic VGVzdDFMZGFwOjEyMzQ1", "127.0.0.1"); + } catch (IOException e) { + assertTrue(e.getMessage().contains("Server returned HTTP response code: 401")); + } + // OAuth authentication is not invoked. + verifyAuthMetrics(Range.closed(1L, 1L), Range.closed(1L, 1L), "oauth"); + + // Case 4: Failed without "Authorization" header. + try { + attemptConnection(null, "127.0.0.1"); + } catch (IOException e) { + assertTrue(e.getMessage().contains("Server returned HTTP response code: 401")); + } + // OAuth authentication is not invoked. + verifyAuthMetrics(Range.closed(1L, 1L), Range.closed(1L, 1L), "oauth"); + } + + // Helper method to make a "get" call to the Web Server using the input auth token // and x-forward-for address. private void attemptConnection(String auth_token, String xff_address) throws Exception { String url = "http://localhost:25000/?json"; diff --git a/fe/src/test/java/org/apache/impala/customcluster/LdapHS2Test.java b/fe/src/test/java/org/apache/impala/customcluster/LdapHS2Test.java index 3cff62b22..fbb980715 100644 --- a/fe/src/test/java/org/apache/impala/customcluster/LdapHS2Test.java +++ b/fe/src/test/java/org/apache/impala/customcluster/LdapHS2Test.java @@ -150,15 +150,16 @@ public class LdapHS2Test { assertEquals(expectedAuthSuccess, actualAuthSuccess); } - private void verifyJwtAuthMetrics(long expectedAuthSuccess, long expectedAuthFailure) + private void verifyAuthMetrics( + long expectedAuthSuccess, long expectedAuthFailure, String authType) throws Exception { long actualAuthSuccess = (long) client_.getMetric("impala.thrift-server.hiveserver2-http-frontend." 
- + "total-jwt-token-auth-failure"); + + "total-" + authType + "-token-auth-failure"); assertEquals(expectedAuthFailure, actualAuthFailure); } @@ -686,13 +687,13 @@ public class LdapHS2Test { TOpenSessionResp openResp = client.OpenSession(openReq); // One successful authentication. verifyMetrics(0, 0); - verifyJwtAuthMetrics(1, 0); + verifyAuthMetrics(1, 0, "jwt"); // Running a query should succeed. TOperationHandle operationHandle = execAndFetch( client, openResp.getSessionHandle(), "select logged_in_user()", "impala"); // Two more successful authentications - for the Exec() and the Fetch(). verifyMetrics(0, 0); - verifyJwtAuthMetrics(3, 0); + verifyAuthMetrics(3, 0, "jwt"); // case 2: Authenticate fails with invalid JWT token which does not have signature. String invalidJwtToken = @@ -706,7 +707,115 @@ public class LdapHS2Test { openResp = client.OpenSession(openReq); fail("Exception exception."); } catch (Exception e) { - verifyJwtAuthMetrics(3, 1); + verifyAuthMetrics(3, 1, "jwt"); + assertEquals(e.getMessage(), "HTTP Response code: 401"); + } + } + + /** + * Tests if sessions are authenticated by verifying both JWT and OAuth token for + * connections to the HTTP hiveserver2 endpoint. 
+ */ + @Test + public void testHiveserver2JwtAndOAuthAuth() throws Exception { + String jwtJwksFilename = + new File(System.getenv("IMPALA_HOME"), "testdata/jwt/jwks_rs256.json").getPath(); + String oauthJwksFilename = + new File(System.getenv("IMPALA_HOME"), + "testdata/jwt/jwks_signing.json").getPath(); + setUp(String.format( + "--jwt_token_auth=true --jwt_validate_signature=true --jwks_file_path=%s " + + "--jwt_allow_without_tls=true --oauth_token_auth=true " + + "--oauth_jwt_validate_signature=true --oauth_jwks_file_path=%s " + + "--jwt_allow_without_tls=true --oauth_jwt_custom_claim_username=sub " + + "--oauth_allow_without_tls=true", + jwtJwksFilename, oauthJwksFilename)); + verifyMetrics(0, 0); + THttpClient transport = new THttpClient("http://localhost:28000"); + Map headers = new HashMap(); + + // Case 1: Authenticate with valid JWT Token in HTTP header. + String jwtToken = + "eyJhbGciOiJSUzI1NiIsImtpZCI6InB1YmxpYzpjNDI0YjY3Yi1mZTI4LTQ1ZDctYjAxNS1m" + + "NzlkYTUwYjViMjEiLCJ0eXAiOiJKV1MifQ.eyJpc3MiOiJhdXRoMCIsInVzZXJuYW1lIjoia" + + "W1wYWxhIn0.OW5H2SClLlsotsCarTHYEbqlbRh43LFwOyo9WubpNTwE7hTuJDsnFoVrvHiWI" + + "02W69TZNat7DYcC86A_ogLMfNXagHjlMFJaRnvG5Ekag8NRuZNJmHVqfX-qr6x7_8mpOdU55" + + "4kc200pqbpYLhhuK4Qf7oT7y9mOrtNrUKGDCZ0Q2y_mizlbY6SMg4RWqSz0RQwJbRgXIWSgc" + + "bZd0GbD_MQQ8x7WRE4nluU-5Fl4N2Wo8T9fNTuxALPiuVeIczO25b5n4fryfKasSgaZfmk0C" + + "oOJzqbtmQxqiK9QNSJAiH2kaqMwLNgAdgn8fbd-lB1RAEGeyPH8Px8ipqcKsPk0bg"; + headers.put("Authorization", "Bearer " + jwtToken); + headers.put("X-Forwarded-For", "127.0.0.1"); + transport.setCustomHeaders(headers); + transport.open(); + TCLIService.Iface client = new TCLIService.Client(new TBinaryProtocol(transport)); + + // Open a session which will get username 'impala' from JWT token and use it as + // login user. + TOpenSessionReq openReq = new TOpenSessionReq(); + TOpenSessionResp openResp = client.OpenSession(openReq); + // One successful authentication. 
+ verifyMetrics(0, 0); + verifyAuthMetrics(1, 0, "jwt"); + // Running a query should succeed. + TOperationHandle operationHandle = execAndFetch( + client, openResp.getSessionHandle(), "select logged_in_user()", "impala"); + // Two more successful authentications - for the Exec() and the Fetch(). + verifyMetrics(0, 0); + verifyAuthMetrics(3, 0, "jwt"); + verifyAuthMetrics(0, 0, "oauth"); + + // Authenticate with a valid OAuth token in HTTP header. + String oauthToken = + "eyJhbGciOiJSUzI1NiIsImtpZCI6IjIwMjMwNTA5LTE2MDQxNSIsInR5cGUiOiJKV1QifQ.eyJ" + + "hdWQiOiJpbXBhbGEtdGVzdHMiLCJleHAiOjE5OTkwMDgyNTUsImlhdCI6MTY4MzY0ODI1NSw" + + "iaXNzIjoiZmlsZTovL3Rlc3RzL3V0aWwvand0L2p3dF91dGlsLnB5Iiwia2lkIjoiMjAyMzA" + + "1MDktMTYwNDE1Iiwic3ViIjoidGVzdC11c2VyIn0.dWMOkcBrwRansZrCZrlbYzr9alIQ23q" + + "lnw4t8Kx_v87CBB90qtmTV88nZAh4APtTE8IUnP0e45R2XyDoH3a8UVrrSOkEzI47wJ0I3Gq" + + "Sc_R_MsGoeGlKreZmcjGhY_ceOo7RWYaBdzsAZe1YXcKJbq2sQJ3issfjBa_fWt0Qhy0Dvzs" + + "sUf3V-g5nQUM3W3pOULiFtMhA8YmIdheHalRz3D_NWMAqe79iUv6tG0Eg08x-cl8GXYsDm45" + + "sU4WkP5fZps6Q4Fm05640FWXG8K0PoLzSI_Iac3zzSAPs-iYNeeNE6C9QxBYSLBvQrWL0SET" + + "afP82Mo-nEZsAJbMMSqm0cQ"; + + transport = new THttpClient("http://localhost:28000"); + headers = new HashMap(); + headers.put("Authorization", "Bearer " + oauthToken); + headers.put("X-Forwarded-For", "127.0.0.1"); + transport.setCustomHeaders(headers); + transport.open(); + client = new TCLIService.Client(new TBinaryProtocol(transport)); + + // Open a session which will get username 'test-user' from OAuth token and use + // it as login user. + openReq = new TOpenSessionReq(); + openResp = client.OpenSession(openReq); + // One successful authentication. + verifyMetrics(0, 0); + verifyAuthMetrics(1, 0, "oauth"); + // Running a query should succeed. + operationHandle = execAndFetch( + client, openResp.getSessionHandle(), "select logged_in_user()", "test-user"); + // Two more successful authentications - for the Exec() and the Fetch(). 
+ verifyMetrics(0, 0); + verifyAuthMetrics(3, 0, "oauth"); + verifyAuthMetrics(3, 0, "jwt"); + + // case 2: Authenticate fails with invalid token for both JWT and OAuth which does + // not have signature. + headers.clear(); + String invalidJwtToken = + "eyJhbGciOiJSUzI1NiIsImtpZCI6InB1YmxpYzpjNDI0YjY3Yi1mZTI4LTQ1ZDctYjAxNS1m" + + "NzlkYTUwYjViMjEiLCJ0eXAiOiJKV1MifQ.eyJpc3MiOiJhdXRoMCIsInVzZXJuYW1lIjoia" + + "W1wYWxhIn0."; + headers.put("Authorization", "Bearer " + invalidJwtToken); + headers.put("X-Forwarded-For", "127.0.0.1"); + transport.setCustomHeaders(headers); + try { + openResp = client.OpenSession(openReq); + fail("Exception exception."); + } catch (Exception e) { + // Both JWT and OAuth have 3 successes and 1 failure each. + verifyAuthMetrics(3, 1, "jwt"); + verifyAuthMetrics(3, 1, "oauth"); assertEquals(e.getMessage(), "HTTP Response code: 401"); } } diff --git a/shell/ImpalaHttpClient.py b/shell/ImpalaHttpClient.py index 5d9c43a33..a1a19ab36 100644 --- a/shell/ImpalaHttpClient.py +++ b/shell/ImpalaHttpClient.py @@ -245,6 +245,12 @@ class ImpalaHttpClient(TTransportBase): self.__bearer_token = jwt self.__get_custom_headers_func = self.getCustomHeadersWithBearerAuth + # Set function to generate customized HTTP headers for OAuth authorization. + def setOAuthAuth(self, oauth): + # auth mechanism: Oauth + self.__bearer_token = oauth + self.__get_custom_headers_func = self.getCustomHeadersWithBearerAuth + # Set function to generate customized HTTP headers for Kerberos authorization. 
def setKerberosAuth(self, kerb_service): # auth mechanism: GSSAPI diff --git a/shell/impala_client.py b/shell/impala_client.py index 4c79a9dbf..6ed2f5f87 100755 --- a/shell/impala_client.py +++ b/shell/impala_client.py @@ -290,7 +290,7 @@ class ImpalaClient(object): verbose=True, use_http_base_transport=False, http_path=None, http_cookie_names=None, http_socket_timeout_s=None, value_converter=None, connect_max_tries=4, rpc_stdout=False, rpc_file=None, http_tracing=True, - jwt=None, hs2_x_forward=None): + jwt=None, oauth=None, hs2_x_forward=None): self.connected = False self.impalad_host = impalad[0] self.impalad_port = int(impalad[1]) @@ -314,6 +314,7 @@ class ImpalaClient(object): self.http_cookie_names = http_cookie_names self.http_tracing = http_tracing self.jwt = jwt + self.oauth = oauth # This is set from ImpalaShell's signal handler when a query is cancelled # from command line via CTRL+C. It is used to suppress error messages of # query cancellation. @@ -592,6 +593,8 @@ class ImpalaClient(object): transport.setLdapAuth(auth) elif self.jwt is not None: transport.setJwtAuth(self.jwt) + elif self.oauth is not None: + transport.setOAuthAuth(self.oauth) elif self.use_kerberos or self.kerberos_host_fqdn: # Set the Kerberos service if self.kerberos_host_fqdn is not None: @@ -1166,7 +1169,6 @@ class ImpalaHS2Client(ImpalaClient): self._request_num += 1 self._current_request_id = "{0}-{1}".format(self._base_request_id, self._request_num) - self._check_connected() num_tries = 1 max_tries = num_tries diff --git a/shell/impala_shell.py b/shell/impala_shell.py index f4adff452..821b90f5b 100755 --- a/shell/impala_shell.py +++ b/shell/impala_shell.py @@ -206,16 +206,20 @@ class ImpalaShell(cmd.Cmd, object): self.user = options.user self.ldap_password_cmd = options.ldap_password_cmd self.jwt_cmd = options.jwt_cmd + self.oauth_cmd = options.oauth_cmd self.strict_hs2_protocol = options.strict_hs2_protocol self.ldap_password = options.ldap_password self.use_jwt = 
options.use_jwt self.jwt = options.jwt + self.use_oauth = options.use_oauth + self.oauth = options.oauth # When running tests in strict mode, the server uses the ldap # protocol but can allow any password. if options.use_ldap_test_password: self.ldap_password = 'password' self.use_ldap = options.use_ldap or \ - (self.strict_hs2_protocol and not self.use_kerberos and not self.use_jwt) + (self.strict_hs2_protocol and not self.use_kerberos and not self.use_jwt + and not self.use_oauth) self.client_connect_timeout_ms = options.client_connect_timeout_ms self.http_socket_timeout_s = None if (options.http_socket_timeout_s != 'None' @@ -649,7 +653,8 @@ class ImpalaShell(cmd.Cmd, object): http_cookie_names=self.http_cookie_names, value_converter=value_converter, rpc_stdout=self.rpc_stdout, rpc_file=self.rpc_file, http_tracing=self.http_tracing, - jwt=self.jwt, hs2_x_forward=self.hs2_x_forward) + jwt=self.jwt, oauth=self.oauth, + hs2_x_forward=self.hs2_x_forward) if protocol == 'hs2': return ImpalaHS2Client(self.impalad, self.fetch_size, self.kerberos_host_fqdn, self.use_kerberos, self.kerberos_service_name, self.use_ssl, @@ -670,7 +675,7 @@ class ImpalaShell(cmd.Cmd, object): value_converter=value_converter, connect_max_tries=self.connect_max_tries, rpc_stdout=self.rpc_stdout, rpc_file=self.rpc_file, - http_tracing=self.http_tracing, jwt=self.jwt, + http_tracing=self.http_tracing, jwt=self.jwt, oauth=self.oauth, hs2_x_forward=self.hs2_x_forward) elif protocol == 'beeswax': return ImpalaBeeswaxClient(self.impalad, self.fetch_size, self.kerberos_host_fqdn, @@ -983,6 +988,9 @@ class ImpalaShell(cmd.Cmd, object): if self.use_jwt and self.jwt is None: self.jwt = getpass.getpass("Enter JWT: ") + if self.use_oauth and self.oauth is None: + self.oauth = getpass.getpass("Enter OAUTH: ") + if not args: args = socket.getfqdn() tokens = args.split(" ") # validate the connection string. 
@@ -1029,6 +1037,8 @@ class ImpalaShell(cmd.Cmd, object): self.ldap_password = None self.use_jwt = False self.jwt = None + self.use_oauth = False + self.oauth = None self.imp_client = self._new_impala_client() self._connect() except OSError: @@ -2015,6 +2025,10 @@ def get_intro(options): intro += ("\n\nJWT authentication is enabled, but the connection to Impala is " "not secured by TLS.\nALL JWTs WILL BE SENT IN THE CLEAR TO IMPALA.") + if not options.ssl and options.creds_ok_in_clear and options.use_oauth: + intro += ("\n\nOAUTH authentication is enabled, but the connection to Impala is " + "not secured by TLS.\nALL OAUTHs WILL BE SENT IN THE CLEAR TO IMPALA.") + if options.protocol == 'beeswax': intro += ("\n\nWARNING: The beeswax protocol is deprecated and will be removed in a " "future version of Impala.") @@ -2156,6 +2170,9 @@ def impala_shell_main(): if options.use_jwt: auth_method_count += 1 + if options.use_oauth: + auth_method_count += 1 + if auth_method_count > 1: - print("Please specify at most one authentication mechanism (-k, -l, or -j)", + print("Please specify at most one authentication mechanism (-k, -l, -j, or -a)", file=sys.stderr) @@ -2191,6 +2208,25 @@ def impala_shell_main(): file=sys.stderr) raise FatalShellException() + if options.use_oauth and options.protocol.lower() != 'hs2-http': + print("Invalid protocol '{0}'. OAUTH authentication requires using the 'hs2-http' " + "protocol".format(options.protocol), file=sys.stderr) + raise FatalShellException() + + if options.use_oauth and options.strict_hs2_protocol: + print("OAUTH authentication is not supported when using strict hs2.", file=sys.stderr) + raise FatalShellException() + + if options.use_oauth and not options.ssl and not options.creds_ok_in_clear: + print("OAUTHs may not be sent over insecure connections. 
Enable SSL or " + "set --auth_creds_ok_in_clear", file=sys.stderr) + raise FatalShellException() + + if not options.use_oauth and options.oauth_cmd: + print("Option --oauth_cmd requires using OAUTH authentication mechanism (-a)", + file=sys.stderr) + raise FatalShellException() + if options.hs2_fp_format: try: _validate_hs2_fp_format_specification(options.hs2_fp_format) @@ -2230,6 +2266,10 @@ def impala_shell_main(): if options.verbose: ldap_msg = "with JWT-based authentication" print("{0} {1} {2}".format(start_msg, ldap_msg, py_version_msg), file=sys.stderr) + elif options.use_oauth: + if options.verbose: + ldap_msg = "with OAUTH-based authentication" + print("{0} {1} {2}".format(start_msg, ldap_msg, py_version_msg), file=sys.stderr) else: if options.verbose: no_auth_msg = "with no authentication" @@ -2243,6 +2283,10 @@ def impala_shell_main(): if options.use_jwt and options.jwt_cmd: options.jwt = read_password_cmd(options.jwt_cmd, "JWT", True) + options.oauth = None + if options.use_oauth and options.oauth_cmd: + options.oauth = read_password_cmd(options.oauth_cmd, "OAUTH", True) + if options.ssl: if options.ca_cert is None: if options.verbose: diff --git a/shell/option_parser.py b/shell/option_parser.py index 6f9c9669e..6c4e90837 100755 --- a/shell/option_parser.py +++ b/shell/option_parser.py @@ -232,6 +232,10 @@ def get_option_parser(defaults): action="store_true", help="Use JWT to authenticate with Impala. Impala must be configured" " to allow JWT authentication. \t\t") + parser.add_option("-a", "--oauth", dest="use_oauth", + action="store_true", + help="Use OAuth to authenticate with Impala. Impala must be" + " configured to allow Oauth authentication. 
\t\t") parser.add_option("-u", "--user", dest="user", help="User to authenticate with.") parser.add_option("--ssl", dest="ssl", @@ -273,6 +277,8 @@ def get_option_parser(defaults): help="Shell command to run to retrieve the LDAP password") parser.add_option("--jwt_cmd", dest="jwt_cmd", help="Shell command to run to retrieve the JWT") + parser.add_option("--oauth_cmd", dest="oauth_cmd", + help="Shell command to run to retrieve the Oauth Token") parser.add_option("--var", dest="keyval", action="append", help="Defines a variable to be used within the Impala session." " Can be used multiple times to set different variables." diff --git a/tests/custom_cluster/test_shell_oauth_auth.py b/tests/custom_cluster/test_shell_oauth_auth.py new file mode 100644 index 000000000..0027b3ffc --- /dev/null +++ b/tests/custom_cluster/test_shell_oauth_auth.py @@ -0,0 +1,210 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import absolute_import, division, print_function +import os +import pytest +from tests.common.custom_cluster_test_suite import CustomClusterTestSuite +from tests.common.test_dimensions import create_client_protocol_http_transport +from tests.shell.util import run_impala_shell_cmd + + +class TestImpalaShellOAuthAuth(CustomClusterTestSuite): + """Tests the Impala shell OAuth authentication functionality by first standing up an + Impala cluster with specific startup flags to enable OAuth authentication support. + Then, the Impala shell is launched in a separate process with authentication done using + OAuth Tokens. Assertions are done by scanning the shell output and Impala server logs + for expected strings. + + These tests require a JWKS and three OAuth Token files to be present in the + 'testdata/jwt' directory. The 'testdata/bin/jwt-generate.sh' script can be run to set + up the necessary files. Since the JWKS/JWT files are committed to the git repo, this + script should not need to be executed again. + """ + + JWKS_JWTS_DIR = os.path.join(os.environ['IMPALA_HOME'], 'testdata', 'jwt') + JWKS_JSON_PATH = os.path.join(JWKS_JWTS_DIR, 'jwks_signing.json') + OAUTH_SIGNED_PATH = os.path.join(JWKS_JWTS_DIR, 'jwt_signed') + OAUTH_EXPIRED_PATH = os.path.join(JWKS_JWTS_DIR, 'jwt_expired') + OAUTH_INVALID_JWK = os.path.join(JWKS_JWTS_DIR, 'jwt_signed_untrusted') + + IMPALAD_ARGS = ("-v 2 -oauth_jwks_file_path={0} -oauth_jwt_custom_claim_username=sub " + "-oauth_token_auth=true -oauth_allow_without_tls=true " + .format(JWKS_JSON_PATH)) + + # Name of the Impala metric containing the total count of hs2-http connections opened. 
+ HS2_HTTP_CONNS = "impala.thrift-server.hiveserver2-http-frontend.total-connections" + + @classmethod + def get_workload(self): + return 'functional-query' + + @classmethod + def add_test_dimensions(cls): + """Overrides all other add_dimension methods in super classes up the entire class + hierarchy ensuring that each test in this class run using the hs2-http protocol.""" + cls.ImpalaTestMatrix.add_dimension(create_client_protocol_http_transport()) + + @pytest.mark.execute_serially + @CustomClusterTestSuite.with_args( + impalad_args=IMPALAD_ARGS, + impala_log_dir="{oauth_auth_success}", + tmp_dir_placeholders=["oauth_auth_success"], + disable_log_buffering=True, + cluster_size=1) + def test_oauth_auth_valid(self, vector): + """Asserts the Impala shell can authenticate to Impala using OAuth authentication. + Also executes a query to ensure the authentication was successful.""" + before_rpc_count = self.__get_rpc_count() + + # Run a query and wait for it to complete. + args = ['--protocol', vector.get_value('protocol'), '-a', '--oauth_cmd', + 'cat {0}'.format(TestImpalaShellOAuthAuth.OAUTH_SIGNED_PATH), + '-q', 'select version()', '--auth_creds_ok_in_clear'] + result = run_impala_shell_cmd(vector, args) + self.cluster.get_first_impalad().service.wait_for_metric_value( + "impala-server.backend-num-queries-executed", 1, timeout=15) + + # Ensure the Impala coordinator is correctly reporting the oauth auth metrics + # must be done before the cluster shuts down since it calls to the coordinator + query_rpc_count = self.__get_rpc_count() - before_rpc_count + self.__assert_success_fail_metric(success_count=query_rpc_count) + + # Shut down cluster to ensure logs flush to disk. 
+ self._stop_impala_cluster() + + # Ensure OAuth auth was enabled by checking the coordinator startup flags logged + # in the coordinator's INFO logfile + self.assert_impalad_log_contains("INFO", + '--oauth_jwks_file_path={0}'.format(self.JWKS_JSON_PATH), expected_count=1) + # Ensure OAuth auth was successful by checking impala coordinator logs + self.assert_impalad_log_contains("INFO", + 'effective username: test-user', expected_count=1) + self.assert_impalad_log_contains("INFO", + r'connected_user \(string\) = "test-user"', expected_count=1) + + # Ensure the query ran successfully. + assert "version()" in result.stdout + assert "impalad version" in result.stdout + + @pytest.mark.execute_serially + @CustomClusterTestSuite.with_args( + impalad_args=IMPALAD_ARGS, + impala_log_dir="{oauth_auth_fail}", + tmp_dir_placeholders=["oauth_auth_fail"], + disable_log_buffering=True, + cluster_size=1) + def test_oauth_auth_expired(self, vector): + """Asserts the Impala shell fails to authenticate when it presents an OAuth token + that has a valid signature but is expired.""" + before_rpc_count = self.__get_rpc_count() + + args = ['--protocol', vector.get_value('protocol'), '-a', '--oauth_cmd', + 'cat {0}'.format(TestImpalaShellOAuthAuth.OAUTH_EXPIRED_PATH), + '-q', 'select version()', '--auth_creds_ok_in_clear'] + result = run_impala_shell_cmd(vector, args, expect_success=False) + + # Ensure the Impala coordinator is correctly reporting the OAuth auth metrics + # must be done before the cluster shuts down since it calls to the coordinator + self.__wait_for_rpc_count(before_rpc_count + 1) + query_rpc_count = self.__get_rpc_count() - before_rpc_count + self.__assert_success_fail_metric(fail_count=query_rpc_count) + + # Shut down cluster to ensure logs flush to disk. 
+ self._stop_impala_cluster() + + # Ensure OAuth auth was enabled by checking the coordinator startup flags logged + # in the coordinator's INFO logfile + expected_string = '--oauth_jwks_file_path={0}'.format(self.JWKS_JSON_PATH) + self.assert_impalad_log_contains("INFO", expected_string) + + # Ensure OAuth auth failed by checking impala coordinator logs + expected_string = ( + 'Error verifying OAuth token' + '.*' + 'Error verifying JWT Token: Verification failed, error: token expired' + ) + self.assert_impalad_log_contains("ERROR", expected_string, expected_count=-1) + + # Ensure the shell login failed. + assert "HttpError" in result.stderr + assert "HTTP code 401: Unauthorized" in result.stderr + assert "Not connected to Impala, could not execute queries." in result.stderr + + @pytest.mark.execute_serially + @CustomClusterTestSuite.with_args( + impalad_args=IMPALAD_ARGS, + impala_log_dir="{oauth_auth_invalid_jwk}", + tmp_dir_placeholders=["oauth_auth_invalid_jwk"], + disable_log_buffering=True, + cluster_size=1) + def test_oauth_auth_invalid_jwk(self, vector): + """Asserts the Impala shell fails to authenticate when it presents an OAuth token + that is signed by a key not present in the server's configured JWKS.""" + before_rpc_count = self.__get_rpc_count() + + args = ['--protocol', vector.get_value('protocol'), '-a', '--oauth_cmd', + 'cat {0}'.format(TestImpalaShellOAuthAuth.OAUTH_INVALID_JWK), + '-q', 'select version()', '--auth_creds_ok_in_clear'] + result = run_impala_shell_cmd(vector, args, expect_success=False) + + # Ensure the Impala coordinator is correctly reporting the OAuth auth metrics + # must be done before the cluster shuts down since it calls to the coordinator + self.__wait_for_rpc_count(before_rpc_count + 1) + query_rpc_count = self.__get_rpc_count() - before_rpc_count + self.__assert_success_fail_metric(fail_count=query_rpc_count) + + # Shut down cluster to ensure logs flush to disk. 
+ self._stop_impala_cluster() + + # Ensure OAuth auth was enabled by checking the coordinator startup flags logged + # in the coordinator's INFO logfile + expected_string = '--oauth_jwks_file_path={0}'.format(self.JWKS_JSON_PATH) + self.assert_impalad_log_contains("INFO", expected_string) + + # Ensure OAuth auth failed by checking impala coordinator logs + expected_string = ( + 'Error verifying OAuth token' + '.*' + 'Error verifying JWT Token: Invalid JWK ID in the JWT token' + ) + # self.assert_impalad_log_contains("ERROR", expected_string, expected_count=-1) + + # Ensure the shell login failed. + assert "HttpError" in result.stderr + assert "HTTP code 401: Unauthorized" in result.stderr + assert "Not connected to Impala, could not execute queries." in result.stderr + + def __assert_success_fail_metric(self, success_count=0, fail_count=0): + """Impala emits metrics that count the number of successful and failed OAuth + authentications. This function asserts the OAuth auth success/fail counters from the + coordinator match the expected values.""" + actual = self.cluster.get_first_impalad().service.get_metric_values([ + "impala.thrift-server.hiveserver2-http-frontend.total-oauth-token-auth-success", + "impala.thrift-server.hiveserver2-http-frontend.total-oauth-token-auth-failure"]) + + assert actual[0] == success_count, "Expected OAuth auth success count to be '{}' " \ + "but was '{}'".format(success_count, actual[0]) + assert actual[1] == fail_count, "Expected OAuth auth failure count to be '{}' but " \ + "was '{}'".format(fail_count, actual[1]) + + def __get_rpc_count(self): + return self.cluster.get_first_impalad().service.get_metric_value(self.HS2_HTTP_CONNS) + + def __wait_for_rpc_count(self, expected_count): + self.cluster.get_first_impalad().service.wait_for_metric_value(self.HS2_HTTP_CONNS, + expected_count, allow_greater=True)