Files
android/libwvdrmengine/cdm/core/test/http_socket.cpp
Alex Dale ee995d5fae Replacing NULL with nullptr in core/
[ Merge of http://go/wvgerrit/84647 ]
[ Merge of http://go/wvgerrit/84648 ]

Replaces most instances of C's NULL with C++'s nullptr.  Also changes
how null checks are performed on smart pointers.  Smart pointers provide
a boolean conversion operator for null checks, meaning the underlying
pointer does not need to be compared directly (as it was in some places
before).

Note that clang-format has performed additional changes to some of the
test files that have not yet been formatted.

Bug: 120602075
Test: Linux and Android unittests
Change-Id: I06ddebe34b0ea6dfecedb5527e7e808e32f5269a
2019-08-19 14:18:25 -07:00

474 lines
12 KiB
C++

// Copyright 2018 Google LLC. All Rights Reserved. This file and proprietary
// source code may only be used and distributed under the Widevine Master
// License Agreement.
#include "http_socket.h"
#include <cstring>
#include <errno.h>
#include <fcntl.h>
#include <stdlib.h>
#include <string.h>
#ifdef _WIN32
# include "winsock2.h"
# include "ws2tcpip.h"
# define ERROR_ASYNC_COMPLETE WSAEWOULDBLOCK
#else
# include <netdb.h>
# include <netinet/in.h>
# include <sys/socket.h>
# include <unistd.h>
# define ERROR_ASYNC_COMPLETE EINPROGRESS
#endif
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/x509.h>
#include "log.h"
#include "platform.h"
namespace wvcdm {
namespace {
// Splits off the piece of |source| that lies between |offset| and the next
// occurrence of |delim|. Centralizing this keeps the URL parsing code free of
// hand-rolled index arithmetic.
//
// On success the piece is stored in |substring_output|, |next_offset| is set
// to the position just past the delimiter, and true is returned. When the
// delimiter is not found, false is returned and neither output is modified.
bool Tokenize(const std::string& source, const std::string& delim,
              const size_t offset, std::string* substring_output,
              size_t* next_offset) {
  const size_t delim_pos = source.find(delim, offset);
  const bool found = (delim_pos != std::string::npos);
  if (found) {
    *substring_output = source.substr(offset, delim_pos - offset);
    *next_offset = delim_pos + delim.size();
  }
  return found;
}
// Creates a client-side TLS context with the RC4 cipher suites named in the
// cipher list below disabled.
//
// Returns the new context, or nullptr if it could not be created. The caller
// is responsible for releasing a non-null result (CloseSocket() does so via
// SSL_CTX_free()). A failure to apply the cipher list is logged but is not
// treated as fatal; the context is still returned.
SSL_CTX* InitSslContext() {
  OpenSSL_add_all_algorithms();
  SSL_load_error_strings();
  const SSL_METHOD* method = TLS_client_method();
  SSL_CTX* ctx = SSL_CTX_new(method);
  if (!ctx) {
    // Bail out here: the cipher-list call below must not be handed a null
    // context.
    LOGE("failed to create SSL context");
    return nullptr;
  }
  int ret = SSL_CTX_set_cipher_list(
      ctx, "ALL:!RC4-MD5:!RC4-SHA:!ECDHE-ECDSA-RC4-SHA:!ECDHE-RSA-RC4-SHA");
  if (0 == ret) LOGE("error disabling vulnerable ciphers");
  return ctx;
}
// Error-queue callback passed to ERR_print_errors_cb(). Logs |message| and
// reports |length| back to the caller; the user-data pointer is ignored.
static int LogBoringSslError(const char* message, size_t length,
                             void* /* user_data */) {
  LOGE(" BoringSSL Error: %s", message);
  const int reported_length = length;
  return reported_length;
}
// Returns true for every SSL error code except the three that are treated as
// final: SSL_ERROR_ZERO_RETURN, SSL_ERROR_SYSCALL and SSL_ERROR_SSL.
bool IsRetryableSslError(int ssl_error) {
  switch (ssl_error) {
    case SSL_ERROR_ZERO_RETURN:
    case SSL_ERROR_SYSCALL:
    case SSL_ERROR_SSL:
      return false;
    default:
      return true;
  }
}
#if 0
// unused, may be useful for debugging SSL-related issues.
//
// Logs the subject and issuer names of the certificate presented by the peer
// of |ssl|, or logs an error if no peer certificate is available.
// This block is compiled out by the surrounding "#if 0".
void ShowServerCertificate(const SSL* ssl) {
// gets the server certificate
X509* cert = SSL_get_peer_certificate(ssl);
if (cert) {
// Passing 0, 0 asks X509_NAME_oneline for a newly allocated string, which is
// released below.
// NOTE(review): the OpenSSL documentation pairs this buffer with
// OPENSSL_free() rather than free() - confirm before re-enabling this code.
char* line = X509_NAME_oneline(X509_get_subject_name(cert), 0, 0);
LOGV("server certificate:");
LOGV("subject: %s", line);
free(line);
line = X509_NAME_oneline(X509_get_issuer_name(cert), 0, 0);
LOGV("issuer: %s", line);
free(line);
// Release the certificate obtained from SSL_get_peer_certificate().
X509_free(cert);
} else {
LOGE("Failed to get server certificate");
}
}
#endif
// Blocks in select() until |fd| becomes readable (|for_read| is true) or
// writable (|for_read| is false), or until |timeout_in_ms| elapses.
// A connection being established counts as "ready for write".
// Returns true when the socket is ready; returns false, after logging, when
// select() times out or fails.
bool SocketWait(int fd, bool for_read, int timeout_in_ms) {
  struct timeval timeout;
  timeout.tv_sec = timeout_in_ms / 1000;
  timeout.tv_usec = (timeout_in_ms % 1000) * 1000;

  fd_set watched;
  FD_ZERO(&watched);
  FD_SET(fd, &watched);

  // The same set is handed to select() as either the read set or the write
  // set, never both.
  fd_set* const readers = for_read ? &watched : nullptr;
  fd_set* const writers = for_read ? nullptr : &watched;

  const int ready = select(fd + 1, readers, writers, nullptr, &timeout);
  if (ready == -1) {
    LOGE("select failed, errno = %d", errno);
    return false;
  }
  if (ready == 0) {
    LOGE("socket timed out");
    return false;
  }
  return true;
}
// Returns the most recent socket error code for the calling thread:
// WSAGetLastError() on Windows, errno elsewhere.
int GetError() {
#ifdef _WIN32
  const int last_error = WSAGetLastError();
#else
  const int last_error = errno;
#endif
  return last_error;
}
// Resets the platform's last-error indicator (the WSA error on Windows, errno
// elsewhere) to zero so that a later GetError() reflects only newer failures.
void ClearError() {
  const int no_error = 0;
#ifdef _WIN32
  WSASetLastError(no_error);
#else
  errno = no_error;
#endif
}
// Returns a human-readable description of the current socket error.
// On Windows the text is formatted into a function-local static buffer, so the
// returned pointer is overwritten by the next call; "Unknown error" is
// returned if formatting fails. Elsewhere this is strerror(errno).
const char* GetErrorString() {
#ifdef _WIN32
  static char buffer[2048];
  const int code = WSAGetLastError();
  const int flags = FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS;
  const bool formatted =
      FormatMessage(flags, nullptr, code, 0, buffer, sizeof(buffer), nullptr);
  return formatted ? buffer : "Unknown error";
#else
  const int code = errno;
  return strerror(code);
#endif
}
} // namespace
// Parses |url| and extracts all relevant information.
//
// Outputs:
//   scheme          - text before "://"; only "http" and "https" are accepted.
//   secure_connect  - true for https, false for http.
//   domain_name     - host part, with any ":port" suffix removed.
//   port            - explicit port if present, otherwise "80" or "443".
//   path            - everything from the first '/' after the host, or "/"
//                     when the URL has no path.
//
// Returns false (after logging) when the scheme is missing or unsupported, or
// when an explicit port is not a decimal number in the range 1..65535. Outputs
// may be partially written when false is returned.
// static
bool HttpSocket::ParseUrl(const std::string& url, std::string* scheme,
                          bool* secure_connect, std::string* domain_name,
                          std::string* port, std::string* path) {
  size_t offset = 0;

  if (!Tokenize(url, "://", offset, scheme, &offset)) {
    LOGE("Invalid URL, scheme not found: %s", url.c_str());
    return false;
  }

  // If the scheme is http or https, set secure_connect and port accordingly.
  // Otherwise, consider the scheme unsupported and fail.
  if (*scheme == "http") {
    *secure_connect = false;
    port->assign("80");
  } else if (*scheme == "https") {
    *secure_connect = true;
    port->assign("443");
  } else {
    LOGE("Invalid URL, scheme not supported: %s", url.c_str());
    return false;
  }

  if (!Tokenize(url, "/", offset, domain_name, &offset)) {
    // The rest of the URL belongs to the domain name.
    domain_name->assign(url, offset, std::string::npos);
    // No explicit path after the domain name.
    path->assign("/");
  } else {
    // The rest of the URL, including the preceding slash, belongs to the path.
    path->assign(url, offset - 1, std::string::npos);
  }

  // The domain name may optionally contain a port which overrides the default.
  std::string domain_name_without_port;
  size_t port_offset;
  if (Tokenize(*domain_name, ":", 0, &domain_name_without_port, &port_offset)) {
    port->assign(*domain_name, port_offset, std::string::npos);

    // Accept decimal digits only. A lenient numeric parse would let trailing
    // garbage such as "80abc" through, and that string is later handed to
    // getaddrinfo() as a numeric service.
    const bool digits_only =
        !port->empty() &&
        port->find_first_not_of("0123456789") == std::string::npos;
    // strtol saturates on overflow instead of invoking undefined behaviour,
    // so oversized values fall into the range check below.
    const long port_num =
        digits_only ? strtol(port->c_str(), nullptr, 10) : 0;
    if (port_num <= 0 || port_num >= 65536) {
      LOGE("Invalid URL, port not valid: %s", url.c_str());
      return false;
    }
    domain_name->assign(domain_name_without_port);
  }

  return true;
}
// Creates an unconnected socket for |url|. The URL is parsed immediately and
// the outcome is remembered in |valid_url_|; no network activity happens until
// Connect() is called.
HttpSocket::HttpSocket(const std::string& url)
    : socket_fd_(-1), ssl_(nullptr), ssl_ctx_(nullptr) {
  SSL_library_init();

  const bool url_parsed = ParseUrl(url, &scheme_, &secure_connect_,
                                   &domain_name_, &port_, &resource_path_);
  valid_url_ = url_parsed;
}
// Releases the socket and any TLS objects still held by this instance.
HttpSocket::~HttpSocket() {
  CloseSocket();
}
// Closes the socket descriptor and frees the SSL session and SSL context, if
// present. Each member is reset afterwards (-1 / nullptr), so calling this
// again is harmless.
void HttpSocket::CloseSocket() {
  const bool socket_open = (socket_fd_ != -1);
  if (socket_open) {
#ifdef _WIN32
    closesocket(socket_fd_);
#else
    close(socket_fd_);
#endif
    socket_fd_ = -1;
  }

  if (ssl_ != nullptr) {
    SSL_free(ssl_);
    ssl_ = nullptr;
  }

  if (ssl_ctx_ != nullptr) {
    SSL_CTX_free(ssl_ctx_);
    ssl_ctx_ = nullptr;
  }
}
// Resolves the host from the parsed URL, opens a non-blocking TCP socket,
// connects to the server and, for https URLs, performs the TLS handshake.
//
// |timeout_in_ms| bounds each individual wait (connection establishment and
// every handshake round trip), not the call as a whole.
//
// Returns true when the connection (and TLS session, if any) is established.
// Returns false if the URL was invalid or any step fails; in that case the
// socket and TLS objects are released via CloseSocket().
bool HttpSocket::Connect(int timeout_in_ms) {
  if (!valid_url_) {
    return false;
  }

#ifdef _WIN32
  // Winsock must be initialized once per process before any socket call.
  static bool initialized = false;
  if (!initialized) {
    WSADATA ignored_data;
    int err = WSAStartup(MAKEWORD(2, 2), &ignored_data);
    if (err != 0) {
      LOGE("Error in WSAStartup: %d", err);
      return false;
    }
    initialized = true;
  }
#endif

  // lookup the server IP
  struct addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags = AI_NUMERICSERV | AI_ADDRCONFIG;

  struct addrinfo* addr_info = nullptr;
  int ret = getaddrinfo(domain_name_.c_str(), port_.c_str(), &hints,
                        &addr_info);
  if (ret != 0) {
    LOGE("getaddrinfo failed, errno = %d", ret);
    return false;
  }

  // get a socket
  socket_fd_ = socket(addr_info->ai_family, addr_info->ai_socktype,
                      addr_info->ai_protocol);
  if (socket_fd_ < 0) {
    LOGE("cannot open socket, errno = %d", GetError());
    socket_fd_ = -1;
    // The address list is owned by this function on every exit path.
    freeaddrinfo(addr_info);
    return false;
  }

  // set the socket in non-blocking mode
#ifdef _WIN32
  u_long mode = 1;  // Non-blocking mode.
  if (ioctlsocket(socket_fd_, FIONBIO, &mode) != 0) {
    LOGE("ioctlsocket error, wsa error = %d", WSAGetLastError());
    freeaddrinfo(addr_info);
    CloseSocket();
    return false;
  }
#else
  int original_flags = fcntl(socket_fd_, F_GETFL, 0);
  if (original_flags == -1) {
    LOGE("fcntl error, errno = %d", errno);
    freeaddrinfo(addr_info);
    CloseSocket();
    return false;
  }
  if (fcntl(socket_fd_, F_SETFL, original_flags | O_NONBLOCK) == -1) {
    LOGE("fcntl error, errno = %d", errno);
    freeaddrinfo(addr_info);
    CloseSocket();
    return false;
  }
#endif

  // connect to the server
  ret = connect(socket_fd_, addr_info->ai_addr, addr_info->ai_addrlen);
  // Capture the error code before freeaddrinfo() gets a chance to disturb it.
  const int connect_error = (ret == 0) ? 0 : GetError();
  freeaddrinfo(addr_info);
  addr_info = nullptr;

  if (ret != 0) {
    if (connect_error != ERROR_ASYNC_COMPLETE) {
      // failed right away.
      LOGE("cannot connect to %s, errno = %d", domain_name_.c_str(),
           connect_error);
      CloseSocket();
      return false;
    }

    // in progress. block until timeout expired or connection established.
    if (!SocketWait(socket_fd_, /* for_read */ false, timeout_in_ms)) {
      LOGE("cannot connect to %s", domain_name_.c_str());
      CloseSocket();
      return false;
    }

    // A socket also becomes writable when an asynchronous connect fails, so
    // the real outcome has to be read back from SO_ERROR.
    int so_error = 0;
    socklen_t so_error_len = sizeof(so_error);
    if (getsockopt(socket_fd_, SOL_SOCKET, SO_ERROR,
                   reinterpret_cast<char*>(&so_error), &so_error_len) != 0) {
      so_error = GetError();
    }
    if (so_error != 0) {
      LOGE("cannot connect to %s, errno = %d", domain_name_.c_str(), so_error);
      CloseSocket();
      return false;
    }
  }

  // set up SSL if needed
  if (secure_connect_) {
    ssl_ctx_ = InitSslContext();
    if (!ssl_ctx_) {
      CloseSocket();
      return false;
    }

    ssl_ = SSL_new(ssl_ctx_);
    if (!ssl_) {
      LOGE("failed SSL_new");
      CloseSocket();
      return false;
    }

    // BIO_NOCLOSE: the descriptor stays owned by this object and is closed in
    // CloseSocket(), not when the BIO is freed.
    BIO* a_bio = BIO_new_socket(socket_fd_, BIO_NOCLOSE);
    if (!a_bio) {
      LOGE("BIO_new_socket error");
      CloseSocket();
      return false;
    }

    SSL_set_bio(ssl_, a_bio, a_bio);

    // The socket is non-blocking, so the handshake is driven in a loop: each
    // WANT_READ / WANT_WRITE result means "wait for the socket, then retry".
    do {
      ret = SSL_connect(ssl_);
      if (ret != 1) {
        int ssl_err = SSL_get_error(ssl_, ret);
        if (ssl_err != SSL_ERROR_WANT_READ && ssl_err != SSL_ERROR_WANT_WRITE) {
          char buf[256];
          LOGE("SSL_connect error: %s", ERR_error_string(ERR_get_error(), buf));
          CloseSocket();
          return false;
        }
        bool for_read = ssl_err == SSL_ERROR_WANT_READ;
        if (!SocketWait(socket_fd_, for_read, timeout_in_ms)) {
          LOGE("cannot connect to %s", domain_name_.c_str());
          CloseSocket();
          return false;
        }
      }
    } while (ret != 1);
  }

  return true;
}
// Returns -1 for error, number of bytes read for success.
// The timeout here only applies to the span between packets of data, for the
// sake of simplicity.
int HttpSocket::Read(char* data, int len, int timeout_in_ms) {
int total_read = 0;
int to_read = len;
if (socket_fd_ == -1) {
LOGE("Socket to %s not open. Cannot read.", domain_name_.c_str());
return -1;
}
while (to_read > 0) {
if (!SocketWait(socket_fd_, /* for_read */ true, timeout_in_ms)) {
LOGE("unable to read from %s", domain_name_.c_str());
return -1;
}
ClearError(); // Reset errors, as we will depend on its value shortly.
int read;
if (secure_connect_) {
read = SSL_read(ssl_, data, to_read);
} else {
read = recv(socket_fd_, data, to_read, 0);
}
if (read > 0) {
to_read -= read;
data += read;
total_read += read;
} else if (secure_connect_) {
// Secure read error
int ssl_error = SSL_get_error(ssl_, read);
if (ssl_error == SSL_ERROR_ZERO_RETURN ||
(ssl_error == SSL_ERROR_SYSCALL && GetError() == 0)) {
// The connection has been closed. No more data.
break;
} else if (IsRetryableSslError(ssl_error)) {
sleep(1);
// After sleeping, fall through to iterate the loop again and retry.
} else {
// Unrecoverable error. Log and abort.
LOGE("SSL_read returned %d, LibSSL Error = %d", read, ssl_error);
if (ssl_error == SSL_ERROR_SYSCALL) {
LOGE(" errno = %d = %s", GetError(), GetErrorString());
}
ERR_print_errors_cb(LogBoringSslError, nullptr);
return -1;
}
} else {
// Non-secure read error
if (read == 0) {
// The connection has been closed. No more data.
break;
} else {
// Log the error received
LOGE("recv returned %d, errno = %d = %s", read, GetError(),
GetErrorString());
return -1;
}
}
}
return total_read;
}
// Returns -1 for error, number of bytes written for success.
// The timeout here only applies to the span between packets of data, for the
// sake of simplicity.
int HttpSocket::Write(const char* data, int len, int timeout_in_ms) {
int total_sent = 0;
int to_send = len;
if (socket_fd_ == -1) {
LOGE("Socket to %s not open. Cannot write.", domain_name_.c_str());
return -1;
}
while (to_send > 0) {
int sent;
if (secure_connect_)
sent = SSL_write(ssl_, data, to_send);
else
sent = send(socket_fd_, data, to_send, 0);
if (sent > 0) {
to_send -= sent;
data += sent;
total_sent += sent;
} else if (sent == 0) {
// We filled up the pipe. Wait for room to write.
if (!SocketWait(socket_fd_, /* for_read */ false, timeout_in_ms)) {
LOGE("unable to write to %s", domain_name_.c_str());
return -1;
}
} else {
LOGE("send returned %d, errno = %d", sent, GetError());
return -1;
}
}
return total_sent;
}
} // namespace wvcdm