sonic-buildimage/platform/mellanox/docker-syncd-mlnx/ecmp_calculator/ecmp_calc.py
2023-01-09 00:48:56 +08:00

511 lines
19 KiB
Python
Executable File

#!/usr/bin/env python3
import json, jsonschema
import argparse
import ipaddress
import re
import subprocess
import pprint
import os
import sys
# The SDK helper modules (ecmp_calc_sdk, packet_scheme, port_utils) are
# imported below from <usr>/lib and <usr>/lib/ecmp_calc, where <usr> is the
# parent of the directory containing this script; make both importable.
usr_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
lib_path = os.path.join(usr_path, "lib")
ecmp_lib_path = os.path.join(lib_path, "ecmp_calc")
sys.path.extend([lib_path, ecmp_lib_path])
from ecmp_calc_sdk import sx_open_sdk_connection, sx_get_active_vrids, sx_router_get_ecmp_id, \
sx_router_ecmp_nexthops_get, sx_get_router_interface, \
sx_port_vport_base_get, sx_router_neigh_get_mac, sx_fdb_uc_mac_addr_get, \
sx_lag_port_group_get, sx_make_ip_prefix_v4, sx_make_ip_prefix_v6, \
sx_vlan_ports_get, sx_ip_addr_to_str, sx_close_sdk_connection, \
PORT, VPORT, VLAN, SX_ENTRY_NOT_FOUND
from packet_scheme import PACKET_SCHEME
from port_utils import sx_get_ports_map, is_lag
# IP version identifiers used throughout the script.
IP_VERSION_IPV4, IP_VERSION_IPV6 = 1, 2

# Index of the name field inside ':'-separated redis keys.
PORT_CHANNEL_IDX = 1
VRF_NAME_IDX = 1

# Longest prefix length per IP version; route lookup starts from this length.
IP_VERSION_MAX_MASK_LEN = {
    IP_VERSION_IPV4: 32,
    IP_VERSION_IPV6: 128,
}

INTF_TABLE = 'INTF_TABLE'

# SDK hash calculator binary and the files used to exchange JSON with it.
# NOTE(review): these are fixed, predictable names under /tmp; consider the
# tempfile module if other local users can write to /tmp.
HASH_CALC_PATH = '/usr/bin/sx_hash_calculator'
HASH_CALC_INPUT_FILE = "/tmp/hash_calculator_input.json"
HASH_CALC_OUTPUT_FILE = "/tmp/hash_calculator_output.json"
def exec_cmd(cmd):
    """Run *cmd* (an argv list, executed without a shell) and return its output.

    stderr is merged into stdout and the combined bytes are decoded as UTF-8.
    Raises subprocess.CalledProcessError when the command exits non-zero.
    """
    raw_output = subprocess.check_output(cmd, shell=False, stderr=subprocess.STDOUT)
    return raw_output.decode("utf-8")
def is_mac_valid(mac):
    """Return True if *mac* is six hex octets, each pair separated by ':' or '-'."""
    mac_pattern = r"^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$"
    return re.match(mac_pattern, mac) is not None
def is_ip_valid(address, ip_version):
    """Return True if *address* is an acceptable packet IP for *ip_version*.

    Args:
        address: textual IP address taken from the packet description.
        ip_version: IP_VERSION_IPV4 parses as IPv4; any other value parses
            as IPv6.

    Returns:
        False if the address does not parse, or if it is link-local, in the
        per-version deny list (IPv4 0.0.0.0 / 255.255.255.255, IPv6 ::),
        multicast or loopback. True otherwise. For every parseable address
        that is rejected, a line explaining why is printed.
    """
    try:
        if ip_version == IP_VERSION_IPV4:
            ip = ipaddress.IPv4Address(address)
            # The deny list must hold address objects: an IPv4Address or
            # IPv6Address never compares equal to a str, so a list of strings
            # silently never matched.
            invalid_list = [ipaddress.IPv4Address('0.0.0.0'),
                            ipaddress.IPv4Address('255.255.255.255')]
        else:
            ip = ipaddress.IPv6Address(address)
            invalid_list = [ipaddress.IPv6Address('::')]
    except ipaddress.AddressValueError:
        return False
    if ip.is_link_local:
        print("Link local IP {} is not valid".format(ip))
        return False
    if ip in invalid_list:
        print("IP {} is not valid".format(ip))
        return False
    if ip.is_multicast:
        print("Multicast IP {} is not valid".format(ip))
        return False
    if ip.is_loopback:
        print("Loopback IP {} is not valid".format(ip))
        return False
    return True
