Program Listing for File main.cpp

Return to documentation for file (src/main.cpp)

#include <expected_helpers.h>
#include <fcntl.h>
#include <ncurses.h>
#include <net/if.h>
#include <rte_cycles.h>
#include <rte_errno.h>
#include <rte_ethdev.h>
#include <rte_hexdump.h>
#include <rte_log.h>
#include <rte_malloc.h>
#include <rte_ring_core.h>
#include <signal.h>
#include <spdlog/fmt/bundled/core.h>
#include <spdlog/sinks/stdout_color_sinks.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <time.h>

#include <algorithm>
#include <bit>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <glaze/glaze.hpp>
#include <iostream>
#include <optional>
#include <ratio>
#include <thread>
#include <tuple>

#include "ProgramOptions.hxx"
#include "arp.h"
#include "dns_format.h"
#include "dpdk_wrappers.h"
#include "eth_rxtx.h"
#include "eth_rxtx_opts.h"
#include "fixed_name.hpp"
#include "fmt_helpers.h"
#include "input_reader.h"
#include "net_info.h"
#include "network_types.h"
#include "parse_helpers.h"
#include "request.h"
#include "scanner_config.h"
#include "spdlog/fmt/bundled/core.h"
#include "spdlog/sinks/basic_file_sink.h"
#include "spdlog/sinks/stdout_color_sinks.h"
#include "spdlog/spdlog.h"
#include "thread_manager.h"
#include "version.h"
#include "worker.h"

namespace {
// Flag raised from the SIGINT handler; starts out cleared.
std::atomic_bool sigint_received{false};

// The text "error:" wrapped in ANSI escape codes (bold red, then reset),
// used as the first argument of the user-facing error messages in this file.
const char* error_str = "\033[1;31merror:\033[0m";
} // namespace

// Signal handler: records that a signal arrived by raising the shared atomic flag.
// The signal number itself is not used.
void SigintHandler([[maybe_unused]] int s) {
    sigint_received.store(true);
}

// Ask `arp` to resolve `gateway_ip`, using queue 0 of `rxtx` for both directions.
//
// Two callbacks are handed to Arp::RequestAddr:
//   - a transmit callback that sends one packet on queue 0 and reports success
//     only when exactly one packet was sent;
//   - a receive callback that polls queue 0 with a burst size of 8.
// The Arp::Error produced by RequestAddr is returned unchanged.
template <class opts>
Arp::Error RequestGatewayMac(Arp& arp, EthRxTx<opts>& rxtx,
    RTEMempool<DefaultPacket, MbufType::Pkt>& mpool, InAddr gateway_ip) {
    using PacketBuf = RTEMbufElement<DefaultPacket, MbufType::Pkt>;

    const auto transmit = [&rxtx](PacketBuf&& frame) {
        const auto sent_count = rxtx.SendPacket(0, std::move(frame));
        return sent_count == 1;
    };

    const auto poll = [&rxtx] {
        return rxtx.template RcvPackets<8>(0);
    };

    return arp.RequestAddr(gateway_ip.s_addr, mpool, transmit, poll);
}

// Settings collected from the command line by InitConfigFromArgs.
// The struct is also dumped as JSON via glz::write_json in main, so member
// names and order are part of the logged output.
struct UserConfig {
    // True when --headless was passed (no terminal UI)
    bool headless;

    // --cores: validated to be at least 2 and at most the hardware concurrency
    uint16_t cores;
    // --rate: scan rate in packets per second, validated to be at least 100
    uint32_t rate;
    // --num-concurrent: max number of concurrent DNS requests; derived from
    // the rate when the option is not given
    uint32_t num_concurrent;
    // --timeout, in milliseconds
    uint32_t timeout_ms;
    // --num-retries
    uint32_t num_retries;

    // --gateway-ip / --static-ip: unset when not given on the command line
    std::optional<InAddr> gateway_ip;
    std::optional<InAddr> static_ip;

    // --gateway-mac: unset when not given on the command line
    std::optional<EtherAddr> gateway_mac;

    // --device-name: unset when not given on the command line
    std::optional<std::string> device_name;
    // --input-file: mandatory, path of the file with domains
    std::string input_file;
    // --xdp-path: only assigned when built with NIC_AF_XDP, empty otherwise
    std::string xdp_path;
    // --resolvers: mandatory, parsed with ParseResolvers
    std::vector<InAddr> resolvers;
    // --rcodes: parsed with ParseDNSReturnCodes, unset when not given
    std::optional<std::vector<DnsRCode>> rcode_filters;

    // --prefix / --postfix: text added to each line of the input, unset when not given
    std::optional<DnsName> prefix;
    std::optional<DnsName> postfix;

    // --log-path: logging to file is enabled only when this is set
    std::optional<std::string> log_path;
    // --output-path (default: output.txt)
    std::string output_path;
    // True when --output-raw was passed (raw DNS packets in hex)
    bool output_raw;
    // True when --no-huge was passed (don't use huge pages)
    bool no_huge;
    // True when --debug was passed (print debug information)
    bool debug;

    // --q-type: DNS question type, parsed with GetQTypeFromString (default: A)
    DnsQType q_type;
};

