From e51f869190d4db0f6f3f67ea32a5b8b76045906d Mon Sep 17 00:00:00 2001 From: Alex Dale Date: Fri, 12 Mar 2021 19:27:51 -0800 Subject: [PATCH] Base64 encoding for string input. [ Merge of http://go/wvgerrit/119805 ] This change adds 3 new functions for encoding binary data from a C++ string to a base64 encoded ASCII string. The CDM and protobuf generated code use C++ strings to store binary data. These binary strings are commonly converted into a base64 encoded ASCII string for logging and for returning to the app. This change also cleans up some of the internal components of the string_conversions library to use several standard library C++11 method. Bug: 181732604 Test: CE CDM unittests Change-Id: I547568c6402e011344260f2df2a06e972122ab8a --- libwvdrmengine/cdm/core/src/cdm_session.cpp | 2 +- .../cdm/util/include/string_conversions.h | 35 +- .../cdm/util/src/string_conversions.cpp | 390 +++++++++--------- libwvdrmengine/cdm/util/test/base64_test.cpp | 98 ++++- 4 files changed, 312 insertions(+), 213 deletions(-) diff --git a/libwvdrmengine/cdm/core/src/cdm_session.cpp b/libwvdrmengine/cdm/core/src/cdm_session.cpp index aeb11a63..8933d717 100644 --- a/libwvdrmengine/cdm/core/src/cdm_session.cpp +++ b/libwvdrmengine/cdm/core/src/cdm_session.cpp @@ -954,7 +954,7 @@ int64_t CdmSession::GetDurationRemaining() { CdmSessionId CdmSession::GenerateSessionId() { static int session_num = 1; - return SESSION_ID_PREFIX + IntToString(++session_num); + return SESSION_ID_PREFIX + std::to_string(++session_num); } bool CdmSession::GenerateKeySetId(bool atsc_mode_enabled, diff --git a/libwvdrmengine/cdm/util/include/string_conversions.h b/libwvdrmengine/cdm/util/include/string_conversions.h index feff9dca..c1606527 100644 --- a/libwvdrmengine/cdm/util/include/string_conversions.h +++ b/libwvdrmengine/cdm/util/include/string_conversions.h @@ -1,7 +1,6 @@ // Copyright 2018 Google LLC. All Rights Reserved. This file and proprietary // source code may only be used and distributed under the Widevine License // Agreement. - #ifndef WVCDM_UTIL_STRING_CONVERSIONS_H_ #define WVCDM_UTIL_STRING_CONVERSIONS_H_ @@ -15,29 +14,49 @@ namespace wvcdm { +// ASCII hex to Binary conversion. CORE_UTIL_EXPORT std::vector a2b_hex(const std::string& b); CORE_UTIL_EXPORT std::vector a2b_hex(const std::string& label, const std::string& b); CORE_UTIL_EXPORT std::string a2bs_hex(const std::string& b); + +// Binary to ASCII hex conversion. CORE_UTIL_EXPORT std::string b2a_hex(const std::vector& b); CORE_UTIL_EXPORT std::string b2a_hex(const std::string& b); +CORE_UTIL_EXPORT std::string HexEncode(const uint8_t* bytes, size_t size); + +// Base64 encoding/decoding. +// Converts binary data into the ASCII Base64 character set and vice +// versa using the encoding rules defined in RFC4648 section 4. CORE_UTIL_EXPORT std::string Base64Encode( const std::vector& bin_input); +CORE_UTIL_EXPORT std::string Base64Encode(const std::string& bin_input); CORE_UTIL_EXPORT std::vector Base64Decode( const std::string& bin_input); + +// URL-Safe Base64 encoding/decoding. +// Converts binary data into the URL/Filename safe ASCII Base64 +// character set and vice versa using the encoding rules defined in +// RFC4648 section 5. CORE_UTIL_EXPORT std::string Base64SafeEncode( const std::vector& bin_input); -CORE_UTIL_EXPORT std::string Base64SafeEncodeNoPad( - const std::vector& bin_input); +CORE_UTIL_EXPORT std::string Base64SafeEncode(const std::string& bin_input); CORE_UTIL_EXPORT std::vector Base64SafeDecode( const std::string& bin_input); -CORE_UTIL_EXPORT std::string HexEncode(const uint8_t* bytes, unsigned size); -CORE_UTIL_EXPORT std::string IntToString(int value); +// URL-Safe Base64 encoding without padding. +// Similar to Base64SafeEncode(), without any padding character '=' +// at the end. +CORE_UTIL_EXPORT std::string Base64SafeEncodeNoPad( + const std::vector& bin_input); +CORE_UTIL_EXPORT std::string Base64SafeEncodeNoPad( + const std::string& bin_input); + +// Host to Network/Network to Host conversion. CORE_UTIL_EXPORT int64_t htonll64(int64_t x); CORE_UTIL_EXPORT inline int64_t ntohll64(int64_t x) { return htonll64(x); } -CORE_UTIL_EXPORT std::string BytesToString(const uint8_t* bytes, unsigned size); -// Encode unsigned integer into a big endian formatted string -CORE_UTIL_EXPORT std::string EncodeUint32(unsigned int u); + +// Encode unsigned integer into a big endian formatted string. +CORE_UTIL_EXPORT std::string EncodeUint32(uint32_t u); } // namespace wvcdm diff --git a/libwvdrmengine/cdm/util/src/string_conversions.cpp b/libwvdrmengine/cdm/util/src/string_conversions.cpp index 2b0ba5f1..3faa2fe5 100644 --- a/libwvdrmengine/cdm/util/src/string_conversions.cpp +++ b/libwvdrmengine/cdm/util/src/string_conversions.cpp @@ -10,15 +10,18 @@ #include #include -#include #include "log.h" #include "platform.h" namespace wvcdm { - -static const char kBase64Codes[] = +namespace { +// Base64 character set, indexed for their 6-bit mapping, plus '='. +const char kBase64Codes[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="; +// URL safe Base64 character set. +const char kBase64SafeCodes[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_="; // Gets the low |n| bits of |in|. #define GET_LOW_BITS(in, n) ((in) & ((1 << (n)) - 1)) @@ -27,26 +30,131 @@ static const char kBase64Codes[] = // Calculates a/b using round-up division (only works for positive numbers). #define CEIL_DIVIDE(a, b) ((((a)-1) / (b)) + 1) -int DecodeBase64Char(char c) { - const char* it = strchr(kBase64Codes, c); +// Decodes a single Base64 encoded character into its 6-bit value. +// The provided |codes| must be a Base64 character map. +int DecodeBase64Char(char c, const char* codes) { + const char* it = strchr(codes, c); if (it == nullptr) return -1; - return it - kBase64Codes; + return it - codes; } -bool DecodeHexChar(char ch, unsigned char* digit) { +bool DecodeHexChar(char ch, uint8_t* digit) { if (ch >= '0' && ch <= '9') { *digit = ch - '0'; - } else { - ch = tolower(ch); - if ((ch >= 'a') && (ch <= 'f')) { - *digit = ch - 'a' + 10; - } else { - return false; + return true; + } + ch = tolower(ch); + if ((ch >= 'a') && (ch <= 'f')) { + *digit = ch - 'a' + 10; + return true; + } + return false; +} + +// Encode for standard base64 encoding (RFC4648). +// https://en.wikipedia.org/wiki/Base64 +// Text | M | a | n | +// ASCI | 77 (0x4d) | 97 (0x61) | 110 (0x6e) | +// Bits | 0 1 0 0 1 1 0 1 0 1 1 0 0 0 0 1 0 1 1 0 1 1 1 0 | +// Index | 19 | 22 | 5 | 46 | +// Base64 | T | W | F | u | +// | <----------------- 24-bits -----------------> | + +// The provided |codes| must be a Base64 character map. +std::string Base64EncodeInternal(const uint8_t* data, size_t length, + const char* codes) { + // |temp| stores a 24-bit block that is treated as an array where insertions + // occur from high to low. + uint32_t temp = 0; + size_t out_index = 0; + const size_t out_size = CEIL_DIVIDE(length, 3) * 4; + std::string result(out_size, '\0'); + for (size_t i = 0; i < length; i++) { + // "insert" 8-bits of data + temp |= (data[i] << ((2 - (i % 3)) * 8)); + if (i % 3 == 2) { + result[out_index++] = codes[GET_BITS(temp, 18, 24)]; + result[out_index++] = codes[GET_BITS(temp, 12, 18)]; + result[out_index++] = codes[GET_BITS(temp, 6, 12)]; + result[out_index++] = codes[GET_BITS(temp, 0, 6)]; + temp = 0; } } - return true; + if (length % 3 == 1) { + result[out_index++] = codes[GET_BITS(temp, 18, 24)]; + result[out_index++] = codes[GET_BITS(temp, 12, 18)]; + result[out_index++] = '='; + result[out_index++] = '='; + } else if (length % 3 == 2) { + result[out_index++] = codes[GET_BITS(temp, 18, 24)]; + result[out_index++] = codes[GET_BITS(temp, 12, 18)]; + result[out_index++] = codes[GET_BITS(temp, 6, 12)]; + result[out_index++] = '='; + } + return result; } +std::vector Base64DecodeInternal(const char* encoded, size_t length, + const char* codes) { + const size_t out_size_max = CEIL_DIVIDE(length * 3, 4); + std::vector result(out_size_max, '\0'); + // |temp| stores 24-bits of data that is treated as an array where insertions + // occur from high to low. + uint32_t temp = 0; + size_t out_index = 0; + size_t i; + for (i = 0; i < length; i++) { + if (encoded[i] == '=') { + // Verify an '=' only appears at the end. We want i to remain at the + // first '=', so we need an inner loop. + for (size_t j = i; j < length; j++) { + if (encoded[j] != '=') { + LOGE("base64Decode failed"); + return std::vector(); + } + } + if (length % 4 != 0) { + // If padded, then the length must be a multiple of 4. + // Unpadded messages are OK. + LOGE("base64Decode failed"); + return std::vector(); + } + break; + } + + const int decoded = DecodeBase64Char(encoded[i], codes); + if (decoded < 0) { + LOGE("base64Decode failed"); + return std::vector(); + } + // "insert" 6-bits of data + temp |= (decoded << ((3 - (i % 4)) * 6)); + + if (i % 4 == 3) { + result[out_index++] = GET_BITS(temp, 16, 24); + result[out_index++] = GET_BITS(temp, 8, 16); + result[out_index++] = GET_BITS(temp, 0, 8); + temp = 0; + } + } + + switch (i % 4) { + case 1: + LOGE("base64Decode failed"); + return std::vector(); + case 2: + result[out_index++] = GET_BITS(temp, 16, 24); + break; + case 3: + result[out_index++] = GET_BITS(temp, 16, 24); + result[out_index++] = GET_BITS(temp, 8, 16); + break; + } + result.resize(out_index); + return result; +} +} // namespace + // converts an ascii hex string(2 bytes per digit) into a decimal byte string std::vector a2b_hex(const std::string& byte) { std::vector array; @@ -97,161 +205,7 @@ std::string b2a_hex(const std::string& byte) { byte.length()); } -// Encode for standard base64 encoding (RFC4648). -// https://en.wikipedia.org/wiki/Base64 -// Text | M | a | n | -// ASCI | 77 (0x4d) | 97 (0x61) | 110 (0x6e) | -// Bits | 0 1 0 0 1 1 0 1 0 1 1 0 0 0 0 1 0 1 1 0 1 1 1 0 | -// Index | 19 | 22 | 5 | 46 | -// Base64 | T | W | F | u | -// | <----------------- 24-bits -----------------> | -std::string Base64Encode(const std::vector& bin_input) { - if (bin_input.empty()) { - return std::string(); - } - - // |temp| stores a 24-bit block that is treated as an array where insertions - // occur from high to low. - uint32_t temp = 0; - size_t out_index = 0; - const size_t out_size = CEIL_DIVIDE(bin_input.size(), 3) * 4; - std::string result(out_size, '\0'); - for (size_t i = 0; i < bin_input.size(); i++) { - // "insert" 8-bits of data - temp |= (bin_input[i] << ((2 - (i % 3)) * 8)); - - if (i % 3 == 2) { - result[out_index++] = kBase64Codes[GET_BITS(temp, 18, 24)]; - result[out_index++] = kBase64Codes[GET_BITS(temp, 12, 18)]; - result[out_index++] = kBase64Codes[GET_BITS(temp, 6, 12)]; - result[out_index++] = kBase64Codes[GET_BITS(temp, 0, 6)]; - temp = 0; - } - } - - if (bin_input.size() % 3 == 1) { - result[out_index++] = kBase64Codes[GET_BITS(temp, 18, 24)]; - result[out_index++] = kBase64Codes[GET_BITS(temp, 12, 18)]; - result[out_index++] = '='; - result[out_index++] = '='; - } else if (bin_input.size() % 3 == 2) { - result[out_index++] = kBase64Codes[GET_BITS(temp, 18, 24)]; - result[out_index++] = kBase64Codes[GET_BITS(temp, 12, 18)]; - result[out_index++] = kBase64Codes[GET_BITS(temp, 6, 12)]; - result[out_index++] = '='; - } - - return result; -} - -// Filename-friendly base64 encoding (RFC4648), commonly referred to -// as Base64WebSafeEncode. -// -// This is the encoding required to interface with the provisioning server, as -// well as for certain license server transactions. It is also used for logging -// certain strings. The difference between web safe encoding vs regular encoding -// is that the web safe version replaces '+' with '-' and '/' with '_'. -std::string Base64SafeEncode(const std::vector& bin_input) { - if (bin_input.empty()) { - return std::string(); - } - - std::string ret = Base64Encode(bin_input); - for (size_t i = 0; i < ret.size(); i++) { - if (ret[i] == '+') - ret[i] = '-'; - else if (ret[i] == '/') - ret[i] = '_'; - } - return ret; -} - -std::string Base64SafeEncodeNoPad(const std::vector& bin_input) { - std::string b64_output = Base64SafeEncode(bin_input); - // Output size: ceiling [ bin_input.size() * 4 / 3 ]. - b64_output.resize((bin_input.size() * 4 + 2) / 3); - return b64_output; -} - -// Decode for standard base64 encoding (RFC4648). -std::vector Base64Decode(const std::string& b64_input) { - if (b64_input.empty()) { - return std::vector(); - } - - const size_t out_size_max = CEIL_DIVIDE(b64_input.size() * 3, 4); - std::vector result(out_size_max, '\0'); - - // |temp| stores 24-bits of data that is treated as an array where insertions - // occur from high to low. - uint32_t temp = 0; - size_t out_index = 0; - size_t i; - for (i = 0; i < b64_input.size(); i++) { - if (b64_input[i] == '=') { - // Verify an '=' only appears at the end. We want i to remain at the - // first '=', so we need an inner loop. - for (size_t j = i; j < b64_input.size(); j++) { - if (b64_input[j] != '=') { - LOGE("base64Decode failed"); - return std::vector(); - } - } - break; - } - - const int decoded = DecodeBase64Char(b64_input[i]); - if (decoded < 0) { - LOGE("base64Decode failed"); - return std::vector(); - } - // "insert" 6-bits of data - temp |= (decoded << ((3 - (i % 4)) * 6)); - - if (i % 4 == 3) { - result[out_index++] = GET_BITS(temp, 16, 24); - result[out_index++] = GET_BITS(temp, 8, 16); - result[out_index++] = GET_BITS(temp, 0, 8); - temp = 0; - } - } - - switch (i % 4) { - case 1: - LOGE("base64Decode failed"); - return std::vector(); - case 2: - result[out_index++] = GET_BITS(temp, 16, 24); - break; - case 3: - result[out_index++] = GET_BITS(temp, 16, 24); - result[out_index++] = GET_BITS(temp, 8, 16); - break; - } - result.resize(out_index); - return result; -} - -// Decode for Filename-friendly base64 encoding (RFC4648), commonly referred -// as Base64WebSafeDecode. Add padding if needed. -std::vector Base64SafeDecode(const std::string& b64_input) { - if (b64_input.empty()) { - return std::vector(); - } - - // Make a copy so we can modify it to replace the web-safe special characters - // with the normal ones. - std::string input_copy = b64_input; - for (size_t i = 0; i < input_copy.size(); i++) { - if (input_copy[i] == '-') - input_copy[i] = '+'; - else if (input_copy[i] == '_') - input_copy[i] = '/'; - } - return Base64Decode(input_copy); -} - -std::string HexEncode(const uint8_t* in_buffer, unsigned int size) { +std::string HexEncode(const uint8_t* in_buffer, size_t size) { static const char kHexChars[] = "0123456789ABCDEF"; if (size == 0) return ""; constexpr unsigned int kMaxSafeSize = 3072; @@ -267,19 +221,83 @@ std::string HexEncode(const uint8_t* in_buffer, unsigned int size) { return out_buffer; } -std::string IntToString(int value) { - // log10(2) ~= 0.3 bytes needed per bit or per byte log10(2**8) ~= 2.4. - // So round up to allocate 3 output characters per byte, plus 1 for '-'. - const int kOutputBufSize = 3 * sizeof(int) + 1; - char buffer[kOutputBufSize]; - memset(buffer, 0, kOutputBufSize); - snprintf(buffer, kOutputBufSize, "%d", value); +// Standard Base64 encoding and decoding. - std::string out_string(buffer); - return out_string; +std::string Base64Encode(const std::vector& bin_input) { + if (bin_input.empty()) { + return std::string(); + } + return Base64EncodeInternal(bin_input.data(), bin_input.size(), kBase64Codes); } -int64_t htonll64(int64_t x) { // Convert to big endian (network-byte-order) +std::string Base64Encode(const std::string& bin_input) { + if (bin_input.empty()) { + return std::string(); + } + return Base64EncodeInternal( + reinterpret_cast(bin_input.data()), bin_input.size(), + kBase64Codes); +} + +// Decode for standard base64 encoding (RFC4648). +std::vector Base64Decode(const std::string& b64_input) { + if (b64_input.empty()) { + return std::vector(); + } + return Base64DecodeInternal(b64_input.data(), b64_input.size(), kBase64Codes); +} + +// URL/Filename Safe Base64 encoding and decoding. + +// This is the encoding required to interface with the provisioning server, as +// well as for certain license server transactions. It is also used for logging +// certain strings. The difference between web safe encoding vs regular encoding +// is that the web safe version replaces '+' with '-' and '/' with '_'. +std::string Base64SafeEncode(const std::vector& bin_input) { + if (bin_input.empty()) { + return std::string(); + } + return Base64EncodeInternal(bin_input.data(), bin_input.size(), + kBase64SafeCodes); +} + +std::string Base64SafeEncode(const std::string& bin_input) { + if (bin_input.empty()) { + return std::string(); + } + return Base64EncodeInternal( + reinterpret_cast(bin_input.data()), bin_input.size(), + kBase64SafeCodes); +} + +std::vector Base64SafeDecode(const std::string& b64_input) { + if (b64_input.empty()) { + return std::vector(); + } + return Base64DecodeInternal(b64_input.data(), b64_input.size(), + kBase64SafeCodes); +} + +// URL/Filename Safe Base64 encoding without padding. + +std::string Base64SafeEncodeNoPad(const std::vector& bin_input) { + std::string b64_output = Base64SafeEncode(bin_input); + // Output size: ceiling [ bin_input.size() * 4 / 3 ]. + b64_output.resize((bin_input.size() * 4 + 2) / 3); + return b64_output; +} + +std::string Base64SafeEncodeNoPad(const std::string& bin_input) { + std::string b64_output = Base64SafeEncode(bin_input); + // Output size: ceiling [ bin_input.size() * 4 / 3 ]. + b64_output.resize((bin_input.size() * 4 + 2) / 3); + return b64_output; +} + +// Host to Network/Network to Host conversion. + +// Convert to big endian (network-byte-order) +int64_t htonll64(int64_t x) { union { uint32_t array[2]; int64_t number; @@ -296,19 +314,13 @@ int64_t htonll64(int64_t x) { // Convert to big endian (network-byte-order) } } -std::string BytesToString(const uint8_t* bytes, unsigned size) { - if (!bytes || !size) return ""; - const char* char_bytes = reinterpret_cast(bytes); - return std::string(char_bytes, char_bytes + size); -} - // Encode unsigned integer into a big endian formatted string std::string EncodeUint32(unsigned int u) { std::string s; - s.append(1, (u >> 24) & 0xFF); - s.append(1, (u >> 16) & 0xFF); - s.append(1, (u >> 8) & 0xFF); - s.append(1, (u >> 0) & 0xFF); + s.push_back((u >> 24) & 0xFF); + s.push_back((u >> 16) & 0xFF); + s.push_back((u >> 8) & 0xFF); + s.push_back(u & 0xFF); return s; } diff --git a/libwvdrmengine/cdm/util/test/base64_test.cpp b/libwvdrmengine/cdm/util/test/base64_test.cpp index 81d32a98..684fa447 100644 --- a/libwvdrmengine/cdm/util/test/base64_test.cpp +++ b/libwvdrmengine/cdm/util/test/base64_test.cpp @@ -55,8 +55,14 @@ const std::pair kBase64TestVectors[] = { make_pair(&kTwoBytesOverData, &kTwoBytesOverB64Data), make_pair(&kTestData, &kB64TestData)}; -const std::string kBase64ErrorVectors[] = {"Foo$sa", "Foo\x99\x23\xfa\02", - "Foo==Foo", "FooBa"}; +// Arbitrary invalid base64 test vectors +const std::string kBase64ErrorVectors[] = {"Foo$sa", + "Foo\x99\x23\xfa\02", + "Foo==Foo", + "FooBa", + "SGVsbG8sIFdvcmxkI===", + "SGVsbG8sIFdvcmxkI======", + "SGVsbG8sIFdvcmxkIQp=="}; std::string ConvertToBase64WebSafe(const std::string& std_base64_string) { std::string str(std_base64_string); @@ -77,28 +83,90 @@ class Base64EncodeDecodeTest TEST_P(Base64EncodeDecodeTest, EncodeDecodeTest) { std::pair values = GetParam(); - std::vector decoded_vector = Base64Decode(values.second->data()); - std::string decoded_string(decoded_vector.begin(), decoded_vector.end()); - EXPECT_STREQ(values.first->data(), decoded_string.data()); - std::string b64_string = Base64Encode(decoded_vector); - EXPECT_STREQ(values.second->data(), b64_string.data()); + const std::string& plain_text_string = *(values.first); + const std::string& expected_encoded = *(values.second); + + // Encode from string. + const std::string b64_string_encoded = Base64Encode(plain_text_string); + EXPECT_EQ(b64_string_encoded, expected_encoded); + + // Encode from vector. + const std::vector plain_text_vector(plain_text_string.begin(), + plain_text_string.end()); + const std::string b64_vector_encoded = Base64Encode(plain_text_vector); + EXPECT_EQ(b64_vector_encoded, expected_encoded); + + // Decode from string. + const std::vector decoded_vector = Base64Decode(expected_encoded); + EXPECT_EQ(decoded_vector, plain_text_vector); } TEST_P(Base64EncodeDecodeTest, WebSafeEncodeDecodeTest) { std::pair values = GetParam(); - std::string encoded_string = ConvertToBase64WebSafe(*(values.second)); - std::vector decoded_vector = Base64SafeDecode(encoded_string); - std::string decoded_string(decoded_vector.begin(), decoded_vector.end()); - EXPECT_STREQ(values.first->data(), decoded_string.data()); - std::string b64_string = Base64SafeEncode(decoded_vector); - EXPECT_STREQ(encoded_string.data(), b64_string.data()); + const std::string& plain_text_string = *(values.first); + const std::string& expected_encoded = + ConvertToBase64WebSafe(*(values.second)); + + // Encode from string. + const std::string b64_string_encoded = Base64SafeEncode(plain_text_string); + EXPECT_EQ(b64_string_encoded, expected_encoded); + + // Encode from vector. + const std::vector plain_text_vector(plain_text_string.begin(), + plain_text_string.end()); + const std::string b64_vector_encoded = Base64SafeEncode(plain_text_vector); + EXPECT_EQ(b64_vector_encoded, expected_encoded); + + // Decode from string. + const std::vector decoded_vector = + Base64SafeDecode(expected_encoded); + EXPECT_EQ(decoded_vector, plain_text_vector); +} + +TEST_P(Base64EncodeDecodeTest, WebSafeEncodeNoPad) { + std::pair values = GetParam(); + const std::string& plain_text_string = *(values.first); + const std::string& padded_encoded = ConvertToBase64WebSafe(*(values.second)); + + // Encode from string. + const std::string b64_string_encoded = + Base64SafeEncodeNoPad(plain_text_string); + + // If input is empty, output will be empty. + if (plain_text_string.empty()) { + EXPECT_TRUE(b64_string_encoded.empty()); + return; + } + + if (padded_encoded.back() == '=') { + // If padding is present in the regular encoding, then it should be + // striped from the result. + EXPECT_NE(b64_string_encoded.back(), '='); + const std::string expected_encoded = + padded_encoded.substr(0, b64_string_encoded.size()); + EXPECT_EQ(b64_string_encoded, expected_encoded); + } else { + // If no padding is present, then results should be equal. + EXPECT_EQ(b64_string_encoded, padded_encoded); + } + + // Encode from vector. + const std::vector plain_text_vector(plain_text_string.begin(), + plain_text_string.end()); + const std::string b64_vector_encoded = + Base64SafeEncodeNoPad(plain_text_vector); + // Assuming the above has passed, the results should be the same as + // a result encoded from a string. + EXPECT_EQ(b64_vector_encoded, b64_string_encoded); } class Base64ErrorDecodeTest : public ::testing::TestWithParam {}; TEST_P(Base64ErrorDecodeTest, EncoderErrors) { - std::vector result = Base64Decode(GetParam()); - EXPECT_EQ(0u, result.size()); + const std::vector standard_result = Base64Decode(GetParam()); + EXPECT_TRUE(standard_result.empty()); + const std::vector safe_result = Base64SafeDecode(GetParam()); + EXPECT_TRUE(safe_result.empty()); } INSTANTIATE_TEST_CASE_P(ExecutesBase64Test, Base64EncodeDecodeTest,