ndb/dns: DoT support

This commit is contained in:
Jacob Moody 2024-01-28 00:01:56 +00:00
parent 917d0fa9b4
commit 6b0574e27e
8 changed files with 276 additions and 72 deletions

View file

@ -232,7 +232,13 @@ There may be multiple
pairs.
.TP
.B dns
a DNS server to use (for DNS and DHCP)
a DNS server to use for resolving (for DNS and DHCP)
.TP
.B dot
a DNS over TLS server to use for resolving (for DNS).
If found,
.B dns
entries are ignored.
.TP
.B ntp
an NTP server to use (for DHCP)

View file

@ -62,6 +62,9 @@ query, ipquery, mkhash, mkdb, mkhosts, cs, csquery, dns, dnsquery, dnsdebug, dns
.B -a
.I maxage
] [
.B -c
.I cert.pem
] [
.B -f
.I dbfile
] [
@ -393,11 +396,18 @@ send `recursive' queries, asking the other servers
to complete lookups.
If present,
.B /env/DNSSERVER
must be a space-separated list of such DNS servers' IP addresses,
or
.B /env/DOTSERVER
must be a space-separated list of such DNS (or DoT) servers' IP addresses,
otherwise optional
.IR ndb (6)
.B dns
attributes name DNS servers to forward queries to.
Note that when
.B DOTSERVER
is specified,
.B DNSSERVER
are ignored.
.TP
.B -R
ignore the `recursive' bit on all incoming requests.
@ -422,6 +432,12 @@ are given,
listen on any interface on network mount point
.IR netmtpt .
.TP
.B -c
When a certificate
.I cert.pem
is specified, also listen on TCP port 853 and handle
DNS requests over TLS.
.TP
.B -x
specifies the mount point of the network.
.PD

View file

@ -916,7 +916,7 @@ addlocaldnsserver(DN *dp, int class, char *addr, int i)
/* check duplicate ip */
for(n = 0; n < i; n++){
snprint(buf, sizeof buf, "local#dns#server%d", n);
snprint(buf, sizeof buf, "%s#%d", dp->name, n);
nsdp = dnlookup(buf, class, 0);
if(nsdp == nil)
continue;
@ -931,7 +931,7 @@ addlocaldnsserver(DN *dp, int class, char *addr, int i)
rrfreelist(rp);
}
snprint(buf, sizeof buf, "local#dns#server%d", i);
snprint(buf, sizeof buf, "%s#%d", dp->name, i);
nsdp = dnlookup(buf, class, 1);
/* ns record for name server, make up an impossible name */
@ -967,6 +967,33 @@ dnsservers(int class)
RR *nsrp;
DN *dp;
/* try first DoT servers */
dp = dnlookup("local#dot#servers", class, 1);
nsrp = rrlookup(dp, Tns, NOneg);
if(nsrp != nil)
return nsrp;
p = getenv("DOTSERVER"); /* list of ip addresses */
if(p != nil && (n = tokenize(p, args, nelem(args))) > 0){
for(i = 0; i < n; i++)
addlocaldnsserver(dp, class, args[i], i);
} else {
t = lookupinfo("@dot"); /* @dot=ip1 ... */
if(t == nil)
return nil;
i = 0;
for(nt = t; nt != nil; nt = nt->entry){
addlocaldnsserver(dp, class, nt->val, i);
i++;
}
ndbfree(t);
}
nsrp = rrlookup(dp, Tns, NOneg);
if(nsrp != nil)
return nsrp;
/* try regular local DNS servers */
dp = dnlookup("local#dns#servers", class, 1);
nsrp = rrlookup(dp, Tns, NOneg);
if(nsrp != nil)

View file

@ -5,6 +5,8 @@
#include <libc.h>
#include <ip.h>
#include <bio.h>
#include <mp.h>
#include <libsec.h>
#include <ndb.h>
#include "dns.h"
@ -79,7 +81,7 @@ procgetname(void)
return strdup(lp+1);
}
void
static void
rrfreelistptr(RR **rpp)
{
RR *rp;
@ -267,9 +269,13 @@ issuequery(Query *qp, char *name, int class, int recurse)
*/
if(cfg.resolver){
nsrp = randomize(getdnsservers(class));
if(nsrp != nil)
if(nsrp != nil){
int dot = strncmp(nsrp->owner->name, "local#dot#server", 16) == 0;
if(netqueryns(qp, nsrp) > Answnone)
return rrlookup(qp->dp, qp->type, OKneg);
else if(dot)
return nil; /* do not fall-back for DoT */
}
}
/*
@ -733,7 +739,7 @@ readreply(Query *qp, int medium, int fd, uvlong endms,
/*
* return non-0 if first list includes second list
*/
int
static int
contains(RR *rp1, RR *rp2)
{
RR *trp1, *trp2;
@ -753,7 +759,7 @@ contains(RR *rp1, RR *rp2)
/*
* return multicast version if any
*/
int
static int
ipisbm(uchar *ip)
{
if(isv4(ip)){
@ -1166,16 +1172,31 @@ writenet(Query *qp, int medium, int fd, uchar *pkt, int len, Dest *p)
return rv;
}
enum {
Maxfree = 4,
};
struct {
QLock lk;
struct {
uvlong when;
char *dest;
int fd;
} l[Maxfree];
} tcpfree;
/*
* send a query via tcp to a single address
* and read the answer(s) into mp->an.
*/
static int
tcpquery(Query *qp, uchar *pkt, int len, Dest *p, uvlong endms, DNSmsg *mp)
tcpquery(Query *qp, uchar *pkt, int len, Dest *p, uvlong endms, DNSmsg *mp, int tls)
{
char buf[NETPATHLEN];
int fd, rv;
int fd, rv, i, retry;
long ms;
TLSconn conn;
Thumbprint *thumb;
memset(mp, 0, sizeof *mp);
@ -1185,57 +1206,139 @@ tcpquery(Query *qp, uchar *pkt, int len, Dest *p, uvlong endms, DNSmsg *mp)
if(ms > Maxtcpdialtm)
ms = Maxtcpdialtm;
procsetname("tcp query to %I/%s for %s %s", p->a, p->s->name,
procsetname("%s query to %I/%s for %s %s", tls ? "tls" : "tcp", p->a, p->s->name,
qp->dp->name, rrname(qp->type, buf, sizeof buf));
snprint(buf, sizeof buf, "%s/tcp!%I!53", mntpt, p->a);
snprint(buf, sizeof buf, "%s/tcp!%I!%s", mntpt, p->a, tls ? "853" : "53");
fd = -1;
retry = 0;
qlock(&tcpfree.lk);
for(i = 0; i < nelem(tcpfree.l); i++){
if(tcpfree.l[i].dest == nil || tcpfree.l[i].fd == -1)
continue;
if(strcmp(tcpfree.l[i].dest, buf) != 0)
continue;
/* RFC does not specify connection reuse timeout */
if(nowms - tcpfree.l[i].when < 5000){
fd = tcpfree.l[i].fd;
tcpfree.l[i].fd = -1;
retry++;
break;
}
}
qunlock(&tcpfree.lk);
if(fd != -1)
goto Found;
Retry:
alarm(ms);
fd = dial(buf, nil, nil, nil);
alarm(0);
if (fd < 0) {
if(fd < 0){
alarm(0);
dnslog("%d: can't dial %s for %I/%s: %r",
qp->req->id, buf, p->a, p->s->name);
return -1;
}
if(tls){
memset(&conn, 0, sizeof conn);
rv = tlsClient(fd, &conn);
alarm(0);
if(rv >= 0){
fd = rv;
thumb = initThumbprints("/sys/lib/tls/dns", nil, "x509");
if(thumb == nil || !okCertificate(conn.cert, conn.certlen, thumb)){
dnslog("%d: invalid fingerprint for %s; echo 'x509 %r' >>/sys/lib/tls/dns",
qp->req->id, buf);
rv = -1;
}
free(conn.cert);
free(conn.sessionID);
freeThumbprints(thumb);
}
if(rv < 0){
close(fd);
return -1;
}
} else {
alarm(0);
}
Found:
rv = writenet(qp, Tcp, fd, pkt, len, p);
if(rv == 0){
timems(); /* account for time dialing and sending */
rv = readreply(qp, Tcp, fd, endms, mp, pkt);
}
close(fd);
if(rv < 0){
close(fd);
if(retry){
retry = 0;
goto Retry;
}
return rv;
}
qlock(&tcpfree.lk);
if(tcpfree.l[nelem(tcpfree.l)-1].dest != nil){
close(tcpfree.l[nelem(tcpfree.l)-1].fd);
free(tcpfree.l[nelem(tcpfree.l)-1].dest);
}
memmove(tcpfree.l + 1, tcpfree.l, sizeof(tcpfree.l[0])*(nelem(tcpfree.l)-1));
tcpfree.l[0].when = nowms;
tcpfree.l[0].fd = fd;
tcpfree.l[0].dest = estrdup(buf);
qunlock(&tcpfree.lk);
return rv;
}
static int
tlsqueryns(Query *qp, uchar *pkt, int len)
{
Dest dest[Maxdest], *p;
int rv, n;
uvlong endms;
DNSmsg m;
/* populates dest with v4 and v6 addresses. */
n = 0;
n = serveraddrs(qp, dest, n, Ta);
n = serveraddrs(qp, dest, n, Taaaa);
endms = nowms + 500;
for(p = dest; p < dest+n; p++){
if(tcpquery(qp, pkt, len, p, endms, &m, 1) == 0){
/* free or incorporate RRs in m */
rv = procansw(qp, p, &m);
if(rv > Answnone)
return rv;
}
}
/* if all servers returned failure, propagate it */
qp->dp->respcode = Rserver;
for(p = dest; p < dest+n; p++)
if(p->code != Rserver)
qp->dp->respcode = Rok;
return Answnone;
}
/*
* query name servers. fill in pkt with on-the-wire representation of a
* DNSmsg derived from qp. if the name server returns a pointer to another
* name server, recurse.
*/
static int
udpqueryns(Query *qp, int fd, uchar *pkt)
udpqueryns(Query *qp, int fd, uchar *pkt, int len)
{
Dest dest[Maxdest], *edest, *p, *np;
int ndest, replywaits, len, flag, rv, n;
int ndest, replywaits, rv, n;
uchar srcip[IPaddrlen];
char buf[32];
uvlong endms;
DNSmsg m;
RR *rp;
/* prepare server RR's for incremental lookup */
for(rp = qp->nsrp; rp; rp = rp->next)
rp->marker = 0;
/* request recursion only for local/override dns servers */
flag = Oquery;
if(strncmp(qp->nsrp->owner->name, "local#", 6) == 0
|| strncmp(qp->nsrp->owner->name, "override#", 9) == 0)
flag |= Frecurse;
/* pack request into a udp message */
qp->id = rand();
len = mkreq(qp->dp, qp->type, pkt, flag, qp->id);
/* no destination yet */
edest = dest;
@ -1307,7 +1410,7 @@ udpqueryns(Query *qp, int fd, uchar *pkt)
/* if response was truncated, try tcp */
if(m.flags & Ftrunc){
freeanswers(&m);
if(tcpquery(qp, pkt, len, p, endms, &m) < 0)
if(tcpquery(qp, pkt, len, p, endms, &m, 0) < 0)
break; /* failed via tcp too */
if(m.flags & Ftrunc){
freeanswers(&m);
@ -1336,27 +1439,44 @@ udpqueryns(Query *qp, int fd, uchar *pkt)
return Answnone;
}
/*
* in principle we could use a single descriptor for a udp port
* to send all queries and receive all the answers to them,
* but we'd have to sort out the answers by dns-query id.
*/
static int
udpquery(Query *qp)
doquery(Query *qp)
{
int fd, rv;
int fd, rv, len, flag;
uchar *pkt;
RR *rp;
pkt = emalloc(Maxudp+Udphdrsize);
fd = udpport(mntpt);
if (fd < 0) {
dnslog("%d: can't get udpport for %s query of name %s: %r",
qp->req->id, mntpt, qp->dp->name);
rv = -1;
goto Out;
/* prepare server RR's for incremental lookup */
for(rp = qp->nsrp; rp; rp = rp->next)
rp->marker = 0;
/* request recursion only for local/override dns servers */
flag = Oquery;
if(strncmp(qp->nsrp->owner->name, "local#", 6) == 0
|| strncmp(qp->nsrp->owner->name, "override#", 9) == 0)
flag |= Frecurse;
/* pack request into a udp message */
qp->id = rand();
len = mkreq(qp->dp, qp->type, pkt, flag, qp->id);
if(strncmp(qp->nsrp->owner->name, "local#dot#server", 16) == 0
|| strncmp(qp->nsrp->owner->name, "override#dot#server", 16) == 0){
rv = tlsqueryns(qp, pkt, len);
} else {
/*
* in principle we could use a single descriptor for a udp port
* to send all queries and receive all the answers to them,
* but we'd have to sort out the answers by dns-query id.
*/
fd = udpport(mntpt);
if (fd < 0) {
dnslog("%d: can't get udpport for %s query of name %s: %r",
qp->req->id, mntpt, qp->dp->name);
rv = -1;
goto Out;
}
rv = udpqueryns(qp, fd, pkt, len);
close(fd);
}
rv = udpqueryns(qp, fd, pkt);
close(fd);
Out:
free(pkt);
return rv;
@ -1383,5 +1503,5 @@ netquery(Query *qp)
if(!qp->req->isslave && strcmp(qp->req->from, "9p") == 0)
return Answnone;
return udpquery(qp);
return doquery(qp);
}

View file

@ -92,7 +92,7 @@ static char *respond(Job*, Mfile*, RR*, char*, int, int);
void
usage(void)
{
fprint(2, "usage: %s [-FnrLR] [-a maxage] [-f ndb-file] [-N target] "
fprint(2, "usage: %s [-FnrLR] [-a maxage] [-c cert.pem] [-f ndb-file] [-N target] "
"[-x netmtpt] [-s [addrs...]]\n", argv0);
exits("usage");
}
@ -101,10 +101,12 @@ void
main(int argc, char *argv[])
{
char ext[Maxpath], servefile[Maxpath];
char *cert;
Dir *dir;
setnetmtpt(mntpt, sizeof mntpt, nil);
ext[0] = 0;
cert = nil;
ARGBEGIN{
case 'a':
maxage = atol(EARGF(usage()));
@ -141,6 +143,9 @@ main(int argc, char *argv[])
cfg.serve = 1; /* serve network */
cfg.cachedb = 1;
break;
case 'c':
cert = EARGF(usage());
break;
case 'x':
setnetmtpt(mntpt, sizeof mntpt, EARGF(usage()));
setext(ext, sizeof ext, mntpt);
@ -181,11 +186,15 @@ main(int argc, char *argv[])
if(cfg.serve){
if(argc == 0) {
dnudpserver(mntpt, "*");
dntcpserver(mntpt, "*");
dntcpserver(mntpt, "*", nil);
if(cert != nil)
dntcpserver(mntpt, "*", cert);
} else {
while(argc-- > 0){
dnudpserver(mntpt, *argv);
dntcpserver(mntpt, *argv);
dntcpserver(mntpt, *argv, nil);
if(cert != nil)
dntcpserver(mntpt, *argv, cert);
argv++;
}
}

View file

@ -522,7 +522,7 @@ void dnserver(DNSmsg*, DNSmsg*, Request*, uchar *, int);
void dnudpserver(char*, char*);
/* dntcpserver.c */
void dntcpserver(char*, char*);
void dntcpserver(char*, char*, char*);
/* dnnotify.c */
void dnnotify(DNSmsg*, DNSmsg*, Request*);

View file

@ -172,14 +172,18 @@ getdnsservers(int class)
{
uchar ip[IPaddrlen];
DN *nsdp;
RR *rp;
RR *rp, *ns;
char name[64];
if(servername == nil)
return dnsservers(class);
if(parseip(ip, servername) == -1){
nsdp = idnlookup(servername, class, 1);
snprint(name, sizeof name, "override#%s#server", servername[0] == '!' ? "dot" : "dns");
ns = rralloc(Tns);
if(parseip(ip, servername+1) == -1){
nsdp = idnlookup(servername+1, class, 1);
} else {
nsdp = dnlookup("local#dns#server", class, 1);
nsdp = dnlookup(name, class, 1);
rp = rralloc(isv4(ip) ? Ta : Taaaa);
rp->owner = nsdp;
rp->ip = ipalookup(ip, class, 1);
@ -187,10 +191,9 @@ getdnsservers(int class)
rp->ttl = 10*Min;
rrattach(rp, Authoritative);
}
rp = rralloc(Tns);
rp->owner = dnlookup("override#dns#servers", class, 1);
rp->host = nsdp;
return rp;
ns->owner = dnlookup(name, class, 1);
ns->host = nsdp;
return ns;
}
int
@ -201,7 +204,7 @@ setserver(char *server)
servername = nil;
cfg.resolver = 0;
}
if(server == nil || *server == 0)
if(server == nil || server[0] == 0 || server[1] == 0)
return 0;
servername = estrdup(server);
cfg.resolver = 1;
@ -276,8 +279,8 @@ docmd(int n, char **f)
name = type = nil;
tmpsrv = 0;
if(*f[0] == '@') {
if(setserver(f[0]+1) < 0)
if(*f[0] == '@' || *f[0] == '!') {
if(setserver(f[0]) < 0)
return;
switch(n){
@ -306,5 +309,5 @@ docmd(int n, char **f)
doquery(name, type);
if(tmpsrv)
setserver("");
setserver("@");
}

View file

@ -3,6 +3,8 @@
#include <bio.h>
#include <ndb.h>
#include <ip.h>
#include <mp.h>
#include <libsec.h>
#include "dns.h"
enum {
@ -12,10 +14,10 @@ enum {
static int readmsg(int, uchar*, int);
static int reply(int, uchar *, DNSmsg*, Request*, uchar*);
static int dnzone(int, uchar *, DNSmsg*, DNSmsg*, Request*, uchar*);
static int tcpannounce(char *mntpt, char *addr, char caller[128]);
static int tcpannounce(char *mntpt, char *addr, char caller[128], char *cert);
void
dntcpserver(char *mntpt, char *addr)
dntcpserver(char *mntpt, char *addr, char *cert)
{
volatile int fd, len, rcode, rv;
volatile long ms;
@ -40,7 +42,7 @@ dntcpserver(char *mntpt, char *addr)
}
procsetname("%s: tcp server %s", mntpt, addr);
if((fd = tcpannounce(mntpt, addr, caller)) < 0){
if((fd = tcpannounce(mntpt, addr, caller, cert)) < 0){
warning("can't announce %s on %s: %r", addr, mntpt);
_exits(0);
}
@ -259,13 +261,20 @@ out:
}
static int
tcpannounce(char *mntpt, char *addr, char caller[128])
tcpannounce(char *mntpt, char *addr, char caller[128], char *cert)
{
char adir[NETPATHLEN], ldir[NETPATHLEN], buf[128];
int acfd, lcfd, dfd, wfd, rfd, procs;
PEMChain *chain = nil;
if(cert != nil){
chain = readcertchain(cert);
if(chain == nil)
return -1;
}
/* announce tcp dns port */
snprint(buf, sizeof(buf), "%s/tcp!%s!53", mntpt, addr);
snprint(buf, sizeof(buf), "%s/tcp!%s!%s", mntpt, addr, cert == nil ? "53" : "853");
acfd = announce(buf, adir);
if(acfd < 0)
return -1;
@ -277,7 +286,6 @@ tcpannounce(char *mntpt, char *addr, char caller[128])
close(acfd);
return -1;
}
procs = 0;
for(;;) {
if(procs >= Maxprocs || (procs % 8) == 0){
@ -314,7 +322,23 @@ tcpannounce(char *mntpt, char *addr, char caller[128])
close(lcfd);
if(dfd < 0)
_exits(0);
if(chain != nil){
TLSconn conn;
int fd;
memset(&conn, 0, sizeof conn);
conn.cert = emalloc(conn.certlen = chain->pemlen);
memmove(conn.cert, chain->pem, conn.certlen);
conn.chain = chain->next;
fd = tlsServer(dfd, &conn);
if(fd < 0){
close(dfd);
_exits(0);
}
free(conn.cert);
free(conn.sessionID);
dfd = fd;
}
/* get the callers ip!port */
memset(caller, 0, 128);
snprint(buf, sizeof(buf), "%s/remote", ldir);
@ -322,7 +346,6 @@ tcpannounce(char *mntpt, char *addr, char caller[128])
read(rfd, caller, 128-1);
close(rfd);
}
/* child returns */
return dfd;
default: