e5808020a7
- Why I did it: Added the ECMP calculator tool. - How I did it: New files were added. - How to verify it: Manual tests were performed according to the tests chapter in the HLD. Automated tests will be added by verification.
511 lines
19 KiB
Python
Executable File
511 lines
19 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
import ipaddress
import json
import os
import pprint
import re
import subprocess
import sys
import tempfile

import jsonschema

# Make the tool's helper modules importable: <prefix>/lib and
# <prefix>/lib/ecmp_calc, where <prefix> is two directories above this file.
usr_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
lib_path = os.path.join(usr_path, "lib")
ecmp_lib_path = os.path.join(lib_path, "ecmp_calc")
sys.path.append(lib_path)
sys.path.append(ecmp_lib_path)

from ecmp_calc_sdk import (sx_open_sdk_connection, sx_get_active_vrids, sx_router_get_ecmp_id,
                           sx_router_ecmp_nexthops_get, sx_get_router_interface,
                           sx_port_vport_base_get, sx_router_neigh_get_mac, sx_fdb_uc_mac_addr_get,
                           sx_lag_port_group_get, sx_make_ip_prefix_v4, sx_make_ip_prefix_v6,
                           sx_vlan_ports_get, sx_ip_addr_to_str, sx_close_sdk_connection,
                           PORT, VPORT, VLAN, SX_ENTRY_NOT_FOUND)
from packet_scheme import PACKET_SCHEME
from port_utils import sx_get_ports_map, is_lag
|
|
|
|
# Identifiers for the two supported IP versions.
IP_VERSION_IPV4, IP_VERSION_IPV6 = 1, 2

# Position of the interesting field inside ':'-separated redis keys.
PORT_CHANNEL_IDX = 1   # port-channel name in LAG_MEMBER_TABLE keys
VRF_NAME_IDX = 1       # VRF name in VRF keys

# Longest possible prefix length for each IP version.
IP_VERSION_MAX_MASK_LEN = {
    IP_VERSION_IPV4: 32,
    IP_VERSION_IPV6: 128,
}

# Redis table holding router-interface attributes (e.g. vrf_name).
INTF_TABLE = 'INTF_TABLE'

# External hash calculator binary and the default names of the JSON
# input/output files exchanged with it.
HASH_CALC_PATH = '/usr/bin/sx_hash_calculator'
HASH_CALC_INPUT_FILE = "/tmp/hash_calculator_input.json"
HASH_CALC_OUTPUT_FILE = "/tmp/hash_calculator_output.json"
|
|
|
|
def exec_cmd(cmd):
    """Run *cmd* directly (argv list, no shell) and return its output as text.

    stderr is merged into stdout, and the combined bytes are decoded as UTF-8.
    A non-zero exit status raises subprocess.CalledProcessError.
    """
    raw_output = subprocess.check_output(cmd, shell=False, stderr=subprocess.STDOUT)
    return raw_output.decode("utf-8")
|
|
|
|
def is_mac_valid(mac):
    """Return True if *mac* is a MAC address made of six hex octets.

    Octets are separated by ':' or '-', and one separator must be used
    consistently (e.g. '00:11:22:33:44:55' or '00-11-22-33-44-55').
    """
    # fullmatch instead of match + '$': '$' also matches before a trailing
    # newline, which would accept '00:11:22:33:44:55\n'.
    # The back-reference \1 forces every separator to equal the first one.
    return bool(re.fullmatch(r"[0-9A-Fa-f]{2}([:-])(?:[0-9A-Fa-f]{2}\1){4}[0-9A-Fa-f]{2}", mac))
