ss/ss.py

287 lines
14 KiB
Python
Raw Normal View History

2023-07-10 15:35:57 +08:00
#!/usr/bin/env python3
import os
import argparse
from pathlib import Path
from functools import reduce
import json
import subprocess
import ipaddress
import socket
# Resources shipped next to this script: per-server configs, shared defaults and nft templates.
script_path = Path(__file__).parent
local_config_path = script_path / "configs"
include_config = local_config_path / "config.common.inc"

# Generated configs are written into the first of these directories that exists.
ss_config_paths = [Path("/etc/shadowsocks"), Path("/etc/shadowsocks-rust"), Path("/etc/shadowsocks-libev")]
ss_service_name = "shadowsocks-libev-redir@"  # alternative: "shadowsocks-rust@"
ss_prefix = "autogen-"  # marks the config files / unit instances managed by this script
ss_config_path = next((p for p in ss_config_paths if p.exists()), None)
if ss_config_path is None:
    # fail with a readable message instead of a bare IndexError at import time
    raise SystemExit("[!] no shadowsocks config directory found (tried: %s)"
                     % ", ".join(str(p) for p in ss_config_paths))


def ss_config_get(name) -> Path:
    """Return the path of the generated config file for instance ``name``."""
    return ss_config_path / f"{ss_prefix}{name}.json"


def ss_service_get(name) -> str:
    """Return the systemd unit name that serves generated instance ``name``."""
    return f"{ss_service_name}{ss_prefix}{name}.service"


# nft rule templates: *-tproxy variants are used when tcp_redir is not 'redirect'.
nft_rule_redir = script_path / "transparent-proxy.nft"
nft_rule_v6_redir = script_path / "transparent-proxy-v6.nft"
nft_rule_tproxy = script_path / "transparent-proxy-tproxy.nft"
nft_rule_v6_tproxy = script_path / "transparent-proxy-v6-tproxy.nft"
# One network per line; loaded into the nft "chnroute" set of the matching family.
chnroute = "/etc/dnsmasq.d/chinadns_chnroute.txt"
chnroute6 = "/etc/dnsmasq.d/chinadns_chnroute6.txt"
# Interface names rendered into the nft {tcp,udp}_proxy_ifnames defines (empty -> @empty_str).
proxy_interfaces = []
# "[] or proxy_interfaces" falls back to the IPv4 list (same object);
# put names inside the [] to configure IPv6 separately.
proxy_interfaces_v6 = [] or proxy_interfaces
# Extra networks (ipaddress.ip_network strings) appended to the chnroute set of their family.
extra_bypass = []
def gen_configs(config_name: str) -> dict:
    """Merge the shared defaults with ``configs/<config_name>.json``.

    The result keeps the layout of ``include_config``:

    * ``"common"`` - the include's common section overridden by the server's
      common settings (a legacy server file without ``"modes"`` is treated as
      one big common section);
    * ``"modes"`` - the server file's mode list if it has one, otherwise the
      include's;
    * one entry per mode - ``common`` + the include's section for that mode +
      the server file's section for that mode, keys sorted.

    Each mode entry is then adapted to the redirect backend chosen by the
    include file's ``common.tcp_redir``: for ``'tproxy'`` the keys listed in
    ``RUST_ATTRS`` are moved into a one-element ``"locals"`` list; for
    ``'redirect'`` only the ``'mvdel'`` keys are dropped. Either way a
    ``"#<key>"`` copy of each original value is kept for this script's own use
    (``#``-prefixed keys are not written to the generated config files). The
    first mode additionally gets ``_meta_name = config_name``.

    Raises AssertionError when the configuration is inconsistent (unsupported
    redir settings, unknown or overlapping modes, an enabled IP family that is
    missing tcp or udp).
    """
    config_inc = json.loads(include_config.read_text())
    RUST_ATTRS = {'mv': ('local_address', 'local_port', 'mode', 'protocol'), 'mvdel': ('tcp_redir', 'udp_redir')}
    all_attrs = RUST_ATTRS['mv'] + RUST_ATTRS['mvdel']
    assert config_inc['common']['tcp_redir'] in ('redirect', 'tproxy')
    assert config_inc['common']['udp_redir'] == 'tproxy'
    rust_only = config_inc['common']['tcp_redir'] != 'redirect'
    config = json.loads((local_config_path / f"{config_name}.json").read_text())
    MODES = ("ipv4_tcp_udp", "ipv6_tcp_udp", "ipv4_tcp_only", "ipv4_udp_only", "ipv6_tcp_only", "ipv6_udp_only", "ipv4_ipv6_tcp_udp")

    # handle legacy config: without a "modes" key the whole file is the common section
    config_common = config if "modes" not in config else config["common"]
    config_inc["common"] = dict(sorted({**config_inc["common"], **config_common}.items()))
    if "modes" in config:
        config_inc["modes"] = config["modes"]
    modes = config_inc["modes"]
    assert all(m in MODES for m in modes)
    assert modes

    # Validate the effective mode list; legacy configs take their modes from
    # the include file, so the server file may not have a "modes" key at all.
    IPX = ("ipv4", "ipv6")
    L4PROTO = ("tcp", "udp")

    def n_modes(ipx, proto):
        return sum(1 for m in modes if ipx in m and proto in m)

    # no (family, protocol) pair may be handled by more than one mode
    assert not any(n_modes(ipx, proto) > 1 for ipx in IPX for proto in L4PROTO)
    # every enabled family must be covered for both tcp and udp
    assert all(n_modes(ipx, proto) for ipx in IPX if any(ipx in m for m in modes) for proto in L4PROTO)

    for m in modes:
        config_inc[m] = dict(sorted({**config_inc["common"], **config_inc.get(m, {}), **config.get(m, {})}.items()))

    for idx, m in enumerate(modes):
        mode_cfg = config_inc[m]
        if rust_only:
            mode_cfg['locals'] = [{k: mode_cfg[k] for k in all_attrs}]
            for attr in all_attrs:
                mode_cfg[f"#{attr}"] = mode_cfg.pop(attr, None)
        else:
            for attr in all_attrs:
                mode_cfg[f"#{attr}"] = mode_cfg.get(attr)
            for attr in RUST_ATTRS['mvdel']:
                mode_cfg.pop(attr, None)
        if idx == 0:
            # lets print_config_names() find out which config is currently up
            mode_cfg['_meta_name'] = config_name
    return config_inc
