// Copyright 2018 Google LLC. All Rights Reserved. This file and proprietary // source code may only be used and distributed under the Widevine License // Agreement. #include "http_socket.h" #include #include #include #include #include #include #include #ifdef _WIN32 # include "winsock2.h" # include "ws2tcpip.h" # define ERROR_ASYNC_COMPLETE WSAEWOULDBLOCK #else # include # include # include # include # define ERROR_ASYNC_COMPLETE EINPROGRESS #endif #include #include #include #include "log.h" #include "platform.h" namespace wvcdm { namespace { // Number of attempts to identify an Internet host and a service should the // host's nameserver be temporarily unavailable. See getaddrinfo(3) for // more info. constexpr size_t kMaxNameserverAttempts = 2; // Helper function to tokenize a string. This makes it easier to avoid silly // parsing bugs that creep in easily when each part of the string is parsed // with its own piece of code. bool Tokenize(const std::string& source, const std::string& delim, const size_t offset, std::string* substring_output, size_t* next_offset) { size_t start_of_delim = source.find(delim, offset); if (start_of_delim == std::string::npos) { return false; } substring_output->assign(source, offset, start_of_delim - offset); *next_offset = start_of_delim + delim.size(); return true; } SSL_CTX* InitSslContext() { OpenSSL_add_all_algorithms(); SSL_load_error_strings(); const SSL_METHOD* method = TLS_client_method(); SSL_CTX* ctx = SSL_CTX_new(method); if (!ctx) LOGE("failed to create SSL context"); int ret = SSL_CTX_set_cipher_list( ctx, "ALL:!RC4-MD5:!RC4-SHA:!ECDHE-ECDSA-RC4-SHA:!ECDHE-RSA-RC4-SHA"); if (0 == ret) LOGE("error disabling vulnerable ciphers"); return ctx; } int LogBoringSslError(const char* message, size_t /* length */, void* /* user_data */) { LOGE(" BoringSSL Error: %s", message); return 1; } bool IsRetryableSslError(int ssl_error) { return ssl_error != SSL_ERROR_ZERO_RETURN && ssl_error != SSL_ERROR_SYSCALL && ssl_error != SSL_ERROR_SSL; } // Ensures that the SSL library is only initialized once. void InitSslLibrary() { static bool ssl_initialized = false; static std::mutex ssl_init_mutex; std::lock_guard guard(ssl_init_mutex); if (!ssl_initialized) { SSL_library_init(); ssl_initialized = true; } } #if 0 // unused, may be useful for debugging SSL-related issues. void ShowServerCertificate(const SSL* ssl) { // gets the server certificate X509* cert = SSL_get_peer_certificate(ssl); if (cert) { char* line = X509_NAME_oneline(X509_get_subject_name(cert), 0, 0); LOGV("server certificate:"); LOGV("subject: %s", line); free(line); line = X509_NAME_oneline(X509_get_issuer_name(cert), 0, 0); LOGV("issuer: %s", line); free(line); X509_free(cert); } else { LOGE("Failed to get server certificate"); } } #endif int GetError() { #ifdef _WIN32 return WSAGetLastError(); #else return errno; #endif } void ClearError() { #ifdef _WIN32 WSASetLastError(0); #else errno = 0; #endif } const char* GetErrorString() { #ifdef _WIN32 static char buffer[2048]; const int flags = FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS; const int code = WSAGetLastError(); if (!FormatMessage(flags, nullptr, code, 0, buffer, sizeof(buffer), nullptr)) return "Unknown error"; return buffer; #else return strerror(errno); #endif } // Formats the provided time point to ISO 8601 format with space // date-time separator (Google's recommended format). std::string FormatTimePoint(const HttpSocket::TimePoint& time_point) { const std::time_t epoch_time = std::chrono::system_clock::to_time_t(time_point); struct tm time_parts = {}; if (::gmtime_r(&epoch_time, &time_parts) == nullptr) { const int saved_errno = GetError(); if (saved_errno == EOVERFLOW) { LOGE("Overflow when converting to time parts: epoch_time = %zu", static_cast(epoch_time)); } else { LOGE( "Failed to convert time point to time parts: " "epoch_time = %zu, errno = %d", static_cast(epoch_time), saved_errno); } // Just convert to epoch seconds. return std::to_string(epoch_time); } static constexpr size_t kMaxLength = 127; static constexpr char kTimeFormat[] = "%F %T"; char time_buffer[kMaxLength + 1]; const size_t res = ::strftime(time_buffer, kMaxLength, kTimeFormat, &time_parts); if (res == 0) { LOGE("Failed to format time"); return std::to_string(epoch_time); } if (res > kMaxLength) { // Very unlikely situation, but cannot trust the contents of // |buffer| in this case. LOGE("Unexpected output from strftime: max = %zu, res = %zu", kMaxLength, res); return std::to_string(epoch_time); } return std::string(time_buffer, &time_buffer[res]); } // Formats the provided duration to Google style duration format, // with microsecond accuracy. // The template parameter D should be a std::chrono::duration // type. This is template to support C++ system_clock which // which duration accuracy may vary by platform. template std::string FormatDuration(const D& duration) { D working_duration = duration; std::string res; // If duration is negative, add a '-' and continue with absolute. if (working_duration < D::zero()) { res.push_back('-'); working_duration = -working_duration; } // Format hours (if non-zero). using Hours = std::chrono::hours; const Hours h = std::chrono::floor(working_duration); if (h != Hours::zero()) { res.append(std::to_string(h.count())); res.push_back('h'); working_duration -= h; } // Format minutes (if non-zero). using Minutes = std::chrono::minutes; const Minutes m = std::chrono::floor(working_duration); if (m != Minutes::zero()) { res.append(std::to_string(m.count())); res.push_back('m'); working_duration -= m; } // Format seconds (if non-zero). using Seconds = std::chrono::seconds; const Seconds s = std::chrono::floor(working_duration); if (s != Seconds::zero()) { res.append(std::to_string(s.count())); res.push_back('s'); working_duration -= s; } // Format microseconds (if non-zero). using Microseconds = std::chrono::microseconds; const Microseconds us = std::chrono::floor(working_duration); if (us != Microseconds::zero()) { res.append(std::to_string(us.count())); res.append("us"); } return res; } } // namespace // Parses the URL and extracts all relevant information. // static bool HttpSocket::ParseUrl(const std::string& url, std::string* scheme, bool* secure_connect, std::string* domain_name, std::string* port, std::string* path) { size_t offset = 0; if (!Tokenize(url, "://", offset, scheme, &offset)) { LOGE("Invalid URL, scheme not found: %s", url.c_str()); return false; } // If the scheme is http or https, set secure_connect and port accordingly. // Otherwise, consider the scheme unsupported and fail. if (*scheme == "http") { *secure_connect = false; port->assign("80"); } else if (*scheme == "https") { *secure_connect = true; port->assign("443"); } else { LOGE("Invalid URL, scheme not supported: %s", url.c_str()); return false; } // Strip off the domain name and port. In the url it will be terminated by // either a splash or a question mark: // like this example.com?key=value // or this example.com/path/to/resource if (!Tokenize(url, "/", offset, domain_name, &offset)) { if (Tokenize(url, "?", offset, domain_name, &offset)) { // url had no '/', but it did have '?'. Use the default path but // keep the extra parameters. i.e. turn '?extra' into '/?extra'. path->assign("/"); path->append(url, offset - 1, std::string::npos); } else { // url had no '/' or '?'. // The rest of the URL belongs to the domain name. domain_name->assign(url, offset, std::string::npos); // Use the default path. path->assign("/"); } } else { // url had a '/'. // The rest of the URL, including the preceding slash, belongs to the path. path->assign(url, offset - 1, std::string::npos); } // The domain name may optionally contain a port which overrides the default. std::string domain_name_without_port; size_t port_offset; if (Tokenize(*domain_name, ":", 0, &domain_name_without_port, &port_offset)) { port->assign(domain_name->c_str() + port_offset); int port_num = atoi(port->c_str()); if (port_num <= 0 || port_num >= 65536) { LOGE("Invalid URL, port not valid: %s", url.c_str()); return false; } domain_name->assign(domain_name_without_port); } return true; } HttpSocket::HttpSocket(const std::string& url) : url_(url) { valid_url_ = ParseUrl(url, &scheme_, &secure_connect_, &domain_name_, &port_, &resource_path_); InitSslLibrary(); } HttpSocket::~HttpSocket() { CloseSocket(); } void HttpSocket::CloseSocket() { if (socket_fd_ != kClosedFd) { #ifdef _WIN32 closesocket(socket_fd_); #else close(socket_fd_); #endif socket_fd_ = kClosedFd; } if (ssl_) { SSL_free(ssl_); ssl_ = nullptr; } if (ssl_ctx_) { SSL_CTX_free(ssl_ctx_); ssl_ctx_ = nullptr; } } bool HttpSocket::ConnectAndLogErrors(int timeout_in_ms) { const TimePoint start = GetNowTimePoint(); const bool result = Connect(timeout_in_ms); const TimePoint end = GetNowTimePoint(); if (!result) LogTime("Socket connect error", start, end); return result; } bool HttpSocket::Connect(int timeout_in_ms) { if (!valid_url_) { LOGE("URL is invalid"); return false; } if (socket_fd_ != kClosedFd) { LOGE("Socket already connected"); return false; } #ifdef _WIN32 static bool initialized = false; if (!initialized) { WSADATA ignored_data; int err = WSAStartup(MAKEWORD(2, 2), &ignored_data); if (err != 0) { LOGE("Error in WSAStartup: %d", err); return false; } initialized = true; } #endif // lookup the server IP struct addrinfo hints; memset(&hints, 0, sizeof(hints)); hints.ai_family = AF_UNSPEC; hints.ai_socktype = SOCK_STREAM; hints.ai_flags = AI_NUMERICSERV | AI_ADDRCONFIG; struct addrinfo* addr_info = nullptr; int ret = EAI_AGAIN; for (size_t attempt = 1; attempt <= kMaxNameserverAttempts && ret == EAI_AGAIN; ++attempt) { if (attempt > 1) { LOGW( "Nameserver is temporarily unavailable, waiting to try again: " "attempt = %zu", attempt); sleep(1); } ret = getaddrinfo(domain_name_.c_str(), port_.c_str(), &hints, &addr_info); } if (ret != 0) { if (ret == EAI_SYSTEM) { // EAI_SYSTEM implies an underlying system issue. Error is // specified by |errno|. LOGE("getaddrinfo %s (port %s) failed due to system error: errno = %d", domain_name_.c_str(), port_.c_str(), GetError()); } else { // Error is specified by return value. LOGE("getaddrinfo %s (port %s) failed: ret = %d", domain_name_.c_str(), port_.c_str(), ret); } return false; } // Open a socket. socket_fd_ = socket(addr_info->ai_family, addr_info->ai_socktype, addr_info->ai_protocol); if (socket_fd_ < 0) { LOGE("Cannot open socket %s (port %s): errno = %d", domain_name_.c_str(), port_.c_str(), GetError()); freeaddrinfo(addr_info); return false; } // Set the socket in non-blocking mode. #ifdef _WIN32 u_long mode = 1; // Non-blocking mode. if (ioctlsocket(socket_fd_, FIONBIO, &mode) != 0) { LOGE("ioctlsocket error %s (port %s), wsa error = %d", domain_name_.c_str(), port_.c_str(), WSAGetLastError()); freeaddrinfo(addr_info); CloseSocket(); return false; } #else const int original_flags = fcntl(socket_fd_, F_GETFL, 0); if (original_flags == -1) { LOGE("fcntl error %s (port %s), errno = %d", domain_name_.c_str(), port_.c_str(), errno); freeaddrinfo(addr_info); CloseSocket(); return false; } if (fcntl(socket_fd_, F_SETFL, original_flags | O_NONBLOCK) == -1) { LOGE("fcntl error %s (port %s), errno = %d", domain_name_.c_str(), port_.c_str(), errno); freeaddrinfo(addr_info); CloseSocket(); return false; } #endif // connect to the server ret = connect(socket_fd_, addr_info->ai_addr, addr_info->ai_addrlen); freeaddrinfo(addr_info); addr_info = nullptr; if (ret == 0) { // Connected right away. } else { if (GetError() != ERROR_ASYNC_COMPLETE) { // failed right away. LOGE("cannot connect to %s, errno = %d", url().c_str(), GetError()); CloseSocket(); return false; } else { // in progress. block until timeout expired or connection established. if (!Wait(/* for_read */ false, timeout_in_ms)) { LOGE("cannot connect to %s", url().c_str()); CloseSocket(); return false; } } } // set up SSL if needed if (secure_connect_) { ssl_ctx_ = InitSslContext(); if (!ssl_ctx_) { CloseSocket(); return false; } ssl_ = SSL_new(ssl_ctx_); if (!ssl_) { LOGE("failed SSL_new"); CloseSocket(); return false; } // |BIO_NOCLOSE| prevents closing the socket from being closed when // the BIO is freed. BIO* a_bio = BIO_new_socket(socket_fd_, BIO_NOCLOSE); if (!a_bio) { LOGE("BIO_new_socket error"); CloseSocket(); return false; } SSL_set_bio(ssl_, a_bio, a_bio); do { ret = SSL_connect(ssl_); if (ret != 1) { int ssl_err = SSL_get_error(ssl_, ret); if (ssl_err != SSL_ERROR_WANT_READ && ssl_err != SSL_ERROR_WANT_WRITE) { char buf[256]; LOGE("SSL_connect error: %s", ERR_error_string(ERR_get_error(), buf)); CloseSocket(); return false; } const bool for_read = (ssl_err == SSL_ERROR_WANT_READ); if (!Wait(for_read, timeout_in_ms)) { LOGE("Cannot connect securely to %s", domain_name_.c_str()); CloseSocket(); return false; } } } while (ret != 1); } return true; } // Returns -1 for error, number of bytes read for success. // The timeout here only applies to the span between packets of data, for the // sake of simplicity. int HttpSocket::ReadAndLogErrors(char* data, int len, int timeout_in_ms) { const TimePoint start = GetNowTimePoint(); const int result = Read(data, len, timeout_in_ms); const TimePoint end = GetNowTimePoint(); if (result < 0) LogTime("Read error", start, end); return result; } int HttpSocket::Read(char* data, int len, int timeout_in_ms) { int total_read = 0; int to_read = len; if (socket_fd_ == kClosedFd) { LOGE("Socket to %s not open. Cannot read.", domain_name_.c_str()); return -1; } while (to_read > 0) { if (!Wait(/* for_read */ true, timeout_in_ms)) { LOGE("unable to read from %s. len=%d, to_read=%d", domain_name_.c_str(), len, to_read); return -1; } ClearError(); // Reset errors, as we will depend on its value shortly. int read; if (secure_connect_) { read = SSL_read(ssl_, data, to_read); } else { read = static_cast(recv(socket_fd_, data, to_read, 0)); } if (read > 0) { to_read -= read; data += read; total_read += read; } else if (secure_connect_) { // Secure read error int ssl_error = SSL_get_error(ssl_, read); if (ssl_error == SSL_ERROR_ZERO_RETURN || (ssl_error == SSL_ERROR_SYSCALL && GetError() == 0)) { // The connection has been closed. No more data. break; } else if (IsRetryableSslError(ssl_error)) { sleep(1); // After sleeping, fall through to iterate the loop again and retry. } else { // Unrecoverable error. Log and abort. LOGE("SSL_read returned %d, LibSSL Error = %d", read, ssl_error); if (ssl_error == SSL_ERROR_SYSCALL) { LOGE(" errno = %d = %s", GetError(), GetErrorString()); } ERR_print_errors_cb(LogBoringSslError, nullptr); return -1; } } else { // Non-secure read error if (read == 0) { // The connection has been closed. No more data. break; } else { // Log the error received LOGE("recv returned %d, errno = %d = %s", read, GetError(), GetErrorString()); return -1; } } } return total_read; } // Returns -1 for error, number of bytes written for success. // The timeout here only applies to the span between packets of data, for the // sake of simplicity. int HttpSocket::WriteAndLogErrors(const char* data, int len, int timeout_in_ms) { const TimePoint start = GetNowTimePoint(); const int result = Write(data, len, timeout_in_ms); const TimePoint end = GetNowTimePoint(); if (result < 0) LogTime("Write error", start, end); return result; } int HttpSocket::Write(const char* data, int len, int timeout_in_ms) { int total_sent = 0; int to_send = len; if (socket_fd_ == kClosedFd) { LOGE("Socket to %s not open. Cannot write.", domain_name_.c_str()); return -1; } while (to_send > 0) { int sent; if (secure_connect_) { sent = SSL_write(ssl_, data, to_send); } else { sent = static_cast(send(socket_fd_, data, to_send, 0)); } if (sent > 0) { to_send -= sent; data += sent; total_sent += sent; } else if (sent == 0) { // We filled up the pipe. Wait for room to write. if (!Wait(/* for_read */ false, timeout_in_ms)) { LOGE("unable to write to %s", domain_name_.c_str()); return -1; } } else { LOGE("send returned %d, errno = %d", sent, GetError()); return -1; } } return total_sent; } bool HttpSocket::Wait(bool for_read, int timeout_in_ms) { fd_set fds; FD_ZERO(&fds); FD_SET(socket_fd_, &fds); struct timeval tv; tv.tv_sec = timeout_in_ms / 1000; tv.tv_usec = (timeout_in_ms % 1000) * 1000; fd_set* read_fds = nullptr; fd_set* write_fds = nullptr; if (for_read) { read_fds = &fds; } else { write_fds = &fds; } const TimePoint start = GetNowTimePoint(); const int ret = select(socket_fd_ + 1, read_fds, write_fds, nullptr, &tv); const TimePoint end = GetNowTimePoint(); if (ret == 0) { LogTime("Socket select timeout", start, end); // TODO(b/186031735): Remove this when the bug is fixed. LOGE("Timeout = %0.3f. Consider adding a comment to http://b/186031735", 0.001 * timeout_in_ms); return false; } else if (ret == -1) { LOGE("select failed, errno = %d", errno); return false; } // socket ready. return true; } void HttpSocket::LogTime(const char* note, const TimePoint& start_time, const TimePoint& end_time) const { const std::string start_string = FormatTimePoint(start_time); const std::string create_start_diff_string = FormatDuration(start_time - create_time_); const std::string end_string = FormatTimePoint(end_time); const std::string start_end_diff_string = FormatDuration(end_time - start_time); LOGE("%s: start = %s = create + %s, end = %s = start + %s", note ? note : "", start_string.c_str(), create_start_diff_string.c_str(), end_string.c_str(), start_end_diff_string.c_str()); } } // namespace wvcdm