|
|
|
|
def is_ip_valid(address, ip_version):
    """Return True if *address* is an acceptable unicast IP of *ip_version*.

    Args:
        address: IP address string to check.
        ip_version: IP_VERSION_IPV4, or anything else for IPv6.

    Returns:
        False when *address* does not parse for the requested version. Also
        False (after printing the reason) when it is link-local, the
        unspecified or IPv4 broadcast address, multicast, or loopback.
        True otherwise.
    """
    try:
        if ip_version == IP_VERSION_IPV4:
            ip = ipaddress.IPv4Address(address)
            # Address objects, not strings: an IPv4Address never compares
            # equal to a str, so a list of strings would never match.
            invalid_list = [ipaddress.IPv4Address('0.0.0.0'),
                            ipaddress.IPv4Address('255.255.255.255')]
        else:
            ip = ipaddress.IPv6Address(address)
            invalid_list = [ipaddress.IPv6Address('::')]
    except ipaddress.AddressValueError:
        return False

    if ip.is_link_local:
        print("Link local IP {} is not valid".format(ip))
        return False

    if ip in invalid_list:
        print("IP {} is not valid".format(ip))
        return False

    if ip.is_multicast:
        print("Multicast IP {} is not valid".format(ip))
        return False

    if ip.is_loopback:
        print("Loopback IP {} is not valid".format(ip))
        return False

    return True
|
|
|
|
def load_json(filename):
    """Load and return the JSON document stored in *filename*.

    Raises:
        ValueError: if the file content is not valid JSON. The original
            JSONDecodeError is chained as the cause.
        OSError: if the file cannot be opened.
    """
    # JSON text is UTF-8 (RFC 8259), so don't depend on the locale's default
    # encoding.
    with open(filename, encoding="utf-8") as f:
        try:
            return json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError("Failed to load JSON file '{}', error: '{}'".format(filename, e)) from e
|
|
|
|
def create_network_addr(ip_addr, mask_len, ip_version):
    """Build the SDK prefix for the network that contains *ip_addr*/*mask_len*.

    Host bits are cleared first (strict=False). The network address and
    dotted/expanded netmask strings are then passed to the SDK prefix
    builder for *ip_version* (IP_VERSION_IPV4, anything else means IPv6).
    """
    is_v4 = ip_version == IP_VERSION_IPV4
    network_cls = ipaddress.IPv4Network if is_v4 else ipaddress.IPv6Network
    network = network_cls("{}/{}".format(ip_addr, mask_len), strict=False)

    # Same values as the two halves of network.with_netmask.
    net_ip = str(network.network_address)
    net_mask = str(network.netmask)

    if is_v4:
        return sx_make_ip_prefix_v4(net_ip, net_mask)
    return sx_make_ip_prefix_v6(net_ip, net_mask)
|
|
|
|
class EcmpCalcExit(Exception):
    """Stop the ECMP calculation early with an informational message.

    main() prints the message and keeps the exit status at 0.
    """
