diff --git a/proto/babel/packets.c b/proto/babel/packets.c index 415ac3f9..1d2f5f5b 100644 --- a/proto/babel/packets.c +++ b/proto/babel/packets.c @@ -120,8 +120,19 @@ struct babel_subtlv_source_prefix { #define BABEL_UF_DEF_PREFIX 0x80 #define BABEL_UF_ROUTER_ID 0x40 +struct babel_parse_state; +struct babel_write_state; + +struct babel_tlv_data { + u8 min_length; + int (*read_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_parse_state *state); + uint (*write_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_write_state *state, uint max_len); + void (*handle_tlv)(union babel_msg *m, struct babel_iface *ifa); +}; struct babel_parse_state { + const struct babel_tlv_data* (*get_tlv_data)(u8 type); + const struct babel_tlv_data* (*get_subtlv_data)(u8 type); struct babel_proto *proto; struct babel_iface *ifa; ip_addr saddr; @@ -167,6 +178,37 @@ struct babel_write_state { #define NET_SIZE(n) BYTES(net_pxlen(n)) + +/* Helper macros to loop over a series of TLVs. + * @start pointer to first TLV (void * or struct babel_tlv *) + * @end byte * pointer to TLV stream end + * @tlv struct babel_tlv pointer used as iterator + * @frame_err boolean (u8) that will be set to 1 if a frame error occurred + * @saddr source addr for use in log output + * @ifname ifname for use in log output + */ +#define WALK_TLVS(start, end, tlv, frame_err, saddr, ifname) \ + for (tlv = start; \ + (byte *)tlv < end; \ + tlv = NEXT_TLV(tlv)) \ + { \ + byte *loop_pos; \ + /* Ugly special case */ \ + if (tlv->type == BABEL_TLV_PAD1) \ + continue; \ + \ + /* The end of the common TLV header */ \ + loop_pos = (byte *)tlv + sizeof(struct babel_tlv); \ + if ((loop_pos > end) || (loop_pos + tlv->length > end)) \ + { \ + LOG_PKT("Bad TLV from %I via %s type %d pos %d - framing error", \ + saddr, ifname, tlv->type, (byte *)tlv - (byte *)start); \ + frame_err = 1; \ + break; \ + } + +#define WALK_TLVS_END } + static inline uint bytes_equal(u8 *b1, u8 *b2, uint maxlen) { @@ -255,13 +297,6 @@ static uint babel_write_route_request(struct babel_tlv *hdr, union babel_msg *ms static uint babel_write_seqno_request(struct babel_tlv *hdr, union babel_msg *msg, struct babel_write_state *state, uint max_len); static int babel_write_source_prefix(struct babel_tlv *hdr, net_addr *net, uint max_len); -struct babel_tlv_data { - u8 min_length; - int (*read_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_parse_state *state); - uint (*write_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_write_state *state, uint max_len); - void (*handle_tlv)(union babel_msg *m, struct babel_iface *ifa); -}; - static const struct babel_tlv_data tlv_data[BABEL_TLV_MAX] = { [BABEL_TLV_ACK_REQ] = { sizeof(struct babel_tlv_ack_req), @@ -319,6 +354,30 @@ static const struct babel_tlv_data tlv_data[BABEL_TLV_MAX] = { }, }; +static const struct babel_tlv_data *get_packet_tlv_data(u8 type) +{ + return type < sizeof(tlv_data) / sizeof(*tlv_data) ? &tlv_data[type] : NULL; +} + +static const struct babel_tlv_data source_prefix_tlv_data = { + sizeof(struct babel_subtlv_source_prefix), + babel_read_source_prefix, + NULL, + NULL +}; + +static const struct babel_tlv_data *get_packet_subtlv_data(u8 type) +{ + switch(type) + { + case BABEL_SUBTLV_SOURCE_PREFIX: + return &source_prefix_tlv_data; + + default: + return NULL; + } +} + static int babel_read_ack_req(struct babel_tlv *hdr, union babel_msg *m, struct babel_parse_state *state) @@ -1083,69 +1142,65 @@ babel_write_source_prefix(struct babel_tlv *hdr, net_addr *n, uint max_len) return len; } - static inline int babel_read_subtlvs(struct babel_tlv *hdr, union babel_msg *msg, struct babel_parse_state *state) { + const struct babel_tlv_data *tlv_data; + struct babel_proto *p = state->proto; struct babel_tlv *tlv; - byte *pos, *end = (byte *) hdr + TLV_LENGTH(hdr); + byte *end = (byte *) hdr + TLV_LENGTH(hdr); + u8 frame_err = 0; int res; - for (tlv = (void *) hdr + state->current_tlv_endpos; - (byte *) tlv < end; - tlv = NEXT_TLV(tlv)) + WALK_TLVS((void *)hdr + state->current_tlv_endpos, end, tlv, frame_err, + state->saddr, state->ifa->ifname) { - /* Ugly special case */ - if (tlv->type == BABEL_TLV_PAD1) + if (tlv->type == BABEL_SUBTLV_PADN) continue; - /* The end of the common TLV header */ - pos = (byte *)tlv + sizeof(struct babel_tlv); - if ((pos > end) || (pos + tlv->length > end)) - return PARSE_ERROR; - - /* - * The subtlv type space is non-contiguous (due to the mandatory bit), so - * use a switch for dispatch instead of the mapping array we use for TLVs - */ - switch (tlv->type) + if (!state->get_subtlv_data || + !(tlv_data = state->get_subtlv_data(tlv->type)) || + !tlv_data->read_tlv) { - case BABEL_SUBTLV_SOURCE_PREFIX: - res = babel_read_source_prefix(tlv, msg, state); - if (res != PARSE_SUCCESS) - return res; - break; - - case BABEL_SUBTLV_PADN: - default: /* Unknown mandatory subtlv; PARSE_IGNORE ignores the whole TLV */ if (tlv->type >= 128) - return PARSE_IGNORE; - break; + return PARSE_IGNORE; + continue; } - } - return PARSE_SUCCESS; + res = tlv_data->read_tlv(tlv, msg, state); + if (res != PARSE_SUCCESS) + return res; + } + WALK_TLVS_END; + + return frame_err ? PARSE_ERROR : PARSE_SUCCESS; } -static inline int +static int babel_read_tlv(struct babel_tlv *hdr, union babel_msg *msg, struct babel_parse_state *state) { + const struct babel_tlv_data *tlv_data; + if ((hdr->type <= BABEL_TLV_PADN) || - (hdr->type >= BABEL_TLV_MAX) || - !tlv_data[hdr->type].read_tlv) + (hdr->type >= BABEL_TLV_MAX)) return PARSE_IGNORE; - if (TLV_LENGTH(hdr) < tlv_data[hdr->type].min_length) + tlv_data = state->get_tlv_data(hdr->type); + + if (!tlv_data || !tlv_data->read_tlv) + return PARSE_IGNORE; + + if (TLV_LENGTH(hdr) < tlv_data->min_length) return PARSE_ERROR; - state->current_tlv_endpos = tlv_data[hdr->type].min_length; + state->current_tlv_endpos = tlv_data->min_length; - int res = tlv_data[hdr->type].read_tlv(hdr, msg, state); + int res = tlv_data->read_tlv(hdr, msg, state); if (res != PARSE_SUCCESS) return res; @@ -1330,6 +1385,7 @@ static void babel_process_packet(struct babel_pkt_header *pkt, int len, ip_addr saddr, struct babel_iface *ifa) { + u8 frame_err UNUSED = 0; struct babel_proto *p = ifa->proto; struct babel_tlv *tlv; struct babel_msg_node *msg; @@ -1337,15 +1393,16 @@ babel_process_packet(struct babel_pkt_header *pkt, int len, int res; int plen = sizeof(struct babel_pkt_header) + get_u16(&pkt->length); - byte *pos; byte *end = (byte *)pkt + plen; struct babel_parse_state state = { - .proto = p, - .ifa = ifa, - .saddr = saddr, - .next_hop_ip6 = saddr, - .sadr_enabled = babel_sadr_enabled(p), + .get_tlv_data = &get_packet_tlv_data, + .get_subtlv_data = &get_packet_subtlv_data, + .proto = p, + .ifa = ifa, + .saddr = saddr, + .next_hop_ip6 = saddr, + .sadr_enabled = babel_sadr_enabled(p), }; if ((pkt->magic != BABEL_MAGIC) || (pkt->version != BABEL_VERSION)) @@ -1369,23 +1426,8 @@ babel_process_packet(struct babel_pkt_header *pkt, int len, /* First pass through the packet TLV by TLV, parsing each into internal data structures. */ - for (tlv = FIRST_TLV(pkt); - (byte *)tlv < end; - tlv = NEXT_TLV(tlv)) + WALK_TLVS(FIRST_TLV(pkt), end, tlv, frame_err, saddr, ifa->iface->name) { - /* Ugly special case */ - if (tlv->type == BABEL_TLV_PAD1) - continue; - - /* The end of the common TLV header */ - pos = (byte *)tlv + sizeof(struct babel_tlv); - if ((pos > end) || (pos + tlv->length > end)) - { - LOG_PKT("Bad TLV from %I via %s type %d pos %d - framing error", - saddr, ifa->iface->name, tlv->type, (byte *)tlv - (byte *)pkt); - break; - } - msg = sl_allocz(p->msg_slab); res = babel_read_tlv(tlv, &msg->msg, &state); if (res == PARSE_SUCCESS) @@ -1405,8 +1447,9 @@ babel_process_packet(struct babel_pkt_header *pkt, int len, break; } } + WALK_TLVS_END; - /* Parsing done, handle all parsed TLVs */ + /* Parsing done, handle all parsed TLVs, regardless of any errors */ WALK_LIST_FIRST(msg, msgs) { if (tlv_data[msg->msg.type].handle_tlv)