def print_config_names(do_print=True) -> str:
    """Return the name of the config that is currently up, optionally listing all configs.

    The current config is the ``_meta_name`` stored in the generated
    instance-0 config file; ``""`` is returned when that file is missing or
    cannot be read/parsed.

    When ``do_print`` is true, every ``*.json`` under ``local_config_path`` is
    listed in name order together with its first mode's ``server:server_port``;
    the current one is marked with a leading ``>``.
    """
    def get_current_up() -> str:
        primary_conf = ss_config_get(0)
        try:
            return json.loads(primary_conf.read_text())['_meta_name']
        except (OSError, ValueError, KeyError, TypeError):
            # missing file, unreadable/invalid JSON or no _meta_name: nothing is up
            return ""

    current_up = get_current_up()
    if do_print:
        for conf in sorted(local_config_path.iterdir()):
            if not conf.name.endswith('.json'):
                continue
            name = conf.name[:-len('.json')]
            merged = gen_configs(name)
            c = merged[merged["modes"][0]]
            marker = ">" if name == current_up else " "
            print("%s%s \t(%s:%d)" % (marker, name, c["server"], c["server_port"]))
    return current_up
def stop_and_remove(config_name):
    """Stop the unit of generated instance ``config_name`` if it is active, then delete its config file."""
    unit = ss_service_get(config_name)
    probe = subprocess.run(["systemctl", "is-active", unit], check=False, capture_output=True)
    if probe.returncode == 0:
        stop = subprocess.run(["systemctl", "stop", unit], check=False)
        if stop.returncode != 0:
            print(f"[!] systemctl stop {unit} failed")
    ss_config_get(config_name).unlink()
def stop_all_configs():
    """Stop every active unit that serves a generated (``ss_prefix``) config.

    The config files themselves are kept. Prints ``stopped <unit>`` for each
    unit stopped successfully and a ``[!]`` warning for each failed stop.
    """
    for conf in ss_config_path.iterdir():
        if not (conf.name.endswith(".json") and conf.name.startswith(ss_prefix)):
            continue
        name = conf.name[len(ss_prefix):-len(".json")]
        service = ss_service_get(name)
        if subprocess.run(["systemctl", "is-active", service], check=False, capture_output=True).returncode:
            continue  # not running: nothing to stop
        if subprocess.run(["systemctl", "stop", service], check=False).returncode:
            print(f"[!] systemctl stop {service} failed")
        else:
            # only claim success when systemctl actually succeeded
            print(f"stopped {service}")