|
|
|
|
class EcmpCalc:
    """Calculate the egress port an ECMP-routed packet will leave through.

    On construction an SDK connection is opened, the SDK logical-port to
    SONiC-port map is loaded and the active VRIDs are read. main() then runs
    validate_args() -> get_ecmp_id() -> get_next_hops() ->
    calculate_egress_port() -> print_egress_port(). Selection of the ECMP
    next hop and LAG member is delegated to the external hash calculator
    (HASH_CALC_PATH).
    """

    def __init__(self):
        self.packet = {}
        self.ports_map = {}        # SDK logical port index -> SONiC port name
        self.ecmp_ids = {}         # VRID -> ECMP id of the matching route
        self.next_hops = {}        # VRID -> next hops of an ECMP with > 1 member
        self.user_vrf = ''         # '' selects the default VRF
        self.ingress_port = ""
        self.egress_ports = []
        self.debug = False
        # Set below or during validation. Initialised here so the attributes
        # exist even if a later step fails.
        self.handle = None
        self.vrid_list = []
        self.dst_ip = None
        self.ip_version = None

        self.open_sdk_connection()
        self.init_ports_map()
        self.get_active_vrids()

    def __del__(self):
        """Close the SDK connection and remove leftover hash-calculator files."""
        # Nothing to release if __init__ failed before an SDK handle was
        # obtained, or if the instance was never initialised.
        if getattr(self, 'handle', None) is None:
            return
        try:
            self.close_sdk_connection()
        finally:
            # Runs even if closing the connection raised.
            self.cleanup()

    def cleanup(self):
        """Delete HASH_CALC_INPUT_FILE / HASH_CALC_OUTPUT_FILE if they exist.

        call_hash_calculator() works in a private temporary directory, so
        this only clears files that may have been left at those fixed paths.
        """
        for filename in (HASH_CALC_INPUT_FILE, HASH_CALC_OUTPUT_FILE):
            try:
                os.remove(filename)
            except FileNotFoundError:
                pass

    def close_sdk_connection(self):
        """Close the SDK connection if one is open. Safe to call more than once."""
        handle = getattr(self, 'handle', None)
        if handle is not None:
            sx_close_sdk_connection(handle)
            self.handle = None

    def open_sdk_connection(self):
        """Open the SDK connection and keep its handle in self.handle."""
        self.handle = sx_open_sdk_connection()

    def debug_print(self, *args, **kwargs):
        """print() the arguments only when debug output is enabled."""
        if self.debug:
            print(*args, **kwargs)

    def init_ports_map(self):
        """Load the SDK logical port -> SONiC port name map."""
        self.ports_map = sx_get_ports_map(self.handle)

    def validate_ingress_port(self, interface):
        """Record *interface* as the ingress port.

        Raises:
            ValueError: if *interface* is not a known SONiC port name.
        """
        if interface not in self.ports_map.values():
            raise ValueError("Invalid interface {}".format(interface))
        self.ingress_port = interface

    def validate_args(self, interface, packet, vrf, debug):
        """Validate the command-line arguments and store them on the instance.

        Args:
            interface: ingress SONiC port name.
            packet: path of the packet-description JSON file.
            vrf: VRF name, or None for the default VRF.
            debug: enables debug output when it is exactly True.

        Raises:
            ValueError: on an unknown interface, an invalid packet
                description, or a VRF that does not exist.
        """
        # Set first so packet validation can already print debug output.
        if debug is True:
            self.debug = True

        self.validate_ingress_port(interface)
        self.validate_packet_json(packet)

        if vrf is not None:
            self.user_vrf = vrf
            if not self.validate_vrf():
                raise ValueError("VRF validation failed: VRF {} does not exist".format(self.user_vrf))

    def validate_vrf(self):
        """Return True if self.user_vrf appears as a VRF key in redis DB 0."""
        query_output = exec_cmd(['/usr/bin/redis-cli', '-n', '0', 'keys', '*VRF*']).strip()
        for entry in query_output.splitlines():
            # Keys are expected to be '<TABLE>:<vrf name>'. Entries without a
            # name field are skipped instead of raising IndexError.
            fields = entry.strip().split(':')
            if len(fields) > VRF_NAME_IDX and fields[VRF_NAME_IDX] == self.user_vrf:
                return True
        return False

    def get_ecmp_id(self):
        """Find, per active VRID, the ECMP id of the longest prefix matching dst_ip.

        Raises:
            EcmpCalcExit: if no VRID has a route for the destination IP.
        """
        max_mask_len = IP_VERSION_MAX_MASK_LEN[self.ip_version]
        route_found = False

        for vrid in self.vrid_list:
            # Longest-prefix match: try /max down to /0. The range includes 0
            # so that a default route is found as well.
            for mask_len in range(max_mask_len, -1, -1):
                network_addr = create_network_addr(self.dst_ip, mask_len, self.ip_version)
                ecmp_id = sx_router_get_ecmp_id(self.handle, vrid, network_addr)
                if ecmp_id != SX_ENTRY_NOT_FOUND:
                    route_found = True
                    self.debug_print("Found route for destination IP {} ECMP id {} VRID {}".format(self.dst_ip, ecmp_id, vrid))
                    self.ecmp_ids[vrid] = ecmp_id
                    # Move on to the next VRID.
                    break

        if not route_found:
            raise EcmpCalcExit("No route found for given packet")

    def get_next_hops(self):
        """Store, per VRID, the next hops of every matched ECMP with more than one member.

        Raises:
            EcmpCalcExit: if no matched route has more than one next hop.
        """
        ecmp_found = False

        for vrid, ecmp_id in self.ecmp_ids.items():
            next_hops = sx_router_ecmp_nexthops_get(self.handle, ecmp_id)
            # A single next hop means there is no ECMP choice to calculate.
            if len(next_hops) <= 1:
                continue

            if self.debug:
                next_hops_ips = [sx_ip_addr_to_str(nh.next_hop_key.next_hop_key_entry.ip_next_hop.address)
                                 for nh in next_hops]
                print("Next hops IPs {}, VRID {}".format(next_hops_ips, vrid))
                print("Found ECMP for destination IP {} ECMP id {}, now checking if port is member in VRF {}".
                      format(self.dst_ip, ecmp_id, 'default' if self.user_vrf == '' else self.user_vrf))

            self.next_hops[vrid] = next_hops
            ecmp_found = True

        if not ecmp_found:
            raise EcmpCalcExit("No ECMP for given packet")

    def calculate_egress_port(self):
        """Append the egress port(s) of the packet to self.egress_ports.

        VRIDs are tried in self.vrid_list order. For each one, the hashed next
        hop is resolved to a port, and the first result whose interface is
        bound to the user VRF is used. For a VLAN interface whose neighbour
        MAC or FDB entry is unknown, the packet is flooded: every VLAN member
        port (one hashed member per LAG) is appended.
        """
        for vrid in self.vrid_list:
            next_hops = self.next_hops.get(vrid)
            if next_hops is None:
                continue

            next_hop = next_hops[self.get_next_hop_index(len(next_hops))]
            ip_next_hop = next_hop.next_hop_key.next_hop_key_entry.ip_next_hop
            rif = ip_next_hop.rif
            ip = ip_next_hop.address
            rif_params = sx_get_router_interface(self.handle, vrid, rif)

            self.debug_print("Next hop ip to which trafic will egress: {}".format(sx_ip_addr_to_str(ip)))

            if PORT in rif_params:
                # Router port
                logical = rif_params[PORT]
                port_type = PORT
                vlan_id = 0
            elif VPORT in rif_params:
                # VLAN sub-interface
                logical, vlan_id = sx_port_vport_base_get(self.handle, rif_params[VPORT])
                port_type = VPORT
            elif VLAN in rif_params:
                # VLAN interface
                vlan_id = rif_params[VLAN]
                port_type = VLAN
                logical = self._get_vlan_neigh_port(rif, ip, vlan_id)
                if logical is None:
                    # Flood case. Apply the same VRF check as the other paths.
                    if self.is_port_bind_to_user_vrf(VLAN, None, vlan_id):
                        self._add_vlan_flood_ports(vlan_id)
                        return
                    continue
            else:
                self.debug_print("Unsupported router interface {} in VRID {}, skipping".format(rif, vrid))
                continue

            # Check if the port is bound to the VRF given by the user. A LAG
            # is bound through its port-channel interface.
            if is_lag(logical):
                egress_port = self.ports_map[self.get_lag_member(logical)]
                bind_name = self.get_port_channel_name(egress_port)
            else:
                egress_port = self.ports_map[logical]
                bind_name = egress_port

            if self.is_port_bind_to_user_vrf(port_type, bind_name, vlan_id):
                self.egress_ports.append(egress_port)
                return

    def _get_vlan_neigh_port(self, rif, ip, vlan_id):
        """Return the logical port behind neighbour *ip* in *vlan_id*, or None if unresolved."""
        neigh_mac = sx_router_neigh_get_mac(self.handle, rif, ip)
        if neigh_mac is None:
            return None
        mac_entry = sx_fdb_uc_mac_addr_get(self.handle, vlan_id, neigh_mac)
        if mac_entry is None:
            return None
        return mac_entry.log_port

    def _add_vlan_flood_ports(self, vlan_id):
        """Append every member port of *vlan_id*, using the hashed member for each LAG."""
        for port in sx_vlan_ports_get(self.handle, vlan_id):
            if is_lag(port):
                port = self.get_lag_member(port, True)
            self.egress_ports.append(self.ports_map[port])

    def print_egress_port(self):
        """Print the calculated egress port(s), or a hint when none was found."""
        if not self.egress_ports:
            print("Egress port not found, check input parameters")
        elif len(self.egress_ports) == 1:
            print("Egress port: {}".format(self.egress_ports[0]))
        else:
            print("Egress ports: {}".format(' '.join(self.egress_ports)))

    def is_port_bind_to_user_vrf(self, port_type, port, vlan_id=0):
        """Return True if the interface's 'vrf_name' in INTF_TABLE equals self.user_vrf.

        The INTF_TABLE key depends on *port_type*:
        PORT -> INTF_TABLE:<port>, VPORT -> INTF_TABLE:<port>.<vlan_id>,
        VLAN -> INTF_TABLE:Vlan<vlan_id> (*port* is ignored).

        Raises:
            ValueError: for any other *port_type*.
        """
        if port_type == PORT:
            # INTF_TABLE:Ethernet0
            entry = '{}:{}'.format(INTF_TABLE, port)
        elif port_type == VPORT:
            # INTF_TABLE:Ethernet0.300
            entry = '{}:{}.{}'.format(INTF_TABLE, port, vlan_id)
        elif port_type == VLAN:
            # INTF_TABLE:Vlan300
            entry = '{}:Vlan{}'.format(INTF_TABLE, vlan_id)
        else:
            raise ValueError("Unsupported port type {}".format(port_type))

        port_vrf = exec_cmd(['/usr/bin/redis-cli', '-n', '0', 'hget', entry, 'vrf_name'])
        return self.user_vrf == port_vrf.strip()

    def get_port_channel_name(self, port):
        """Return the port-channel name that *port* is a member of.

        Raises:
            KeyError: if *port* is not listed in any LAG_MEMBER_TABLE key.
        """
        query_output = exec_cmd(['/usr/bin/redis-cli', '-n', '0', 'keys', '*LAG_MEMBER_TABLE*'])
        port_channel = self._parse_port_channel_name(query_output, port)
        if port_channel is None:
            raise KeyError("Failed to get port-channel name for interface {}".format(port))
        return port_channel

    @staticmethod
    def _parse_port_channel_name(query_output, port):
        """Return the port-channel of *port* from a LAG_MEMBER_TABLE key listing, or None.

        Each line is expected to be 'LAG_MEMBER_TABLE:<port-channel>:<member>'.
        """
        port = str(port)
        for line in query_output.splitlines():
            fields = line.strip().split(':')
            # Match member fields exactly. A substring test would wrongly
            # match 'Ethernet1' against '...:Ethernet12'.
            if port in fields[PORT_CHANNEL_IDX + 1:]:
                return fields[PORT_CHANNEL_IDX]
        return None

    def get_ingress_port_logical_idx(self):
        """Return the SDK logical index of self.ingress_port.

        Raises:
            KeyError: if the ingress port is not in self.ports_map.
        """
        for logical_index, sonic_port_name in self.ports_map.items():
            if sonic_port_name == self.ingress_port:
                return logical_index

        raise KeyError("Failed to get logical index for interface {}".format(self.ingress_port))

    def get_next_hop_index(self, ecmp_size):
        """Return the index in the next-hop array that the packet egresses from."""
        ecmp_hash = {
            "ingress_port": hex(self.get_ingress_port_logical_idx()),
            "packet_info": self.packet['packet_info'],
            "ecmp_size": ecmp_size,
        }

        self.debug_print("Calling hash calculator for ECMP")
        hash_result = self.call_hash_calculator({'ecmp_hash': ecmp_hash})
        index = hash_result['ecmp_hash']['ecmp_index']

        self.debug_print("Next hop index to which trafic will egress: {}".format(index))
        return index

    def get_lag_member_index(self, lag_size, flood_case=False):
        """Return the index in the LAG member array that the packet egresses from.

        The flood case uses the calculator's 'lag_mc_index' result instead of
        'lag_index'.
        """
        lag_hash = {
            "ingress_port": hex(self.get_ingress_port_logical_idx()),
            "packet_info": self.packet['packet_info'],
            "lag_size": lag_size,
        }

        self.debug_print("Calling hash calculator for LAG, flood case {}".format(bool(flood_case)))
        lag_hash_result = self.call_hash_calculator({"lag_hash": lag_hash})["lag_hash"]
        index = lag_hash_result['lag_mc_index'] if flood_case else lag_hash_result['lag_index']

        self.debug_print("Lag member index from which trafic will egress: {}".format(index))
        return index

    def get_lag_member(self, logical, flood_case=False):
        """Return the logical port of the LAG member the packet egresses from."""
        # Sorted so the hashed index refers to a deterministic member order.
        lag_members = sorted(sx_lag_port_group_get(self.handle, logical))

        lag_member = lag_members[self.get_lag_member_index(len(lag_members), flood_case)]

        self.debug_print("Lag member from which trafic will egress: {}".format(lag_member))
        return lag_member

    def call_hash_calculator(self, input_dict):
        """Run the hash calculator on *input_dict* and return its parsed JSON output.

        The input and output files live in a private temporary directory that
        is removed on return. This replaces fixed /tmp paths, which concurrent
        runs would clobber and other local users could pre-create or
        symlink.
        """
        with tempfile.TemporaryDirectory(prefix='ecmp_calc_') as tmp_dir:
            input_file = os.path.join(tmp_dir, os.path.basename(HASH_CALC_INPUT_FILE))
            output_file = os.path.join(tmp_dir, os.path.basename(HASH_CALC_OUTPUT_FILE))

            with open(input_file, "w") as outfile:
                json.dump(input_dict, outfile)

            out = exec_cmd([HASH_CALC_PATH, '-c', input_file, '-o', output_file, '-d'])
            self.debug_print("Hash calculator output:\n{}".format(out))

            with open(output_file, 'r') as openfile:
                return json.load(openfile)

    def get_active_vrids(self):
        """Load the list of active VRIDs from the SDK."""
        self.vrid_list = sx_get_active_vrids(self.handle)

    def _validate_header_ips(self, header, ip_version):
        """Raise ValueError if a present 'sip'/'dip' in *header* is rejected by is_ip_valid()."""
        for ip in ('sip', 'dip'):
            if ip in header and not is_ip_valid(header[ip], ip_version):
                raise ValueError("Json validation failed: invalid IP {}".format(header[ip]))

    def validate_ipv4_header(self, header):
        """Validate the IPv4 source/destination addresses of *header*."""
        self._validate_header_ips(header, IP_VERSION_IPV4)

    def validate_ipv6_header(self, header):
        """Validate the IPv6 source/destination addresses of *header*."""
        self._validate_header_ips(header, IP_VERSION_IPV6)

    def validate_layer2_header(self, header):
        """Raise ValueError if a present 'smac'/'dmac' in *header* is malformed."""
        for mac in ('smac', 'dmac'):
            if mac in header and not is_mac_valid(header[mac]):
                raise ValueError("Json validation failed: invalid mac {}".format(header[mac]))

    def validate_header(self, header, is_outer_header=False):
        """Validate one (outer or inner) packet header.

        For the outer header, the destination IP and IP version used for the
        route lookup are stored in self.dst_ip / self.ip_version.

        Raises:
            ValueError: on conflicting or invalid header contents.
        """
        ipv4_header = 'ipv4' in header
        ipv6_header = 'ipv6' in header

        # IPv4 and IPv6 headers must not co-exist.
        if ipv4_header and ipv6_header:
            raise ValueError("Json validation failed: IPv4 and IPv6 headers can not co-exist")

        if ipv4_header:
            self.validate_ipv4_header(header['ipv4'])

            if is_outer_header:
                if 'dip' not in header['ipv4']:
                    raise ValueError("Json validation failed: destination IP is mandatory")
                self.dst_ip = header['ipv4']['dip']
                self.ip_version = IP_VERSION_IPV4

            if 'tcp_udp' in header and 'proto' not in header['ipv4']:
                raise ValueError("Json validation failed: transport protocol (proto) is mandatory when transport layer port exists")

        elif ipv6_header:
            self.validate_ipv6_header(header['ipv6'])

            if is_outer_header:
                if 'dip' not in header['ipv6']:
                    raise ValueError("Json validation failed: destination IP is mandatory")
                self.dst_ip = header['ipv6']['dip']
                self.ip_version = IP_VERSION_IPV6

            if 'tcp_udp' in header and 'next_header' not in header['ipv6']:
                raise ValueError("Json validation failed: transport protocol (next_header) is mandatory when transport layer port exists")

        elif is_outer_header:
            # The route lookup needs a destination IP, which only an IPv4/IPv6
            # header provides. Without this check dst_ip would stay unset.
            raise ValueError("Json validation failed: destination IP is mandatory")

        # Validate MACs. The layer2 header is optional.
        layer2 = header.get('layer2')
        if layer2:
            self.validate_layer2_header(layer2)

    def validate_outer_header(self):
        """Validate the mandatory outer header of the packet."""
        outer_header = self.packet['packet_info'].get('outer')
        if not outer_header:
            raise ValueError("Json validation failed: outer header is mandatory")

        self.validate_header(outer_header, is_outer_header=True)

    def validate_inner_header(self):
        """Validate the optional inner header of the packet."""
        inner_header = self.packet['packet_info'].get('inner')
        if not inner_header:
            return

        self.validate_header(inner_header)

    def validate_packet_json(self, packet_json):
        """Load the packet description from *packet_json* and validate it.

        Raises:
            ValueError: if the file is not valid JSON, does not match
                PACKET_SCHEME, or has invalid header contents.
        """
        # Verify the JSON format.
        self.packet = load_json(packet_json)

        # Verify the JSON schema.
        try:
            jsonschema.validate(self.packet, PACKET_SCHEME)
        except jsonschema.exceptions.ValidationError as e:
            raise ValueError("Json validation failed: {}".format(e)) from e

        self.validate_outer_header()
        self.validate_inner_header()

        if self.debug:
            print('Packet:')
            pprint.pprint(self.packet)
