511 lines
19 KiB
Python
511 lines
19 KiB
Python
|
#!/usr/bin/env python3
|
||
|
|
||
|
import json, jsonschema
|
||
|
import argparse
|
||
|
import ipaddress
|
||
|
import re
|
||
|
import subprocess
|
||
|
import pprint
|
||
|
import os
|
||
|
import sys
|
||
|
|
||
|
# The helper modules imported below are not installed as packages; they are
# looked up under <prefix>/lib and <prefix>/lib/ecmp_calc, where <prefix> is
# the parent of the directory that contains this script.
usr_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
lib_path = os.path.join(usr_path, "lib")
ecmp_lib_path = os.path.join(lib_path, "ecmp_calc")
sys.path.extend([lib_path, ecmp_lib_path])
|
||
|
|
||
|
from ecmp_calc_sdk import sx_open_sdk_connection, sx_get_active_vrids, sx_router_get_ecmp_id, \
|
||
|
sx_router_ecmp_nexthops_get, sx_get_router_interface, \
|
||
|
sx_port_vport_base_get, sx_router_neigh_get_mac, sx_fdb_uc_mac_addr_get, \
|
||
|
sx_lag_port_group_get, sx_make_ip_prefix_v4, sx_make_ip_prefix_v6, \
|
||
|
sx_vlan_ports_get, sx_ip_addr_to_str, sx_close_sdk_connection, \
|
||
|
PORT, VPORT, VLAN, SX_ENTRY_NOT_FOUND
|
||
|
from packet_scheme import PACKET_SCHEME
|
||
|
from port_utils import sx_get_ports_map, is_lag
|
||
|
|
||
|
# IP version selectors, passed as the ip_version argument of is_ip_valid()
# and create_network_addr().
IP_VERSION_IPV4 = 1
IP_VERSION_IPV6 = 2

# Longest possible prefix length for each IP version.
IP_VERSION_MAX_MASK_LEN = dict([
    (IP_VERSION_IPV4, 32),
    (IP_VERSION_IPV6, 128),
])

# Position of the wanted field in ':'-separated redis key names: the
# port-channel name in LAG member keys and the VRF name in VRF keys.
PORT_CHANNEL_IDX = 1
VRF_NAME_IDX = 1

# Redis (DB 0) table whose per-interface entries carry the vrf_name field.
INTF_TABLE = 'INTF_TABLE'

