diff --git a/pyroute2/ndb/schema.py b/pyroute2/ndb/schema.py index 42a4514cc..4a6692af2 100644 --- a/pyroute2/ndb/schema.py +++ b/pyroute2/ndb/schema.py @@ -129,6 +129,7 @@ from pyroute2 import config from pyroute2.common import basestring, uuid32 +from pyroute2.netlink import NLM_F_REPLACE # from .objects import address, interface, neighbour, netns, probe, route, rule @@ -906,6 +907,36 @@ def load_netlink(self, table, target, event, ctable=None, propagate=False): compiled = self.compiled[table] # a map of sub-NLAs nodes = {} + # replace + r_conditions = [] + r_values = [] + + # Check route replace + if ( + event['header'].get('flags', 0) == NLM_F_REPLACE + and event['event'] == 'RTM_NEWROUTE' + ): + # Replace existing route + r_conditions = [table + '.f_target = %s' % self.plch] + r_values = [target] + for key in self.indices[table]: + if key not in [ + 'RTA_DST', + 'dst_len', + 'table', + 'RTA_PRIORITY', + ]: + continue + + r_conditions.append( + table + '.f_%s = %s' % (key, self.plch) + ) + value = event.get(key) or event.get_attr(key) + if value is None: + value = self.key_defaults[table][key] + if isinstance(value, (dict, list, tuple, set)): + value = json.dumps(value) + r_values.append(value) # fetch values (exc. the first two columns) for fname, ftype in self.spec[table].items(): @@ -942,6 +973,12 @@ def load_netlink(self, table, target, event, ctable=None, propagate=False): values.append(value) try: + w_fidx = compiled['fidx'] + w_ivalues = ivalues + if r_conditions: + w_fidx = ' AND '.join(r_conditions) + w_ivalues = r_values + if self.provider == DBProvider.psycopg2: # # run UPSERT -- the DB provider must support it @@ -957,9 +994,9 @@ def load_netlink(self, table, target, event, ctable=None, propagate=False): compiled['plchs'], compiled['knames'], compiled['fset'], - compiled['fidx'], + w_fidx, ), - (values + values + ivalues), + (values + values + w_ivalues), ) ) # @@ -976,8 +1013,8 @@ def load_netlink(self, table, target, event, ctable=None, propagate=False): ''' SELECT count(*) FROM %s WHERE %s ''' - % (table, compiled['fidx']), - ivalues, + % (table, w_fidx), + w_ivalues, ).fetchone() )[0] if count == 0: @@ -993,8 +1030,8 @@ def load_netlink(self, table, target, event, ctable=None, propagate=False): ''' UPDATE %s SET %s WHERE %s ''' - % (table, compiled['fset'], compiled['fidx']), - (values + ivalues), + % (table, compiled['fset'], w_fidx), + (values + w_ivalues), ) else: raise NotImplementedError()