def load_json(filename):
    """Load and return the JSON document stored in *filename*.

    The file is decoded as UTF-8, the encoding RFC 8259 mandates for JSON,
    rather than with the locale-dependent default.

    Raises:
        ValueError: if the content is not valid JSON. The original
            json.JSONDecodeError is chained as ``__cause__``.
        OSError: if the file cannot be opened.
    """
    with open(filename, encoding="utf-8") as f:
        try:
            return json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError("Failed to load JSON file '{}', error: '{}'".format(filename, e)) from e
def create_network_addr(ip_addr, mask_len, ip_version):
    """Return an SDK IP prefix for the network containing *ip_addr*/*mask_len*.

    Host bits of ip_addr are cleared (strict=False). The network address and
    the netmask are both passed to the SDK prefix helper as address strings.
    ip_version IP_VERSION_IPV4 selects IPv4; any other value selects IPv6.
    """
    is_v4 = ip_version == IP_VERSION_IPV4
    network_cls = ipaddress.IPv4Network if is_v4 else ipaddress.IPv6Network
    network = network_cls("{}/{}".format(ip_addr, mask_len), strict=False)
    make_prefix = sx_make_ip_prefix_v4 if is_v4 else sx_make_ip_prefix_v6
    return make_prefix(str(network.network_address), str(network.netmask))
class EcmpCalcExit(Exception):
    """Raised to stop the calculation early with an informational message.

    main() prints the message and keeps exit status 0, unlike the error paths.
    """
