diff --git a/filter/f-inst.c b/filter/f-inst.c index 4041a804..cee5b8e7 100644 --- a/filter/f-inst.c +++ b/filter/f-inst.c @@ -11,47 +11,35 @@ /* Binary operators */ case FI_ADD: - ARG(1,T_INT); - ARG(2,T_INT); - res.type = T_INT; - res.val.i = v1.val.i + v2.val.i; + ARG_T(1,0,T_INT); + ARG_T(2,1,T_INT); + res.val.i += v1.val.i; break; case FI_SUBTRACT: - ARG(1,T_INT); - ARG(2,T_INT); - res.type = T_INT; - res.val.i = v1.val.i - v2.val.i; + ARG_T(1,0,T_INT); + ARG_T(2,1,T_INT); + res.val.i -= v1.val.i; break; case FI_MULTIPLY: - ARG(1,T_INT); - ARG(2,T_INT); - res.type = T_INT; - res.val.i = v1.val.i * v2.val.i; + ARG_T(1,0,T_INT); + ARG_T(2,1,T_INT); + res.val.i *= v1.val.i; break; case FI_DIVIDE: - ARG(1,T_INT); - ARG(2,T_INT); - res.type = T_INT; - if (v2.val.i == 0) runtime( "Mother told me not to divide by 0" ); - res.val.i = v1.val.i / v2.val.i; + ARG_T(1,0,T_INT); + ARG_T(2,1,T_INT); + if (v1.val.i == 0) runtime( "Mother told me not to divide by 0" ); + res.val.i /= v1.val.i; break; case FI_AND: - ARG(1,T_BOOL); - if (!v1.val.i) { - res = v1; - } else { - ARG(2,T_BOOL); - res = v2; - } + ARG_T(1,0,T_BOOL); + if (res.val.i) + ARG_T(2,0,T_BOOL); break; case FI_OR: - ARG(1,T_BOOL); - if (v1.val.i) { - res = v1; - } else { - ARG(2,T_BOOL); - res = v2; - } + ARG_T(1,0,T_BOOL); + if (!res.val.i) + ARG_T(2,0,T_BOOL); break; case FI_PAIR_CONSTRUCT: ARG(1,T_INT); @@ -184,8 +172,7 @@ break; case FI_NOT: - ARG(1,T_BOOL); - res = v1; + ARG_T(1,0,T_BOOL); res.val.i = !res.val.i; break; @@ -642,7 +629,6 @@ break; case FI_ROUTE_DISTINGUISHER: ARG(1, T_NET); - res.type = T_IP; if (!net_is_vpn(v1.val.net)) runtime( "VPN address expected" ); res.type = T_RD; @@ -671,11 +657,10 @@ res.val.i = as_path_get_last_nonaggregated(v1.val.ad); break; case FI_RETURN: - ARG_ANY(1); - res = v1; + ARG_ANY_T(1,0); return F_RETURN; case FI_CALL: - ARG_ANY(1); + ARG_ANY_T(1,0); fret = interpret(fs, what->a2.p); if (fret > F_RETURN) return fret; diff --git a/filter/filter.c b/filter/filter.c index b15ede8a..a1bb7415 100644 --- a/filter/filter.c +++ b/filter/filter.c @@ -634,6 +634,7 @@ interpret(struct filter_state *fs, struct f_inst *what) u32 as; #define res fs->stack[fs->stack_ptr].val +#define v0 res #define v1 fs->stack[fs->stack_ptr + 1].val #define v2 fs->stack[fs->stack_ptr + 2].val #define v3 fs->stack[fs->stack_ptr + 3].val @@ -650,15 +651,18 @@ interpret(struct filter_state *fs, struct f_inst *what) return F_ERROR; \ } while(0) -#define ARG_ANY(n) INTERPRET(what->a##n.p, n) +#define ARG_ANY_T(n, tt) INTERPRET(what->a##n.p, tt) +#define ARG_ANY(n) ARG_ANY_T(n, n) -#define ARG(n,t) do { \ - ARG_ANY(n); \ - if (v##n.type != t) \ +#define ARG_T(n,tt,t) do { \ + ARG_ANY_T(n,tt); \ + if (v##tt.type != t) \ runtime("Argument %d of instruction %s must be of type %02x, got %02x", \ - n, f_instruction_name(what->fi_code), t, v##n.type); \ + n, f_instruction_name(what->fi_code), t, v##tt.type); \ } while (0) +#define ARG(n,t) ARG_T(n,n,t) + #define INTERPRET(what_, n) do { \ fs->stack_ptr += n; \ fret = interpret(fs, what_); \