def write_and_enable_configs(config_dict, dry_run=False) -> bool:
    """Sync the generated shadowsocks configs and their systemd units with ``config_dict``.

    ``config_dict`` is the output of ``gen_configs``: mode ``i`` of
    ``config_dict['modes']`` is written to ``ss_config_get(str(i))`` (keys that
    start with ``#`` are omitted) and served by unit ``ss_service_get(str(i))``.

    * Generated config files that do not belong to a current mode index are
      stopped and deleted.
    * A config file is rewritten only when its JSON text differs.
    * A unit is restarted when it is not active or its config changed.

    With ``dry_run=True`` nothing is modified; every pending action is printed
    as a ``check failed: ...`` line instead.

    Returns True if anything was changed (or, in a dry run, would need to
    change), False if everything was already in sync.
    """
    deleted = wrote = restarted = pending = False
    # generated instances are named after their mode index: "0", "1", ...
    wanted = {str(idx) for idx in range(len(config_dict['modes']))}

    # Snapshot the listing (sorted for a stable order): stop_and_remove()
    # deletes files from this very directory.
    for conf in sorted(ss_config_path.iterdir()):
        if not (conf.name.endswith(".json") and conf.name.startswith(ss_prefix)):
            continue
        name = conf.name[len(ss_prefix):-len(".json")]
        if name in wanted:
            continue
        if dry_run:
            print(f"check failed: should stop and remove {conf.name=}")
            pending = True
        else:
            stop_and_remove(name)
            deleted = True

    for idx, mode in enumerate(config_dict['modes']):
        cfgname = str(idx)
        cfg = ss_config_get(cfgname)
        old = cfg.read_text() if cfg.exists() else ""
        new = json.dumps({k: v for k, v in config_dict[mode].items() if not k.startswith("#")})
        config_same = new == old
        if not config_same:
            if dry_run:
                print(f"check failed: should write {cfgname} {mode}")
                pending = True
            else:
                cfg.write_text(new)
                wrote = True
        service = ss_service_get(cfgname)
        active = subprocess.run(["systemctl", "is-active", service], check=False, capture_output=True).returncode == 0
        if active and config_same:
            continue
        if dry_run:
            print(f"check failed: should start {service}")
            pending = True
        else:
            if subprocess.run(["systemctl", "restart", service], check=False).returncode:
                print(f"[!] systemctl start {service} failed")
            restarted = True

    if deleted:
        print("deleted old config")
    if wrote:
        print("wrote new config")
    if restarted:
        print("restart systemd")
    return deleted or wrote or restarted or pending
def invoke_self_with_sudo():
    """Re-execute this script through ``sudo`` with the same interpreter and arguments.

    Must only be called when not already root. Returns the exit status of the
    sudo child process.
    """
    assert os.getuid() != 0
    import sys
    command = ["sudo", sys.executable] + sys.argv
    child = subprocess.run(command, check=False)
    return child.returncode
def prepare_cgroup_path(cgroup_root: Path = Path('/sys/fs/cgroup')) -> None:
    """Make sure the ``ss_*`` slice directories exist directly under ``cgroup_root``.

    ``cgroup_root`` defaults to ``/sys/fs/cgroup`` (the usual cgroup v2 mount
    point); pass another path (str or Path) when the unified hierarchy is
    mounted elsewhere. Existing slice directories are left untouched;
    ``cgroup_root`` itself must already exist.
    """
    needed_slices = ('ss_bp.slice', 'ss_bp_tcp.slice', 'ss_bp_udp.slice', 'ss_fw.slice', 'ss_fw_tcp.slice', 'ss_fw_udp.slice')
    root = Path(cgroup_root)
    for slice_name in needed_slices:  # avoid shadowing the builtin `slice`
        (root / slice_name).mkdir(exist_ok=True)