class EcmpCalc:
    """Predict which egress port an ECMP-routed packet will leave through.

    Construction opens an SDK connection, loads the logical-port map and the
    list of active virtual router ids (VRIDs). main() then drives the
    calculation: validate_args() -> get_ecmp_id() -> get_next_hops() ->
    calculate_egress_port() -> print_egress_port().
    """

    def __init__(self):
        self.packet = {}        # parsed packet description (JSON)
        self.ports_map = {}     # logical port index -> SONiC port name
        self.ecmp_ids = {}      # VRID -> ECMP id of the matching route
        self.next_hops = {}     # VRID -> next hops of that ECMP group
        self.user_vrf = ''      # '' selects the default VRF
        self.ingress_port = ""
        self.egress_ports = []
        self.debug = False
        self.open_sdk_connection()
        self.init_ports_map()
        self.get_active_vrids()

    def __del__(self):
        # If __init__ failed before self.handle was assigned there is no
        # connection to close. Always remove the temporary hash-calculator
        # files, even if closing the connection raises.
        try:
            if hasattr(self, 'handle'):
                self.close_sdk_connection()
        finally:
            self.cleanup()

    def cleanup(self):
        """Delete the hash calculator input/output files if they exist."""
        for filename in (HASH_CALC_INPUT_FILE, HASH_CALC_OUTPUT_FILE):
            if os.path.exists(filename):
                os.remove(filename)

    def close_sdk_connection(self):
        """Close the SDK connection held in self.handle."""
        sx_close_sdk_connection(self.handle)

    def open_sdk_connection(self):
        """Open an SDK connection and keep its handle in self.handle."""
        self.handle = sx_open_sdk_connection()

    def debug_print(self, *args, **kwargs):
        """print() the arguments only when debug output is enabled."""
        if self.debug:
            print(*args, **kwargs)

    def init_ports_map(self):
        """Load the logical port index -> SONiC port name map from the SDK."""
        self.ports_map = sx_get_ports_map(self.handle)

    def validate_ingress_port(self, interface):
        """Accept *interface* as the ingress port if it is a known SONiC port name.

        Raises:
            ValueError: if *interface* is not a value of the ports map.
        """
        if interface not in self.ports_map.values():
            raise ValueError("Invalid interface {}".format(interface))
        self.ingress_port = interface

    def validate_args(self, interface, packet, vrf, debug):
        """Validate the command-line arguments and store them on the instance.

        Args:
            interface: ingress SONiC port name.
            packet: path of the packet description JSON file.
            vrf: VRF name, or None for the default VRF.
            debug: enables debug output only when it is exactly True.

        Raises:
            ValueError: on an unknown interface, an invalid packet description
                or a VRF that does not exist.
        """
        if debug is True:
            self.debug = True
        self.validate_ingress_port(interface)
        self.validate_packet_json(packet)
        if vrf is not None:
            self.user_vrf = vrf
            if not self.validate_vrf():
                raise ValueError("VRF validation failed: VRF {} does not exist".format(self.user_vrf))

    def validate_vrf(self):
        """Return True if a redis DB 0 key matching '*VRF*' names self.user_vrf.

        The name is the VRF_NAME_IDX-th ':'-separated field of each key.
        """
        query_output = exec_cmd(['/usr/bin/redis-cli', '-n', '0', 'keys', '*VRF*']).strip()
        if not query_output:
            return False
        for entry in query_output.split('\n'):
            fields = entry.strip().split(':')
            # Keys with no name field (no ':') used to raise IndexError here.
            if len(fields) > VRF_NAME_IDX and fields[VRF_NAME_IDX] == self.user_vrf:
                return True
        return False

    def get_ecmp_id(self):
        """Record, per active VRID, the ECMP id of the route matching self.dst_ip.

        Results go to self.ecmp_ids.

        Raises:
            EcmpCalcExit: if no VRID has a matching route.
        """
        max_mask_len = IP_VERSION_MAX_MASK_LEN[self.ip_version]
        route_found = False
        for vrid in self.vrid_list:
            # Try prefixes from the longest down; the first one found is the
            # longest matching prefix for this VRID.
            # NOTE(review): mask length 0 (a default route) is never tried;
            # confirm this is intentional.
            for mask_len in range(max_mask_len, 0, -1):
                network_addr = create_network_addr(self.dst_ip, mask_len, self.ip_version)
                ecmp_id = sx_router_get_ecmp_id(self.handle, vrid, network_addr)
                if ecmp_id == SX_ENTRY_NOT_FOUND:
                    continue
                route_found = True
                self.debug_print("Found route for destination IP {} ECMP id {} VRID {}".format(self.dst_ip, ecmp_id, vrid))
                self.ecmp_ids[vrid] = ecmp_id
                break  # move to the next VRID
        if not route_found:
            raise EcmpCalcExit("No route found for given packet")

    def get_next_hops(self):
        """Keep the next hops of every matching route that has more than one next hop.

        Results go to self.next_hops (VRID -> next-hop list).

        Raises:
            EcmpCalcExit: if none of the matching routes has more than one next hop.
        """
        ecmp_found = False
        for vrid, ecmp_id in self.ecmp_ids.items():
            next_hops = sx_router_ecmp_nexthops_get(self.handle, ecmp_id)
            if len(next_hops) <= 1:
                continue
            if self.debug:
                next_hops_ips = [
                    sx_ip_addr_to_str(nh.next_hop_key.next_hop_key_entry.ip_next_hop.address)
                    for nh in next_hops
                ]
                print("Next hops IPs {}, VRID {}".format(next_hops_ips, vrid))
                print("Found ECMP for destination IP {} ECMP id {}, now checking if port is member in VRF {}".
                      format(self.dst_ip, ecmp_id, 'default' if self.user_vrf == '' else self.user_vrf))
            self.next_hops[vrid] = next_hops
            ecmp_found = True
        if not ecmp_found:
            raise EcmpCalcExit("No ECMP for given packet")

    def _get_vlan_neigh_port(self, rif, ip, vlan_id):
        """Return the log_port of the FDB entry for the next hop's MAC, or None.

        None means either the neighbour MAC or its FDB entry was not found.
        """
        neigh_mac = sx_router_neigh_get_mac(self.handle, rif, ip)
        if neigh_mac is None:
            return None
        mac_entry = sx_fdb_uc_mac_addr_get(self.handle, vlan_id, neigh_mac)
        if mac_entry is None:
            return None
        return mac_entry.log_port

    def _append_flood_ports(self, vlan_id):
        """Append every member port of *vlan_id* (a hashed member for LAGs) to self.egress_ports."""
        for port in sx_vlan_ports_get(self.handle, vlan_id):
            if is_lag(port):
                port = self.get_lag_member(port, True)
            self.egress_ports.append(self.ports_map[port])

    def calculate_egress_port(self):
        """Resolve the egress port(s) of the packet into self.egress_ports.

        VRIDs are visited in self.vrid_list order. The first VRID whose hashed
        next hop egresses through an interface bound to the user VRF wins.
        If the next hop is behind a VLAN interface and its MAC cannot be
        resolved to a port, all VLAN member ports are reported (flood case).

        Raises:
            RuntimeError: if the router interface is neither a port, a VLAN
                sub-interface nor a VLAN interface.
        """
        for vrid in self.vrid_list:
            if vrid not in self.next_hops:
                continue
            next_hops = self.next_hops[vrid]
            next_hop_idx = self.get_next_hop_index(len(next_hops))
            ip_next_hop = next_hops[next_hop_idx].next_hop_key.next_hop_key_entry.ip_next_hop
            rif = ip_next_hop.rif
            ip = ip_next_hop.address
            rif_params = sx_get_router_interface(self.handle, vrid, rif)
            self.debug_print("Next hop ip to which trafic will egress: {}".format(sx_ip_addr_to_str(ip)))
            vlan_id = 0
            if PORT in rif_params:
                # Router port
                logical = rif_params[PORT]
                port_type = PORT
            elif VPORT in rif_params:
                # VLAN sub-interface
                logical, vlan_id = sx_port_vport_base_get(self.handle, rif_params[VPORT])
                port_type = VPORT
            elif VLAN in rif_params:
                # VLAN interface
                vlan_id = rif_params[VLAN]
                port_type = VLAN
                logical = self._get_vlan_neigh_port(rif, ip, vlan_id)
                if logical is None:
                    # Flood case. Apply the same user-VRF filter as the other
                    # paths so that a VLAN in another VRF is not reported.
                    if not self.is_port_bind_to_user_vrf(VLAN, None, vlan_id):
                        continue
                    self._append_flood_ports(vlan_id)
                    return
            else:
                # Previously this fell through to an UnboundLocalError on 'logical'.
                raise RuntimeError("Unsupported router interface type for RIF {}".format(rif))

            # Check if the port is bound to the VRF given by the user
            if is_lag(logical):
                egress_port = self.ports_map[self.get_lag_member(logical)]
                # VRF binding is looked up on the port-channel, not on its member.
                bind_name = self.get_port_channel_name(egress_port)
            else:
                egress_port = self.ports_map[logical]
                bind_name = egress_port
            if self.is_port_bind_to_user_vrf(port_type, bind_name, vlan_id):
                self.egress_ports.append(egress_port)
                return

    def print_egress_port(self):
        """Print the resolved egress port(s), or a hint if none was found."""
        if not self.egress_ports:
            print("Egress port not found, check input parameters")
        elif len(self.egress_ports) == 1:
            print("Egress port: {}".format(self.egress_ports[0]))
        else:
            print("Egress ports: {}".format(' '.join(self.egress_ports)))

    def is_port_bind_to_user_vrf(self, port_type, port, vlan_id=0):
        """Return True if the interface's INTF_TABLE 'vrf_name' equals self.user_vrf.

        The INTF_TABLE key depends on port_type: PORT -> '<port>',
        VPORT -> '<port>.<vlan_id>', VLAN -> 'Vlan<vlan_id>' (port unused).

        Raises:
            ValueError: for an unknown port_type.
        """
        if port_type == PORT:
            entry = '{}:{}'.format(INTF_TABLE, port)               # INTF_TABLE:Ethernet0
        elif port_type == VPORT:
            entry = '{}:{}.{}'.format(INTF_TABLE, port, vlan_id)   # INTF_TABLE:Ethernet0.300
        elif port_type == VLAN:
            entry = '{}:Vlan{}'.format(INTF_TABLE, vlan_id)        # INTF_TABLE:Vlan300
        else:
            raise ValueError("Unsupported port type {}".format(port_type))
        port_vrf = exec_cmd(['/usr/bin/redis-cli', '-n', '0', 'hget', entry, 'vrf_name'])
        return self.user_vrf == port_vrf.strip()

    def get_port_channel_name(self, port):
        """Return the name of the port-channel that *port* is a member of.

        Raises:
            KeyError: if no LAG_MEMBER_TABLE key lists *port* as a member.
        """
        query_output = exec_cmd(['/usr/bin/redis-cli', '-n', '0', 'keys', '*LAG_MEMBER_TABLE*'])
        port_channel = self._find_port_channel(query_output, port)
        if port_channel is None:
            raise KeyError("Failed to get port-channel name for interface {}".format(port))
        return port_channel

    @staticmethod
    def _find_port_channel(keys_output, port):
        """Return the port-channel field of the first key whose last field is *port*, else None.

        Assumes keys look like 'LAG_MEMBER_TABLE:<port-channel>:<member>'
        (TODO confirm against the APPL_DB schema).
        """
        port = str(port)
        for line in keys_output.split('\n'):
            fields = line.strip().split(':')
            # Compare the member field exactly. A substring test let e.g.
            # 'Ethernet4' match the key of 'Ethernet48'.
            if len(fields) > PORT_CHANNEL_IDX + 1 and fields[-1] == port:
                return fields[PORT_CHANNEL_IDX]
        return None

    def get_ingress_port_logical_idx(self):
        """Return the logical index of self.ingress_port.

        Raises:
            KeyError: if the ingress port is not in the ports map.
        """
        for logical_index, sonic_port_name in self.ports_map.items():
            if sonic_port_name == self.ingress_port:
                return logical_index
        raise KeyError("Failed to get logical index for interface {}".format(self.ingress_port))

    def get_next_hop_index(self, ecmp_size):
        """Return the index in the next-hop array through which the packet egresses."""
        ecmp_hash = {
            "ingress_port": hex(self.get_ingress_port_logical_idx()),
            "packet_info": self.packet['packet_info'],
            "ecmp_size": ecmp_size,
        }
        self.debug_print("Calling hash calculator for ECMP")
        hash_result = self.call_hash_calculator({'ecmp_hash': ecmp_hash})
        index = hash_result['ecmp_hash']['ecmp_index']
        self.debug_print("Next hop index to which trafic will egress: {}".format(index))
        return index

    def get_lag_member_index(self, lag_size, flood_case=False):
        """Return the index in the LAG member array through which the packet egresses.

        flood_case selects the calculator's 'lag_mc_index' result instead of 'lag_index'.
        """
        lag_hash = {
            "ingress_port": hex(self.get_ingress_port_logical_idx()),
            "packet_info": self.packet['packet_info'],
            "lag_size": lag_size,
        }
        self.debug_print("Calling hash calculator for LAG, flood case {}".format(bool(flood_case)))
        lag_hash_result = self.call_hash_calculator({"lag_hash": lag_hash})["lag_hash"]
        index = lag_hash_result['lag_mc_index'] if flood_case else lag_hash_result['lag_index']
        self.debug_print("Lag member index from which trafic will egress: {}".format(index))
        return index

    def get_lag_member(self, logical, flood_case=False):
        """Return the logical port of the LAG member through which the packet egresses.

        Members are sorted before hashing.

        Raises:
            RuntimeError: if the LAG has no member ports.
        """
        lag_members = sorted(sx_lag_port_group_get(self.handle, logical))
        if not lag_members:
            raise RuntimeError("LAG {} has no member ports".format(logical))
        member_index = self.get_lag_member_index(len(lag_members), flood_case)
        lag_member = lag_members[member_index]
        self.debug_print("Lag member from which trafic will egress: {}".format(lag_member))
        return lag_member

    def call_hash_calculator(self, input_dict):
        """Run the SDK hash calculator on *input_dict* and return its parsed JSON output.

        Request and response are exchanged through HASH_CALC_INPUT_FILE and
        HASH_CALC_OUTPUT_FILE.
        """
        # NOTE(review): these are fixed, predictable paths under /tmp.
        with open(HASH_CALC_INPUT_FILE, "w") as outfile:
            json.dump(input_dict, outfile)
        out = exec_cmd([HASH_CALC_PATH, '-c', HASH_CALC_INPUT_FILE, '-o', HASH_CALC_OUTPUT_FILE, '-d'])
        self.debug_print("Hash calculator output:\n{}".format(out))
        with open(HASH_CALC_OUTPUT_FILE, 'r') as openfile:
            return json.load(openfile)

    def get_active_vrids(self):
        """Load the list of active VRIDs from the SDK into self.vrid_list."""
        self.vrid_list = sx_get_active_vrids(self.handle)

    def validate_ipv4_header(self, header):
        """Raise ValueError if a present 'sip'/'dip' is not a valid IPv4 address."""
        for ip in ['sip', 'dip']:
            if ip in header and not is_ip_valid(header[ip], IP_VERSION_IPV4):
                raise ValueError("Json validation failed: invalid IP {}".format(header[ip]))

    def validate_ipv6_header(self, header):
        """Raise ValueError if a present 'sip'/'dip' is not a valid IPv6 address."""
        for ip in ['sip', 'dip']:
            if ip in header and not is_ip_valid(header[ip], IP_VERSION_IPV6):
                raise ValueError("Json validation failed: invalid IP {}".format(header[ip]))

    def validate_layer2_header(self, header):
        """Raise ValueError if a present 'smac'/'dmac' is not a valid MAC address."""
        for mac in ['smac', 'dmac']:
            if mac in header and not is_mac_valid(header[mac]):
                raise ValueError("Json validation failed: invalid mac {}".format(header[mac]))

    def validate_header(self, header, is_outer_header=False):
        """Validate one (outer or inner) header of the packet description.

        For the outer header, the destination IP and IP version are stored in
        self.dst_ip / self.ip_version.

        Raises:
            ValueError: if IPv4 and IPv6 headers co-exist, if the outer header
                has no destination IP, if tcp_udp is present without a
                transport protocol field, or if an address is invalid.
        """
        ipv4_header = 'ipv4' in header
        ipv6_header = 'ipv6' in header
        if ipv4_header and ipv6_header:
            raise ValueError("Json validation failed: IPv4 and IPv6 headers can not co-exist")
        if is_outer_header and not (ipv4_header or ipv6_header):
            # Without an IP header self.dst_ip would never be set and the
            # route lookup would fail later with an AttributeError.
            raise ValueError("Json validation failed: destination IP is mandatory")
        if ipv4_header:
            self.validate_ipv4_header(header['ipv4'])
            if is_outer_header:
                if 'dip' not in header['ipv4']:
                    raise ValueError("Json validation failed: destination IP is mandatory")
                self.dst_ip = header['ipv4']['dip']
                self.ip_version = IP_VERSION_IPV4
            if 'tcp_udp' in header and 'proto' not in header['ipv4']:
                raise ValueError("Json validation failed: transport protocol (proto) is mandatory when transport layer port exists")
        elif ipv6_header:
            self.validate_ipv6_header(header['ipv6'])
            if is_outer_header:
                if 'dip' not in header['ipv6']:
                    raise ValueError("Json validation failed: destination IP is mandatory")
                self.dst_ip = header['ipv6']['dip']
                self.ip_version = IP_VERSION_IPV6
            if 'tcp_udp' in header and 'next_header' not in header['ipv6']:
                raise ValueError("Json validation failed: transport protocol (next_header) is mandatory when transport layer port exists")
        # 'layer2' may be absent; indexing it directly raised KeyError.
        layer2 = header.get('layer2')
        if layer2:
            self.validate_layer2_header(layer2)

    def validate_outer_header(self):
        """Validate the mandatory outer header.

        Raises:
            ValueError: if the outer header is missing or invalid.
        """
        outer_header = self.packet['packet_info'].get('outer')
        if not outer_header:
            raise ValueError("Json validation failed: outer header is mandatory")
        self.validate_header(outer_header, is_outer_header=True)

    def validate_inner_header(self):
        """Validate the optional inner header, if present."""
        inner_header = self.packet['packet_info'].get('inner')
        if not inner_header:
            return
        self.validate_header(inner_header)

    def validate_packet_json(self, packet_json):
        """Load the packet description from *packet_json* and validate it.

        Raises:
            ValueError: if the file is not valid JSON, does not match
                PACKET_SCHEME, or fails header validation.
        """
        self.packet = load_json(packet_json)
        try:
            jsonschema.validate(self.packet, PACKET_SCHEME)
        except jsonschema.exceptions.ValidationError as e:
            raise ValueError("Json validation failed: {}".format(e)) from e
        self.validate_outer_header()
        self.validate_inner_header()
        if self.debug:
            print('Packet:')
            pprint.pprint(self.packet)
def _parse_args():
    """Parse and return the command-line arguments."""
    parser = argparse.ArgumentParser(description="ECMP calculator")
    parser.add_argument("-i", "--interface", required=True, help="Ingress interface")
    parser.add_argument("-p", "--packet", required=True, help="Packet description")
    parser.add_argument("-v", "--vrf", help="VRF name")
    parser.add_argument("-d", "--debug", default=False, action="store_true", help="Flag for debug")
    return parser.parse_args()


def main():
    """Command-line entry point; return the process exit status.

    0 on success or when the calculation stops via EcmpCalcExit,
    1 on ValueError (invalid input), 2 on any other exception.
    """
    try:
        args = _parse_args()
        calc = EcmpCalc()
        calc.validate_args(args.interface, args.packet, args.vrf, args.debug)
        calc.get_ecmp_id()
        calc.get_next_hops()
        calc.calculate_egress_port()
        calc.print_egress_port()
    except EcmpCalcExit as stop_reason:
        print(stop_reason)
        return 0
    except ValueError as err:
        print("Value error: {}".format(err))
        return 1
    except Exception as err:
        print("Error: {}".format(err))
        return 2
    return 0


if __name__ == "__main__":
    sys.exit(main())