|
|
|
|
def main():
    """Command-line entry point: parse arguments, run the calculation, print the result.

    Returns:
        0 on success, including when the calculation stops early with
        EcmpCalcExit. 1 on a ValueError. 2 on any other exception.
    """
    parser = argparse.ArgumentParser(description="ECMP calculator")
    parser.add_argument("-i", "--interface", required=True, help="Ingress interface")
    parser.add_argument("-p", "--packet", required=True, help="Packet description")
    parser.add_argument("-v", "--vrf", help="VRF name")
    parser.add_argument("-d", "--debug", default=False, action="store_true", help="Flag for debug")
    # Argument errors raise SystemExit, which the handlers below do not catch.
    args = parser.parse_args()

    try:
        calculator = EcmpCalc()
        calculator.validate_args(args.interface, args.packet, args.vrf, args.debug)
        calculator.get_ecmp_id()
        calculator.get_next_hops()
        calculator.calculate_egress_port()
        calculator.print_egress_port()
    except EcmpCalcExit as err:
        # Informational early stop: report it, but not as a failure.
        print(err)
        return 0
    except ValueError as err:
        print("Value error: {}".format(err))
        return 1
    except Exception as err:
        print("Error: {}".format(err))
        return 2

    return 0


if __name__ == "__main__":
    sys.exit(main())
|