def process_nft_rule(configs: dict) -> dict:
    """Render the nftables ruleset for every enabled IP family.

    ``configs`` is the dict produced by ``gen_configs``. The templates are
    chosen by ``configs['common']['tcp_redir']``: 'redirect' selects the plain
    rule files, anything else the ``-tproxy`` variants. From each template only
    the part starting at the ``## DO NOT CHANGE THIS LINE`` marker is kept, and
    it is prefixed with ``define`` lines for the tcp/udp server address and
    port, the local (redir) port and the proxy interface names.

    Returns ``{family: ruleset_text}`` for the enabled families among 4 and 6.
    Raises ValueError if a template lacks the marker line, and socket.gaierror
    if a server hostname cannot be resolved.
    """
    nft_rule, nft_rule_v6 = ((nft_rule_redir, nft_rule_v6_redir)
                             if configs['common']['tcp_redir'] == 'redirect'
                             else (nft_rule_tproxy, nft_rule_v6_tproxy))
    resolved = {}  # resolve every server once, even if used by several families/protocols

    def get_family_proto_config(family: int, l4proto: str) -> str:
        """Return the name of the mode that serves ``l4proto`` over IPv``family``."""
        return [m for m in configs['modes'] if f"ipv{family}" in m and l4proto in m][0]

    def get_server(hostname_or_ip: str):
        """Parse ``hostname_or_ip`` as an IP address, resolving it via DNS if it is a hostname."""
        if hostname_or_ip not in resolved:
            try:
                addr = ipaddress.ip_address(hostname_or_ip)
            except ValueError:
                # first result of any address family, as returned by the resolver
                addr = ipaddress.ip_address(socket.getaddrinfo(hostname_or_ip, None, type=socket.SOCK_RAW)[0][4][0])
            resolved[hostname_or_ip] = addr
        return resolved[hostname_or_ip]

    def render(family: int) -> str:
        """Return the complete nft script for IPv``family``."""
        template = nft_rule_v6 if family == 6 else nft_rule
        nft_lines = [line for line in template.read_text().split('\n') if line]
        nft_lines = nft_lines[nft_lines.index('## DO NOT CHANGE THIS LINE'):]
        _tcp = configs[get_family_proto_config(family, 'tcp')]
        _udp = configs[get_family_proto_config(family, 'udp')]
        tcp_server = get_server(_tcp['server'])
        udp_server = get_server(_udp['server'])
        proxy_ifs_real = proxy_interfaces_v6 if family == 6 else proxy_interfaces
        ifnames = ("{ %s }" % ', '.join(f'"{x}"' for x in proxy_ifs_real)) if proxy_ifs_real else '@empty_str'
        nft_define = {
            # a server of the other address family cannot appear in this family's rules
            'tcp_host': f"@empty_ipv{family}" if tcp_server.version != family else str(tcp_server),
            'udp_host': f"@empty_ipv{family}" if udp_server.version != family else str(udp_server),
            'tcp_proxy_ifnames': ifnames,
            'udp_proxy_ifnames': ifnames,
            'tcp_server_port': _tcp['server_port'],
            'udp_server_port': _udp['server_port'],
            # gen_configs keeps a "#"-prefixed copy of local_port for this purpose
            'tcp_local_port': _tcp['#local_port'],
            'udp_local_port': _udp['#local_port'],
        }
        defines = [f"define {k} = {v}" for k, v in nft_define.items()]
        return '\n'.join(defines + nft_lines)

    def ipvx_enabled(family: int) -> bool:
        return any(f"ipv{family}" in m for m in configs['modes'])

    return {family: render(family) for family in (4, 6) if ipvx_enabled(family)}
def flush_nft() -> bool:
    """Delete the nftables tables this script installs; return True on success.

    Each table is ``add``-ed right before it is deleted: ``add table`` does not
    fail for an existing table, and it makes the ``delete`` succeed when the
    table was absent.
    """
    tables = (
        ('ip', 'transparent_proxy'),
        ('ip6', 'transparent_proxy_v6'),
        ('ip6', 'output_deny'),
    )
    script = '\n'.join(f'{verb} table {family} {table}'
                       for family, table in tables
                       for verb in ('add', 'delete'))
    ok = subprocess.run(["nft", "-f", "-"], input=script.encode('utf-8'), check=False).returncode == 0
    if not ok:
        print("[!] nft flush failed")
    return ok
