Program Listing for File dns_packet.cpp
↰ Return to documentation for file (src/dns_packet.cpp
)
#include "dns_packet.h"
#include <expected.h>
#include <expected_helpers.h>
#include <netinet/in.h>
#include <rte_branch_prediction.h>
#include <rte_byteorder.h>
#include <rte_ether.h>
#include <rte_ip.h>
#include <rte_mbuf.h>
#include <rte_udp.h>
#include <spdlog/spdlog.h>
#include <algorithm>
#include <cstring>
#include <exception>
#include "dns_format.h"
#include "dns_struct_defs.h"
#include "dpdk_wrappers.h"
#include "network_types.h"
#include "spdlog/fmt/bundled/core.h"
#include "spdlog/fmt/fmt.h"
auto fmt::formatter<DNSParseError>::format(DNSParseError e, format_context &ctx) const
-> decltype(ctx.out()) {
string_view error = "unknown error :(";
switch (e) {
case DNSParseError::OutOfBounds:
error = "out of bounds";
break;
case DNSParseError::PktError:
error = "packet error";
break;
case DNSParseError::SrcPortErr:
error = "source port error";
break;
case DNSParseError::TxtTooLong:
error = "text too long";
break;
case DNSParseError::NameTooLong:
error = "name too long";
break;
case DNSParseError::InvalidQCount:
error = "invalid question count";
break;
case DNSParseError::IpHdrProtoErr:
error = "wrong ip header proto";
break;
case DNSParseError::AllocationError:
error = "allocation error";
break;
case DNSParseError::MalformedPacket:
error = "malformed packet";
break;
case DNSParseError::MaxJumpsReached:
error = "max jumps reached";
break;
case DNSParseError::EtherHdrProtoErr:
error = "ethernet header error";
case DNSParseError::InvalidChar:
error = "invalid character detected in packet";
break;
}
return formatter<string_view>::format(error, ctx);
}
template <typename T>
inline tl::expected<const T *, DNSParseError> AdvanceReader(std::span<const std::byte> bytes,
std::span<const std::byte>::iterator &reader) {
if (static_cast<std::span<const std::byte>::iterator>(reader + sizeof(T)) > bytes.end())
[[unlikely]]
return tl::unexpected(DNSParseError::OutOfBounds);
auto ptr = reinterpret_cast<const T *>(reader.base());
reader += sizeof(T);
return ptr;
}
bool contains_unprintable_chars_or_space(std::string_view sv) {
bool result = false;
for (const auto &c : sv) {
result |= c < '!' || c > '~';
}
return result;
}
tl::expected<DnsName, DNSParseError> ReadFromDNSNameFormat(std::span<const std::byte> bytes,
std::span<const std::byte>::iterator &reader) {
DnsName name_parsed;
// Keep track of jump count to prevent infinite loops
int jmp_cnt = 0;
name_parsed.buf[0] = '\0';
// Number of bytes read to name
name_parsed.len = 0;
// Number of bytes stepped forward in packet, start at one to step past NULL terminator
int count = 1;
// Save where we started so we can set the iterator to the correct
// position later.
auto begin = reader;
// Check if current character can be safely accessed
if (reader >= bytes.end())
return tl::unexpected(DNSParseError::OutOfBounds);
// After every read/write the packet bounds are checked for the subsequent character,
// in this way it is always possible to check for / insert NULL terminator
// Read until \0 terminator is found
while (static_cast<char>(*reader) != '\0') {
// Check for jump before every segement, indicated by 0b11xxxxxx byte
if (static_cast<unsigned char>(*reader) >= 0xC0) [[unlikely]] {
// Check bounds before incrementing pointer
if (reader + 1 >= bytes.end()) [[unlikely]]
return tl::unexpected(DNSParseError::OutOfBounds);
std::byte msb = *reader;
std::byte lsb = *(reader + 1);
uint16_t offset = static_cast<uint16_t>(msb & std::byte{0x3F}) << 8 |
static_cast<uint16_t>(lsb);
reader = bytes.begin() + offset;
// Check new location of reader for bounds before reading
if (reader >= bytes.end() || reader < bytes.begin()) [[unlikely]]
return tl::unexpected(DNSParseError::OutOfBounds);
jmp_cnt++;
if (jmp_cnt > 10) [[unlikely]]
return tl::unexpected(DNSParseError::MaxJumpsReached);
continue;
}
// Read length of next segment
const uint8_t len = static_cast<uint8_t>(*(reader));
if (++reader >= bytes.end()) [[unlikely]]
return tl::unexpected(DNSParseError::OutOfBounds);
// Increment count if we haven't jumped yet
count += !(jmp_cnt);
// Read segment into buffer
for (uint8_t p = 0; p < len; p++) {
name_parsed.buf[name_parsed.len] = static_cast<uint16_t>(*(reader));
if (++reader >= bytes.end()) [[unlikely]]
return tl::unexpected(DNSParseError::OutOfBounds);
if (++name_parsed.len >= DOMAIN_NAME_MAX_SIZE) [[unlikely]]
return tl::unexpected(DNSParseError::NameTooLong);
// Increment count if we haven't jumped yet
count += !(jmp_cnt);
}
name_parsed.buf[name_parsed.len] = '.';
if (++name_parsed.len >= DOMAIN_NAME_MAX_SIZE) [[unlikely]]
return tl::unexpected(DNSParseError::NameTooLong);
}
// Add NULL terminator
name_parsed.buf[name_parsed.len] = '\0';
if (contains_unprintable_chars_or_space(std::string_view(name_parsed))) [[unlikely]]
return tl::unexpected(DNSParseError::InvalidChar);
// Add one to count if jumped to account for 2 byte offset field instead of 1 byte NULL
// terminator
count += (bool) jmp_cnt;
reader = begin + count;
return name_parsed;
}
template <size_t N>
tl::expected<FixedName<N>, DNSParseError> ParseFixedName(std::span<const std::byte> bytes,
std::span<const std::byte>::iterator &reader, uint16_t length) {
FixedName<N> result;
// Tag field name length does not take \0 terminator into account
if (length >= N) [[unlikely]]
return tl::unexpected(DNSParseError::NameTooLong);
if (reader + length > bytes.end()) [[unlikely]]
return tl::unexpected(DNSParseError::OutOfBounds);
// Copy tag into r_data
std::copy(reader, reader + length, reinterpret_cast<std::byte *>(result.buf.begin()));
result.buf[length] = '\0';
result.len = length;
reader += length;
return result;
}
tl::expected<ResourceRecord, DNSParseError> ParseResourceRecord(std::span<const std::byte> bytes,
std::span<const std::byte>::iterator &reader) {
ResourceRecord parsed_record;
parsed_record.name = UNWRAP_OR_RETURN(ReadFromDNSNameFormat(bytes, reader));
const RData *response = UNWRAP_OR_RETURN(AdvanceReader<RData>(bytes, reader));
parsed_record.q_type = static_cast<DnsQType>(rte_be_to_cpu_16(response->type));
parsed_record.ttl = rte_be_to_cpu_32(response->ttl);
auto rdata_bytes = std::span(reader, rte_be_to_cpu_16(response->data_len));
// Make sure that the rdata bytes area doesn't go outside the packet byte area
if (rdata_bytes.end() > bytes.end()) [[unlikely]]
return tl::unexpected(DNSParseError::OutOfBounds);
auto begin = reader;
switch (parsed_record.q_type) {
case DnsQType::A: {
ARdata r_data;
r_data.ipv4_addr =
*UNWRAP_OR_RETURN(AdvanceReader<InAddr>(rdata_bytes, reader));
parsed_record.r_data = r_data;
break;
}
case DnsQType::AAAA: {
AAAARdata r_data;
r_data.ipv6_addr =
*UNWRAP_OR_RETURN(AdvanceReader<In6Addr>(rdata_bytes, reader));
parsed_record.r_data = r_data;
break;
}
case DnsQType::NS: {
NSRdata r_data;
// Always pass in the full DNS packet as valid area for the
// ReadFromDNSNameFormat since the reader might jump
r_data.nameserver = UNWRAP_OR_RETURN(ReadFromDNSNameFormat(bytes, reader));
parsed_record.r_data = r_data;
break;
}
case DnsQType::MX: {
MXRdata r_data;
r_data.preference = rte_be_to_cpu_16(
*UNWRAP_OR_RETURN(AdvanceReader<uint16_t>(rdata_bytes, reader)));
r_data.mailserver = UNWRAP_OR_RETURN(ReadFromDNSNameFormat(bytes, reader));
parsed_record.r_data = r_data;
break;
}
case DnsQType::CNAME: {
CNAMERdata r_data;
r_data.cname = UNWRAP_OR_RETURN(ReadFromDNSNameFormat(bytes, reader));
parsed_record.r_data = r_data;
break;
}
case DnsQType::DNAME: {
DNAMERdata r_data;
r_data.dname = UNWRAP_OR_RETURN(ReadFromDNSNameFormat(bytes, reader));
parsed_record.r_data = r_data;
break;
}
case DnsQType::PTR: {
PTRRdata r_data;
r_data.ptr = UNWRAP_OR_RETURN(ReadFromDNSNameFormat(bytes, reader));
parsed_record.r_data = r_data;
break;
}
case DnsQType::TXT: {
TXTRdata r_data;
r_data.txt.len = 0;
auto txt_writer = r_data.txt.buf.begin();
while (reader < rdata_bytes.end()) {
uint8_t next_string_size =
*UNWRAP_OR_RETURN(AdvanceReader<uint8_t>(rdata_bytes, reader));
if (reader + next_string_size > rdata_bytes.end()) [[unlikely]]
return tl::unexpected(DNSParseError::OutOfBounds);
// Also take \0 terminator into account
if (txt_writer + next_string_size + 1 > r_data.txt.buf.end())
[[unlikely]]
return tl::unexpected(DNSParseError::TxtTooLong);
std::copy(reader, reader + next_string_size,
reinterpret_cast<std::byte *>(txt_writer));
reader += next_string_size;
txt_writer += next_string_size;
r_data.txt.len += next_string_size;
}
// Bound validity is already checked
*txt_writer = '\0';
parsed_record.r_data = r_data;
break;
}
case DnsQType::SOA: {
SOARdata r_data;
r_data.m_name = UNWRAP_OR_RETURN(ReadFromDNSNameFormat(bytes, reader));
r_data.r_name = UNWRAP_OR_RETURN(ReadFromDNSNameFormat(bytes, reader));
r_data.interval_settings = *UNWRAP_OR_RETURN(
AdvanceReader<SOARdata::IntervalSettings>(bytes, reader));
r_data.interval_settings.serial =
rte_be_to_cpu_32(r_data.interval_settings.serial);
r_data.interval_settings.refresh =
rte_be_to_cpu_32(r_data.interval_settings.refresh);
r_data.interval_settings.retry =
rte_be_to_cpu_32(r_data.interval_settings.retry);
r_data.interval_settings.expire =
rte_be_to_cpu_32(r_data.interval_settings.expire);
r_data.interval_settings.minimum =
rte_be_to_cpu_32(r_data.interval_settings.minimum);
parsed_record.r_data = r_data;
break;
}
case DnsQType::CAA: {
CAARdata r_data;
r_data.flags =
*UNWRAP_OR_RETURN(AdvanceReader<uint8_t>(rdata_bytes, reader));
uint8_t tag_length =
*UNWRAP_OR_RETURN(AdvanceReader<uint8_t>(rdata_bytes, reader));
r_data.tag = UNWRAP_OR_RETURN(
ParseFixedName<CAA_TAG_MAX_SIZE>(rdata_bytes, reader, tag_length));
if (contains_unprintable_chars_or_space(std::string_view(r_data.tag)))
[[unlikely]]
return tl::unexpected(DNSParseError::InvalidChar);
uint16_t value_len = rdata_bytes.end() - reader;
// Max length of the value is not specified in the RFC,
// character string should be sufficient
r_data.value = UNWRAP_OR_RETURN(ParseFixedName<CHARACTER_STRING_MAX_SIZE>(
rdata_bytes, reader, value_len));
parsed_record.r_data = r_data;
break;
}
case DnsQType::OPT: {
reader += rte_be_to_cpu_16(response->data_len);
parsed_record.r_data = OPTRdata{};
break;
}
default:
reader += rte_be_to_cpu_16(response->data_len);
parsed_record.r_data = std::monostate();
break;
}
// Check that the actual read bytes are equal to the number of bytes indicated in
// the RData length section
if (reader != begin + rte_be_to_cpu_16(response->data_len)) [[unlikely]]
return tl::unexpected(DNSParseError::MalformedPacket);
return parsed_record;
}
tl::expected<DNSPacket, DNSParseError> DNSPacket::init(
RTEMempool<DefaultPacket, MbufType::Raw> &mempool,
RTEMbufElement<DefaultPacket, MbufType::Pkt> &raw_pkt) {
auto pkt = raw_pkt.get();
std::span<const std::byte> packet_bytes(pkt.data().padding.data(), pkt.data_len);
auto reader = packet_bytes.begin();
IpData ip_data{};
const rte_ether_hdr *ether_hdr =
UNWRAP_OR_RETURN(AdvanceReader<rte_ether_hdr>(packet_bytes, reader));
uint16_t ether_type = rte_be_to_cpu_16(ether_hdr->ether_type);
if (ether_type == RTE_ETHER_TYPE_IPV4) {
const rte_ipv4_hdr *ip_hdr =
UNWRAP_OR_RETURN(AdvanceReader<rte_ipv4_hdr>(packet_bytes, reader));
ip_data.dst_ip = InAddr{ip_hdr->dst_addr};
ip_data.src_ip = InAddr{ip_hdr->src_addr};
if (ip_hdr->next_proto_id != IPPROTO_UDP)
return tl::unexpected(DNSParseError::IpHdrProtoErr);
} else if (ether_type == RTE_ETHER_TYPE_IPV6) {
const rte_ipv6_hdr *ip_hdr =
UNWRAP_OR_RETURN(AdvanceReader<rte_ipv6_hdr>(packet_bytes, reader));
ip_data.dst_ip = *reinterpret_cast<const In6Addr *>(ip_hdr->dst_addr);
ip_data.src_ip = *reinterpret_cast<const In6Addr *>(ip_hdr->src_addr);
if (ip_hdr->proto != IPPROTO_UDP)
return tl::unexpected(DNSParseError::IpHdrProtoErr);
} else {
return tl::unexpected(DNSParseError::EtherHdrProtoErr);
}
const rte_udp_hdr *udp_hdr =
UNWRAP_OR_RETURN(AdvanceReader<rte_udp_hdr>(packet_bytes, reader));
ip_data.dst_port = rte_be_to_cpu_16(udp_hdr->dst_port);
ip_data.src_port = rte_be_to_cpu_16(udp_hdr->src_port);
if (ip_data.src_port != 53)
return tl::unexpected(DNSParseError::SrcPortErr);
std::span<const std::byte> dns_bytes = std::span(reader, packet_bytes.end());
const DnsHeader *hdr = UNWRAP_OR_RETURN(AdvanceReader<DnsHeader>(dns_bytes, reader));
if (rte_be_to_cpu_16(hdr->q_count) != 1) [[unlikely]]
return tl::unexpected(DNSParseError::InvalidQCount);
DnsName question = UNWRAP_OR_RETURN(ReadFromDNSNameFormat(dns_bytes, reader));
const QuestionInfo *question_info =
UNWRAP_OR_RETURN(AdvanceReader<QuestionInfo>(dns_bytes, reader));
uint16_t num_ans = std::min(MAX_RECORDS, rte_be_to_cpu_16(hdr->ans_count));
uint16_t num_auth = std::min(MAX_RECORDS, rte_be_to_cpu_16(hdr->auth_count));
uint16_t num_add = std::min(MAX_RECORDS, rte_be_to_cpu_16(hdr->add_count));
bool records_capped = rte_be_to_cpu_16(hdr->add_count) > MAX_RECORDS ||
rte_be_to_cpu_16(hdr->auth_count) > MAX_RECORDS ||
rte_be_to_cpu_16(hdr->add_count) > MAX_RECORDS;
auto ans_mbufs_ =
RTEMbufArray<ResourceRecord, MAX_RECORDS, MbufType::Raw>::init(mempool, num_ans);
auto ans_mbufs =
UNWRAP_OR_RETURN_ERR(std::move(ans_mbufs_), DNSParseError::AllocationError);
auto auth_mbufs_ =
RTEMbufArray<ResourceRecord, MAX_RECORDS, MbufType::Raw>::init(mempool, num_auth);
auto auth_mbufs =
UNWRAP_OR_RETURN_ERR(std::move(auth_mbufs_), DNSParseError::AllocationError);
auto add_mbufs_ =
RTEMbufArray<ResourceRecord, MAX_RECORDS, MbufType::Raw>::init(mempool, num_add);
auto add_mbufs =
UNWRAP_OR_RETURN_ERR(std::move(add_mbufs_), DNSParseError::AllocationError);
for (auto &rec : ans_mbufs) {
rec = UNWRAP_OR_RETURN(ParseResourceRecord(dns_bytes, reader));
}
for (auto &rec : auth_mbufs) {
rec = UNWRAP_OR_RETURN(ParseResourceRecord(dns_bytes, reader));
}
for (auto &rec : add_mbufs) {
rec = UNWRAP_OR_RETURN(ParseResourceRecord(dns_bytes, reader));
}
// Construct packet
auto res = DNSPacket(ip_data, rte_be_to_cpu_16(hdr->id), question,
static_cast<DnsQType>(rte_be_to_cpu_16(question_info->qtype)),
static_cast<DnsRCode>(hdr->rcode), records_capped, std::move(ans_mbufs),
std::move(auth_mbufs), std::move(add_mbufs));
return res;
}