diff --git a/files/build_templates/mux.service.j2 b/files/build_templates/mux.service.j2 index 5aff5a24c0..eb9ea1d333 100644 --- a/files/build_templates/mux.service.j2 +++ b/files/build_templates/mux.service.j2 @@ -17,5 +17,15 @@ ExecStopPost=/usr/local/bin/write_standby.py Restart=always RestartSec=30 +[Service] +User={{ sonicadmin_user }} +ExecStartPre=/usr/local/bin/mark_dhcp_packet.py +ExecStartPre=/usr/bin/{{docker_container_name}}.sh start +ExecStart=/usr/bin/{{docker_container_name}}.sh wait +ExecStop=/usr/bin/{{docker_container_name}}.sh stop +ExecStopPost=/usr/local/bin/mark_dhcp_packet.py +Restart=always +RestartSec=30 + [Install] WantedBy=sonic.target diff --git a/files/build_templates/sonic_debian_extension.j2 b/files/build_templates/sonic_debian_extension.j2 index b00119b54f..9938b3e815 100644 --- a/files/build_templates/sonic_debian_extension.j2 +++ b/files/build_templates/sonic_debian_extension.j2 @@ -719,6 +719,9 @@ sudo LANG=C cp $SCRIPTS_DIR/sonic-netns-exec $FILESYSTEM_ROOT/usr/bin/sonic-netn # Copy write_standby script for mux state sudo LANG=C cp $SCRIPTS_DIR/write_standby.py $FILESYSTEM_ROOT/usr/local/bin/write_standby.py +# Copy mark_dhcp_packet script +sudo LANG=C cp $SCRIPTS_DIR/mark_dhcp_packet.py $FILESYSTEM_ROOT/usr/local/bin/mark_dhcp_packet.py + # Copy systemd timer configuration # It implements delayed start of services sudo cp $BUILD_TEMPLATES/snmp.timer $FILESYSTEM_ROOT_USR_LIB_SYSTEMD_SYSTEM diff --git a/files/scripts/mark_dhcp_packet.py b/files/scripts/mark_dhcp_packet.py new file mode 100644 index 0000000000..1ad2fd6509 --- /dev/null +++ b/files/scripts/mark_dhcp_packet.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 + +import os +import subprocess +import sys +import time + +from sonic_py_common import logger +from swsscommon import swsscommon + +log = logger.Logger('mark_dhcp_packet') + +class MarkDhcpPacket(object): + """ + Class used to configure dhcp packet mark in ebtables + """ + + def __init__(self): + self.config_db_connector = None + self.state_db_connector = None + + @property + def config_db(self): + """ + Returns config DB connector. + Initializes the connector during the first call + """ + if self.config_db_connector is None: + self.config_db_connector = swsscommon.ConfigDBConnector() + self.config_db_connector.connect() + + return self.config_db_connector + + @property + def state_db(self): + """ + Returns the state DB connector. + Initializes the connector during the first call + """ + if self.state_db_connector is None: + self.state_db_connector = swsscommon.SonicV2Connector(host='127.0.0.1') + self.state_db_connector.connect(self.state_db_connector.STATE_DB) + + return self.state_db_connector + + @property + def is_dualtor(self): + """ + Checks if script is running on a dual ToR system + """ + localhost_key = self.config_db.get_keys('DEVICE_METADATA')[0] + metadata = self.config_db.get_entry('DEVICE_METADATA', localhost_key) + + return 'subtype' in metadata and 'dualtor' in metadata['subtype'].lower() + + def get_mux_intfs(self): + """ + Returns a list of mux cable interfaces + """ + mux_cables = self.config_db.get_table('MUX_CABLE') + mux_intfs = [intf for intf in mux_cables] + + return mux_intfs + + def generate_mark_from_index(self, index): + ''' + type: string, format: hexadecimal + Example: 0x67001, 0x67002, ... + ''' + intf_mark = "0x67" + str(index).zfill(3) + + return intf_mark + + def run_command(self, cmd): + subprocess.call(cmd, shell=True) + log.log_info("run command: {}".format(cmd)) + + def clear_dhcp_packet_marks(self): + ''' + Flush the INPUT chain in ebtables upon restart + ''' + self.run_command("sudo ebtables -F INPUT") + + def apply_mark_in_ebtables(self, intf, mark): + self.run_command("sudo ebtables -A INPUT -i {} -j mark --mark-set {}".format(intf, mark)) + + def update_mark_in_state_db(self, intf, mark): + self.state_db.set(self.state_db.STATE_DB, 'DHCP_PACKET_MARK', intf, mark) + + def apply_marks(self): + """ + Writes dhcp packet marks in ebtables + """ + if not self.is_dualtor: + return + + self.clear_dhcp_packet_marks() + + for (index, intf) in enumerate(self.get_mux_intfs(), 1): + mark = self.generate_mark_from_index(index) + self.apply_mark_in_ebtables(intf, mark) + self.update_mark_in_state_db(intf, mark) + + log.log_info("Finish marking dhcp packets in ebtables.") + +if __name__ == '__main__': + mark_dhcp_packet = MarkDhcpPacket() + mark_dhcp_packet.apply_marks()