porting PR #8223, which uses one shot timer to reaload tacacs config (#9987)

Why I did it
There is a small window between load & listen to config-DB. If TACACS config got updated during that gap, the listen will not show it, hence hostcfgd would miss it, until another update.

How I did it
porting PR #8223, which uses one shot timer to reload tacacs config.
This commit is contained in:
Renuka Manavalan 2022-02-17 08:16:03 -08:00 committed by GitHub
parent 2ed7f537d4
commit 7910108fd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,6 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import threading
import sys import sys
import subprocess import subprocess
import syslog import syslog
@ -22,6 +23,20 @@ TACPLUS_SERVER_PASSKEY_DEFAULT = ""
TACPLUS_SERVER_TIMEOUT_DEFAULT = "5" TACPLUS_SERVER_TIMEOUT_DEFAULT = "5"
TACPLUS_SERVER_AUTH_TYPE_DEFAULT = "pap" TACPLUS_SERVER_AUTH_TYPE_DEFAULT = "pap"
global_lock = None
class lock_mgr:
def __init__(self):
self.lock = global_lock
def __enter__( self ):
if self.lock:
self.lock.acquire()
def __exit__( self, exc_type, exc_value, traceback ):
if self.lock:
self.lock.release()
def is_true(val): def is_true(val):
if val == 'True' or val == 'true': if val == 'True' or val == 'true':
@ -118,7 +133,7 @@ class Iptables(object):
.format(err.cmd, err.returncode, err.output)) .format(err.cmd, err.returncode, err.output))
class AaaCfg(object): class AaaCfg(object):
def __init__(self): def __init__(self, config_db):
self.auth_default = { self.auth_default = {
'login': 'local', 'login': 'local',
} }
@ -131,49 +146,42 @@ class AaaCfg(object):
self.tacplus_global = {} self.tacplus_global = {}
self.tacplus_servers = {} self.tacplus_servers = {}
self.debug = False self.debug = False
self.config_db = config_db
# Load conf from ConfigDb # Load conf from ConfigDb
def load(self, aaa_conf, tac_global_conf, tacplus_conf): def load(self):
for row in aaa_conf:
self.aaa_update(row, aaa_conf[row], modify_conf=False)
for row in tac_global_conf:
self.tacacs_global_update(row, tac_global_conf[row], modify_conf=False)
for row in tacplus_conf:
self.tacacs_server_update(row, tacplus_conf[row], modify_conf=False)
self.modify_conf_file() self.modify_conf_file()
def aaa_update(self, key, data, modify_conf=True): def aaa_update(self, key):
if key == 'authentication': if key == 'authentication':
self.auth = data
if 'failthrough' in data:
self.auth['failthrough'] = is_true(data['failthrough'])
if 'debug' in data:
self.debug = is_true(data['debug'])
if modify_conf:
self.modify_conf_file() self.modify_conf_file()
def tacacs_global_update(self, key, data, modify_conf=True): def tacacs_global_update(self, key):
if key == 'global': if key == 'global':
self.tacplus_global = data
if modify_conf:
self.modify_conf_file()
def tacacs_server_update(self, key, data, modify_conf=True):
if data == {}:
if key in self.tacplus_servers:
del self.tacplus_servers[key]
else:
self.tacplus_servers[key] = data
if modify_conf:
self.modify_conf_file() self.modify_conf_file()
def tacacs_server_update(self, key):
self.modify_conf_file()
def modify_single_file(self, filename, operations=None): def modify_single_file(self, filename, operations=None):
if operations: if operations:
cmd = "sed -e {0} {1} > {1}.new; mv -f {1} {1}.old; mv -f {1}.new {1}".format(' -e '.join(operations), filename) cmd = "sed -e {0} {1} > {1}.new; mv -f {1} {1}.old; mv -f {1}.new {1}".format(' -e '.join(operations), filename)
os.system(cmd) os.system(cmd)
def modify_conf_file(self): def modify_conf_file(self):
with lock_mgr():
self.auth = self.config_db.get_table('AAA').get("authentication", {})
if 'failthrough' in self.auth:
self.auth['failthrough'] = is_true(self.auth['failthrough'])
if 'debug' in self.auth:
self.debug = is_true(self.auth['debug'])
self.tacplus_global = self.config_db.get_table('TACPLUS').get(
"global", {})
self.tacplus_servers = self.config_db.get_table('TACPLUS_SERVER')
self._modify_conf_file()
def _modify_conf_file(self):
auth = self.auth_default.copy() auth = self.auth_default.copy()
auth.update(self.auth) auth.update(self.auth)
tacplus_global = self.tacplus_global_default.copy() tacplus_global = self.tacplus_global_default.copy()
@ -219,33 +227,50 @@ class AaaCfg(object):
with open(NSS_TACPLUS_CONF, 'w') as f: with open(NSS_TACPLUS_CONF, 'w') as f:
f.write(nss_tacplus_conf) f.write(nss_tacplus_conf)
if 'passkey' in tacplus_global:
tacplus_global['passkey'] = obfuscate(tacplus_global['passkey'])
syslog.syslog(syslog.LOG_INFO, 'pam.d files updated auth={} global={}'.
format(auth, tacplus_global))
class HostConfigDaemon: class HostConfigDaemon:
def __init__(self): def __init__(self):
self.config_db = ConfigDBConnector() self.config_db = ConfigDBConnector()
self.config_db.connect(wait_for_init=True, retry_on=True) self.config_db.connect(wait_for_init=True, retry_on=True)
syslog.syslog(syslog.LOG_INFO, 'ConfigDB connect success') syslog.syslog(syslog.LOG_INFO, 'ConfigDB connect success')
aaa = self.config_db.get_table('AAA')
tacacs_global = self.config_db.get_table('TACPLUS') self.aaacfg = AaaCfg(self.config_db)
tacacs_server = self.config_db.get_table('TACPLUS_SERVER')
self.aaacfg = AaaCfg()
self.aaacfg.load(aaa, tacacs_global, tacacs_server)
lpbk_table = self.config_db.get_table('LOOPBACK_INTERFACE')
self.iptables = Iptables() self.iptables = Iptables()
def timer_load(self):
global global_lock
syslog.syslog(syslog.LOG_INFO, 'reloading tacacs from timer thread')
self.aaacfg.load()
# Remove lock as timer is one shot
global_lock = None
def load(self):
self.aaacfg.load()
lpbk_table = self.config_db.get_table('LOOPBACK_INTERFACE')
self.iptables.load(lpbk_table) self.iptables.load(lpbk_table)
def aaa_handler(self, key, data): def aaa_handler(self, key, data):
self.aaacfg.aaa_update(key, data) self.aaacfg.aaa_update(key)
syslog.syslog(syslog.LOG_INFO, 'value of {} changed to {}'.format(key, data))
def tacacs_server_handler(self, key, data): def tacacs_server_handler(self, key, data):
self.aaacfg.tacacs_server_update(key, data) self.aaacfg.tacacs_server_update(key)
log_data = copy.deepcopy(data) log_data = copy.deepcopy(data)
if log_data.has_key('passkey'): if log_data.has_key('passkey'):
log_data['passkey'] = obfuscate(log_data['passkey']) log_data['passkey'] = obfuscate(log_data['passkey'])
syslog.syslog(syslog.LOG_INFO, 'value of {} changed to {}'.format(key, log_data)) syslog.syslog(syslog.LOG_INFO, 'value of {} changed to {}'.format(key, log_data))
def tacacs_global_handler(self, key, data): def tacacs_global_handler(self, key, data):
self.aaacfg.tacacs_global_update(key, data) self.aaacfg.tacacs_global_update(key)
log_data = copy.deepcopy(data) log_data = copy.deepcopy(data)
if log_data.has_key('passkey'): if log_data.has_key('passkey'):
log_data['passkey'] = obfuscate(log_data['passkey']) log_data['passkey'] = obfuscate(log_data['passkey'])
@ -263,10 +288,20 @@ class HostConfigDaemon:
self.iptables.iptables_handler(key, data, add) self.iptables.iptables_handler(key, data, add)
def start(self): def start(self):
global global_lock
self.config_db.subscribe('AAA', lambda table, key, data: self.aaa_handler(key, data)) self.config_db.subscribe('AAA', lambda table, key, data: self.aaa_handler(key, data))
self.config_db.subscribe('TACPLUS_SERVER', lambda table, key, data: self.tacacs_server_handler(key, data)) self.config_db.subscribe('TACPLUS_SERVER', lambda table, key, data: self.tacacs_server_handler(key, data))
self.config_db.subscribe('TACPLUS', lambda table, key, data: self.tacacs_global_handler(key, data)) self.config_db.subscribe('TACPLUS', lambda table, key, data: self.tacacs_global_handler(key, data))
self.config_db.subscribe('LOOPBACK_INTERFACE', lambda table, key, data: self.lpbk_handler(key, data)) self.config_db.subscribe('LOOPBACK_INTERFACE', lambda table, key, data: self.lpbk_handler(key, data))
# Defer load until subscribe
self.load()
global_lock = threading.Lock()
self.tmr_thread = threading.Timer(30, self.timer_load)
self.tmr_thread.start()
self.config_db.listen() self.config_db.listen()