// Parse the command line into a UserConfig.
//
// Returns std::nullopt when
//   - the parser rejects the arguments,
//   - --help or --version was requested (the text is printed first), or
//   - a value fails validation (an error message is printed first).
// Otherwise the fully populated configuration is returned.
//
// Mandatory options: --input-file and --resolvers.
std::optional<UserConfig> InitConfigFromArgs(int argc, char** argv) {
    po::parser parser;

    auto& help = parser["help"].abbreviation('h').description("print this help screen");

    auto& version = parser["version"].description("print the version and exit");

    auto& headless = parser["headless"].description("run in headless mode (no terminal UI)");

    auto& cores = parser["cores"]
                      .abbreviation('w')
                      .description("number of cores to use (default: 2)")
                      .type(po::u32)
                      .fallback(2);

    auto& rate = parser["rate"]
                     .abbreviation('r')
                     .description("scan rate in [packets per second] (default: 1000)")
                     .type(po::u32)
                     .fallback(1000);

    auto& num_concurrent =
        parser["num-concurrent"]
        .abbreviation('c')
        .description("max number of concurrent DNS requests\n  (default: rate/5)")
        .type(po::u32)
        .fallback(200);

    auto& timeout_ms = parser["timeout"]
                           .abbreviation('t')
                           .description("timeout [ms] (default: 15'000)")
                           .type(po::u32)
                           .fallback(15'000);

    auto& num_retries = parser["num-retries"]
                            .description("number of retries (default: 10)")
                            .type(po::u32)
                            .fallback(10);

    auto& gateway_ip = parser["gateway-ip"]
                           .abbreviation('g')
                           .description("IP address of gateway")
                           .type(po::string);

    auto& static_ip = parser["static-ip"]
                          .abbreviation('s')
                          .description("own (static) IP address")
                          .type(po::string);

    auto& gateway_mac = parser["gateway-mac"]
                            .abbreviation('m')
                            .description("gateway mac, ARP will be used if no MAC is specified")
                            .type(po::string);

    auto& device_name = parser["device-name"]
                            .abbreviation('d')
                            .description("Device name (example: 0000:2e:00:0)")
                            .type(po::string);

    auto& input_file = parser["input-file"]
                           .abbreviation('i')
                           .description("Path of input file with domains")
                           .type(po::string);

#ifdef NIC_AF_XDP
    auto& xdp_path = parser["xdp-path"]
                         .abbreviation('x')
                         .description("Path to XDP program")
                         .type(po::string)
                         .fallback("/usr/local/bin/sanicdns_xdp.c.o");
#endif

    auto& resolvers =
        parser["resolvers"]
        .description("Resolvers, either:\n   1. Comma-seperated "
                 "list of IP's\n   2. File with a resolver specified on each line")
        .type(po::string);

    auto& rcodes_filters = parser["rcodes"]
                               .description("Only output results with these DNS return codes\n"
                        "    Example: --rcodes NOERROR,SERVFAIL")
                               .type(po::string);

    auto& prefix = parser["prefix"]
                       .description("Prefix to add to each line of the input")
                       .type(po::string);

    auto& postfix = parser["postfix"]
                        .description("Postfix to add to each line of the input")
                        .type(po::string);

    auto& log_path =
        parser["log-path"]
        .abbreviation('l')
        .description("Log file path, logging will be enabled when a log path is set")
        .type(po::string);

    auto& output_path = parser["output-path"]
                            .abbreviation('o')
                            .description("output path (default: output.txt)")
                            .type(po::string)
                            .fallback("output.txt");

    auto& output_raw = parser["output-raw"].description(
        "output raw DNS packets in hex (from DNS header to end of packet)");

    auto& no_huge = parser["no-huge"].description("Don't use huge pages");

    auto& debug = parser["debug"].description("Print debug information");

    auto& q_type = parser["q-type"]
                       .abbreviation('q')
                       .description("Question type\n (A, NS, CNAME, DNAME, SOA, "
                    "PTR, MX, TXT, AAAA, CAA, OPT)")
                       .type(po::string)
                       .fallback("A");

    if (!parser(argc, argv))
        return std::nullopt;

    if (help.was_set()) {
        std::cout << parser << "\n";
        return std::nullopt;
    }

    if (version.was_set()) {
        std::cout << fmt::format("sanicdns {}, built for NIC type: {}", SANICDNS_VERSION, NIC_NAME) << std::endl;
        return std::nullopt;
    }

    // Value-initialised: all optionals start out empty, all flags false
    UserConfig config{};

    config.headless = headless.was_set();

    // Validate the core count on the full 32-bit value. UserConfig::cores is only
    // 16 bits wide, so narrowing first would let e.g. 65538 wrap around to 2 and
    // slip through the range check.
    const uint32_t hw_concurrency = std::thread::hardware_concurrency();
    const uint32_t max_cores = std::min<uint32_t>(hw_concurrency, UINT16_MAX);
    const uint32_t requested_cores = cores.get().u32;
    if (requested_cores < 2 || requested_cores > max_cores) {
        fmt::print("{} minimum number of cores is 2, max is {}\n", error_str, max_cores);
        return std::nullopt;
    }
    config.cores = static_cast<uint16_t>(requested_cores);

    config.rate = rate.get().u32;
    if (config.rate < 100) {
        fmt::print("{} minimum rate limit is 100[pps]\n", error_str);
        return std::nullopt;
    }

    // When not given explicitly, derive the concurrency from the rate as advertised
    // in the help text (rate/5; the parser fallback of 200 is 1000/5 as well).
    config.num_concurrent = num_concurrent.get().u32;
    if (!num_concurrent.was_set())
        config.num_concurrent = config.rate / 5;

    config.timeout_ms = timeout_ms.get().u32;
    config.num_retries = num_retries.get().u32;

    // Parse gateway IP
    if (gateway_ip.was_set()) {
        std::optional res = InAddr::init(gateway_ip.get().string);
        if (!res) {
            fmt::print("{} cannot parse IP {}\n", error_str, gateway_ip.get().string);
            return std::nullopt;
        }
        config.gateway_ip = res;
    }

    // Parse own (static) IP
    if (static_ip.was_set()) {
        std::optional res = InAddr::init(static_ip.get().string);
        if (!res) {
            fmt::print("{} cannot parse IP {}\n", error_str, static_ip.get().string);
            return std::nullopt;
        }
        config.static_ip = res;
    }

    // Parse gateway MAC
    if (gateway_mac.was_set()) {
        auto res = EtherAddr::init(gateway_mac.get().string);
        if (!res) {
            fmt::print("{} cannot parse MAC {}\n", error_str, gateway_mac.get().string);
            return std::nullopt;
        }
        config.gateway_mac = res;
    }

    if (device_name.was_set()) {
        config.device_name = device_name.get().string;
    }

    // The input file is mandatory
    if (!input_file.was_set()) {
        fmt::print("{} provide an input file\n", error_str);
        return std::nullopt;
    }
    config.input_file = input_file.get().string;

#ifdef NIC_AF_XDP
    config.xdp_path = xdp_path.get().string;
#endif

    // Resolvers are mandatory
    if (!resolvers.was_set()) {
        fmt::print("{} provide resolvers\n", error_str);
        return std::nullopt;
    }
    config.resolvers = ({
        tl::expected res = ParseResolvers(resolvers.get().string);
        if (!res) {
            fmt::print("{} parsing resolvers: {}\n", error_str, res.error());
            return std::nullopt;
        }
        std::move(*res);
    });

    if (rcodes_filters.was_set()) {
        config.rcode_filters = ({
            tl::expected res = ParseDNSReturnCodes(rcodes_filters.get().string);
            if (!res) {
                fmt::print("{} parsing DNS return codes: {}\n", error_str,
                    res.error());
                return std::nullopt;
            }
            std::move(*res);
        });
    }

    if (prefix.was_set()) {
        auto tmp = DnsName::init(prefix.get().string);
        if (!tmp) {
            fmt::print("{} prefix too long\n", error_str);
            return std::nullopt;
        }
        config.prefix = tmp.value();
    }

    if (postfix.was_set()) {
        auto tmp = DnsName::init(postfix.get().string);
        if (!tmp) {
            fmt::print("{} postfix too long\n", error_str);
            return std::nullopt;
        }
        config.postfix = tmp.value();
    }

    if (log_path.was_set())
        config.log_path = log_path.get().string;

    config.output_path = output_path.get().string;
    config.output_raw = output_raw.was_set();
    config.no_huge = no_huge.was_set();
    config.debug = debug.was_set();

    config.q_type = ({
        std::optional res = GetQTypeFromString(q_type.get().string);
        if (!res) {
            fmt::print("{} invalid question type {}, choose from A, NS, CNAME, "
                   "DNAME, SOA, PTR, MX, TXT, AAAA, CAA, OPT\n",
                error_str, q_type.get().string);
            return std::nullopt;
        }
        std::move(*res);
    });

    return config;
}

// Link-level settings resolved by GetEthernetConfig from the user configuration
// and, where available, the kernel route information.
struct EthernetConfig {
    // Interface name taken from the route info (AF_XDP builds) or from the
    // user-provided device name
    std::string device_name;

    // If nullopt, destination MAC has to be retrieved through ARP
    std::optional<EtherAddr> dst_mac;

    // Own address: the user's static IP, else the route's source address
    InAddr src_ip;
    // Gateway address: the user's gateway IP, else the route's gateway address
    InAddr dst_ip;
};

// Resolve device name, source IP, gateway IP and gateway MAC for the scan.
//
// Each value is taken from the user configuration when provided. On AF_XDP
// builds the kernel route information (and the OS ARP table for the MAC) is
// used as a fallback. Returns std::nullopt after printing a message when a
// required value cannot be determined.
std::optional<EthernetConfig> GetEthernetConfig(const UserConfig& user_config) {
    std::optional<net_info::RouteInfo> route_info;

    EthernetConfig to_ret{};

    // Fetch the kernel network routes when AF_XDP is enabled and initialize the device name
    if (NIC_OPTS::make_af_xdp_socket) {
        std::optional<FixedName<IFNAMSIZ>> dev_name_linux{};
        if (user_config.device_name.has_value()) {
            dev_name_linux = FixedName<IFNAMSIZ>::init(user_config.device_name.value());
        }

        route_info = ({
            tl::expected res = net_info::get_route_info(dev_name_linux);
            if (!res) {
                fmt::print("{} get_route_info: {}\n", error_str, res.error());
                return std::nullopt;
            }
            std::move(*res);
        });

        spdlog::info("Got route info: {}", glz::write_json(route_info.value()));
        to_ret.device_name = std::string_view(route_info->if_name);
    } else if (user_config.device_name.has_value()) {
        to_ret.device_name = user_config.device_name.value();
    } else {
        fmt::print("{} Enter device name\n", error_str);
        return std::nullopt;
    }

    // Source IP: explicit static IP wins over the route's source address.
    // error_str already ends in a colon, so no extra ':' follows the placeholder.
    if (user_config.static_ip) {
        to_ret.src_ip = user_config.static_ip.value();
    } else if (route_info.has_value() && route_info->source_addr.has_value()) {
        to_ret.src_ip = route_info->source_addr.value();
    } else {
        fmt::print("{} Enter static IP manually\n", error_str);
        return std::nullopt;
    }

    // Gateway IP: explicit gateway IP wins over the route's gateway address
    if (user_config.gateway_ip) {
        to_ret.dst_ip = user_config.gateway_ip.value();
    } else if (route_info.has_value() && route_info->gateway_addr.has_value()) {
        to_ret.dst_ip = route_info->gateway_addr.value();
    } else {
        fmt::print("{} Enter gateway IP manually\n", error_str);
        return std::nullopt;
    }

    // On AF_XDP builds, try to look up the gateway MAC via net_info. A failed
    // lookup is only a warning because the MAC may still come from the user or ARP.
    std::optional<EtherAddr> mac_addr_os_arp_table{};
    if constexpr (NIC_OPTS::make_af_xdp_socket) {
        auto dev_name = UNWRAP_OR_RETURN_VAL_AND_LOG(
            FixedName<IFNAMSIZ>::init(to_ret.device_name), std::nullopt);
        auto res = net_info::get_mac_address(dev_name, to_ret.dst_ip);
        if (res) {
            mac_addr_os_arp_table = res.value();
        } else {
            // No trailing newline: spdlog terminates each message itself
            spdlog::warn("Cannot get MAC of {}: {}", to_ret.dst_ip.str(), res.error());
        }
    }

    // Gateway MAC priority: user-provided, then ARP request (left unset here),
    // then the MAC found above
    if (user_config.gateway_mac) {
        to_ret.dst_mac = user_config.gateway_mac.value();
    } else if (NIC_OPTS::request_arp) {
        to_ret.dst_mac = std::nullopt; // Request ARP when network is configured
    } else if (mac_addr_os_arp_table) {
        to_ret.dst_mac = mac_addr_os_arp_table;
    } else {
        fmt::print("{} Enter gateway MAC manually\n", error_str);
        return std::nullopt;
    }

    return to_ret;
}

// Check that the number of combined channels reported for `dev_name` matches
// the number of queues sanicdns is going to use.
//
// dev_name:     interface to query through net_info::get_channel_count
// num_cores:    value of the '-w' option, only used in the error text
// total_queues: number of queues required
//
// Returns an empty expected on success, otherwise a newline-terminated message.
// The message carries no "error:" tag of its own: the caller prints error_str
// in front of it, so embedding it here would duplicate the prefix.
tl::expected<void, std::string> VerifyQueues(FixedName<IFNAMSIZ> dev_name, const uint16_t num_cores,
    const uint16_t total_queues) {
    auto channel_count = net_info::get_channel_count(dev_name);
    if (!channel_count.has_value()) {
        return tl::unexpected(
            fmt::format("channel_count error: {}\n", channel_count.error()));
    }

    uint32_t combined_channels = channel_count.value().combined_count;

    // Seems to be the DPDK logic
    if (combined_channels == 0)
        combined_channels = 1;

    if (total_queues != combined_channels) {
        return tl::unexpected(fmt::format(
            "Running sanicdns with '-w {}' requires {} queues (current {}). "
            "Configure using 'sudo ethtool -L {} combined {}'\n",
            num_cores, total_queues, combined_channels, dev_name, total_queues));
    }

    return {};
}

// Sum the allocated and free heap byte counts over all NUMA sockets.
// Only heap_allocsz_bytes and heap_freesz_bytes are accumulated; every other
// field of the returned struct stays zero. Sockets for which the stats query
// does not return 0 are skipped.
rte_malloc_socket_stats GetMemoryUsage() {
    rte_malloc_socket_stats totals{};

    for (int node = 0; node < RTE_MAX_NUMA_NODES; node++) {
        rte_malloc_socket_stats node_stats;
        if (rte_malloc_get_socket_stats(node, &node_stats) != 0)
            continue;

        totals.heap_allocsz_bytes += node_stats.heap_allocsz_bytes;
        totals.heap_freesz_bytes += node_stats.heap_freesz_bytes;
    }

    return totals;
}

std::string FormatCounters(const PerCoreCounters& total_count,
    const PerCoreCounters& total_count_per_s, const bool debug) {
    using namespace fmt_helpers;

    std::string out;
    out += fmt::format("Resolved:\t{}domains\t- {}domains/second\n",
        fmt_count(total_count.num_resolved), fmt_count(total_count_per_s.num_resolved));
    out += fmt::format("Retry:\t{}pkts\t- {}pps\n", fmt_count(total_count.retry),
        fmt_count(total_count_per_s.retry));
    out += fmt::format("Max retry:\t{}pkts\t- {}pps\n", fmt_count(total_count.max_retry),
        fmt_count(total_count_per_s.max_retry));
    out += "-----------------------------------\n";
    out += fmt::format("Input lines rejected:\t{}lines\t- {}lines/second\n",
        fmt_count(total_count.lines_rejected), fmt_count(total_count_per_s.max_retry));
    out += "-----------------------------------\n";
    out += fmt::format("TX:\t{}pkts\t- {}pps\n", fmt_count(total_count.sent_pkts),
        fmt_count(total_count_per_s.sent_pkts));
    out += fmt::format("RX:\t{}pkts\t- {}pps\n", fmt_count(total_count.rcvd_pkts),
        fmt_count(total_count_per_s.rcvd_pkts));
    out += "-----------------------------------\n";
    out += fmt::format("Parse fail:\t{}pkts\t- {}pps\n", fmt_count(total_count.parse_fail),
        fmt_count(total_count_per_s.parse_fail));
    out += fmt::format("Parse success:\t{}pkts\t- {}pps\n",
        fmt_count(total_count.parse_success), fmt_count(total_count_per_s.parse_success));
    out += "-----------------------------------\n";
    out += "DNS error codes:\n";
    out += fmt::format("NOERROR:\t{}pkts\t- {}pps\n", fmt_count(total_count.noerror),
        fmt_count(total_count_per_s.noerror));
    out += fmt::format("NXDOMAIN:\t{}pkts\t- {}pps\n", fmt_count(total_count.nxdomain),
        fmt_count(total_count_per_s.nxdomain));
    out += fmt::format("SERVFAIL:\t{}pkts\t- {}pps\n", fmt_count(total_count.servfail),
        fmt_count(total_count_per_s.servfail));
    out += fmt::format("Other:\t{}pkts\t- {}pps\n", fmt_count(total_count.rcode_other),
        fmt_count(total_count_per_s.rcode_other));

    if (!debug)
        return out;
    out += "-----------------------------------\n";

    auto usage = GetMemoryUsage();
    out += fmt::format("Memory usage: {}B\n", fmt_count(usage.heap_allocsz_bytes));
    out += fmt::format("Memory free: {}B\n", fmt_count(usage.heap_freesz_bytes));

    return out;
}

// Build the extra EAL arguments derived from the configuration:
//   - "--no-huge" when the user asked not to use huge pages;
//   - on AF_XDP builds, a "--vdev=net_af_xdp,..." argument naming the interface,
//     the XDP program and the queue count.
// The queue count is also logged.
static std::vector<std::string> init_eal_args(const UserConfig& user_config,
    const EthernetConfig& ethernet_config) {
    // Cores is validated to be at least 2; one of them is the main thread
    const int worker_count = user_config.cores - 1;
    const int main_thread_queues = NIC_OPTS::queue_for_main_thread ? 1 : 0;
    const int queue_count = worker_count + main_thread_queues;

    std::vector<std::string> eal_args;
    if (user_config.no_huge) {
        eal_args.emplace_back("--no-huge");
    }

    spdlog::info("total_queues: {}", queue_count);

    if constexpr (NIC_OPTS::make_af_xdp_socket) {
        auto vdev_arg = fmt::format("--vdev=net_af_xdp,iface={},xdp_prog={},queue_count={}",
            ethernet_config.device_name, user_config.xdp_path, queue_count);
        eal_args.emplace_back(std::move(vdev_arg));
    }

    return eal_args;
}

int main(int argc, char** argv) {
    auto user_config = UNWRAP_OR_RETURN_VAL(InitConfigFromArgs(argc, argv), -1);

    // Init loggers
    std::vector<spdlog::sink_ptr> sinks{};
    if (user_config.headless) {
        sinks.push_back(std::make_shared<spdlog::sinks::stdout_color_sink_mt>());
    }
    if (user_config.log_path) {
        sinks.push_back(std::make_shared<spdlog::sinks::basic_file_sink_mt>(
            user_config.log_path.value(), true));
    }

    auto logger = std::make_shared<spdlog::logger>("", begin(sinks), end(sinks));
    logger->set_level(spdlog::level::debug);
    spdlog::set_default_logger(logger);

    auto output_logger = spdlog::basic_logger_mt("output_log", user_config.output_path, true);
    output_logger->set_pattern("%v");

    auto ethernet_config = UNWRAP_OR_RETURN_VAL(GetEthernetConfig(user_config), -1);

    spdlog::info("User config: {}", glz::write_json(user_config));
    spdlog::info("Ethernet config: {}", glz::write_json(ethernet_config));

    // Cores is always larger than 2, subtract 1 for main thread
    const uint16_t num_workers = user_config.cores - 1;
    const uint16_t total_queues = num_workers + (NIC_OPTS::queue_for_main_thread ? 1 : 0);

    if constexpr (NIC_OPTS::make_af_xdp_socket) {
        auto dev_name = UNWRAP_OR_RETURN_VAL(
            FixedName<IFNAMSIZ>::init(ethernet_config.device_name), -1);
        auto res = VerifyQueues(dev_name, user_config.cores, total_queues);
        if (!res) {
            fmt::print("{} {}", error_str, res.error());
            return -1;
        }
    }

    auto dpdk_args = init_eal_args(user_config, ethernet_config);

    if (NIC_OPTS::make_af_xdp_socket && !std::filesystem::exists(user_config.xdp_path)) {
        fmt::print("{} Cannot find XDP path {}\n", error_str, user_config.xdp_path);
        return -1;
    }

    std::vector<char*> dpdk_args_argv;
    dpdk_args_argv.push_back(argv[0]);
    for (auto& arg : dpdk_args)
        dpdk_args_argv.push_back(arg.data());

    auto eal_guard_ = ({
        tl::expected res = EALGuard::init(dpdk_args_argv.size(), dpdk_args_argv.data());
        if (!res) {
            fmt::print("{} Failed to initialize EAL\n", error_str);
            return -1;
        }
        std::move(*res);
    });

    ThreadManager thread_manager;

    signal(SIGINT, SigintHandler);

    std::unique_ptr<FILE, int (*)(FILE*)> fptr(fopen(user_config.input_file.data(), "rt"),
        [](FILE* fp) -> int {
            if (fp)
                return ::fclose(fp);
            return EOF;
        });

    if (fptr.get() == NULL) {
        fmt::print("{} File {} cannot be opened\n", error_str, user_config.input_file);
        return -1;
    }

    InputReader input_reader(fptr.get());

    /****************************************
    ***** Init Mempools, Ring and Rxtx *****
    ****************************************/
    auto rxtx_pool = ({
        tl::expected res = RTEMempool<DefaultPacket, MbufType::Pkt>::init("RXTX_POOL",
            RXTX_POOL_SIZE, 512, 0);
        if (!res) {
            fmt::print("{} Failed to initialize rxtx_pool: {}\n", error_str,
                rte_strerror(res.error()));
            return -1;
        }
        std::move(*res);
    });

    EthDevConf eth_config{};

    eth_config.nb_rx_descrs = 512;
    eth_config.nb_tx_descrs = 512;

    eth_config.nb_tx_queues = total_queues;
    eth_config.nb_rx_queues = total_queues;

    const auto nic_id =
        NIC_OPTS::make_af_xdp_socket ? "net_af_xdp" : ethernet_config.device_name;

    auto rxtx_if = ({
        tl::expected res = NICType::init(eth_config, nic_id, rxtx_pool.get());
        if (!res) {
            fmt::print("{} Failed to initialize NIC: {}\n", error_str, res.error());
            return -1;
        }
        std::move(*res);
    });

    if constexpr (NIC_OPTS::make_af_xdp_socket)
        rte_delay_ms(3000);

    const uint32_t request_mempool_size = std::max(user_config.num_concurrent, 1023u);

    auto request_mempool = ({
        tl::expected res =
            RTEMempool<Request>::init("request_mempool", request_mempool_size, 512, 0, 0);
        if (!res) {
            fmt::print("{} Failed to initialize request mempool: {}\n", error_str,
                rte_strerror(res.error()));
            return -1;
        }
        std::move(*res);
    });

    auto dispatch_ring = ({
        tl::expected res = RTERing<Request>::init("request_dispatch_ring", request_mempool,
            std::bit_ceil(10 * num_workers * TX_PKT_BURST),
            RING_F_SP_ENQ | RING_F_MC_HTS_DEQ);
        if (!res) {
            fmt::print("{} Failed to initialize dispatch ring: {}\n", error_str,
                rte_strerror(res.error()));
            return -1;
        }
        std::move(*res);
    });

    auto dns_mempool = ({
        tl::expected res =
            RTEMempool<DefaultPacket>::init("dns_mempool", DNS_MEMPOOL_SIZE, 512, 0, 0);
        if (!res) {
            fmt::print("{} Failed to initialize dns mempool: {}\n", error_str,
                rte_strerror(res.error()));
            return -1;
        }
        std::move(*res);
    });

    auto parsed_mempool = ({
        tl::expected res = RTEMempool<DNSPacketDistr>::init("parsed_dns_mempool",
            DNS_MEMPOOL_SIZE, 512, 0, 0);
        if (!res) {
            fmt::print("{} Failed to initialize parsed mempool: {}\n", error_str,
                rte_strerror(res.error()));
            return -1;
        }
        std::move(*res);
    });

    std::vector<RTERing<DNSPacketDistr>> distribution_rings;
    distribution_rings.reserve(num_workers);
    for (uint16_t i = 0; i < num_workers; i++) {
        auto dns_ring = UNWRAP_OR_RETURN_RAW((RTERing<DNSPacketDistr>::init(
            fmt::format("distribution_ring_{}", i), parsed_mempool, DNS_RING_SIZE, 0)));
        distribution_rings.push_back(std::move(dns_ring));
    }

    if constexpr (NIC_OPTS::hw_flowsteering) {
        for (uint16_t i = 0; i < num_workers; i++) {
            auto res =
                rxtx_if.GenerateDNSFlow(i + 1, 0, 0, 0, 0, (i + 1) << 10, 0xFC00);
            if (!res) {
                fmt::print("{} Failed to generate DNS flow for {}: {}\n", error_str,
                    i, res.error());
                return -1;
            }
        }
    }

    Arp arp(ethernet_config.src_ip.s_addr, rxtx_if.GetMacAddr());

    if (!ethernet_config.dst_mac) {
        spdlog::info("Requesting gateway MAC through ARP");
        if (auto e = RequestGatewayMac(arp, rxtx_if, rxtx_pool, ethernet_config.dst_ip);
            e != Arp::Error::ARP_OK) {
            fmt::print("{} Failed to request gateway MAC: {}\n", error_str,
                Arp::ErrorToString(e));
            return -1;
        }
    } else {
        arp.InsertAddr(ethernet_config.dst_ip.s_addr, ethernet_config.dst_mac.value());
    }

    auto _dst_mac = arp.GetEtherAddr(ethernet_config.dst_ip.s_addr);
    if (!_dst_mac) {
        fmt::print("{} could not find gateway MAC\n", error_str);
        return -1;
    }
    auto dst_mac = std::move(*_dst_mac);

    spdlog::info("Gateway mac: {:x}",
        fmt::join(dst_mac.addr_bytes, dst_mac.addr_bytes + RTE_ETHER_ADDR_LEN, ", "));

    auto start = std::chrono::steady_clock::now();

    /*************************
    ***** Start workers *****
    *************************/
    std::latch workers_finished(num_workers);
    std::latch domains_finished(1);

    // Reserve a PerCoreCounter for the main thread as well
    std::vector<PerCoreCounters, RteAllocator<PerCoreCounters>> counters(num_workers + 1);
    PerCoreCounters& main_core_counters = counters[num_workers];

    WorkerParams shared_worker_params = {
        .num_workers = num_workers,
        .num_containers = (100 + (user_config.num_concurrent / num_workers)),
        .rate_lim_pps = user_config.rate / num_workers,
        .timeout_ms = user_config.timeout_ms,
        .max_retries = user_config.num_retries,
        .counters = counters,
        .resolvers = user_config.resolvers,
        .rcode_filters = user_config.rcode_filters,
        .workers_finished = workers_finished,
        .domains_finished = domains_finished,
        .ring = dispatch_ring,
        .rxtx_if = rxtx_if,
        .raw_mempool = dns_mempool,
        .pkt_mempool = rxtx_pool,
        .request_mempool = request_mempool,
        .output_raw = user_config.output_raw,
        .dns_mempool = parsed_mempool,
        .distribution_rings = distribution_rings,
    };

    for (uint16_t i = 0; i < num_workers; i++) {
        auto res = thread_manager.LaunchThread(Worker, i, shared_worker_params);
        if (res != ThreadManager::LaunchThreadResult::Success) {
            fmt::print("{} Launching thread {} failed\n", error_str, i);
            return -1;
        }
    }

    // EthDevConf eth_config2{};

    // eth_config2.nb_rx_descrs = 4096;
    // eth_config2.nb_tx_descrs = 4096;

    // eth_config2.nb_tx_queues = 1;
    // eth_config2.nb_rx_queues = 1;

    // auto rxtx_if2 = ({
    //  tl::expected res =
    //      EthRxTx<I40E_OPTS>::init(eth_config2, "0000:2e:00.1", rxtx_pool.get());
    //  if (!res) {
    //      spdlog::error("Failed to initialize NIC: {}", res.error());
    //      return -1;
    //  }
    //  std::move(*res);
    // });

    // std::ignore = thread_manager.LaunchThread(
    //     [](std::stop_token stp, EthRxTx<I40E_OPTS>& rxtx_if2,
    //  RTEMempool<DefaultPacket, MbufType::Pkt>& rxtx_pool) {
    //      uint64_t total_packets{};
    //      while (!stp.stop_requested()) {
    //          auto recvd = rxtx_if2.RcvPackets<RX_PKT_BURST>(0);
    //          for (auto& pkt : recvd) {
    //              auto eth_hdr = &pkt.data<struct rte_ether_hdr>();
    //              rte_ipv4_hdr* ipv4_hdr = (rte_ipv4_hdr*) (eth_hdr + 1);
    //              rte_udp_hdr* udp_hdr = (rte_udp_hdr*) (ipv4_hdr + 1);
    //              DnsHeader* dns_hdr = (DnsHeader*) (udp_hdr + 1);

    //              std::swap(eth_hdr->src_addr, eth_hdr->dst_addr);
    //              auto src_ip_old = ipv4_hdr->src_addr;
    //              ipv4_hdr->src_addr = ipv4_hdr->dst_addr;
    //              ipv4_hdr->dst_addr = src_ip_old;

    //              auto src_port_old = udp_hdr->src_port;
    //              udp_hdr->src_port = udp_hdr->dst_port;
    //              udp_hdr->dst_port = src_port_old;

    //              dns_hdr->qr = 1;
    //              dns_hdr->rcode = (unsigned char) DnsRCode::R_NXDOMAIN;

    //              pkt.l2_len = sizeof(rte_ether_hdr);
    //              pkt.l3_len = sizeof(rte_ipv4_hdr);
    //              pkt.l4_len = sizeof(rte_udp_hdr);

    //              pkt.nb_segs = 1;

    //              rxtx_if2
    //              .PreparePktCksums<DefaultPacket, L3Type::Ipv4, L4Type::UDP>(
    //                  pkt);
    //              total_packets++;
    //          }

    //          rxtx_if2.PreparePackets(0, recvd);
    //          rxtx_if2.SendPackets(0, std::move(recvd));
    //      }

    //      spdlog::info("Pkts: {}", total_packets);
    //      return 0;
    //     },
    //     std::ref(rxtx_if2), std::ref(rxtx_pool));

    /***************************************************
    ***** Main loop dispatches domains to workers *****
    ***************************************************/
    RTEMbufArray<Request, TX_PKT_BURST> filled_request_buffer =
        RTEMbufArray<Request, TX_PKT_BURST>::init(request_mempool, 0).value();

    std::chrono::time_point<std::chrono::steady_clock> next_log_time =
        std::chrono::steady_clock::now();

    std::chrono::time_point<std::chrono::steady_clock> next_tui_time =
        std::chrono::steady_clock::now();
    PerCoreCounters last_total_count{};

    if (!user_config.headless)
        initscr();

    while (!workers_finished.try_wait()) {
        if (sigint_received.load(std::memory_order_relaxed)) {
            // Request all threads to stop
            thread_manager.request_stop();
            break;
        }

        // Refill step of the dispatch loop: allocate a fresh burst of request
        // mbufs, populate them with domains from the input reader and append the
        // populated ones to filled_request_buffer. It is an immediately-invoked
        // lambda so that an allocation failure can bail out with `return` and
        // the surrounding loop simply carries on with the next stage.
        [&] {
            // Get some new requests to fill; skip this round when the mempool
            // cannot hand out a burst.
            auto new_mbufs = ({
                tl::expected res = RTEMbufArray<Request, TX_PKT_BURST>::init(
                    request_mempool, TX_PKT_BURST);
                if (!res) {
                    return;
                }
                std::move(*res);
            });

            // Applies the optional prefix/postfix from the user config and makes
            // sure the name ends with a '.'. Returns false when one of the
            // concatenations yields an empty optional; the caller treats that
            // as a rejected line.
            const auto decorate_name = [&](Request& request) -> bool {
                if (user_config.prefix) {
                    std::optional prefixed =
                        user_config.prefix.value() +
                        static_cast<std::string_view>(request.name);
                    if (!prefixed)
                        return false;
                    request.name = std::move(*prefixed);
                }

                if (user_config.postfix) {
                    std::optional suffixed =
                        request.name +
                        static_cast<std::string_view>(
                            user_config.postfix.value());
                    if (!suffixed)
                        return false;
                    request.name = std::move(*suffixed);
                }

                // Append the trailing dot if it is missing. The length check
                // keeps an empty name from indexing buf[len - 1] out of bounds.
                // NOTE(review): the append writes buf[len] and buf[len + 1];
                // this assumes the name buffer always has two spare bytes —
                // confirm against the fixed-name capacity.
                if (request.name.len == 0 ||
                    request.name.buf[request.name.len - 1] != '.') {
                    request.name.buf[request.name.len] = '.';
                    request.name.buf[++request.name.len] = '\0';
                }
                return true;
            };

            // Never fill more requests than the pending buffer can still take.
            const size_t max_iters =
                std::min(filled_request_buffer.free_cnt(), new_mbufs.size());
            size_t i = 0;
            for (i = 0; i < max_iters; i++) {
                Request& request = new_mbufs.get_data(i);
                ReadDomainResult res;

                // Keep reading lines into this request until one is accepted or
                // the reader reports something other than NotValid.
                do {
                    DomainInputInfo domain_info;
                    // The reader writes the domain straight into the request.
                    domain_info.buf = request.name.buf.data();

                    res = input_reader.GetDomain(domain_info);

                    if (res == ReadDomainResult::FileEnd) {
                        // Count down only while try_wait() still reports the
                        // input as unfinished.
                        if (!domains_finished.try_wait())
                            domains_finished.count_down();
                        break;
                    }

                    // Only lines the reader accepted are post-processed; a
                    // rejected line must not be touched (its length may be 0).
                    if (res == ReadDomainResult::Success) {
                        request.name.len = domain_info.len;
                        if (!decorate_name(request))
                            res = ReadDomainResult::NotValid;
                    }

                    // Keep track of the number of invalid lines, whether they
                    // were rejected by the reader or by prefix/postfix handling.
                    if (res == ReadDomainResult::NotValid)
                        main_core_counters.lines_rejected++;

                } while (res == ReadDomainResult::NotValid);

                // Anything but Success leaves request i unfilled; it is dropped
                // by the split below.
                if (res != ReadDomainResult::Success)
                    break;

                // Per-request constants, set once the name is accepted.
                request.src_ip = ethernet_config.src_ip;
                request.num_ips = REQUEST_MAX_IPS;
                request.dst_mac = dst_mac;

                request.q_type = user_config.q_type;
            }

            // Requests [0, i) are populated; the remainder goes out of scope
            // here unused.
            auto [filled, _] = new_mbufs.split(i);
            std::ignore = filled_request_buffer.insert(std::move(filled));
        }();

        auto not_enqueued_requests =
            dispatch_ring.enqueue_burst(std::move(filled_request_buffer));
        filled_request_buffer = std::move(not_enqueued_requests);

        std::chrono::time_point<std::chrono::steady_clock> current_time =
            std::chrono::steady_clock::now();

        if (current_time > next_log_time) {
            next_log_time =
                current_time + std::chrono::milliseconds(COUNTER_LOG_TIME_MS);

            PerCoreCounters total_count;
            for (auto& core_counter : counters)
                total_count += core_counter;

            std::string str = glz::write_json(total_count);
            spdlog::info(str);
        }

        if (current_time > next_tui_time && !user_config.headless) {
            clear();

            PerCoreCounters total_count;
            for (auto& core_counter : counters)
                total_count += core_counter;

            constexpr auto tui_delay = std::chrono::milliseconds(TERMINAL_TUI_TIME_MS);

            const double time_since_last_tui_update =
                std::chrono::duration<double>(current_time - next_tui_time + tui_delay)
                .count();

            PerCoreCounters total_count_per_s =
                (total_count - last_total_count) / time_since_last_tui_update;

            std::string counters_formatted =
                FormatCounters(total_count, total_count_per_s, user_config.debug);

            printw("%s", counters_formatted.c_str());
            refresh();

            next_tui_time = current_time + tui_delay;
            last_total_count = total_count;
        }
    }

    if (!user_config.headless) {
        // End window and print last counter values to stdout
        endwin();

        PerCoreCounters total_count;
        for (auto& core_counter : counters)
            total_count += core_counter;

        std::string counters_formatted =
            FormatCounters(total_count, PerCoreCounters{}, user_config.debug);
        std::cout << counters_formatted;
    }

    thread_manager.request_stop();

    // Join all threads
    thread_manager.join();

    auto stop = std::chrono::steady_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);

    // std::cout << "\nRxTx if stats:\n\n";

    rxtx_if.PrintStats();

    std::cout << "\n\nRunning scanner took " << duration.count() << " ms\n";
}