Babel: Refactor TLV parsing code for easier reuse

In preparation for adding authentication checks, refactor the TLV
walking code so it can be reused for a separate pass of the packet
for authentication checks.
This commit is contained in:
Toke Høiland-Jørgensen 2021-04-15 20:15:53 +02:00 committed by Ondrej Zajicek (work)
parent 589f7d1e4f
commit 69d10132a6

View file

@ -120,8 +120,19 @@ struct babel_subtlv_source_prefix {
#define BABEL_UF_DEF_PREFIX 0x80 #define BABEL_UF_DEF_PREFIX 0x80
#define BABEL_UF_ROUTER_ID 0x40 #define BABEL_UF_ROUTER_ID 0x40
struct babel_parse_state;
struct babel_write_state;
struct babel_tlv_data {
u8 min_length;
int (*read_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_parse_state *state);
uint (*write_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_write_state *state, uint max_len);
void (*handle_tlv)(union babel_msg *m, struct babel_iface *ifa);
};
struct babel_parse_state { struct babel_parse_state {
const struct babel_tlv_data* (*get_tlv_data)(u8 type);
const struct babel_tlv_data* (*get_subtlv_data)(u8 type);
struct babel_proto *proto; struct babel_proto *proto;
struct babel_iface *ifa; struct babel_iface *ifa;
ip_addr saddr; ip_addr saddr;
@ -167,6 +178,37 @@ struct babel_write_state {
#define NET_SIZE(n) BYTES(net_pxlen(n)) #define NET_SIZE(n) BYTES(net_pxlen(n))
/* Helper macros to loop over a series of TLVs.
* @start pointer to first TLV (void * or struct babel_tlv *)
* @end byte * pointer to TLV stream end
* @tlv struct babel_tlv pointer used as iterator
* @frame_err boolean (u8) that will be set to 1 if a frame error occurred
* @saddr source addr for use in log output
* @ifname ifname for use in log output
*/
#define WALK_TLVS(start, end, tlv, frame_err, saddr, ifname) \
for (tlv = start; \
(byte *)tlv < end; \
tlv = NEXT_TLV(tlv)) \
{ \
byte *loop_pos; \
/* Ugly special case */ \
if (tlv->type == BABEL_TLV_PAD1) \
continue; \
\
/* The end of the common TLV header */ \
loop_pos = (byte *)tlv + sizeof(struct babel_tlv); \
if ((loop_pos > end) || (loop_pos + tlv->length > end)) \
{ \
LOG_PKT("Bad TLV from %I via %s type %d pos %d - framing error", \
saddr, ifname, tlv->type, (byte *)tlv - (byte *)start); \
frame_err = 1; \
break; \
}
#define WALK_TLVS_END }
static inline uint static inline uint
bytes_equal(u8 *b1, u8 *b2, uint maxlen) bytes_equal(u8 *b1, u8 *b2, uint maxlen)
{ {
@ -255,13 +297,6 @@ static uint babel_write_route_request(struct babel_tlv *hdr, union babel_msg *ms
static uint babel_write_seqno_request(struct babel_tlv *hdr, union babel_msg *msg, struct babel_write_state *state, uint max_len); static uint babel_write_seqno_request(struct babel_tlv *hdr, union babel_msg *msg, struct babel_write_state *state, uint max_len);
static int babel_write_source_prefix(struct babel_tlv *hdr, net_addr *net, uint max_len); static int babel_write_source_prefix(struct babel_tlv *hdr, net_addr *net, uint max_len);
struct babel_tlv_data {
u8 min_length;
int (*read_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_parse_state *state);
uint (*write_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_write_state *state, uint max_len);
void (*handle_tlv)(union babel_msg *m, struct babel_iface *ifa);
};
static const struct babel_tlv_data tlv_data[BABEL_TLV_MAX] = { static const struct babel_tlv_data tlv_data[BABEL_TLV_MAX] = {
[BABEL_TLV_ACK_REQ] = { [BABEL_TLV_ACK_REQ] = {
sizeof(struct babel_tlv_ack_req), sizeof(struct babel_tlv_ack_req),
@ -319,6 +354,30 @@ static const struct babel_tlv_data tlv_data[BABEL_TLV_MAX] = {
}, },
}; };
static const struct babel_tlv_data *get_packet_tlv_data(u8 type)
{
return type < sizeof(tlv_data) / sizeof(*tlv_data) ? &tlv_data[type] : NULL;
}
static const struct babel_tlv_data source_prefix_tlv_data = {
sizeof(struct babel_subtlv_source_prefix),
babel_read_source_prefix,
NULL,
NULL
};
static const struct babel_tlv_data *get_packet_subtlv_data(u8 type)
{
switch(type)
{
case BABEL_SUBTLV_SOURCE_PREFIX:
return &source_prefix_tlv_data;
default:
return NULL;
}
}
static int static int
babel_read_ack_req(struct babel_tlv *hdr, union babel_msg *m, babel_read_ack_req(struct babel_tlv *hdr, union babel_msg *m,
struct babel_parse_state *state) struct babel_parse_state *state)
@ -1083,69 +1142,65 @@ babel_write_source_prefix(struct babel_tlv *hdr, net_addr *n, uint max_len)
return len; return len;
} }
static inline int static inline int
babel_read_subtlvs(struct babel_tlv *hdr, babel_read_subtlvs(struct babel_tlv *hdr,
union babel_msg *msg, union babel_msg *msg,
struct babel_parse_state *state) struct babel_parse_state *state)
{ {
const struct babel_tlv_data *tlv_data;
struct babel_proto *p = state->proto;
struct babel_tlv *tlv; struct babel_tlv *tlv;
byte *pos, *end = (byte *) hdr + TLV_LENGTH(hdr); byte *end = (byte *) hdr + TLV_LENGTH(hdr);
u8 frame_err = 0;
int res; int res;
for (tlv = (void *) hdr + state->current_tlv_endpos; WALK_TLVS((void *)hdr + state->current_tlv_endpos, end, tlv, frame_err,
(byte *) tlv < end; state->saddr, state->ifa->ifname)
tlv = NEXT_TLV(tlv))
{ {
/* Ugly special case */ if (tlv->type == BABEL_SUBTLV_PADN)
if (tlv->type == BABEL_TLV_PAD1)
continue; continue;
/* The end of the common TLV header */ if (!state->get_subtlv_data ||
pos = (byte *)tlv + sizeof(struct babel_tlv); !(tlv_data = state->get_subtlv_data(tlv->type)) ||
if ((pos > end) || (pos + tlv->length > end)) !tlv_data->read_tlv)
return PARSE_ERROR;
/*
* The subtlv type space is non-contiguous (due to the mandatory bit), so
* use a switch for dispatch instead of the mapping array we use for TLVs
*/
switch (tlv->type)
{ {
case BABEL_SUBTLV_SOURCE_PREFIX:
res = babel_read_source_prefix(tlv, msg, state);
if (res != PARSE_SUCCESS)
return res;
break;
case BABEL_SUBTLV_PADN:
default:
/* Unknown mandatory subtlv; PARSE_IGNORE ignores the whole TLV */ /* Unknown mandatory subtlv; PARSE_IGNORE ignores the whole TLV */
if (tlv->type >= 128) if (tlv->type >= 128)
return PARSE_IGNORE; return PARSE_IGNORE;
break; continue;
} }
}
return PARSE_SUCCESS; res = tlv_data->read_tlv(tlv, msg, state);
if (res != PARSE_SUCCESS)
return res;
}
WALK_TLVS_END;
return frame_err ? PARSE_ERROR : PARSE_SUCCESS;
} }
static inline int static int
babel_read_tlv(struct babel_tlv *hdr, babel_read_tlv(struct babel_tlv *hdr,
union babel_msg *msg, union babel_msg *msg,
struct babel_parse_state *state) struct babel_parse_state *state)
{ {
const struct babel_tlv_data *tlv_data;
if ((hdr->type <= BABEL_TLV_PADN) || if ((hdr->type <= BABEL_TLV_PADN) ||
(hdr->type >= BABEL_TLV_MAX) || (hdr->type >= BABEL_TLV_MAX))
!tlv_data[hdr->type].read_tlv)
return PARSE_IGNORE; return PARSE_IGNORE;
if (TLV_LENGTH(hdr) < tlv_data[hdr->type].min_length) tlv_data = state->get_tlv_data(hdr->type);
if (!tlv_data || !tlv_data->read_tlv)
return PARSE_IGNORE;
if (TLV_LENGTH(hdr) < tlv_data->min_length)
return PARSE_ERROR; return PARSE_ERROR;
state->current_tlv_endpos = tlv_data[hdr->type].min_length; state->current_tlv_endpos = tlv_data->min_length;
int res = tlv_data[hdr->type].read_tlv(hdr, msg, state); int res = tlv_data->read_tlv(hdr, msg, state);
if (res != PARSE_SUCCESS) if (res != PARSE_SUCCESS)
return res; return res;
@ -1330,6 +1385,7 @@ static void
babel_process_packet(struct babel_pkt_header *pkt, int len, babel_process_packet(struct babel_pkt_header *pkt, int len,
ip_addr saddr, struct babel_iface *ifa) ip_addr saddr, struct babel_iface *ifa)
{ {
u8 frame_err UNUSED = 0;
struct babel_proto *p = ifa->proto; struct babel_proto *p = ifa->proto;
struct babel_tlv *tlv; struct babel_tlv *tlv;
struct babel_msg_node *msg; struct babel_msg_node *msg;
@ -1337,15 +1393,16 @@ babel_process_packet(struct babel_pkt_header *pkt, int len,
int res; int res;
int plen = sizeof(struct babel_pkt_header) + get_u16(&pkt->length); int plen = sizeof(struct babel_pkt_header) + get_u16(&pkt->length);
byte *pos;
byte *end = (byte *)pkt + plen; byte *end = (byte *)pkt + plen;
struct babel_parse_state state = { struct babel_parse_state state = {
.proto = p, .get_tlv_data = &get_packet_tlv_data,
.ifa = ifa, .get_subtlv_data = &get_packet_subtlv_data,
.saddr = saddr, .proto = p,
.next_hop_ip6 = saddr, .ifa = ifa,
.sadr_enabled = babel_sadr_enabled(p), .saddr = saddr,
.next_hop_ip6 = saddr,
.sadr_enabled = babel_sadr_enabled(p),
}; };
if ((pkt->magic != BABEL_MAGIC) || (pkt->version != BABEL_VERSION)) if ((pkt->magic != BABEL_MAGIC) || (pkt->version != BABEL_VERSION))
@ -1369,23 +1426,8 @@ babel_process_packet(struct babel_pkt_header *pkt, int len,
/* First pass through the packet TLV by TLV, parsing each into internal data /* First pass through the packet TLV by TLV, parsing each into internal data
structures. */ structures. */
for (tlv = FIRST_TLV(pkt); WALK_TLVS(FIRST_TLV(pkt), end, tlv, frame_err, saddr, ifa->iface->name)
(byte *)tlv < end;
tlv = NEXT_TLV(tlv))
{ {
/* Ugly special case */
if (tlv->type == BABEL_TLV_PAD1)
continue;
/* The end of the common TLV header */
pos = (byte *)tlv + sizeof(struct babel_tlv);
if ((pos > end) || (pos + tlv->length > end))
{
LOG_PKT("Bad TLV from %I via %s type %d pos %d - framing error",
saddr, ifa->iface->name, tlv->type, (byte *)tlv - (byte *)pkt);
break;
}
msg = sl_allocz(p->msg_slab); msg = sl_allocz(p->msg_slab);
res = babel_read_tlv(tlv, &msg->msg, &state); res = babel_read_tlv(tlv, &msg->msg, &state);
if (res == PARSE_SUCCESS) if (res == PARSE_SUCCESS)
@ -1405,8 +1447,9 @@ babel_process_packet(struct babel_pkt_header *pkt, int len,
break; break;
} }
} }
WALK_TLVS_END;
/* Parsing done, handle all parsed TLVs */ /* Parsing done, handle all parsed TLVs, regardless of any errors */
WALK_LIST_FIRST(msg, msgs) WALK_LIST_FIRST(msg, msgs)
{ {
if (tlv_data[msg->msg.type].handle_tlv) if (tlv_data[msg->msg.type].handle_tlv)