Trie: Clarify handling of less-common net types

For convenience, Trie functions generally accept as input values not only
NET_IPx types of nets, but also NET_VPNx and NET_ROAx types. But returned
values are always NET_IPx types.
This commit is contained in:
Ondrej Zajicek (work) 2021-11-29 19:00:24 +01:00
parent 14fc24f3a5
commit 78ddfd2600

View file

@ -373,9 +373,23 @@ trie_add_prefix(struct f_trie *t, const net_addr *net, uint l, uint h)
switch (net->type) switch (net->type)
{ {
case NET_IP4: px = ipt_from_ip4(net4_prefix(net)); v4 = 1; break; case NET_IP4:
case NET_IP6: px = ipa_from_ip6(net6_prefix(net)); v4 = 0; break; case NET_VPN4:
default: bug("invalid type"); case NET_ROA4:
px = ipt_from_ip4(net4_prefix(net));
v4 = 1;
break;
case NET_IP6:
case NET_VPN6:
case NET_ROA6:
case NET_IP6_SADR:
px = ipa_from_ip6(net6_prefix(net));
v4 = 0;
break;
default:
bug("invalid type");
} }
if (t->ipv4 != v4) if (t->ipv4 != v4)
@ -562,7 +576,9 @@ trie_match_net(const struct f_trie *t, const net_addr *n)
* can be used to enumerate all matching prefixes for the network @net using * can be used to enumerate all matching prefixes for the network @net using
* function trie_match_next_longest_ip4() or macro TRIE_WALK_TO_ROOT_IP4(). * function trie_match_next_longest_ip4() or macro TRIE_WALK_TO_ROOT_IP4().
* *
* This function assumes IPv4 trie, there is also an IPv6 variant. * This function assumes IPv4 trie, there is also an IPv6 variant. The @net
* argument is typed as net_addr_ip4, but would accept any IPv4-based net_addr,
* like net4_prefix(). Anyway, returned @dst is always net_addr_ip4.
* *
* Result: 1 if a matching prefix was found, 0 if not. * Result: 1 if a matching prefix was found, 0 if not.
*/ */
@ -571,6 +587,9 @@ trie_match_longest_ip4(const struct f_trie *t, const net_addr_ip4 *net, net_addr
{ {
ASSERT(t->ipv4); ASSERT(t->ipv4);
const ip4_addr prefix = net->prefix;
const int pxlen = net->pxlen;
const struct f_trie_node4 *n = &t->root.v4; const struct f_trie_node4 *n = &t->root.v4;
int len = 0; int len = 0;
@ -580,13 +599,13 @@ trie_match_longest_ip4(const struct f_trie *t, const net_addr_ip4 *net, net_addr
while (n) while (n)
{ {
/* We are out of path */ /* We are out of path */
if (!ip4_prefix_equal(net->prefix, n->addr, MIN(net->pxlen, n->plen))) if (!ip4_prefix_equal(prefix, n->addr, MIN(pxlen, n->plen)))
goto done; goto done;
/* Check accept mask */ /* Check accept mask */
for (; len < n->plen; len++) for (; len < n->plen; len++)
{ {
if (len > net->pxlen) if (len > pxlen)
goto done; goto done;
if (ip4_getbit(n->accept, len - 1)) if (ip4_getbit(n->accept, len - 1))
@ -607,9 +626,9 @@ trie_match_longest_ip4(const struct f_trie *t, const net_addr_ip4 *net, net_addr
} }
/* Check local mask */ /* Check local mask */
for (int pos = 1; pos < (1 << TRIE_STEP); pos = 2 * pos + ip4_getbit(net->prefix, len), len++) for (int pos = 1; pos < (1 << TRIE_STEP); pos = 2 * pos + ip4_getbit(prefix, len), len++)
{ {
if (len > net->pxlen) if (len > pxlen)
goto done; goto done;
if (n->local & (1u << pos)) if (n->local & (1u << pos))
@ -621,16 +640,14 @@ trie_match_longest_ip4(const struct f_trie *t, const net_addr_ip4 *net, net_addr
} }
/* Choose child */ /* Choose child */
n = n->c[ip4_getbits(net->prefix, n->plen, TRIE_STEP)]; n = n->c[ip4_getbits(prefix, n->plen, TRIE_STEP)];
} }
done: done:
if (last < 0) if (last < 0)
return 0; return 0;
net_copy_ip4(dst, net); *dst = NET_ADDR_IP4(ip4_and(prefix, ip4_mkmask(last)), last);
dst->prefix = ip4_and(dst->prefix, ip4_mkmask(last));
dst->pxlen = last;
if (found0) if (found0)
*found0 = found; *found0 = found;
@ -653,7 +670,9 @@ done:
* can be used to enumerate all matching prefixes for the network @net using * can be used to enumerate all matching prefixes for the network @net using
* function trie_match_next_longest_ip6() or macro TRIE_WALK_TO_ROOT_IP6(). * function trie_match_next_longest_ip6() or macro TRIE_WALK_TO_ROOT_IP6().
* *
* This function assumes IPv6 trie, there is also an IPv4 variant. * This function assumes IPv6 trie, there is also an IPv4 variant. The @net
* argument is typed as net_addr_ip6, but would accept any IPv6-based net_addr,
* like net6_prefix(). Anyway, returned @dst is always net_addr_ip6.
* *
* Result: 1 if a matching prefix was found, 0 if not. * Result: 1 if a matching prefix was found, 0 if not.
*/ */
@ -662,6 +681,9 @@ trie_match_longest_ip6(const struct f_trie *t, const net_addr_ip6 *net, net_addr
{ {
ASSERT(!t->ipv4); ASSERT(!t->ipv4);
const ip6_addr prefix = net->prefix;
const int pxlen = net->pxlen;
const struct f_trie_node6 *n = &t->root.v6; const struct f_trie_node6 *n = &t->root.v6;
int len = 0; int len = 0;
@ -671,13 +693,13 @@ trie_match_longest_ip6(const struct f_trie *t, const net_addr_ip6 *net, net_addr
while (n) while (n)
{ {
/* We are out of path */ /* We are out of path */
if (!ip6_prefix_equal(net->prefix, n->addr, MIN(net->pxlen, n->plen))) if (!ip6_prefix_equal(prefix, n->addr, MIN(pxlen, n->plen)))
goto done; goto done;
/* Check accept mask */ /* Check accept mask */
for (; len < n->plen; len++) for (; len < n->plen; len++)
{ {
if (len > net->pxlen) if (len > pxlen)
goto done; goto done;
if (ip6_getbit(n->accept, len - 1)) if (ip6_getbit(n->accept, len - 1))
@ -698,9 +720,9 @@ trie_match_longest_ip6(const struct f_trie *t, const net_addr_ip6 *net, net_addr
} }
/* Check local mask */ /* Check local mask */
for (int pos = 1; pos < (1 << TRIE_STEP); pos = 2 * pos + ip6_getbit(net->prefix, len), len++) for (int pos = 1; pos < (1 << TRIE_STEP); pos = 2 * pos + ip6_getbit(prefix, len), len++)
{ {
if (len > net->pxlen) if (len > pxlen)
goto done; goto done;
if (n->local & (1u << pos)) if (n->local & (1u << pos))
@ -712,16 +734,14 @@ trie_match_longest_ip6(const struct f_trie *t, const net_addr_ip6 *net, net_addr
} }
/* Choose child */ /* Choose child */
n = n->c[ip6_getbits(net->prefix, n->plen, TRIE_STEP)]; n = n->c[ip6_getbits(prefix, n->plen, TRIE_STEP)];
} }
done: done:
if (last < 0) if (last < 0)
return 0; return 0;
net_copy_ip6(dst, net); *dst = NET_ADDR_IP6(ip6_and(prefix, ip6_mkmask(last)), last);
dst->prefix = ip6_and(dst->prefix, ip6_mkmask(last));
dst->pxlen = last;
if (found0) if (found0)
*found0 = found; *found0 = found;