Program Listing for File worker.cpp
↰ Return to documentation for file (src/worker.cpp)
#include "worker.h"
#include <rte_common.h>
#include <rte_hexdump.h>
#include <rte_ip.h>
#include <rte_malloc.h>
#include <rte_mempool.h>
#include <spdlog/spdlog.h>
#include <glaze/glaze.hpp>
#include <optional>
#include <stack>
#include <tuple>
#include <variant>
#include "dns_format.h"
#include "dns_packet.h"
#include "dns_packet_constructor.h"
#include "dns_struct_defs.h"
#include "dpdk_wrappers.h"
#include "expected_helpers.h"
#include "intrusive_list.h"
#include "network_types.h"
#include "scanner_config.h"
// One bookkeeping slot for a single in-flight DNS request. A container is
// threaded onto the worker's ready-for-send and timeout intrusive lists via
// its two embedded nodes, so scheduling needs no extra allocation.
struct RequestContainer {
RequestContainer()
: ready_for_send_node(this), timeout_node(this), request(std::nullopt) { }
// Intrusive hook for WorkerContext::ready_for_send_list.
Node<RequestContainer> ready_for_send_node;
// Intrusive hook for WorkerContext::timeout_list.
Node<RequestContainer> timeout_node;
// Owned request payload; nullopt when the slot is free or already resolved.
std::optional<RTEMbufElement<Request>> request;
// Timestamp of the last transmission; compared against timeout_ms.
std::chrono::time_point<std::chrono::steady_clock> time_sent;
// Resolver the request was last sent to; responses must come from it.
InAddr resolver;
// Number of transmissions attempted so far (0 before the first send).
size_t num_tries;
// Index of this container in WorkerContext::request_containers; encoded
// into the outgoing DNS id / UDP source port (see PackPacketParams).
size_t loc;
};
struct WorkerContext {
static tl::expected<WorkerContext, int> init(uint16_t worker_id,
const WorkerParams ¶m) {
WorkerContext ctx{worker_id, param};
for (size_t i = 0; i < param.num_containers; i++) {
ctx.request_containers[i].loc = i;
ctx.available_container_stack.push(ctx.request_containers.data() + i);
}
ctx.pkt_distributors.reserve(param.num_workers);
for (size_t i = 0; i < param.num_workers; i++) {
auto dns_array = UNWRAP_OR_RETURN(
(RTEMbufArray<DNSPacketDistr, RX_PKT_BURST>::init(param.dns_mempool)));
ctx.pkt_distributors.emplace_back(std::move(dns_array));
}
return ctx;
}
uint16_t worker_id;
uint16_t queue_id;
size_t resolver_count;
IntrusiveList<RequestContainer, &RequestContainer::ready_for_send_node> ready_for_send_list;
IntrusiveList<RequestContainer, &RequestContainer::timeout_node> timeout_list;
std::vector<RequestContainer, RteAllocator<RequestContainer>> request_containers;
std::stack<RequestContainer *,
std::vector<RequestContainer *, RteAllocator<RequestContainer *>>>
available_container_stack;
std::vector<RTEMbufArray<DNSPacketDistr, RX_PKT_BURST>> pkt_distributors;
private:
WorkerContext(uint16_t worker_id, const WorkerParams ¶m)
: worker_id(worker_id),
queue_id(worker_id + (NIC_OPTS::queue_for_main_thread ? 1 : 0)),
request_containers(param.num_containers) { }
};
/// Encode (worker_id, buffer_loc) into the outgoing packet's UDP source port
/// and DNS transaction id so that the matching response can be routed back to
/// the right worker and request container:
///   dns_id bits 10-15 : worker_id + 1 (the +1 keeps the field nonzero)
///   dns_id bits 0-9   : buffer_loc bits 14-23
///   udp_port          : buffer_loc bits 0-13, offset by 1024 so the port
///                       stays above the reserved range
/// Returns {udp_src_port, dns_id}.
std::pair<uint16_t, uint16_t> PackPacketParams(uint16_t worker_id, uint32_t buffer_loc) {
    // Use upper 6 bits of dns id for worker id.
    uint16_t dns_id = (worker_id + 1) << 10;
    // Store buffer_loc bits 14-23 in the low 10 bits of the dns id.
    // (Mask normalized from the original 9-digit literal 0x000FFC000.)
    dns_id |= (buffer_loc & static_cast<uint32_t>(0x00FFC000)) >> 14;
    // Store lower 14 bits of buffer_loc in udp_src; 1024 is the minimum port.
    uint16_t udp_port = (buffer_loc & static_cast<uint32_t>(0x00003FFF)) + 1024;
    return std::make_pair(udp_port, dns_id);
}
/// Drain up to max_to_prepare entries from the worker's ready-for-send list,
/// construct an IPv4 DNS query packet for each, and move the entry onto the
/// timeout list with a fresh time_sent stamp. Returns the burst of constructed
/// packets (empty on mbuf allocation failure); the caller transmits them.
RTEMbufArray<DefaultPacket, TX_PKT_BURST, MbufType::Pkt> ConsumeReadyForSendAndPrepare(
        NICType &rxtx_if, RTEMempool<DefaultPacket, MbufType::Pkt> &mempool, WorkerContext &ctx,
        WorkerParams &param, uint16_t max_to_prepare) {
    auto current_time = std::chrono::steady_clock::now();
    // Allocate the mbuf burst, or fall back to an empty array so the caller
    // simply has nothing to send this round.
    auto packets = ({
        tl::expected res = RTEMbufArray<DefaultPacket, TX_PKT_BURST, MbufType::Pkt>::init(
            mempool, max_to_prepare);
        if (!res) {
            return RTEMbufArray<DefaultPacket, TX_PKT_BURST, MbufType::Pkt>::init(
                       mempool, 0)
                .value();
        }
        std::move(*res);
    });
    const rte_ether_addr src_mac = rxtx_if.GetMacAddr();
    size_t tx_write_cnt = 0;
    for (auto it = ctx.ready_for_send_list.begin(); it != ctx.ready_for_send_list.end();) {
        if (tx_write_cnt >= packets.size())
            break;
        // A container without a request should never be queued; recycle it.
        if (!it->request) {
            ctx.available_container_stack.push(&it.front());
            it = ctx.ready_for_send_list.erase(it);
            spdlog::error("Request container without request in ready for send queue");
            continue;
        }
        Request &req = it->request->get_data();
        // Encode worker id + container index into the DNS id / UDP src port so
        // the response can be matched back to this exact container.
        auto [udp_port, dns_id] = PackPacketParams(ctx.worker_id, it->loc);
        param.counters[ctx.worker_id].retry += (it->num_tries != 0);
        // Round-robin across the configured resolvers.
        ctx.resolver_count++;
        ctx.resolver_count %= param.resolvers.size();
        it->resolver = param.resolvers[ctx.resolver_count];
        DNSPacketConstructor::ConstructIpv4DNSPacket(packets[tx_write_cnt], src_mac,
            req.dst_mac, std::get<InAddr>(req.src_ip).s_addr, it->resolver.s_addr, udp_port,
            dns_id, req.name.buf.data(), req.name.len, req.q_type);
        rxtx_if.PreparePktCksums<DefaultPacket, L3Type::Ipv4, L4Type::UDP>(
            packets[tx_write_cnt]);
        tx_write_cnt++;
        // Start the timeout clock and hand the entry over to the timeout list.
        it->time_sent = current_time;
        ctx.timeout_list.push_back(*it);
        it = ctx.ready_for_send_list.erase(it);
    }
    // Keep only the packets that were actually filled in.
    auto [constructed_packets, _] = packets.split(tx_write_cnt);
    rxtx_if.PreparePackets(ctx.queue_id, constructed_packets);
    // std::move is required here: structured bindings inhibit NRVO.
    return std::move(constructed_packets);
}
void ProcessTimeouts(WorkerContext &ctx, WorkerParams ¶m) {
auto current_time = std::chrono::steady_clock::now();
for (auto it = ctx.timeout_list.begin(); it != ctx.timeout_list.end();) {
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
current_time - it->time_sent);
if (static_cast<uint64_t>(elapsed_ms.count()) < param.timeout_ms)
break;
if (++it->num_tries > param.max_retries) {
spdlog::warn("{} has reached max retries",
it->request->get().name.buf.data());
it->request = std::nullopt;
it->num_tries = 0;
param.counters[ctx.worker_id].max_retry++;
ctx.available_container_stack.push(&it.front());
} else {
ctx.ready_for_send_list.push_back(*it);
}
it = ctx.timeout_list.erase(it);
}
}
// Move requests out of request_buf into free containers and queue them on the
// ready-for-send list. Stops when either the free-container pool or the
// request buffer runs dry; requests with an invalid IP count are dropped.
void AddRequestsToReadyForSend(WorkerContext &ctx,
        RTEMbufArray<Request, TX_PKT_BURST> &request_buf) {
    while (!ctx.available_container_stack.empty()) {
        RequestContainer *slot = ctx.available_container_stack.top();
        auto popped = request_buf.pop();
        if (!popped)
            break;
        const bool bad_ip_count = popped->get_data().num_ips == 0 ||
            popped->get_data().num_ips > REQUEST_MAX_IPS;
        if (bad_ip_count) [[unlikely]] {
            // The popped element is discarded; the container stays available.
            spdlog::warn("Dropping request, invalid number of IPs");
            continue;
        }
        slot->num_tries = 0;
        slot->request = std::move(*popped);
        ctx.ready_for_send_list.push_back(*slot);
        ctx.available_container_stack.pop();
    }
}
void HandleParsedPacket(WorkerContext &ctx, WorkerParams ¶m, const DNSPacket &pkt,
RTEMbuf<DefaultPacket> &raw_pkt, std::shared_ptr<spdlog::logger> out_logger) {
switch (pkt.r_code) {
case DnsRCode::NOERROR:
param.counters[ctx.worker_id].noerror++;
break;
case DnsRCode::SERVFAIL:
param.counters[ctx.worker_id].servfail++;
break;
case DnsRCode::NXDOMAIN:
param.counters[ctx.worker_id].nxdomain++;
break;
default:
param.counters[ctx.worker_id].rcode_other++;
break;
}
if (pkt.GetBufferLoc() >= ctx.request_containers.size()) [[unlikely]] {
spdlog::warn("packet has out of bounds request buffer location: {}",
pkt.GetBufferLoc());
return;
}
auto &request_container = ctx.request_containers[pkt.GetBufferLoc()];
if (!request_container.request) [[unlikely]] {
spdlog::warn("packet with name {} in location {:x} already processed!",
pkt.question, pkt.GetBufferLoc());
return;
}
if (request_container.request->get().q_type != pkt.q_type) [[unlikely]] {
spdlog::warn("packet with name {} has q type mismatch!", pkt.question);
return;
}
if (request_container.resolver.s_addr != std::get<InAddr>(pkt.ip_data.src_ip).s_addr)
[[unlikely]] {
spdlog::warn("packet with name {} has IP mismatch!", pkt.question);
return;
}
if (request_container.request->get().name != pkt.question) [[unlikely]] {
spdlog::warn("packet with name {} has name mismatch with request {}", pkt.question,
request_container.request->get().name);
return;
}
param.counters[ctx.worker_id].num_resolved++;
request_container.request = std::nullopt;
request_container.num_tries = 0;
ctx.timeout_list.delete_elem(request_container);
ctx.ready_for_send_list.delete_elem(request_container);
ctx.available_container_stack.push(&request_container);
if (param.rcode_filters) {
const auto &v = param.rcode_filters.value();
if (!std::count(v.begin(), v.end(), pkt.r_code))
return;
}
// fmt::print("{:02x}\n", fmt::join(reinterpret_cast<unsigned char
// *>(raw_pkt.data().padding.data()),
// reinterpret_cast<unsigned char *>(raw_pkt.data().padding.data() +
// raw_pkt.data_len), ""));
if (!param.output_raw) {
std::string out = glz::write_json(pkt);
out_logger->info(out);
} else {
auto dns_offset = sizeof(rte_ether_hdr) + sizeof(rte_udp_hdr);
dns_offset += std::holds_alternative<InAddr>(pkt.ip_data.dst_ip)
? sizeof(rte_ipv4_hdr)
: sizeof(rte_ipv6_hdr);
auto pkt_begin = raw_pkt.data().padding.data();
auto dns_begin = pkt_begin + dns_offset;
auto pkt_end = pkt_begin + raw_pkt.data_len;
// Reset transaction ID to 0
auto dns_header = reinterpret_cast<DnsHeader *>(dns_begin);
dns_header->id = 0;
out_logger->info("{:02x}", fmt::join(reinterpret_cast<unsigned char *>(dns_begin),
reinterpret_cast<unsigned char *>(pkt_end), ""));
}
}
void RX(WorkerContext &ctx, NICType &rxtx_if, uint16_t worker_id, WorkerParams ¶m,
std::shared_ptr<spdlog::logger> out_logger) {
auto pkts = rxtx_if.RcvPackets<RX_PKT_BURST>(ctx.queue_id);
param.counters[ctx.worker_id].rcvd_pkts += pkts.size();
std::optional<RTEMbufElement<DefaultPacket, MbufType::Pkt>> p;
while ((p = pkts.pop())) {
if ((p->get().ol_flags & RTE_MBUF_F_RX_IP_CKSUM_MASK) ==
RTE_MBUF_F_RX_IP_CKSUM_BAD ||
(p->get().ol_flags & RTE_MBUF_F_RX_L4_CKSUM_MASK) == RTE_MBUF_F_RX_IP_CKSUM_BAD)
continue;
auto parsed_packet = ({
auto res = DNSPacket::init(param.raw_mempool, *p);
if (!res) {
auto pkt_begin = p->get().data().padding.data();
auto pkt_end = pkt_begin + p->get().data_len;
auto pkt_formatted =
fmt::join(reinterpret_cast<unsigned char *>(pkt_begin),
reinterpret_cast<unsigned char *>(pkt_end), "");
spdlog::warn("error parsing packet: {}, raw pkt: {:02x}",
res.error(), std::move(pkt_formatted));
param.counters[worker_id].parse_fail++;
continue;
}
std::move(*res);
});
if (parsed_packet.GetWorkerId() >= param.num_workers) [[unlikely]] {
spdlog::warn("packet worker id {} is larger than number of workers",
parsed_packet.GetWorkerId());
param.counters[worker_id].parse_fail++;
continue;
}
param.counters[worker_id].parse_success++;
if constexpr (!NIC_OPTS::hw_flowsteering) {
if (worker_id != parsed_packet.GetWorkerId()) {
auto &dns_dist = ctx.pkt_distributors[parsed_packet.GetWorkerId()];
auto dns_packet_distr =
DNSPacketDistr{.dns_packet = std::move(parsed_packet),
.raw_packet = std::move(*p)};
// Distributing the packet on the ring requires it to be wrapped in
// an mbuf. We can avoid the allocation in the fast path, but here
// it is required.
auto wrapped_packet = ({
auto res = RTEMbufElement<DNSPacketDistr>::init(
param.dns_mempool, std::move(dns_packet_distr));
if (!res) {
spdlog::warn("worker {}: dropping packet due to "
"allocation failure {}",
worker_id, res.error());
continue;
}
std::move(*res);
});
auto dns_push = dns_dist.push(std::move(wrapped_packet));
if (dns_push)
spdlog::warn("worker {}: distribution buffer is full, "
"dropping packet");
continue;
}
}
if (worker_id != parsed_packet.GetWorkerId()) [[unlikely]] {
spdlog::warn("packet worker id {} differs from ours: {}",
parsed_packet.GetWorkerId(), worker_id);
param.counters[worker_id].other_worker_id++;
continue;
}
// Fast path - the packet does not need to be distributed on a ring.
HandleParsedPacket(ctx, param, parsed_packet, p->get(), out_logger);
}
// With flowsteering there is nothing left to do.
// Without flowsteering we continue on the slow path.
if constexpr (NIC_OPTS::hw_flowsteering)
return;
// Distribute the packets in the distribution buffer on to each other worker's ring.
for (size_t remote = 0; remote < param.num_workers; remote++) {
// Packets meant for our worker use fast path - even when flow steering is disabled.
if (remote == ctx.worker_id)
continue;
auto &dns_ring = param.distribution_rings[remote];
auto &dns_buf = ctx.pkt_distributors[remote];
if (dns_buf.size() == 0)
continue;
auto dns_remainder = dns_ring.enqueue_burst(std::move(dns_buf));
ctx.pkt_distributors[remote] = std::move(dns_remainder);
}
// Dequeue all outstanding packets and continue packet receive process.
auto &dns_ring = param.distribution_rings[worker_id];
auto rx_dns = dns_ring.dequeue_burst<RX_PKT_BURST>();
for (auto &p : rx_dns)
HandleParsedPacket(ctx, param, p.dns_packet, p.raw_packet.get(), out_logger);
}
// Worker thread entry point. Runs the full send/receive loop for one NIC
// queue: pull new requests off the shared ring, retransmit timeouts, construct
// and send DNS queries under a packets-per-second rate limit, and receive and
// process responses. Returns 0 on clean shutdown, -1 if context init fails.
// Termination: each worker counts down param.workers_finished once it is idle
// (or stop is requested); the loop exits when all workers have done so.
int Worker(std::stop_token stop_token, uint16_t worker_id, WorkerParams param) {
const size_t num_containers = param.num_containers;
auto ctx = ({
auto res = WorkerContext::init(worker_id, param);
if (!res) {
spdlog::error("worker {}: failed to init WorkerContext: {}", worker_id,
res.error());
// Still count down so the other workers' shutdown latch is not stuck.
param.workers_finished.count_down();
return -1;
}
std::move(*res);
});
std::chrono::time_point<std::chrono::steady_clock> next_send_time =
std::chrono::steady_clock::now();
auto request_buf = RTEMbufArray<Request, TX_PKT_BURST>::init(param.request_mempool).value();
auto prepared_packet_buf =
RTEMbufArray<DefaultPacket, TX_PKT_BURST, MbufType::Pkt>::init(param.pkt_mempool)
.value();
auto out_logger = spdlog::get("output_log");
bool finished_set = false;
while (!param.workers_finished.try_wait()) {
// This worker is done when all containers are free, the domain source is
// exhausted, and no requests remain buffered locally or on the ring.
bool worker_finished = ctx.available_container_stack.size() == num_containers &&
param.domains_finished.try_wait() &&
request_buf.size() == 0 && param.ring.empty();
// Count down exactly once; keep looping so we still service the NIC until
// every other worker is finished too.
if ((stop_token.stop_requested() || worker_finished) && !finished_set) {
finished_set = true;
param.workers_finished.count_down();
}
auto current_time = std::chrono::steady_clock::now();
// Top up the local request buffer from the shared ring.
std::ignore = request_buf.insert(
param.ring.dequeue_burst<TX_PKT_BURST>(request_buf.free_cnt()));
ProcessTimeouts(ctx, param);
AddRequestsToReadyForSend(ctx, request_buf);
// Rate limiting: only construct a new burst once next_send_time is due.
if (current_time >= next_send_time) {
auto prepared_packets = ConsumeReadyForSendAndPrepare(param.rxtx_if,
param.pkt_mempool, ctx, param, prepared_packet_buf.free_cnt());
// Picoseconds per request at the configured pps rate; the /1000 below
// converts (ps * packets) back to nanoseconds.
size_t ps_per_req = 1000'000'000'000 / param.rate_lim_pps;
next_send_time =
current_time +
std::chrono::nanoseconds((ps_per_req * prepared_packets.size()) / 1000);
std::ignore = prepared_packet_buf.insert(std::move(prepared_packets));
}
// for (auto& packet: prepared_packet_buf) {
// char* data = rte_pktmbuf_mtod(&packet, char*);
// rte_memdump(stdout, "PACKET", data, packet.data_len);
// }
// Transmit; whatever the NIC could not take stays buffered for next round.
auto [num_sent, unsent_packets] =
param.rxtx_if.SendPackets(ctx.queue_id, std::move(prepared_packet_buf));
prepared_packet_buf = std::move(unsent_packets);
param.counters[worker_id].sent_pkts += num_sent;
RX(ctx, param.rxtx_if, worker_id, param, out_logger);
}
return 0;
}