(Learn Zig Series) Exercise 1: AAAA record lookups (IPv6)
const std = @import("std");
const dns = @import("dns.zig");
/// Look up the IPv4 (A) and IPv6 (AAAA) addresses of google.com via 8.8.8.8
/// and print each address with its TTL.
///
/// The two lookups are independent: a failed A lookup is reported and the
/// AAAA lookup still runs (otherwise an IPv6-only name would never be shown).
pub fn main() !void {
    var resolver = try dns.DnsResolver.init(.{ 8, 8, 8, 8 });
    defer resolver.deinit();

    const name = "google.com";

    // Resolve A records
    var records: [16]dns.DnsRecord = undefined;
    if (resolver.resolve(name, .A, &records)) |a_result| {
        std.debug.print("IPv4 results for {s}:\n", .{name});
        for (records[0..a_result.count]) |rec| {
            if (rec.data_tag != .ipv4) continue;
            std.debug.print(" A: {d}.{d}.{d}.{d} (TTL: {d}s)\n", .{
                rec.data.ipv4[0], rec.data.ipv4[1],
                rec.data.ipv4[2], rec.data.ipv4[3],
                rec.ttl,
            });
        }
    } else |err| {
        std.debug.print("A lookup failed: {}\n", .{err});
    }

    // Resolve AAAA records: each address is 16 raw bytes, printed as eight
    // colon-separated 16-bit hex groups (no "::" shortening).
    var records6: [16]dns.DnsRecord = undefined;
    if (resolver.resolve(name, .AAAA, &records6)) |aaaa_result| {
        std.debug.print("IPv6 results for {s}:\n", .{name});
        for (records6[0..aaaa_result.count]) |rec| {
            if (rec.data_tag != .ipv6) continue;
            std.debug.print(" AAAA: ", .{});
            for (0..8) |i| {
                if (i > 0) std.debug.print(":", .{});
                const hi = rec.data.ipv6[i * 2];
                const lo = rec.data.ipv6[i * 2 + 1];
                std.debug.print("{x:0>2}{x:0>2}", .{ hi, lo });
            }
            std.debug.print(" (TTL: {d}s)\n", .{rec.ttl});
        }
    } else |err| {
        std.debug.print("AAAA lookup failed: {}\n", .{err});
    }
}
The key insight is that AAAA lookups use the same query format -- you just change the query type to 28. The 16-byte response data is the raw IPv6 address in network byte order, which we format as colon-separated hex groups.
Exercise 2: DNS cache with TTL
const std = @import("std");
const dns = @import("dns.zig");
/// Identity of a cached answer: the queried name (truncated to 256 bytes)
/// together with the query type. Only `name_buf[0..name_len]` is initialised.
const CacheKey = struct {
    name_buf: [256]u8,
    name_len: usize,
    qtype: dns.QueryType,

    fn init(name: []const u8, qtype: dns.QueryType) CacheKey {
        const kept = name[0..@min(name.len, 256)];
        var out: CacheKey = .{
            .name_buf = undefined,
            .name_len = kept.len,
            .qtype = qtype,
        };
        @memcpy(out.name_buf[0..kept.len], kept);
        return out;
    }
};
/// One cached answer set for a (name, query type) key.
const CacheEntry = struct {
// Copy of the answer records; only records[0..count] are meaningful.
records: [16]dns.DnsRecord,
// Number of valid entries in `records` (DnsCache.store clamps this to 16).
count: usize,
// std.time.timestamp() (seconds) at the moment the entry was stored.
cached_at: i64,
// Smallest TTL among the stored records, in seconds; the entry is treated
// as expired once now >= cached_at + min_ttl.
min_ttl: u32,
};
/// A small in-memory DNS answer cache keyed by (name, query type).
///
/// Entries expire after the smallest TTL of the records they hold. Map keys
/// are 64-bit Wyhash digests of the lower-cased name plus the query type; the
/// name itself is not stored, so two different keys whose digests collided
/// would share an entry (astronomically unlikely, but worth knowing).
const DnsCache = struct {
    map: std.AutoHashMap(u64, CacheEntry),

    fn init(allocator: std.mem.Allocator) DnsCache {
        return .{ .map = std.AutoHashMap(u64, CacheEntry).init(allocator) };
    }

    fn deinit(self: *DnsCache) void {
        self.map.deinit();
    }

    /// Hash a cache key. DNS names compare case-insensitively (RFC 4343), so
    /// the name is lower-cased first: "Example.COM" and "example.com" share
    /// one entry.
    fn hashKey(key: CacheKey) u64 {
        var lowered: [256]u8 = undefined;
        const name = std.ascii.lowerString(&lowered, key.name_buf[0..key.name_len]);
        var h = std.hash.Wyhash.init(0);
        h.update(name);
        h.update(std.mem.asBytes(&key.qtype));
        return h.final();
    }

    /// Return the live entry for (name, qtype), or null if it is absent or
    /// expired. Expired entries are removed. The returned pointer points into
    /// the map and is invalidated by the next `store`.
    fn lookup(self: *DnsCache, name: []const u8, qtype: dns.QueryType) ?*const CacheEntry {
        const key_hash = hashKey(CacheKey.init(name, qtype));
        const entry = self.map.getPtr(key_hash) orelse return null;
        const expires_at = entry.cached_at + @as(i64, entry.min_ttl);
        if (std.time.timestamp() >= expires_at) {
            _ = self.map.remove(key_hash);
            return null;
        }
        return entry;
    }

    /// Cache up to 16 records for (name, qtype), replacing any previous entry.
    /// `count` is clamped to both 16 and `records.len`, so an inconsistent
    /// count can no longer slice past the end of `records`. An empty answer
    /// is cached for 60 seconds.
    fn store(self: *DnsCache, name: []const u8, qtype: dns.QueryType, records: []const dns.DnsRecord, count: usize) !void {
        const n: usize = @min(count, records.len, 16);
        var entry: CacheEntry = undefined;
        entry.count = n;
        @memcpy(entry.records[0..n], records[0..n]);
        entry.cached_at = std.time.timestamp();
        entry.min_ttl = if (n == 0) 60 else blk: {
            var min_ttl: u32 = std.math.maxInt(u32);
            for (records[0..n]) |rec| {
                // RFC 2181 §8: TTL values with the top bit set are treated as zero.
                const ttl: u32 = if (rec.ttl > std.math.maxInt(i32)) 0 else rec.ttl;
                min_ttl = @min(min_ttl, ttl);
            }
            break :blk min_ttl;
        };
        try self.map.put(hashKey(CacheKey.init(name, qtype)), entry);
    }
};
test "cache hit and TTL expiry" {
    var cache = DnsCache.init(std.testing.allocator);
    defer cache.deinit();

    var rec: dns.DnsRecord = undefined;
    // Assign the whole union: writing one field of an undefined (inactive)
    // bare-union member trips Zig's active-field safety check.
    rec.data = .{ .ipv4 = .{ 1, 2, 3, 4 } };
    rec.data_tag = .ipv4;
    rec.ttl = 2; // 2 second TTL
    rec.rtype = 1;
    rec.class = 1;
    rec.name_len = 0;
    rec.cname_len = 0;
    try cache.store("test.com", .A, &.{rec}, 1);

    // Should hit
    const hit = cache.lookup("test.com", .A);
    try std.testing.expect(hit != null);
    try std.testing.expectEqual(@as(usize, 1), hit.?.count);

    // Different type should miss
    const miss = cache.lookup("test.com", .AAAA);
    try std.testing.expect(miss == null);

    // Expiry: backdate the entry past its 2 s TTL instead of sleeping.
    const key_hash = DnsCache.hashKey(CacheKey.init("test.com", .A));
    cache.map.getPtr(key_hash).?.cached_at -= 3;
    try std.testing.expect(cache.lookup("test.com", .A) == null);
    // An expired lookup also evicts the entry.
    try std.testing.expectEqual(@as(u32, 0), cache.map.count());
}
The cache uses a hash of the domain name + query type as the key. The minimum TTL across all records in an answer determines when the entry expires -- this is conservative but correct, since you'd need to re-query anyway if any record has changed.
Exercise 3: Batch DNS resolver
const std = @import("std");
const dns = @import("dns.zig");
/// Resolve every domain listed in a text file (default "domains.txt", or the
/// first command-line argument) and print the A/CNAME answers as a table.
///
/// Lines are trimmed of surrounding whitespace/CR; blank lines and lines
/// starting with '#' are skipped. The last line is processed even without a
/// trailing newline. A line longer than the 1024-byte buffer is reported and
/// skipped rather than silently ending the batch. Per-domain resolve errors
/// are printed and counted as failures.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    const args = try std.process.argsAlloc(allocator);
    defer std.process.argsFree(allocator, args);
    const file_path = if (args.len > 1) args[1] else "domains.txt";

    const file = std.fs.cwd().openFile(file_path, .{}) catch |err| {
        std.debug.print("cannot open {s}: {}\n", .{ file_path, err });
        return;
    };
    defer file.close();

    var resolver = try dns.DnsResolver.init(.{ 8, 8, 8, 8 });
    defer resolver.deinit();

    const start = std.time.milliTimestamp();
    var total: usize = 0;
    var failed: usize = 0;

    std.debug.print("{s:<40} | {s:<6} | {s:<16} | {s}\n", .{ "domain", "type", "value", "TTL" });
    std.debug.print("{s}\n", .{"-" ** 80});

    var buf: [1024]u8 = undefined;
    const reader = file.reader();
    while (true) {
        // readUntilDelimiterOrEof (unlike readUntilDelimiter) returns the final
        // partial line instead of failing with EndOfStream.
        const maybe_line = reader.readUntilDelimiterOrEof(&buf, '\n') catch |err| switch (err) {
            error.StreamTooLong => {
                std.debug.print("skipping line longer than {d} bytes\n", .{buf.len});
                reader.skipUntilDelimiterOrEof('\n') catch break;
                continue;
            },
            else => {
                std.debug.print("read error on {s}: {}\n", .{ file_path, err });
                break;
            },
        };
        const raw = maybe_line orelse break;
        const line = std.mem.trim(u8, raw, " \t\r");
        if (line.len == 0 or line[0] == '#') continue;
        total += 1;

        var records: [16]dns.DnsRecord = undefined;
        const result = resolver.resolve(line, .A, &records) catch |err| {
            std.debug.print("{s:<40} | ERROR | {}\n", .{ line, err });
            failed += 1;
            continue;
        };

        for (records[0..result.count]) |rec| {
            switch (rec.data_tag) {
                .ipv4 => {
                    var ip_buf: [16]u8 = undefined;
                    const ip_str = std.fmt.bufPrint(&ip_buf, "{d}.{d}.{d}.{d}", .{
                        rec.data.ipv4[0], rec.data.ipv4[1],
                        rec.data.ipv4[2], rec.data.ipv4[3],
                    }) catch "?";
                    std.debug.print("{s:<40} | A | {s:<16} | {d}s\n", .{ line, ip_str, rec.ttl });
                },
                .cname => {
                    std.debug.print("{s:<40} | CNAME | {s:<16} | {d}s\n", .{
                        line, rec.data.cname[0..rec.cname_len], rec.ttl,
                    });
                },
                else => {
                    std.debug.print("{s:<40} | type={d:<2} | (raw) | {d}s\n", .{ line, rec.rtype, rec.ttl });
                },
            }
        }
    }

    const elapsed = std.time.milliTimestamp() - start;
    std.debug.print("\n{d} domains resolved ({d} failed) in {d}ms\n", .{ total, failed, elapsed });
}
The batch resolver reads domains line by line and prints a formatted table. The catch |err| on each resolve call means one failed lookup (NXDOMAIN, timeout) doesn't kill the whole batch -- it prints the error and moves on.
Last episode we built a DNS resolver -- the client side of the protocol. We crafted queries, sent them to Google's 8.8.8.8, and parsed the responses we got back. Today we're flipping the whole thing around: we're building the server that answers those queries. The part that sits there listening on port 53, receives a domain name question, looks up the answer in its records, and sends back a properly formed DNS response.
This is where things get really satisfying, because we can test our server with the resolver we already wrote. Client talks to server, both written by us, both handling the same wire protocol. If that's not a good reason to keep going, I don't know what is ;-)
You might wonder: "who runs their own DNS server?" More people than you'd think. Local development environments, home labs, ad blockers (Pi-hole is essentially a DNS server), split-horizon DNS in corporate networks, custom service discovery -- there are plenty of legitimate reasons to run a DNS server that answers queries for your own domains.
Having said that, the real reason we're building one is educational. Understanding both sides of DNS makes you much better at debugging network issues. When you know exactly what a DNS response looks like on the wire, "it's probably DNS" stops being a meme and becomes something you can actually investigate.
A DNS server needs to know what records it's authoritative for. In real DNS infrastructure this comes from zone files -- text files that map domain names to IP addresses (and other record types). We'll use a simplified format for our server:
const std = @import("std");
/// Resource-record types this server understands, tagged with their
/// on-the-wire TYPE codes (RFC 1035 §3.2.2; AAAA from RFC 3596).
pub const RecordType = enum(u16) {
    A = 1,
    NS = 2,
    CNAME = 5,
    MX = 15,
    TXT = 16,
    AAAA = 28,
};
/// One resource record loaded from the zone.
///
/// `data` is an untagged union; `data_tag` records which member is live
/// (a hand-rolled tagged union). Only `name[0..name_len]` is meaningful.
pub const ZoneRecord = struct {
    name: [256]u8,
    name_len: usize,
    rtype: RecordType,
    ttl: u32,
    data: Data,
    data_tag: DataTag,

    /// Type-specific payload; read only the member named by `data_tag`.
    pub const Data = union {
        ipv4: [4]u8,
        ipv6: [16]u8,
        name_ref: NameRef,
        txt: Text,
        mx: MailExchange,
    };

    /// Which `Data` member is active.
    pub const DataTag = enum { ipv4, ipv6, name_ref, txt, mx };

    /// Target domain name of a CNAME or NS record (`buf[0..len]`).
    pub const NameRef = struct {
        buf: [256]u8,
        len: usize,
    };

    /// Text of a TXT record (`buf[0..len]`).
    pub const Text = struct {
        buf: [512]u8,
        len: usize,
    };

    /// MX preference plus exchange host name (`exchange[0..exchange_len]`).
    pub const MailExchange = struct {
        preference: u16,
        exchange: [256]u8,
        exchange_len: usize,
    };

    /// The record's owner name as a slice into `name`.
    pub fn getName(self: *const ZoneRecord) []const u8 {
        const len = self.name_len;
        return self.name[0..len];
    }
};
The ZoneRecord struct keeps the record data in a plain (untagged) union and tracks which member is active in a separate data_tag enum -- a hand-rolled tagged union, the same pattern we used in the resolver (episode 82). Each record type stores its data differently: A records are 4 bytes, AAAA records are 16 bytes, CNAME and NS records contain another domain name, MX records have a preference number plus a domain name, and TXT records are arbitrary text.
Real BIND zone files have a complex format with $ORIGIN directives, relative names, and all sorts of edge cases. We'll parse a simplified version where each line is name type ttl value:
/// An in-memory authoritative zone: up to 256 records, searched linearly.
pub const Zone = struct {
    records: [256]ZoneRecord,
    count: usize,

    /// Result of `lookup`: up to 16 matching records, copied by value.
    /// `buf[0..count]` (or `items()`) holds the matches. The slice borrows from
    /// this struct, so keep the result alive while using it. (The previous
    /// anonymous result also carried a `records` slice, but that pointed into
    /// lookup's own stack frame and dangled on return; use `items()`.)
    pub const LookupResult = struct {
        count: usize,
        buf: [16]ZoneRecord,

        /// View of the matching records.
        pub fn items(self: *const LookupResult) []const ZoneRecord {
            return self.buf[0..self.count];
        }
    };

    /// Create an empty zone. Only `records[0..count]` is ever read.
    pub fn init() Zone {
        return .{
            .records = undefined,
            .count = 0,
        };
    }

    /// Append `rec`; fails with `error.ZoneFull` once 256 records are stored.
    pub fn addRecord(self: *Zone, rec: ZoneRecord) !void {
        if (self.count >= self.records.len) return error.ZoneFull;
        self.records[self.count] = rec;
        self.count += 1;
    }

    /// Parse a line like: "example.com A 3600 93.184.216.34"
    ///
    /// Format: `name type ttl value...`, fields separated by spaces or tabs.
    /// `type` is one of A, AAAA, CNAME, MX, TXT, NS (case-sensitive). Values:
    /// A = dotted quad; AAAA = any RFC 4291 text form (including "::");
    /// CNAME/NS = target name; MX = `preference exchange`; TXT = remaining
    /// tokens re-joined with single spaces (tokens that would overflow the
    /// 512-byte buffer are dropped).
    ///
    /// The payload union is always set by whole-union assignment: writing a
    /// field of an inactive bare-union member is safety-checked UB in Zig.
    pub fn parseLine(line: []const u8) !ZoneRecord {
        var rec: ZoneRecord = undefined;
        var it = std.mem.tokenizeAny(u8, line, " \t");

        // Field 1: name
        const name = it.next() orelse return error.InvalidLine;
        if (name.len > 255) return error.NameTooLong;
        rec.name_len = name.len;
        @memcpy(rec.name[0..name.len], name);

        // Field 2: record type (enum tag names match the zone-file spelling)
        const rtype_str = it.next() orelse return error.InvalidLine;
        rec.rtype = std.meta.stringToEnum(RecordType, rtype_str) orelse return error.UnknownType;

        // Field 3: TTL
        const ttl_str = it.next() orelse return error.InvalidLine;
        rec.ttl = std.fmt.parseInt(u32, ttl_str, 10) catch return error.InvalidTTL;

        // Field 4+: data (depends on type)
        switch (rec.rtype) {
            .A => {
                const addr_str = it.next() orelse return error.InvalidLine;
                var parts: [4]u8 = undefined;
                var dot_it = std.mem.splitScalar(u8, addr_str, '.');
                for (&parts) |*p| {
                    const part = dot_it.next() orelse return error.InvalidAddress;
                    p.* = std.fmt.parseInt(u8, part, 10) catch return error.InvalidAddress;
                }
                // Reject trailing octets such as "1.2.3.4.5".
                if (dot_it.next() != null) return error.InvalidAddress;
                rec.data = .{ .ipv4 = parts };
                rec.data_tag = .ipv4;
            },
            .AAAA => {
                const addr_str = it.next() orelse return error.InvalidLine;
                // The std parser expands "::" correctly and rejects malformed groups.
                const addr = std.net.Ip6Address.parse(addr_str, 0) catch return error.InvalidAddress;
                rec.data = .{ .ipv6 = addr.sa.addr };
                rec.data_tag = .ipv6;
            },
            .CNAME, .NS => {
                const target = it.next() orelse return error.InvalidLine;
                if (target.len > 255) return error.NameTooLong;
                rec.data = .{ .name_ref = .{ .buf = undefined, .len = target.len } };
                @memcpy(rec.data.name_ref.buf[0..target.len], target);
                rec.data_tag = .name_ref;
            },
            .MX => {
                const pref_str = it.next() orelse return error.InvalidLine;
                const preference = std.fmt.parseInt(u16, pref_str, 10) catch return error.InvalidLine;
                const exchange = it.next() orelse return error.InvalidLine;
                if (exchange.len > 255) return error.NameTooLong;
                rec.data = .{ .mx = .{
                    .preference = preference,
                    .exchange = undefined,
                    .exchange_len = exchange.len,
                } };
                @memcpy(rec.data.mx.exchange[0..exchange.len], exchange);
                rec.data_tag = .mx;
            },
            .TXT => {
                const first = it.next() orelse return error.InvalidLine;
                rec.data = .{ .txt = .{ .buf = undefined, .len = 0 } };
                const txt = &rec.data.txt;
                if (first.len > txt.buf.len) return error.TxtTooLong;
                @memcpy(txt.buf[0..first.len], first);
                txt.len = first.len;
                while (it.next()) |extra| {
                    if (txt.len + 1 + extra.len > txt.buf.len) break;
                    txt.buf[txt.len] = ' ';
                    @memcpy(txt.buf[txt.len + 1 ..][0..extra.len], extra);
                    txt.len += 1 + extra.len;
                }
                rec.data_tag = .txt;
            },
        }
        return rec;
    }

    /// Load every record from zone-file `text` (one record per line).
    /// Blank lines and lines starting with '#' are ignored. Lines that fail
    /// to parse are skipped with a warning naming the line and the error;
    /// running out of space (`error.ZoneFull`) is returned.
    pub fn loadFromText(self: *Zone, text: []const u8) !void {
        var line_it = std.mem.splitScalar(u8, text, '\n');
        var line_no: usize = 0;
        while (line_it.next()) |line| {
            line_no += 1;
            const trimmed = std.mem.trim(u8, line, " \t\r");
            if (trimmed.len == 0 or trimmed[0] == '#') continue;
            const rec = parseLine(trimmed) catch |err| {
                std.log.warn("zone line {d} skipped ({s}): {s}", .{ line_no, @errorName(err), trimmed });
                continue;
            };
            try self.addRecord(rec);
        }
    }

    /// Find up to 16 records of type `rtype` whose name equals `name`,
    /// compared ASCII case-insensitively (DNS names are case-insensitive,
    /// RFC 4343). Matches are returned in zone order.
    pub fn lookup(self: *const Zone, name: []const u8, rtype: RecordType) LookupResult {
        var result: LookupResult = .{ .count = 0, .buf = undefined };
        for (self.records[0..self.count]) |*rec| {
            if (result.count >= result.buf.len) break;
            if (rec.rtype == rtype and std.ascii.eqlIgnoreCase(rec.getName(), name)) {
                result.buf[result.count] = rec.*;
                result.count += 1;
            }
        }
        return result;
    }
};
The lookup function does a linear scan through all records. For a server with thousands of records you'd want a hash map (like we built in episode 22), but for a few hundred records linear search is perfectly fine and keeps the code simple. The zone file format is one record per line, with # for comments -- nothing fancy, but enough to be useful.
Now the core of the server: taking a query we received, looking up matching records, and constructing a valid DNS response. The response has to echo back the original question and include the answer records in the correct wire format:
/// The fixed 12-byte DNS message header (RFC 1035 §4.1.1).
///
/// Declared `extern` so the in-memory layout is exactly the six u16 fields in
/// declaration order with no padding (`@sizeOf(DnsHeader) == 12`), which lets
/// `std.mem.asBytes` on a network-order header yield wire bytes. A
/// `packed struct` here would be backed by a u96: its `@sizeOf` is 16 and its
/// byte order follows the host, so it cannot be copied to/from 12 wire bytes.
/// Prefer `fromBytes`/`toBytes`, which are endian-independent.
pub const DnsHeader = extern struct {
    id: u16,
    flags: u16,
    qdcount: u16,
    ancount: u16,
    nscount: u16,
    arcount: u16,

    comptime {
        std.debug.assert(@sizeOf(DnsHeader) == 12);
    }

    /// Return a copy with every field converted from host to big-endian order.
    pub fn toNetworkOrder(self: DnsHeader) DnsHeader {
        return .{
            .id = std.mem.nativeToBig(u16, self.id),
            .flags = std.mem.nativeToBig(u16, self.flags),
            .qdcount = std.mem.nativeToBig(u16, self.qdcount),
            .ancount = std.mem.nativeToBig(u16, self.ancount),
            .nscount = std.mem.nativeToBig(u16, self.nscount),
            .arcount = std.mem.nativeToBig(u16, self.arcount),
        };
    }

    /// Return a copy with every field converted from big-endian to host order.
    pub fn fromNetworkOrder(self: DnsHeader) DnsHeader {
        return .{
            .id = std.mem.bigToNative(u16, self.id),
            .flags = std.mem.bigToNative(u16, self.flags),
            .qdcount = std.mem.bigToNative(u16, self.qdcount),
            .ancount = std.mem.bigToNative(u16, self.ancount),
            .nscount = std.mem.bigToNative(u16, self.nscount),
            .arcount = std.mem.bigToNative(u16, self.arcount),
        };
    }

    /// Decode a host-order header from the first 12 bytes of a DNS message.
    pub fn fromBytes(bytes: *const [12]u8) DnsHeader {
        return .{
            .id = std.mem.readInt(u16, bytes[0..2], .big),
            .flags = std.mem.readInt(u16, bytes[2..4], .big),
            .qdcount = std.mem.readInt(u16, bytes[4..6], .big),
            .ancount = std.mem.readInt(u16, bytes[6..8], .big),
            .nscount = std.mem.readInt(u16, bytes[8..10], .big),
            .arcount = std.mem.readInt(u16, bytes[10..12], .big),
        };
    }

    /// Encode this host-order header into its 12-byte wire form.
    pub fn toBytes(self: DnsHeader) [12]u8 {
        var out: [12]u8 = undefined;
        std.mem.writeInt(u16, out[0..2], self.id, .big);
        std.mem.writeInt(u16, out[2..4], self.flags, .big);
        std.mem.writeInt(u16, out[4..6], self.qdcount, .big);
        std.mem.writeInt(u16, out[6..8], self.ancount, .big);
        std.mem.writeInt(u16, out[8..10], self.nscount, .big);
        std.mem.writeInt(u16, out[10..12], self.arcount, .big);
        return out;
    }
};
/// Encode a dotted domain name into DNS wire format at `buf[pos..]`.
///
/// Each label is written as a length byte followed by its bytes, and the
/// name ends with a zero byte (RFC 1035 §3.1). Empty labels (from a trailing
/// dot or "a..b") are skipped, so "example.com." and "example.com" encode
/// identically and "" encodes as the root name (a single zero byte).
///
/// Returns the number of bytes written.
/// Errors: `LabelTooLong` for a label over 63 bytes, `NameTooLong` when the
/// encoded name would exceed 255 bytes (RFC 1035 §2.3.4), `BufferTooSmall`
/// when `buf` has no room left.
fn encodeName(name: []const u8, buf: []u8, pos: usize) !usize {
    const max_wire_len = 255;
    var offset = pos;
    var it = std.mem.splitScalar(u8, name, '.');
    while (it.next()) |label| {
        if (label.len == 0) continue;
        if (label.len > 63) return error.LabelTooLong;
        // +1 for the length byte, +1 reserved for the terminating zero.
        if (offset - pos + 1 + label.len + 1 > max_wire_len) return error.NameTooLong;
        // `>=` (not `>`) also keeps one byte free for the terminator.
        if (offset + 1 + label.len >= buf.len) return error.BufferTooSmall;
        buf[offset] = @intCast(label.len);
        @memcpy(buf[offset + 1 ..][0..label.len], label);
        offset += 1 + label.len;
    }
    if (offset >= buf.len) return error.BufferTooSmall;
    buf[offset] = 0;
    return offset + 1 - pos;
}
/// Append one answer resource record for `rec` at `out[pos..]` and return the
/// new write position.
///
/// The owner name is emitted as a compression pointer to offset 12 (the
/// question name, which every answer shares). RDLENGTH is back-filled after
/// the RDATA is written. Fails with `BufferTooSmall` if the record does not fit.
fn writeAnswer(rec: *const ZoneRecord, out: []u8, pos: usize) !usize {
    // NAME pointer (2) + TYPE (2) + CLASS (2) + TTL (4) + RDLENGTH (2)
    if (pos + 12 > out.len) return error.BufferTooSmall;
    out[pos] = 0xC0;
    out[pos + 1] = 12; // pointer to question name
    std.mem.writeInt(u16, out[pos + 2 ..][0..2], @intFromEnum(rec.rtype), .big);
    std.mem.writeInt(u16, out[pos + 4 ..][0..2], 1, .big); // class IN
    std.mem.writeInt(u32, out[pos + 6 ..][0..4], rec.ttl, .big);
    const rdlen_pos = pos + 10;
    const data_start = pos + 12;
    var p = data_start;

    switch (rec.data_tag) {
        .ipv4 => {
            if (p + 4 > out.len) return error.BufferTooSmall;
            @memcpy(out[p..][0..4], &rec.data.ipv4);
            p += 4;
        },
        .ipv6 => {
            if (p + 16 > out.len) return error.BufferTooSmall;
            @memcpy(out[p..][0..16], &rec.data.ipv6);
            p += 16;
        },
        .name_ref => {
            const target = rec.data.name_ref.buf[0..rec.data.name_ref.len];
            p += try encodeName(target, out, p);
        },
        .txt => {
            // TXT RDATA is one or more <character-string>s, each at most 255
            // bytes behind a one-byte length, so longer text is split up.
            var rest: []const u8 = rec.data.txt.buf[0..rec.data.txt.len];
            while (true) {
                const chunk_len = @min(rest.len, 255);
                if (p + 1 + chunk_len > out.len) return error.BufferTooSmall;
                out[p] = @intCast(chunk_len);
                @memcpy(out[p + 1 ..][0..chunk_len], rest[0..chunk_len]);
                p += 1 + chunk_len;
                rest = rest[chunk_len..];
                if (rest.len == 0) break;
            }
        },
        .mx => {
            if (p + 2 > out.len) return error.BufferTooSmall;
            std.mem.writeInt(u16, out[p..][0..2], rec.data.mx.preference, .big);
            p += 2;
            p += try encodeName(rec.data.mx.exchange[0..rec.data.mx.exchange_len], out, p);
        },
    }

    std.mem.writeInt(u16, out[rdlen_pos..][0..2], @intCast(p - data_start), .big);
    return p;
}

/// Build the DNS response to the raw `query` packet in `out` and return the
/// number of bytes written.
///
/// - Packets that are themselves responses (QR=1) or carry no question are
///   rejected with an error so the caller drops them. Answering responses
///   lets two servers bounce packets back and forth forever.
/// - QNAME is validated: plain labels of at most 63 bytes and a bounded total
///   length. Compression pointers are not accepted in the question. The
///   question section is echoed back verbatim.
/// - Opcodes other than QUERY get NOTIMP (rcode 4).
/// - A name absent from the zone gets NXDOMAIN (rcode 3). A name that exists
///   but has no record of the asked type gets NOERROR with an empty answer
///   section (NODATA), so clients don't negatively cache the whole name.
///   Query types the zone cannot hold (SOA, PTR, ANY, ...) never match.
/// - Names match ASCII case-insensitively (RFC 4343).
/// - If not all answers fit in `out`, those that fit are kept and TC is set
///   so the client can retry over TCP.
/// - RA stays clear: this server is authoritative-only and never recurses.
pub fn buildResponse(
    query: []const u8,
    zone: *const Zone,
    out: []u8,
) !usize {
    if (query.len < 12) return error.QueryTooShort;
    if (out.len < 12) return error.BufferTooSmall;

    const id = std.mem.readInt(u16, query[0..2], .big);
    const qflags = std.mem.readInt(u16, query[2..4], .big);
    const qdcount = std.mem.readInt(u16, query[4..6], .big);
    if (qflags & 0x8000 != 0) return error.NotAQuery;
    if (qdcount == 0) return error.NoQuestion;
    const opcode: u16 = (qflags >> 11) & 0xF;

    // Decode QNAME into dotted form for zone matching.
    var qname_buf: [255]u8 = undefined;
    var qname_len: usize = 0;
    var qoffset: usize = 12;
    while (true) {
        if (qoffset >= query.len) return error.Truncated;
        const len = query[qoffset];
        qoffset += 1;
        if (len == 0) break;
        // 0xC0 compression pointers and the reserved 0x40/0x80 label types are
        // not valid here; a plain label is at most 63 bytes.
        if (len > 63) return error.UnsupportedLabel;
        if (qoffset + len > query.len) return error.Truncated;
        const sep: usize = if (qname_len > 0) 1 else 0;
        if (qname_len + sep + len > qname_buf.len) return error.NameTooLong;
        if (sep == 1) qname_buf[qname_len] = '.';
        qname_len += sep;
        @memcpy(qname_buf[qname_len..][0..len], query[qoffset..][0..len]);
        qname_len += len;
        qoffset += len;
    }
    if (qoffset + 4 > query.len) return error.Truncated;
    const qtype = std.mem.readInt(u16, query[qoffset..][0..2], .big);
    qoffset += 4; // QTYPE + QCLASS
    const qname = qname_buf[0..qname_len];

    // Echo the question (name, type, class) exactly as received.
    const question = query[12..qoffset];
    if (12 + question.len > out.len) return error.BufferTooSmall;
    @memcpy(out[12..][0..question.len], question);
    var pos: usize = 12 + question.len;

    // Write matching answers; also note whether the name exists at all.
    const rtype: ?RecordType = std.meta.intToEnum(RecordType, qtype) catch null;
    var ancount: u16 = 0;
    var name_exists = false;
    var truncated = false;
    if (opcode == 0) {
        for (zone.records[0..zone.count]) |*rec| {
            if (!std.ascii.eqlIgnoreCase(rec.getName(), qname)) continue;
            name_exists = true;
            const wanted = rtype orelse continue;
            if (rec.rtype != wanted) continue;
            if (writeAnswer(rec, out, pos)) |next_pos| {
                pos = next_pos;
                ancount += 1;
            } else |err| switch (err) {
                error.BufferTooSmall => {
                    truncated = true;
                    break;
                },
                else => return err,
            }
        }
    }

    const rcode: u16 = if (opcode != 0) 4 else if (!name_exists) 3 else 0;
    var flags: u16 = 0x8400 | (opcode << 11); // QR=1, AA=1, echo OPCODE
    flags |= qflags & 0x0100; // copy RD
    if (truncated) flags |= 0x0200; // TC
    flags |= rcode;

    std.mem.writeInt(u16, out[0..2], id, .big);
    std.mem.writeInt(u16, out[2..4], flags, .big);
    std.mem.writeInt(u16, out[4..6], 1, .big); // exactly one question echoed
    std.mem.writeInt(u16, out[6..8], ancount, .big);
    std.mem.writeInt(u16, out[8..10], 0, .big);
    std.mem.writeInt(u16, out[10..12], 0, .big);
    return pos;
}
The compression pointer trick at out[pos] = 0xC0; out[pos + 1] = 12; is worth noting. Since every answer record in this response is for the same domain name that was in the question (which starts at byte 12 of the packet), we can just point back to it instead of encoding the name again. This is exactly the same compression we had to decode in the resolver -- now we're the ones generating it. The 2 bytes of pointer versus potentially 15+ bytes of repeated name encoding... it adds up when you have multiple records.
The main server loop is actually pretty simple. Listen on port 53 (or a custom port for development), receive datagrams, process them, send responses:
const posix = std.posix;
/// A single-threaded UDP DNS server answering from an in-memory `Zone`.
pub const DnsServer = struct {
    sock: posix.socket_t,
    zone: Zone,
    queries_served: u64,
    bind_port: u16,

    /// Open an IPv4 UDP socket bound to 0.0.0.0:`port` with SO_REUSEADDR set.
    /// The zone starts empty; call `loadZone` before `serve`.
    pub fn init(port: u16) !DnsServer {
        const sock = try posix.socket(posix.AF.INET, posix.SOCK.DGRAM, 0);
        errdefer posix.close(sock);
        try posix.setsockopt(sock, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
        const addr = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, port);
        try posix.bind(sock, &addr.any, addr.getOsSockLen());
        return .{
            .sock = sock,
            .zone = Zone.init(),
            .queries_served = 0,
            .bind_port = port,
        };
    }

    /// Close the server socket.
    pub fn deinit(self: *DnsServer) void {
        posix.close(self.sock);
    }

    /// Parse zone-file text and add its records to the zone.
    pub fn loadZone(self: *DnsServer, zone_text: []const u8) !void {
        try self.zone.loadFromText(zone_text);
    }

    /// Receive and answer queries forever using 512-byte UDP buffers.
    ///
    /// Per-packet problems are logged and the loop keeps running: malformed
    /// queries, send failures, and transient receive errors such as an
    /// ICMP-triggered ConnectionRefused. Other receive errors are returned.
    pub fn serve(self: *DnsServer) !void {
        std.debug.print("DNS server listening on port {d}\n", .{self.bind_port});
        std.debug.print("loaded {d} zone records\n", .{self.zone.count});

        var recv_buf: [512]u8 = undefined;
        var resp_buf: [512]u8 = undefined;
        while (true) {
            var client_addr: posix.sockaddr.storage = undefined;
            var addr_len: posix.socklen_t = @sizeOf(posix.sockaddr.storage);
            const n = posix.recvfrom(
                self.sock,
                &recv_buf,
                0,
                @ptrCast(&client_addr),
                &addr_len,
            ) catch |err| switch (err) {
                error.ConnectionRefused,
                error.ConnectionResetByPeer,
                error.MessageTooBig,
                error.SystemResources,
                => {
                    std.debug.print("error receiving query: {}\n", .{err});
                    continue;
                },
                else => return err,
            };
            if (n < 12) continue; // too short for DNS header

            const query = recv_buf[0..n];
            const resp_len = buildResponse(query, &self.zone, &resp_buf) catch |err| {
                std.debug.print("error building response: {}\n", .{err});
                continue;
            };

            _ = posix.sendto(
                self.sock,
                resp_buf[0..resp_len],
                0,
                @ptrCast(&client_addr),
                addr_len,
            ) catch |err| {
                std.debug.print("error sending response: {}\n", .{err});
                continue;
            };
            self.queries_served += 1;

            // Log the query.
            const sender = std.net.Address.initPosix(@ptrCast(&client_addr));
            // NOTE(review): std's sockaddr.in declares `addr` as a u32 holding the
            // address in network byte order (it cannot be indexed as an array);
            // @bitCast keeps the memory order, i.e. the dotted-quad octets.
            const ip: [4]u8 = @bitCast(sender.in.sa.addr);
            const qid = std.mem.readInt(u16, query[0..2], .big);
            std.debug.print("[{d}] query 0x{x:0>4} from {d}.{d}.{d}.{d}:{d} ({d} bytes -> {d} bytes)\n", .{
                self.queries_served,
                qid,
                ip[0],
                ip[1],
                ip[2],
                ip[3],
                sender.getPort(),
                n,
                resp_len,
            });
        }
    }
};
The serve function is an infinite loop -- recvfrom, process, sendto, repeat. Every incoming query gets logged with the client's IP, port, and the packet sizes. The query ID is logged too, which is helpful for matching up queries and responses when debugging.
Notice we're using 512-byte buffers for both receive and send. The classic DNS limit is 512 bytes for UDP messages (RFC 1035). Larger responses require TCP fallback or EDNS0 extensions -- we're keeping it simple and staying within the original spec.
Here's the main function that loads a zone file and starts the server:
/// Start the example DNS server on the port given as the first argument
/// (default 5353) with a small built-in zone. An unparsable port argument is
/// reported, and the default is used.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    const args = try std.process.argsAlloc(allocator);
    defer std.process.argsFree(allocator, args);

    const default_port: u16 = 5353;
    var port: u16 = default_port;
    if (args.len > 1) {
        port = std.fmt.parseInt(u16, args[1], 10) catch blk: {
            std.debug.print("invalid port '{s}', using {d}\n", .{ args[1], default_port });
            break :blk default_port;
        };
    }

    var server = try DnsServer.init(port);
    defer server.deinit();

    // Example zone data
    const zone_text =
        \\# Example zone file
        \\example.local A 3600 192.168.1.100
        \\example.local AAAA 3600 fd00:0000:0000:0000:0000:0000:0000:0001
        \\www.example.local CNAME 3600 example.local
        \\example.local MX 3600 10 mail.example.local
        \\mail.example.local A 3600 192.168.1.200
        \\example.local TXT 3600 v=spf1 include:example.local ~all
        \\example.local NS 86400 ns1.example.local
        \\ns1.example.local A 86400 192.168.1.1
        \\test.local A 60 10.0.0.1
        \\test.local A 60 10.0.0.2
    ;
    try server.loadZone(zone_text);
    try server.serve();
}
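We default to port 5353 instead of 53 because binding to ports below 1024 (including 53) requires root privileges on most Unix-like systems. For development and testing, 5353 works just fine -- you just need to tell your resolver to use the right port. One caveat: 5353 is also the multicast DNS (mDNS) port, so if Avahi or mDNSResponder is running on your machine, pass a different port on the command line (for example 5300).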
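The zone data defines a small local domain with an A record, an AAAA record, a CNAME alias (www pointing to the root), an MX record for mail, a TXT record for SPF, and an NS record naming the zone's name server, plus that name server's A record. The test.local domain has two A records -- that's called a round-robin setup, and real DNS servers use it for basic load balancing. Real servers typically also rotate the order of the records between responses; ours always returns them in zone-file order.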
Here's the best part -- we can test the server using the resolver from episode 82. Modify the resolver to connect to our local server instead of 8.8.8.8:
const std = @import("std");
const dns = @import("dns.zig");
/// Query the local example server (127.0.0.1:5353) for a fixed list of
/// names/types and print every answer, or the error for a failed lookup.
pub fn main() !void {
    // Connect to our local DNS server on port 5353. init() only takes the
    // IPv4 address, so the port (5353, not 53) is patched in afterwards.
    var resolver = try dns.DnsResolver.init(.{ 127, 0, 0, 1 });
    defer resolver.deinit();
    resolver.server = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 5353);

    const Query = struct { name: []const u8, qtype: dns.QueryType };
    const queries = [_]Query{
        .{ .name = "example.local", .qtype = .A },
        .{ .name = "example.local", .qtype = .AAAA },
        .{ .name = "www.example.local", .qtype = .CNAME },
        .{ .name = "example.local", .qtype = .MX },
        .{ .name = "test.local", .qtype = .A },
        .{ .name = "nonexistent.local", .qtype = .A },
    };

    var records: [16]dns.DnsRecord = undefined;
    for (queries) |q| {
        std.debug.print("\nQuerying {s} ({s}):\n", .{ q.name, @tagName(q.qtype) });

        const result = resolver.resolve(q.name, q.qtype, &records) catch |err| {
            std.debug.print(" error: {}\n", .{err});
            continue;
        };

        for (records[0..result.count]) |rec| {
            switch (rec.data_tag) {
                .ipv4 => {
                    const octets = rec.data.ipv4;
                    std.debug.print(" A: {d}.{d}.{d}.{d}\n", .{ octets[0], octets[1], octets[2], octets[3] });
                },
                .ipv6 => {
                    std.debug.print(" AAAA: ", .{});
                    var group: usize = 0;
                    while (group < 8) : (group += 1) {
                        if (group != 0) std.debug.print(":", .{});
                        std.debug.print("{x:0>2}{x:0>2}", .{
                            rec.data.ipv6[2 * group],
                            rec.data.ipv6[2 * group + 1],
                        });
                    }
                    std.debug.print("\n", .{});
                },
                .cname => {
                    std.debug.print(" CNAME: {s}\n", .{rec.data.cname[0..rec.cname_len]});
                },
                .mx => {
                    const mx = rec.data.mx;
                    std.debug.print(" MX: {d} {s}\n", .{ mx.preference, mx.exchange[0..mx.exchange_len] });
                },
                else => {
                    std.debug.print(" type={d}\n", .{rec.rtype});
                },
            }
        }
    }
}
Run the server in one terminal, run this test client in another, and you should see the queries flowing through. The server logs each query, and the client prints the responses. The nonexistent.local query should return NameNotFound -- that's the NXDOMAIN response code we set when there are no matching records.
A DNS server exposed to the internet without rate limiting is asking for trouble. Even for a local server, it's good practice to implement basic protection against query floods. Here's a simple per-IP rate limiter:
/// Fixed-window, per-IPv4 query limiter.
///
/// Each source address gets a counter that resets once at least one second
/// (per `std.time.timestamp`) has passed since its window began. Up to 256
/// addresses are tracked. When the table is full, the slot whose window
/// started earliest is reused for the newcomer.
pub const RateLimiter = struct {
    /// Per-address counters; only `counts[0..entries]` hold live data.
    counts: [256]struct {
        ip: u32,
        count: u32,
        window_start: i64,
    },
    entries: usize,
    max_per_second: u32,

    pub fn init(max_per_second: u32) RateLimiter {
        return .{
            .counts = undefined,
            .entries = 0,
            .max_per_second = max_per_second,
        };
    }

    /// Count one query from `ip_bytes` and report whether it is within the
    /// limit. The first query of a window is always allowed, as is the first
    /// query from an address that is new (or was evicted).
    pub fn allow(self: *RateLimiter, ip_bytes: [4]u8) bool {
        const key = std.mem.readInt(u32, &ip_bytes, .big);
        const now = std.time.timestamp();
        const live = self.counts[0..self.entries];

        for (live) |*slot| {
            if (slot.ip != key) continue;
            if (now - slot.window_start < 1) {
                slot.count += 1;
                return slot.count <= self.max_per_second;
            }
            // A fresh one-second window begins for this address.
            slot.count = 1;
            slot.window_start = now;
            return true;
        }

        // Unknown address: use a free slot, or evict the stalest window.
        const target: usize = if (self.entries < self.counts.len) blk: {
            self.entries += 1;
            break :blk self.entries - 1;
        } else blk: {
            var victim: usize = 0;
            for (live[1..], 1..) |slot, idx| {
                if (slot.window_start < live[victim].window_start) victim = idx;
            }
            break :blk victim;
        };
        self.counts[target] = .{ .ip = key, .count = 1, .window_start = now };
        return true;
    }
};
This is a fixed-window rate limiter -- it counts queries per IP within 1-second windows. The table has 256 slots, and when it's full, the oldest entry gets evicted. For a small local DNS server this is perfectly adequate. Production DNS servers use more sophisticated approaches (sliding windows, token buckets), but the principle is the same.
You'd integrate this into the server loop right after recvfrom:
// In the serve() function, after receiving the query. This assumes a
// `limiter: RateLimiter` field was added to DnsServer (initialised with
// RateLimiter.init(...)). Reuse this `sender` for the logging further down
// instead of declaring a second `const sender` in the same loop body.
const sender = std.net.Address.initPosix(@ptrCast(&client_addr));
// allow() takes the four address octets. NOTE(review): std's sockaddr.in
// declares `addr` as a u32 in network byte order, so it can't be passed as
// [4]u8 directly; @bitCast yields the bytes in dotted-quad order.
const sender_ip: [4]u8 = @bitCast(sender.in.sa.addr);
if (!self.limiter.allow(sender_ip)) {
    // Drop the query silently
    continue;
}
Testing the response builder is critical because DNS clients will reject malformed packets. We can build a query by hand (using the same encodeName helper the server uses for names) and verify the response:
test "server responds to A query" {
    var zone = Zone.init();
    try zone.loadFromText("test.local A 300 10.0.0.42");

    // Build a query for test.local A
    var query_buf: [512]u8 = undefined;
    // Header: id=0xBEEF, flags=0x0100 (RD), 1 question
    std.mem.writeInt(u16, query_buf[0..2], 0xBEEF, .big);
    std.mem.writeInt(u16, query_buf[2..4], 0x0100, .big);
    std.mem.writeInt(u16, query_buf[4..6], 1, .big); // qdcount
    std.mem.writeInt(u16, query_buf[6..8], 0, .big);
    std.mem.writeInt(u16, query_buf[8..10], 0, .big);
    std.mem.writeInt(u16, query_buf[10..12], 0, .big);
    var qpos: usize = 12;
    qpos += try encodeName("test.local", &query_buf, qpos);
    std.mem.writeInt(u16, query_buf[qpos..][0..2], 1, .big); // type A
    std.mem.writeInt(u16, query_buf[qpos + 2 ..][0..2], 1, .big); // class IN
    qpos += 4;

    // Build response
    var resp_buf: [512]u8 = undefined;
    const resp_len = try buildResponse(query_buf[0..qpos], &zone, &resp_buf);

    // Verify response header
    const resp_id = std.mem.readInt(u16, resp_buf[0..2], .big);
    try std.testing.expectEqual(@as(u16, 0xBEEF), resp_id);
    // Check QR bit is set (bit 15 of flags)
    const resp_flags = std.mem.readInt(u16, resp_buf[2..4], .big);
    try std.testing.expect(resp_flags & 0x8000 != 0); // QR=1
    // Check answer count = 1
    const ancount = std.mem.readInt(u16, resp_buf[6..8], .big);
    try std.testing.expectEqual(@as(u16, 1), ancount);

    // The question section is echoed byte-for-byte.
    try std.testing.expectEqualSlices(u8, query_buf[12..qpos], resp_buf[12..qpos]);

    // The answer follows the question and is exactly
    // name pointer (2) + type (2) + class (2) + TTL (4) + RDLENGTH (2) + RDATA (4).
    try std.testing.expectEqual(qpos + 16, resp_len);
    const ans = resp_buf[qpos..resp_len];
    try std.testing.expectEqualSlices(u8, &.{ 0xC0, 12 }, ans[0..2]); // pointer to question name
    try std.testing.expectEqual(@as(u16, 1), std.mem.readInt(u16, ans[2..4], .big)); // type A
    try std.testing.expectEqual(@as(u16, 1), std.mem.readInt(u16, ans[4..6], .big)); // class IN
    try std.testing.expectEqual(@as(u32, 300), std.mem.readInt(u32, ans[6..10], .big)); // TTL
    try std.testing.expectEqual(@as(u16, 4), std.mem.readInt(u16, ans[10..12], .big)); // RDLENGTH
    try std.testing.expectEqualSlices(u8, &.{ 10, 0, 0, 42 }, ans[12..16]);
}
test "server returns NXDOMAIN for unknown domain" {
    var zone = Zone.init();
    try zone.loadFromText("known.local A 300 10.0.0.1");

    // Hand-built query: id=0x1234, RD set, one question for nope.local A IN.
    var query_buf: [512]u8 = undefined;
    const header_words = [_]u16{ 0x1234, 0x0100, 1, 0, 0, 0 };
    for (header_words, 0..) |word, i| {
        std.mem.writeInt(u16, query_buf[i * 2 ..][0..2], word, .big);
    }
    var qpos: usize = 12;
    qpos += try encodeName("nope.local", &query_buf, qpos);
    std.mem.writeInt(u16, query_buf[qpos..][0..2], 1, .big); // QTYPE A
    std.mem.writeInt(u16, query_buf[qpos + 2 ..][0..2], 1, .big); // QCLASS IN
    qpos += 4;

    var resp_buf: [512]u8 = undefined;
    _ = try buildResponse(query_buf[0..qpos], &zone, &resp_buf);

    // RCODE lives in the low 4 bits of the flags word.
    const flags = std.mem.readInt(u16, resp_buf[2..4], .big);
    try std.testing.expectEqual(@as(u16, 3), flags & 0x000F); // NXDOMAIN
    try std.testing.expectEqual(@as(u16, 0), std.mem.readInt(u16, resp_buf[6..8], .big)); // ancount
}
test "zone file parsing" {
    var zone = Zone.init();
    try zone.loadFromText(
        \\# Comment line
        \\example.com A 3600 1.2.3.4
        \\example.com AAAA 3600 2001:0db8:0000:0000:0000:0000:0000:0001
        \\www.example.com CNAME 7200 example.com
        \\
    );
    // The comment and the trailing blank line are skipped.
    try std.testing.expectEqual(@as(usize, 3), zone.count);

    const a_result = zone.lookup("example.com", .A);
    try std.testing.expectEqual(@as(usize, 1), a_result.count);
    try std.testing.expectEqualSlices(u8, &.{ 1, 2, 3, 4 }, &a_result.buf[0].data.ipv4);

    const aaaa_result = zone.lookup("example.com", .AAAA);
    try std.testing.expectEqual(@as(usize, 1), aaaa_result.count);
    try std.testing.expectEqualSlices(
        u8,
        &.{ 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01 },
        &aaaa_result.buf[0].data.ipv6,
    );

    const cname_result = zone.lookup("www.example.com", .CNAME);
    try std.testing.expectEqual(@as(usize, 1), cname_result.count);
    try std.testing.expectEqual(@as(u32, 7200), cname_result.buf[0].ttl);
    const target = &cname_result.buf[0].data.name_ref;
    try std.testing.expectEqualStrings("example.com", target.buf[0..target.len]);
}
These tests verify three critical paths: a successful A record lookup, an NXDOMAIN response for unknown domains, and zone file parsing. The first two test the entire query-to-response pipeline without touching the network -- we build a raw query packet and feed it directly to buildResponse. This is the same synthetic packet technique we used in episode 82, and it makes the tests fast, deterministic, and CI-friendly.
Our server is authoritative -- it only answers for domains it has records for, and returns NXDOMAIN for everything else. This is distinct from a recursive resolver (like 8.8.8.8), which follows the chain of referrals from the root servers down to find the answer for any domain. The resolver we built last episode is a stub resolver: it hands the whole job to a recursive resolver such as 8.8.8.8 and just parses the reply.
If you wanted to extend this into a recursive server, you'd add logic to: (1) check your own zone first, (2) if no match, query an upstream recursive resolver, (3) cache the result. That's essentially what Pi-hole does -- it's an authoritative server for its blocklist (returning 0.0.0.0 for ad domains) and a recursive forwarder for everything else. Conceptually that's not hard to build now that we have both the resolver and the server, but the edge cases around CNAME chaining, negative caching, and DNS security extensions (DNSSEC) make a production recursive resolver a significantly larger project.
For now, knowing the difference between authoritative and recursive is the important thing. The wire protocol is identical -- the difference is entirely in the server's behavior when it doesn't have a local record.
Add a statistics endpoint to the DNS server. Create a special domain like stats.server.local that, when queried for a TXT record, returns the server's statistics: total queries served, uptime in seconds, and number of zone records loaded. The response should be a single TXT record with a format like "queries=1234 uptime=5678s records=10". Test it by querying stats.server.local TXT with your resolver.
Implement zone file reloading without restarting the server. Add support for a SIGUSR1 signal handler (we covered signal handling in episode 67) that re-reads the zone file from disk when triggered. The server should log a message like "zone reloaded: {N} records" after a successful reload. The tricky part is making sure the reload is safe if a query is being processed at the same time -- think about what guarantees you need.
Build a DNS proxy that forwards queries it can't answer locally to an upstream DNS server (like 8.8.8.8). The proxy should check its own zone first -- if it has a matching record, serve it directly. If not, forward the query to the upstream server, cache the response (using the TTL from the answer), and return it to the client. This is essentially a minimal Pi-hole: you could add a blocklist zone where known ad domains return 0.0.0.0, and everything else gets forwarded.