def flush_iproute2() -> None:
    """Remove the policy routing installed by ``up`` (table 100, fwmark 0xdeaf rule).

    Runs the same ``ip -batch`` cleanup for IPv4 and then IPv6. Failures
    (e.g. nothing left to delete) are ignored and stderr is discarded.
    """
    commands = ('route flush table 100', 'rule del fwmark 0xdeaf table 100')
    batch = '\n'.join(commands).encode('utf-8')
    # the IPv6 pass always runs, whether or not IPv6 was enabled
    for ip_cmd in (["ip"], ["ip", "-6"]):
        subprocess.run([*ip_cmd, "-force", "-batch", "-"], input=batch, check=False, stderr=subprocess.DEVNULL)
def _main_info():
    """Handle ``info``: list the configs and dry-run-check the one currently up."""
    name = print_config_names()
    if name:
        if (local_config_path / f"{name}.json").exists():
            write_and_enable_configs(gen_configs(name), dry_run=True)
        else:
            print(f"[!] current config {name}.json is missing")
    return 0


def _main_up(config_name):
    """Handle ``up``: deploy ``config_name`` (default: the config currently up) and install the rules.

    Returns 0 on success and 1 if any step failed.
    """
    prepare_cgroup_path()
    if not config_name:
        config_name = print_config_names(do_print=False)
        if not config_name:
            print("[!] no config is currently up, please specify a config name")
            return 1
        print("autoselected config %s" % config_name)
    configs = gen_configs(config_name)
    write_and_enable_configs(configs)
    families = [x for x in (4, 6) if any(f"ipv{x}" in m for m in configs['modes'])]
    nfts = {k: v.encode('utf-8') for k, v in process_nft_rule(configs).items()}
    status = 0

    # policy routing: packets marked 0xdeaf are routed via table 100, whose
    # default route delivers them locally on lo
    flush_iproute2()
    ip_batch = '\n'.join(('route add local default dev lo table 100', 'rule add fwmark 0xdeaf table 100')).encode('utf-8')
    for x in families:
        if subprocess.run(["ip", f"-{x}", "-force", "-batch", "-"], input=ip_batch, check=False).returncode:
            print(f"[!] iproute2 ipv{x} failed")
            status = 1

    flush_nft()
    for x, nft in nfts.items():
        if subprocess.run(["nft", "-f", "-"], input=nft, check=False).returncode:
            print(f"[!] nft ipv{x} failed, flushing")
            flush_nft()
            return 1

    # fill each installed table's chnroute set with the chnroute list plus extra_bypass
    bypass = [ipaddress.ip_network(net) for net in extra_bypass]
    for x in families:
        nets = [line for line in Path(chnroute6 if x == 6 else chnroute).read_text().split('\n') if line]
        nets.extend(str(net) for net in bypass if net.version == x)
        family, table = ('ip6', 'transparent_proxy_v6') if x == 6 else ('ip', 'transparent_proxy')
        rule = '\n'.join(f"add element {family} {table} chnroute {{ {net} }}" for net in nets).encode('utf-8')
        if subprocess.run(["nft", "-f", "-"], input=rule, check=False).returncode:
            print("[!] nft chnroute failed")
            status = 1
    return status


def _main_down(stop_all):
    """Handle ``down``: remove routing and nft rules, optionally stopping the units too."""
    flush_iproute2()
    status = 0 if flush_nft() else 1
    if stop_all:
        stop_all_configs()
    return status


def main():
    """Parse the command line and run the ``info`` (default), ``up`` or ``down`` action.

    ``up`` and ``down`` need root: when run as another user the script
    re-executes itself under sudo and returns that process's exit status.
    Returns the process exit status.
    """
    parser = argparse.ArgumentParser(description='ss.py')
    parser.add_argument('action', type=str, default='info', nargs='?', choices=['info', 'up', 'down'], help='what to do')
    parser.add_argument('config', type=str, default=None, nargs='?', help='config name')
    parser.add_argument('-s', '--stop-all', action='store_true', help='stop systemd units')
    args = parser.parse_args()
    if args.action == 'info':
        return _main_info()
    if os.getuid() != 0:
        return invoke_self_with_sudo()
    if args.action == 'up':
        return _main_up(args.config)
    return _main_down(args.stop_all)
if __name__ == "__main__":
    # SystemExit instead of the site-provided exit() builtin (absent under `python -S`)
    raise SystemExit(main() or 0)