Refactor get_num_dpus function. (#17601)

Refactor current get_num_dpus function to consuming platform.json which is reliable way to retrieve the number of DPUs.
This commit is contained in:
Xincun Li 2024-01-05 13:33:04 -08:00 committed by GitHub
parent acb2e94475
commit cdd164bd12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -847,23 +847,37 @@ def is_frontend_port_present_in_host():
def get_num_dpus():
# Todo: we should use platform api to get the dpu number
# instead of rely on the platform env config.
num_dpus = 0
platform_env_conf_file_path = get_platform_env_conf_file_path()
"""
Retrieves the number of DPUs from platform.json file.
# platform_env.conf file not present for platform
if platform_env_conf_file_path is None:
return num_dpus
Args:
# Else open the file check for keyword - num_dpu -
with open(platform_env_conf_file_path) as platform_env_conf_file:
for line in platform_env_conf_file:
tokens = line.split('=')
if len(tokens) < 2:
continue
if tokens[0].lower() == 'num_dpu':
num_dpus = tokens[1].strip()
break
return int(num_dpus)
Returns:
A integer to indicate the number of DPUs.
"""
platform = get_platform()
if not platform:
return 0
# Get Platform path.
platform_path = get_path_to_platform_dir()
if os.path.isfile(os.path.join(platform_path, PLATFORM_JSON_FILE)):
json_file = os.path.join(platform_path, PLATFORM_JSON_FILE)
try:
with open(json_file, 'r') as file:
platform_data = json.load(file)
except (json.JSONDecodeError, IOError, TypeError, ValueError):
# Handle any file reading and JSON parsing errors
return 0
# Convert to lower case avoid case sensitive.
data = {k.lower(): v for k, v in platform_data.items()}
DPUs = data.get('dpus', None)
if DPUs is not None and len(DPUs) > 0:
return len(DPUs)
return 0