// ce_cdm/cdm/test/perf_test.cpp
// Copyright 2021 Google LLC. All Rights Reserved. This file and proprietary
// source code may only be used and distributed under the Widevine License
// Agreement.
#include "perf_test.h"
#include <gtest/gtest.h>
#include <stdio.h>
#include <stdlib.h>
#include <chrono>
#include <cmath>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "config_test_env.h"
#include "license_request.h"
#include "logger_global.h"
#include "test_host.h"
#include "url_request.h"
#define ASSERT_SUCCESS(code) ASSERT_EQ(code, Cdm::kSuccess)
#define EXPECT_SUCCESS(code) EXPECT_EQ(code, Cdm::kSuccess)
#define WALL_NOW std::chrono::high_resolution_clock::now()
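// Process-wide test host and logger, shared with the CDM library at
// initialization time.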
TestHost* g_host = nullptr;
widevine::StderrLogger g_stderr_logger;
namespace widevine {
namespace {
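// Number of measured iterations per operation.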
constexpr size_t kTestCount = 50;
const wvcdm::ConfigTestEnv kTestData(wvcdm::kContentProtectionUatServer);
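// Factory for the CDM under test, supplied by the embedder via PerfTestMain().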
CreateFuncType create_func = nullptr;
using TimeType = std::chrono::duration<double, std::milli>;
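// Summary statistics over a fixed-size array of timing samples: mean, min,
// max, and population standard deviation, all in the unit of the inputs.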
struct PerfInfo {
double mean;
double min;
double max;
double std_dev;
template <size_t Size>
PerfInfo(const double (&values)[Size]) {
static_assert(Size > 0, "Must pass at least one value");
// First pass to calculate min/max/mean.
bool first = true;
double sum = 0;
for (auto v : values) {
sum += v;
if (first) {
min = max = v;
first = false;
} else {
if (v < min) min = v;
if (v > max) max = v;
}
}
mean = sum / Size;
// Second pass to calculate standard deviation.
sum = 0;
for (auto v : values) {
sum += (v - mean) * (v - mean);
}
std_dev = std::sqrt(sum / Size);
}
};
std::ostream& operator<<(std::ostream& os, const PerfInfo& info) {
// mean=12.33442, std-dev=1.44421, min=1.22431, max=244.1133144
return os << "mean=" << info.mean << ", std-dev=" << info.std_dev
<< ", min=" << info.min << ", max=" << info.max;
}
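// Accumulates up to kTestCount wall-clock and CPU-time measurements. The
// nested RAII |Test| helper records the elapsed times of the scope it lives
// in when it is destroyed.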
class PerfTracker {
public:
class Test {
public:
Test(PerfTracker* tracker)
: wall_start_(WALL_NOW), cpu_start_(std::clock()), tracker_(tracker) {}
~Test() {
tracker_->wall_times_[tracker_->index_] =
TimeType(WALL_NOW - wall_start_).count();
tracker_->cpu_times_[tracker_->index_] =
(std::clock() - cpu_start_) * 1000.0 / CLOCKS_PER_SEC;
tracker_->index_++;
}
private:
std::chrono::high_resolution_clock::time_point wall_start_;
std::clock_t cpu_start_;
PerfTracker* tracker_;
};
void Print(const std::string& name, size_t block_size_bytes = 0) {
PerfInfo wall_perf(wall_times_);
PerfInfo cpu_perf(cpu_times_);
std::cout << name << " (wall, ms): " << wall_perf << "\n";
std::cout << name << " (cpu, ms): " << cpu_perf << "\n";
if (block_size_bytes) {
// |mean| is in milliseconds.
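// bytes * 8 = bits; * 1000 / mean_ms = bits/sec; / 1024 / 1024 = MBit/sec
// (1 MBit here is 2^20 bits).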
std::cout << name << " (wall, MBit/sec): "
<< (block_size_bytes * 8 * 1000 / wall_perf.mean / 1024 / 1024)
<< "\n";
std::cout << name << " (cpu, MBit/sec): "
<< (block_size_bytes * 8 * 1000 / cpu_perf.mean / 1024 / 1024)
<< "\n";
}
}
private:
// Zero-initialized so Print() never reads indeterminate values if fewer than
// kTestCount iterations were recorded.
double wall_times_[kTestCount] = {};
double cpu_times_[kTestCount] = {};
size_t index_ = 0;
};
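// Runs |code| inside a scoped PerfTracker::Test so its wall and CPU time are
// recorded in |tracker|. Example (DoWork is any expression to time):
//   PerfTracker tracker;
//   MEASURE_PERF(tracker, DoWork());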
#define MEASURE_PERF(tracker, code)       \
  do {                                    \
    PerfTracker::Test test(&(tracker));   \
    code;                                 \
  } while (0)
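// Captures onMessage callbacks (license requests and other CDM messages) so
// the tests can forward them to the license server; other events are ignored.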
class EventListener : public Cdm::IEventListener {
public:
struct MessageInfo {
std::string session_id;
std::string message;
Cdm::MessageType message_type;
std::string url;
};
void onMessage(const std::string& session_id, Cdm::MessageType message_type,
const std::string& message, const std::string& url) override {
messages.push_back({session_id, message, message_type, url});
}
void onKeyStatusesChange(const std::string& session_id,
bool has_new_usable_key) override {}
void onExpirationChange(const std::string& session_id,
int64_t new_expiration) override {}
void onRemoveComplete(const std::string& session_id) override {}
std::vector<MessageInfo> messages;
};
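// Posts |message| to the UAT license server and extracts the wrapped DRM
// message from the raw HTTP response. Returns false on any transport failure.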
bool SendPost(const std::string& message, std::string* response) {
wvcdm::UrlRequest req(kTestData.license_server());
std::string raw_response;
if (!req.is_connected() || !req.PostRequest(message) ||
!req.GetResponse(&raw_response)) {
return false;
}
wvcdm::LicenseRequest helper;
helper.GetDrmMessage(raw_response, *response);
return true;
}
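// Creates a CDM instance backed by the per-origin storage and, on success,
// installs the provisioning and licensing service certificates.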
std::unique_ptr<Cdm> CreateCdm(EventListener* event_listener) {
std::unique_ptr<Cdm> ret(
create_func(event_listener, &g_host->per_origin_storage(), true));
if (ret) {
EXPECT_SUCCESS(ret->setServiceCertificate(
Cdm::kProvisioningService,
kTestData.provisioning_service_certificate()));
EXPECT_SUCCESS(ret->setServiceCertificate(
Cdm::kLicensingService, kTestData.license_service_certificate()));
}
return ret;
}
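// gtest environment that creates the shared TestHost, seeds the stored
// certificate (if one was provided), and initializes the CDM library once for
// the whole binary.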
class GlobalEnv : public testing::Environment {
public:
GlobalEnv(InitFuncType init_func, const std::string& cert)
: init_func_(init_func), cert_(cert) {}
void SetUp() override {
// Manually set the logger because `TestHost` makes logging calls before
// the global logger is set in |init_func_|.
g_logger = &g_stderr_logger;
g_host = new TestHost;
if (!cert_.empty()) g_host->per_origin_storage().write("cert.bin", cert_);
Cdm::LogLevel log_level = Cdm::kErrors;
if (const char* verbose = getenv("VERBOSE_OUTPUT")) {
if (std::strcmp(verbose, "1") == 0) log_level = Cdm::kVerbose;
}
ASSERT_SUCCESS(init_func_(Cdm::kNoSecureOutput, &g_host->global_storage(),
g_host, g_host, &g_stderr_logger, log_level));
}
private:
const InitFuncType init_func_;
const std::string cert_;
};
} // namespace
class PerfTest : public testing::Test {};
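// Measures per-call latency for each step of a temporary-session license
// exchange: createSession, generateRequest, update with a real server
// response, and close. The network round trip itself is not timed.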
TEST_F(PerfTest, LicenseExchange) {
EventListener event_listener;
auto cdm = CreateCdm(&event_listener);
ASSERT_TRUE(cdm);
ASSERT_EQ(cdm->getProvisioningStatus(), Cdm::kProvisioned);
PerfTracker create;
PerfTracker generate;
PerfTracker update;
PerfTracker close;
for (size_t i = 0; i < kTestCount; i++) {
std::string session_id;
MEASURE_PERF(create, ASSERT_SUCCESS(
cdm->createSession(Cdm::kTemporary, &session_id)));
MEASURE_PERF(
generate,
ASSERT_SUCCESS(cdm->generateRequest(
session_id, Cdm::kCenc,
wvcdm::ConfigTestEnv::GetInitData(wvcdm::kContentIdStreaming))));
std::string response;
ASSERT_TRUE(SendPost(event_listener.messages[0].message, &response));
MEASURE_PERF(update, ASSERT_SUCCESS(cdm->update(session_id, response)));
MEASURE_PERF(close, ASSERT_SUCCESS(cdm->close(session_id)));
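// Each iteration is expected to produce exactly one message; discard it so
// the next request lands at messages[0] again.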
event_listener.messages.pop_back();
}
create.Print("Create ");
generate.Print("Generate");
update.Print("Update ");
close.Print("Close ");
}
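// Decrypt throughput for three block sizes, parameterized on the scheme:
// true exercises AES-CTR with full-sample encryption, false exercises AES-CBC
// with a 1:9 encrypt:skip pattern (cbcs-style).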
class DecryptPerfTest : public PerfTest,
public testing::WithParamInterface<bool> {};
TEST_P(DecryptPerfTest, Decrypt) {
EventListener event_listener;
auto cdm = CreateCdm(&event_listener);
ASSERT_TRUE(cdm);
ASSERT_EQ(cdm->getProvisioningStatus(), Cdm::kProvisioned);
std::string session_id;
ASSERT_SUCCESS(cdm->createSession(Cdm::kTemporary, &session_id));
ASSERT_SUCCESS(cdm->generateRequest(
session_id, Cdm::kCenc,
wvcdm::ConfigTestEnv::GetInitData(wvcdm::kContentIdStreaming)));
std::string response;
ASSERT_TRUE(SendPost(event_listener.messages[0].message, &response));
ASSERT_SUCCESS(cdm->update(session_id, response));
Cdm::KeyStatusMap statuses;
ASSERT_SUCCESS(cdm->getKeyStatuses(session_id, &statuses));
ASSERT_GT(statuses.size(), 0u);
const std::string key_id = statuses.begin()->first;
// Use in-place decrypt to avoid allocations. We don't care about the data,
// so we can just decrypt the same buffer again.
constexpr size_t k16M = 16 * 1024 * 1024;
std::vector<uint8_t> buffer(k16M);
// IV contents don't matter for throughput; zero-init to avoid reading
// indeterminate bytes.
uint8_t iv[16] = {};
for (auto& b : buffer) b = static_cast<uint8_t>(rand());
Cdm::DecryptionBatch batch;
batch.key_id = reinterpret_cast<const uint8_t*>(key_id.data());
batch.key_id_length = static_cast<uint32_t>(key_id.size());
if (GetParam()) {
batch.pattern.encrypted_blocks = batch.pattern.clear_blocks = 0;
} else {
batch.pattern.encrypted_blocks = 1;
batch.pattern.clear_blocks = 9;
}
batch.is_secure = false;
batch.encryption_scheme = GetParam() ? Cdm::kAesCtr : Cdm::kAesCbc;
batch.is_video = true;
Cdm::Subsample subsample;
subsample.clear_bytes = 0;
// subsample.protected_bytes is set per block size in the test loop below.
Cdm::Sample sample;
sample.input.iv = iv;
sample.input.iv_length = 16;
sample.input.data = buffer.data();
// sample.input.data_length is set per block size in the test loop below.
sample.input.subsamples = &subsample;
sample.input.subsamples_length = 1;
sample.output.data = buffer.data();
sample.output.data_offset = 0;
sample.output.data_length = static_cast<uint32_t>(buffer.size());
batch.samples = &sample;
batch.samples_length = 1;
constexpr size_t block_sizes[] = {8 * 1024, 256 * 1024, k16M};
constexpr size_t sizes_count = sizeof(block_sizes) / sizeof(block_sizes[0]);
const std::string block_names[] = {" 8k", "256k", " 16M"};
PerfTracker perf[sizes_count];
for (size_t i = 0; i < sizes_count; i++) {
subsample.protected_bytes = sample.input.data_length =
sample.output.data_length = static_cast<uint32_t>(block_sizes[i]);
for (size_t j = 0; j < kTestCount; j++) {
MEASURE_PERF(perf[i], ASSERT_SUCCESS(cdm->decrypt(batch)));
}
}
for (size_t i = 0; i < sizes_count; i++) {
perf[i].Print("Decrypt " + block_names[i], block_sizes[i]);
}
}
std::string PrintDecryptParam(const testing::TestParamInfo<bool>& info) {
return info.param ? "CTR" : "CBC";
}
INSTANTIATE_TEST_SUITE_P(Decrypt, DecryptPerfTest, testing::Bool(),
PrintDecryptParam);
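// Entry point supplied by each CDM build flavor with its init and create
// functions, plus an optional certificate blob persisted to per-origin
// storage before the tests run. A sketch of a typical caller (names here are
// illustrative; gtest must be initialized before RUN_ALL_TESTS):
//   int main(int argc, char** argv) {
//     testing::InitGoogleTest(&argc, argv);
//     return widevine::PerfTestMain(MyInitFunc, MyCreateFunc, LoadCert());
//   }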
int PerfTestMain(InitFuncType init_func, CreateFuncType create,
const std::string& cert) {
#ifdef _DEBUG
// Don't use #error since we build all targets and we don't want to fail the
// debug build (and we can't have configuration-specific targets).
fprintf(stderr, "Don't run performance tests in Debug mode\n");
return 1;
#else
create_func = create;
testing::AddGlobalTestEnvironment(new GlobalEnv(init_func, cert));
return RUN_ALL_TESTS();
#endif
}
} // namespace widevine