# External hash calculator binary and the JSON files used to talk to it.
HASH_CALC_PATH = '/usr/bin/sx_hash_calculator'
HASH_CALC_INPUT_FILE = "/tmp/hash_calculator_input.json"
HASH_CALC_OUTPUT_FILE = "/tmp/hash_calculator_output.json"
|
||
|
|
||
|
def exec_cmd(cmd):
    """Run *cmd* (an argv list, executed directly without a shell) and return its output.

    stderr is merged into stdout and the combined bytes are decoded as UTF-8.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    completed = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        check=True,
        shell=False,
    )
    return completed.stdout.decode("utf-8")
|
||
|
|
||
|
# Six two-digit hex octets separated by ':' or '-'.
_MAC_ADDR_RE = re.compile(r"([0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}")


def is_mac_valid(mac):
    """Return True if *mac* is a textual MAC address such as 00:11:22:aa:bb:cc.

    Octets may be separated by ':' or '-'.  fullmatch() is used because a
    '^...$' pattern with re.match() also accepts a trailing newline ('$'
    matches just before a final '\\n').
    """
    return _MAC_ADDR_RE.fullmatch(mac) is not None
|
||
|
|
||
|
def is_ip_valid(address, ip_version):
    """Return True if *address* is a usable unicast address of the given version.

    Args:
        address: textual IP address taken from the packet description.
        ip_version: IP_VERSION_IPV4 selects IPv4; any other value means IPv6.

    Returns:
        False (silently) if *address* does not parse as an address of that
        version.  False (after printing the reason) for link-local,
        unspecified (0.0.0.0 / ::), IPv4 limited-broadcast
        (255.255.255.255), multicast and loopback addresses.  True otherwise.
    """
    try:
        if ip_version == IP_VERSION_IPV4:
            ip = ipaddress.IPv4Address(address)
            # Hold address objects, not strings: an IPv4Address/IPv6Address
            # never compares equal to a str, so a list of strings would
            # never match and these addresses would slip through.
            invalid_list = [ipaddress.IPv4Address('0.0.0.0'),
                            ipaddress.IPv4Address('255.255.255.255')]
        else:
            ip = ipaddress.IPv6Address(address)
            invalid_list = [ipaddress.IPv6Address('0::0')]
    except ipaddress.AddressValueError:
        return False

    if ip.is_link_local:
        print("Link local IP {} is not valid".format(ip))
        return False

    if ip in invalid_list:
        print("IP {} is not valid".format(ip))
        return False

    if ip.is_multicast:
        print("Multicast IP {} is not valid".format(ip))
        return False

    if ip.is_loopback:
        print("Loopback IP {} is not valid".format(ip))
        return False

    return True
|
||
|
|
||
|
def load_json(filename):
    """Load and return the JSON document stored in *filename*.

    The file is read as UTF-8, the encoding JSON text uses, instead of the
    locale's default encoding, which may be ASCII in minimal environments.

    Raises:
        ValueError: if the content is not valid JSON or is not valid UTF-8.
        OSError: if the file cannot be opened; this is propagated unchanged.
    """
    with open(filename, encoding="utf-8") as f:
        try:
            return json.load(f)
        except ValueError as e:
            # Covers json.JSONDecodeError and UnicodeDecodeError; both are
            # ValueError subclasses.
            raise ValueError("Failed to load JSON file '{}', error: '{}'".format(filename, e)) from e
|
||
|
|
||
|
def create_network_addr(ip_addr, mask_len, ip_version):
    """Build the SDK prefix for the network *ip_addr*/*mask_len*.

    Host bits of *ip_addr* are cleared (strict=False).  The network address
    and netmask, both as strings, are passed to sx_make_ip_prefix_v4 when
    *ip_version* equals IP_VERSION_IPV4, and to sx_make_ip_prefix_v6
    otherwise.  The SDK's prefix object is returned.
    """
    is_v4 = ip_version == IP_VERSION_IPV4
    network_cls = ipaddress.IPv4Network if is_v4 else ipaddress.IPv6Network
    make_prefix = sx_make_ip_prefix_v4 if is_v4 else sx_make_ip_prefix_v6

    network = network_cls("{}/{}".format(ip_addr, mask_len), strict=False)
    # with_netmask renders "<network>/<netmask>", e.g. "10.0.0.0/255.255.255.0".
    net_ip, net_mask = network.with_netmask.split('/')
    return make_prefix(net_ip, net_mask)
|
||
|
|
||
|
class EcmpCalcExit(Exception):
    """Stop the calculation early with a message meant for the user.

    main() prints the message and keeps a zero exit status.
    """
|
||
|
|
||
|
class EcmpCalc:
    """Compute the port through which an ECMP-routed packet leaves the switch.

    Construction opens an SDK connection and loads the logical-port map and
    the active VRIDs.  main() then calls validate_args(), get_ecmp_id(),
    get_next_hops(), calculate_egress_port() and print_egress_port() in that
    order.  When the object is garbage-collected, the SDK connection is
    closed and the hash calculator's temporary files are removed.
    """

    def __init__(self):
        self.packet = {}          # parsed packet description (see validate_packet_json)
        self.ports_map = {}       # logical port index -> SONiC port name
        self.ecmp_ids = {}        # VRID -> ECMP id of the best matching route
        self.next_hops = {}       # VRID -> next-hop list, only for routes with > 1 next hop
        self.user_vrf = ''        # requested VRF name; '' stands for the default VRF
        self.ingress_port = ""    # SONiC name of the ingress interface
        self.egress_ports = []    # filled by calculate_egress_port()
        self.debug = False

        self.open_sdk_connection()
        self.init_ports_map()
        self.get_active_vrids()

    def __del__(self):
        """Release the SDK connection and remove the hash-calculator temp files."""
        # If __init__ failed before open_sdk_connection() returned, there is
        # no handle to close, so don't raise AttributeError from the finalizer.
        # The temp files are removed even if closing the connection fails.
        try:
            if hasattr(self, 'handle'):
                self.close_sdk_connection()
        finally:
            self.cleanup()

    def cleanup(self):
        """Delete the hash calculator's input/output files if they exist."""
        for filename in (HASH_CALC_INPUT_FILE, HASH_CALC_OUTPUT_FILE):
            if os.path.exists(filename):
                os.remove(filename)

    def close_sdk_connection(self):
        sx_close_sdk_connection(self.handle)

    def open_sdk_connection(self):
        self.handle = sx_open_sdk_connection()

    def debug_print(self, *args, **kwargs):
        """print() the arguments only when debug output is enabled."""
        if self.debug:
            print(*args, **kwargs)

    def init_ports_map(self):
        self.ports_map = sx_get_ports_map(self.handle)

    def validate_ingress_port(self, interface):
        """Record *interface* as the ingress port.

        Raises:
            ValueError: if *interface* is not one of the known port names.
        """
        if interface not in self.ports_map.values():
            raise ValueError("Invalid interface {}".format(interface))
        self.ingress_port = interface

    def validate_args(self, interface, packet, vrf, debug):
        """Validate the command-line arguments and store them on the instance.

        Args:
            interface: ingress interface name.
            packet: path of the JSON packet description.
            vrf: VRF name, or None for the default VRF.
            debug: enable debug output when truthy.

        Raises:
            ValueError: for an unknown interface, an invalid packet
                description, or a VRF that does not exist.
        """
        if debug:
            self.debug = True

        self.validate_ingress_port(interface)
        self.validate_packet_json(packet)

        if vrf is not None:
            self.user_vrf = vrf
            if not self.validate_vrf():
                raise ValueError("VRF validation failed: VRF {} does not exist".format(self.user_vrf))

    def validate_vrf(self):
        """Return True if some redis (DB 0) key matching *VRF* names self.user_vrf."""
        query_output = exec_cmd(['/usr/bin/redis-cli', '-n', '0', 'keys', '*VRF*']).strip()
        if not query_output:
            return False

        for entry in query_output.split('\n'):
            fields = entry.split(':')
            # The *VRF* pattern can also match key names without ':'; skip
            # those instead of failing with IndexError.
            if len(fields) > VRF_NAME_IDX and fields[VRF_NAME_IDX] == self.user_vrf:
                return True

        return False

    def get_ecmp_id(self):
        """Find, per active VRID, the longest prefix covering dst_ip that has an ECMP id.

        Results are stored in self.ecmp_ids (VRID -> ECMP id).

        Raises:
            EcmpCalcExit: if no VRID has a matching route.
        """
        max_mask_len = IP_VERSION_MAX_MASK_LEN[self.ip_version]
        route_found = False

        for vrid in self.vrid_list:
            # Longest-prefix match, walking from a host route down to /1.
            # NOTE(review): /0 (a default route) is never tried; confirm that
            # this is intentional.
            for mask_len in range(max_mask_len, 0, -1):
                network_addr = create_network_addr(self.dst_ip, mask_len, self.ip_version)
                ecmp_id = sx_router_get_ecmp_id(self.handle, vrid, network_addr)
                if ecmp_id == SX_ENTRY_NOT_FOUND:
                    continue

                route_found = True
                self.debug_print("Found route for destination IP {} ECMP id {} VRID {}".format(self.dst_ip, ecmp_id, vrid))
                self.ecmp_ids[vrid] = ecmp_id
                # Longest match for this VRID found; move to the next VRID.
                break

        if not route_found:
            raise EcmpCalcExit("No route found for given packet")

    def get_next_hops(self):
        """Load next hops for every found route that has more than one of them.

        Results are stored in self.next_hops (VRID -> next-hop list).

        Raises:
            EcmpCalcExit: if none of the routes has more than one next hop.
        """
        ecmp_found = False

        for vrid, ecmp_id in self.ecmp_ids.items():
            next_hops = sx_router_ecmp_nexthops_get(self.handle, ecmp_id)
            # A single next hop involves no ECMP hashing.
            if len(next_hops) <= 1:
                continue

            if self.debug:
                next_hops_ips = [
                    sx_ip_addr_to_str(nh.next_hop_key.next_hop_key_entry.ip_next_hop.address)
                    for nh in next_hops
                ]
                print("Next hops IPs {}, VRID {}".format(next_hops_ips, vrid))
                print("Found ECMP for destination IP {} ECMP id {}, now checking if port is member in VRF {}".
                      format(self.dst_ip, ecmp_id, 'default' if self.user_vrf == '' else self.user_vrf))

            self.next_hops[vrid] = next_hops
            ecmp_found = True

        if not ecmp_found:
            raise EcmpCalcExit("No ECMP for given packet")

    def calculate_egress_port(self):
        """Resolve the egress port(s) and append them to self.egress_ports.

        VRIDs are visited in self.vrid_list order; those without an ECMP
        entry in self.next_hops are skipped.  The first VRID whose egress
        port is bound to the user's VRF wins.  For a VLAN RIF whose
        neighbour MAC or its FDB entry is unknown, the packet is treated as
        flooded: one port per VLAN member (a hashed member for LAGs) is
        appended and the method returns without a VRF check.
        """
        for vrid in self.vrid_list:
            if vrid not in self.next_hops:
                continue

            next_hops = self.next_hops[vrid]
            next_hop = next_hops[self.get_next_hop_index(len(next_hops))]

            ip_next_hop = next_hop.next_hop_key.next_hop_key_entry.ip_next_hop
            rif = ip_next_hop.rif
            ip = ip_next_hop.address
            rif_params = sx_get_router_interface(self.handle, vrid, rif)

            self.debug_print("Next hop ip to which trafic will egress: {}".format(sx_ip_addr_to_str(ip)))

            if PORT in rif_params:
                # Router port
                logical = rif_params[PORT]
                port_type = PORT
                vlan_id = 0
            elif VPORT in rif_params:
                # VLAN sub-interface
                logical, vlan_id = sx_port_vport_base_get(self.handle, rif_params[VPORT])
                port_type = VPORT
            elif VLAN in rif_params:
                # VLAN interface: the port comes from the FDB entry of the
                # neighbour's MAC.
                vlan_id = rif_params[VLAN]
                mac_entry = None
                neigh_mac = sx_router_neigh_get_mac(self.handle, rif, ip)
                if neigh_mac is not None:
                    mac_entry = sx_fdb_uc_mac_addr_get(self.handle, vlan_id, neigh_mac)

                if mac_entry is None:
                    # Unknown neighbour MAC or FDB entry: flood to all VLAN members.
                    for port in sx_vlan_ports_get(self.handle, vlan_id):
                        if is_lag(port):
                            port = self.get_lag_member(port, True)
                        self.egress_ports.append(self.ports_map[port])
                    return

                logical = mac_entry.log_port
                port_type = VLAN
            else:
                # RIF type this tool does not handle; try the next VRID
                # instead of failing on an unset 'logical'.
                continue

            if is_lag(logical):
                # Egress through the hashed LAG member, but check VRF binding
                # against the port-channel the member belongs to.
                egress_port = self.ports_map[self.get_lag_member(logical)]
                vrf_port = self.get_port_channel_name(egress_port)
            else:
                egress_port = self.ports_map[logical]
                vrf_port = egress_port

            if self.is_port_bind_to_user_vrf(port_type, vrf_port, vlan_id):
                self.egress_ports.append(egress_port)
                return

    def print_egress_port(self):
        """Print the computed egress port(s), or a hint when none was found."""
        if not self.egress_ports:
            print("Egress port not found, check input parameters")
        elif len(self.egress_ports) == 1:
            print("Egress port: {}".format(self.egress_ports[0]))
        else:
            print("Egress ports: {}".format(' '.join(self.egress_ports)))

    def is_port_bind_to_user_vrf(self, port_type, port, vlan_id = 0):
        """Return True if the interface's vrf_name in INTF_TABLE equals self.user_vrf.

        The INTF_TABLE key is derived from *port_type*: the port name for
        PORT, "<port>.<vlan_id>" for VPORT, and "Vlan<vlan_id>" for VLAN.
        Unknown port types are reported as not bound.
        """
        if port_type == PORT:
            # INTF_TABLE:Ethernet0
            entry = '{}:{}'.format(INTF_TABLE, port)
        elif port_type == VPORT:
            # INTF_TABLE:Ethernet0.300
            entry = '{}:{}.{}'.format(INTF_TABLE, port, vlan_id)
        elif port_type == VLAN:
            # INTF_TABLE:Vlan300
            entry = '{}:Vlan{}'.format(INTF_TABLE, vlan_id)
        else:
            return False

        port_vrf = exec_cmd(['/usr/bin/redis-cli', '-n', '0', 'hget', entry, 'vrf_name'])
        # A missing field reads back as an empty line, which matches the
        # default VRF ('').
        return self.user_vrf == port_vrf.strip()

    def get_port_channel_name(self, port):
        """Return the port-channel that LAG member *port* belongs to.

        Raises:
            KeyError: if no LAG member key ends with *port*.
        """
        query_output = exec_cmd(['/usr/bin/redis-cli', '-n', '0', 'keys', '*LAG_MEMBER_TABLE*'])
        member = str(port)
        for line in query_output.split('\n'):
            # Assumes keys of the form LAG_MEMBER_TABLE:<port-channel>:<member>.
            # Compare the member field exactly; a substring test would let
            # "Ethernet1" match the key of "Ethernet12".
            fields = line.strip().split(':')
            if len(fields) > PORT_CHANNEL_IDX + 1 and fields[-1] == member:
                return fields[PORT_CHANNEL_IDX]

        raise KeyError("Failed to get port-channel name for interface {}".format(port))

    def get_ingress_port_logical_idx(self):
        """Return the logical port index of self.ingress_port.

        Raises:
            KeyError: if the ingress port is not in self.ports_map.
        """
        for logical_index, sonic_port_name in self.ports_map.items():
            if sonic_port_name == self.ingress_port:
                return logical_index

        raise KeyError("Failed to get logical index for interface {}".format(self.ingress_port))

    def get_next_hop_index(self, ecmp_size):
        """Ask the hash calculator which of *ecmp_size* next hops the packet uses."""
        logical = self.get_ingress_port_logical_idx()

        ecmp_hash = {
            "ingress_port": str(hex(logical)),
            "packet_info": self.packet['packet_info'],
            "ecmp_size": ecmp_size,
        }

        self.debug_print("Calling hash calculator for ECMP")
        hash_result = self.call_hash_calculator({'ecmp_hash': ecmp_hash})
        index = hash_result['ecmp_hash']['ecmp_index']

        self.debug_print("Next hop index to which trafic will egress: {}".format(index))
        return index

    def get_lag_member_index(self, lag_size, flood_case = False):
        """Ask the hash calculator which of *lag_size* LAG members the packet uses.

        The calculator's 'lag_mc_index' is used when *flood_case* is set and
        'lag_index' otherwise.
        """
        logical = self.get_ingress_port_logical_idx()

        lag_hash = {
            "ingress_port": str(hex(logical)),
            "packet_info": self.packet['packet_info'],
            "lag_size": lag_size,
        }

        self.debug_print("Calling hash calculator for LAG, flood case {}".format(bool(flood_case)))
        lag_hash_result = self.call_hash_calculator({"lag_hash": lag_hash})["lag_hash"]
        index = lag_hash_result['lag_mc_index' if flood_case else 'lag_index']

        self.debug_print("Lag member index from which trafic will egress: {}".format(index))
        return index

    def get_lag_member(self, logical, flood_case = False):
        """Return the logical index of the LAG member the packet egresses through."""
        # Members are sorted so the hash result indexes a stable order.
        lag_members = sorted(sx_lag_port_group_get(self.handle, logical))

        member_index = self.get_lag_member_index(len(lag_members), flood_case)
        lag_member = lag_members[member_index]

        self.debug_print("Lag member from which trafic will egress: {}".format(lag_member))
        return lag_member

    def call_hash_calculator(self, input_dict):
        """Run sx_hash_calculator on *input_dict* and return its parsed JSON output.

        The input is written to HASH_CALC_INPUT_FILE, and the result is read
        back from HASH_CALC_OUTPUT_FILE.
        """
        with open(HASH_CALC_INPUT_FILE, "w") as outfile:
            json.dump(input_dict, outfile)

        out = exec_cmd([HASH_CALC_PATH, '-c', HASH_CALC_INPUT_FILE, '-o', HASH_CALC_OUTPUT_FILE, '-d'])
        self.debug_print("Hash calculator output:\n{}".format(out))

        with open(HASH_CALC_OUTPUT_FILE, 'r') as openfile:
            return json.load(openfile)

    def get_active_vrids(self):
        self.vrid_list = sx_get_active_vrids(self.handle)

    def validate_ipv4_header(self, header):
        """Raise ValueError if the 'sip' or 'dip' present in *header* is not a valid IPv4 address."""
        for ip in ('sip', 'dip'):
            if ip in header and not is_ip_valid(header[ip], IP_VERSION_IPV4):
                raise ValueError("Json validation failed: invalid IP {}".format(header[ip]))

    def validate_ipv6_header(self, header):
        """Raise ValueError if the 'sip' or 'dip' present in *header* is not a valid IPv6 address."""
        for ip in ('sip', 'dip'):
            if ip in header and not is_ip_valid(header[ip], IP_VERSION_IPV6):
                raise ValueError("Json validation failed: invalid IP {}".format(header[ip]))

    def validate_layer2_header(self, header):
        """Raise ValueError if the 'smac' or 'dmac' present in *header* is not a valid MAC."""
        for mac in ('smac', 'dmac'):
            if mac in header and not is_mac_valid(header[mac]):
                raise ValueError("Json validation failed: invalid mac {}".format(header[mac]))

    def validate_header(self, header, is_outer_header=False):
        """Validate one (outer or inner) header of the packet description.

        For the outer header, the destination IP and IP version are also
        stored in self.dst_ip and self.ip_version.

        Raises:
            ValueError: if IPv4 and IPv6 headers co-exist, an address is
                invalid, the outer header has no destination IP, or a
                transport port is given without the transport protocol.
        """
        ipv4_header = 'ipv4' in header
        ipv6_header = 'ipv6' in header

        if ipv4_header and ipv6_header:
            raise ValueError("Json validation failed: IPv4 and IPv6 headers can not co-exist")

        if ipv4_header:
            self.validate_ipv4_header(header['ipv4'])

            if is_outer_header:
                if 'dip' not in header['ipv4']:
                    raise ValueError("Json validation failed: destination IP is mandatory")
                self.dst_ip = header['ipv4']['dip']
                self.ip_version = IP_VERSION_IPV4

            if 'tcp_udp' in header and 'proto' not in header['ipv4']:
                raise ValueError("Json validation failed: transport protocol (proto) is mandatory when transport layer port exists")

        elif ipv6_header:
            self.validate_ipv6_header(header['ipv6'])

            if is_outer_header:
                if 'dip' not in header['ipv6']:
                    raise ValueError("Json validation failed: destination IP is mandatory")
                self.dst_ip = header['ipv6']['dip']
                self.ip_version = IP_VERSION_IPV6

            if 'tcp_udp' in header and 'next_header' not in header['ipv6']:
                raise ValueError("Json validation failed: transport protocol (next_header) is mandatory when transport layer port exists")

        elif is_outer_header:
            # The route lookup needs the outer destination IP; without an IP
            # header dst_ip would never be set and the lookup would fail later.
            raise ValueError("Json validation failed: destination IP is mandatory")

        # Verify the MACs of an optional layer2 header.
        if header.get('layer2'):
            self.validate_layer2_header(header['layer2'])

    def validate_outer_header(self):
        """Validate the mandatory outer header of the loaded packet."""
        outer_header = self.packet['packet_info'].get('outer')
        if not outer_header:
            raise ValueError("Json validation failed: outer header is mandatory")

        self.validate_header(outer_header, is_outer_header=True)

    def validate_inner_header(self):
        """Validate the optional inner header of the loaded packet."""
        inner_header = self.packet['packet_info'].get('inner')
        if inner_header:
            self.validate_header(inner_header)

    def validate_packet_json(self, packet_json):
        """Load the packet description from file *packet_json* and validate it.

        The JSON is checked against PACKET_SCHEME and then header by header.

        Raises:
            ValueError: if the file is not valid JSON or fails validation.
        """
        self.packet = load_json(packet_json)

        try:
            jsonschema.validate(self.packet, PACKET_SCHEME)
        except jsonschema.exceptions.ValidationError as e:
            raise ValueError("Json validation failed: {}".format(e)) from e

        self.validate_outer_header()
        self.validate_inner_header()

        if self.debug:
            print('Packet:')
            pprint.pprint(self.packet)
|
||
|
|
||
|
def main():
    """Command-line entry point; return the process exit status.

    Exit status: 0 on success, and also when the calculation stops early
    with EcmpCalcExit (its message is printed); 1 for ValueError (bad
    input); 2 for any other exception.
    """
    try:
        parser = argparse.ArgumentParser(description="ECMP calculator")
        parser.add_argument("-i", "--interface", required=True, help="Ingress interface")
        parser.add_argument("-p", "--packet", required=True, help="Packet description")
        parser.add_argument("-v", "--vrf", help="VRF name")
        parser.add_argument("-d", "--debug", default=False, action="store_true", help="Flag for debug")
        args = parser.parse_args()

        calc = EcmpCalc()
        calc.validate_args(args.interface, args.packet, args.vrf, args.debug)
        # Route lookup -> ECMP next hops -> hashed egress port -> report.
        for step in (calc.get_ecmp_id,
                     calc.get_next_hops,
                     calc.calculate_egress_port,
                     calc.print_egress_port):
            step()
    except EcmpCalcExit as exc:
        print(exc)
        return 0
    except ValueError as exc:
        print("Value error: {}".format(exc))
        return 1
    except Exception as exc:
        print("Error: {}".format(exc))
        return 2
    return 0
|
||
|
|
||
|
if __name__ == "__main__":
    # main() returns the exit status; SystemExit hands it to the shell.
    raise SystemExit(main())
|