[Mellanox] Fix race condition while creating SFP (#17441)

- Why I did it
Fix issue xcvrd crashes due to cannot import name 'initialize_sfp_thermal':

Nov 27 09:47:16.388639 sonic ERR pmon#xcvrd: Exception occured at CmisManagerTask thread due to ImportError("cannot import name 'initialize_sfp_thermal' from partially initialized module 'sonic_platform.thermal' (most likely due to a circular import) (/usr/local/lib/python3.9/dist-packages/sonic_platform/thermal.py)")

- How I did it
Add lock for creating SFP object

- How to verify it
Unit test
Manual Test
This commit is contained in:
Junchao-Mellanox 2023-12-14 18:01:11 +08:00 committed by GitHub
parent da3e7cbbba
commit f373a16e95
No account linked to committer's email address
2 changed files with 66 additions and 28 deletions

View File

@ -124,6 +124,7 @@ class Chassis(ChassisBase):
self.reboot_cause_initialized = False self.reboot_cause_initialized = False
self.sfp_module = None self.sfp_module = None
self.sfp_lock = threading.Lock()
# Build the RJ45 port list from platform.json and hwsku.json # Build the RJ45 port list from platform.json and hwsku.json
self._RJ45_port_inited = False self._RJ45_port_inited = False
@ -277,38 +278,49 @@ class Chassis(ChassisBase):
def initialize_single_sfp(self, index): def initialize_single_sfp(self, index):
sfp_count = self.get_num_sfps() sfp_count = self.get_num_sfps()
# Use double checked locking mechanism for:
# 1. protect shared resource self._sfp_list
# 2. performance (avoid locking every time)
if index < sfp_count: if index < sfp_count:
if not self._sfp_list: if not self._sfp_list or not self._sfp_list[index]:
self._sfp_list = [None] * sfp_count with self.sfp_lock:
if not self._sfp_list:
self._sfp_list = [None] * sfp_count
if not self._sfp_list[index]: if not self._sfp_list[index]:
sfp_module = self._import_sfp_module() sfp_module = self._import_sfp_module()
if self.RJ45_port_list and index in self.RJ45_port_list: if self.RJ45_port_list and index in self.RJ45_port_list:
self._sfp_list[index] = sfp_module.RJ45Port(index) self._sfp_list[index] = sfp_module.RJ45Port(index)
else: else:
self._sfp_list[index] = sfp_module.SFP(index) self._sfp_list[index] = sfp_module.SFP(index)
self.sfp_initialized_count += 1 self.sfp_initialized_count += 1
def initialize_sfp(self): def initialize_sfp(self):
if not self._sfp_list: sfp_count = self.get_num_sfps()
sfp_module = self._import_sfp_module() # Use double checked locking mechanism for:
sfp_count = self.get_num_sfps() # 1. protect shared resource self._sfp_list
for index in range(sfp_count): # 2. performance (avoid locking every time)
if self.RJ45_port_list and index in self.RJ45_port_list: if sfp_count != self.sfp_initialized_count:
sfp_object = sfp_module.RJ45Port(index) with self.sfp_lock:
else: if sfp_count != self.sfp_initialized_count:
sfp_object = sfp_module.SFP(index) if not self._sfp_list:
self._sfp_list.append(sfp_object) sfp_module = self._import_sfp_module()
self.sfp_initialized_count = sfp_count for index in range(sfp_count):
elif self.sfp_initialized_count != len(self._sfp_list): if self.RJ45_port_list and index in self.RJ45_port_list:
sfp_module = self._import_sfp_module() sfp_object = sfp_module.RJ45Port(index)
for index in range(len(self._sfp_list)): else:
if self._sfp_list[index] is None: sfp_object = sfp_module.SFP(index)
if self.RJ45_port_list and index in self.RJ45_port_list: self._sfp_list.append(sfp_object)
self._sfp_list[index] = sfp_module.RJ45Port(index) self.sfp_initialized_count = sfp_count
else: elif self.sfp_initialized_count != len(self._sfp_list):
self._sfp_list[index] = sfp_module.SFP(index) sfp_module = self._import_sfp_module()
self.sfp_initialized_count = len(self._sfp_list) for index in range(len(self._sfp_list)):
if self._sfp_list[index] is None:
if self.RJ45_port_list and index in self.RJ45_port_list:
self._sfp_list[index] = sfp_module.RJ45Port(index)
else:
self._sfp_list[index] = sfp_module.SFP(index)
self.sfp_initialized_count = len(self._sfp_list)
def get_num_sfps(self): def get_num_sfps(self):
""" """

View File

@ -16,8 +16,10 @@
# #
import os import os
import random
import sys import sys
import subprocess import subprocess
import threading
from mock import MagicMock from mock import MagicMock
if sys.version_info.major == 3: if sys.version_info.major == 3:
@ -167,6 +169,30 @@ class TestChassis:
assert len(sfp_list) == 3 assert len(sfp_list) == 3
assert chassis.sfp_initialized_count == 3 assert chassis.sfp_initialized_count == 3
def test_create_sfp_in_multi_thread(self):
DeviceDataManager.get_sfp_count = mock.MagicMock(return_value=3)
iteration_num = 100
while iteration_num > 0:
chassis = Chassis()
assert chassis.sfp_initialized_count == 0
t1 = threading.Thread(target=lambda: chassis.get_sfp(1))
t2 = threading.Thread(target=lambda: chassis.get_sfp(1))
t3 = threading.Thread(target=lambda: chassis.get_all_sfps())
t4 = threading.Thread(target=lambda: chassis.get_all_sfps())
threads = [t1, t2, t3, t4]
random.shuffle(threads)
for t in threads:
t.start()
for t in threads:
t.join()
assert len(chassis.get_all_sfps()) == 3
assert chassis.sfp_initialized_count == 3
for index, s in enumerate(chassis.get_all_sfps()):
assert s.sdk_index == index
iteration_num -= 1
@mock.patch('sonic_platform.device_data.DeviceDataManager.get_sfp_count', MagicMock(return_value=3)) @mock.patch('sonic_platform.device_data.DeviceDataManager.get_sfp_count', MagicMock(return_value=3))
def test_change_event(self): def test_change_event(self):
chassis = Chassis() chassis = Chassis()