Source code for sporco.cdict

# -*- coding: utf-8 -*-
# Copyright (C) 2015-2019 by Brendt Wohlberg <brendt@ieee.org>
# All rights reserved. BSD 3-clause License.
# This file is part of the SPORCO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""Constrained dictionary class."""

from builtins import str

import pprint


__author__ = """Brendt Wohlberg <brendt@ieee.org>"""



class UnknownKeyError(KeyError):
    """Exception for unrecognised dict key."""

    def __init__(self, arg):
        """Initialise the exception.

        Parameters
        ----------
        arg : str or tuple/list of str
          Offending key, or key path (sequence of keys from the root of
          a nested dict tree to the offending key)
        """
        super(UnknownKeyError, self).__init__(arg)


    def __repr__(self):
        """Return a message naming the unknown key.

        A key path given as a tuple or list is rendered as a
        dot-separated string (e.g. ``('a', 'b')`` becomes ``a.b``).
        """
        key = self.args[0]
        # Key paths are raised as tuples (e.g. ``self.pth + (key,)``), so
        # tuples must be joined as well as lists. Each element is passed
        # through str() so that non-string keys cannot break the join.
        if isinstance(key, (list, tuple)):
            s = ".".join(str(k) for k in key)
        else:
            s = str(key)
        return 'Unknown dictionary key: ' + s


    def __str__(self):
        """Return the same message as :meth:`__repr__`."""
        return repr(self)
class InvalidValueError(ValueError):
    """Exception for invalid dict value."""

    def __init__(self, arg):
        """Initialise the exception.

        Parameters
        ----------
        arg : any
          Offending key, key path (tuple/list of keys), or offending value
        """
        super(InvalidValueError, self).__init__(arg)


    def __repr__(self):
        """Return a message naming the key with an invalid value.

        A key path given as a tuple or list is rendered as a
        dot-separated string (e.g. ``('a', 'b')`` becomes ``a.b``).
        """
        key = self.args[0]
        # Key paths are raised as tuples (e.g. ``self.pth + (key,)``), so
        # tuples must be joined as well as lists. Each element is passed
        # through str() so that non-string elements (e.g. when the
        # argument is an arbitrary list value) cannot break the join.
        if isinstance(key, (list, tuple)):
            s = ".".join(str(k) for k in key)
        else:
            s = str(key)
        return 'Invalid dictionary value for key: ' + s


    def __str__(self):
        """Return the same message as :meth:`__repr__`."""
        return repr(self)
class ConstrainedDict(dict):
    """A dict subclass that constrains the allowed dict keys.

    Base class for a dict subclass that constrains the allowed dict keys,
    including those of nested dicts, and also initialises the dict with
    default content on instantiation. The default content is specified by
    the `defaults` class attribute, and the allowed keys are determined
    from the same attribute.
    """

    defaults = {}
    """Default content and allowed dict keys"""


    def __init__(self, d=None, pth=(), dflt=None):
        """Initialise a ConstrainedDict object.

        The object is first created with default content, which is then
        overwritten with the content of parameter `d`. When a subdict is
        initialised via this constructor, the key path from the root to
        this subdict (i.e. the set of keys, in sequence, that select the
        subdict starting from the top-level dict) should be passed as a
        tuple via the `pth` parameter, and the defaults dict should be
        passed via the `dflt` parameter.

        Parameters
        ----------
        d : dict
          Content to overwrite the defaults
        pth : tuple of str
          Key path for objects that are subdicts of other objects
        dflt: dict
          Reference to top level defaults dict for objects that are
          subdicts of other objects
        """

        # Default arguments
        if d is None:
            d = {}

        # Initialise with empty dictionary and set path attribute (if
        # path length is zero then current object is a tree root).
        super(ConstrainedDict, self).__init__()
        self.pth = pth
        # If dflt parameter has None value then this is the top-level
        # dict in the tree and the dflt attribute should be set to the
        # class defaults attribute. Otherwise, the dflt attribute is
        # initialised with the dflt parameter.
        if dflt is None:
            self.dflt = self.__class__.defaults
        else:
            self.dflt = dflt
        # Initialise object with defaults with the corresponding node
        # (as determined by pth) in the defaults tree
        self.update(self.__class__.getnode(self.dflt, self.pth))
        # Overwrite defaults with content of parameter d
        self.update(d)


    def update(self, d):
        """Update the dict with the dict tree in parameter `d`.

        Parameters
        ----------
        d : dict
          New dict content
        """

        # Route every entry through __setitem__ so that keys are validated
        # against the defaults tree and nested dicts are merged into the
        # existing subdicts. The keys are snapshotted before the loop, so
        # the iteration is not affected by assignments made during it.
        for key in list(d.keys()):
            self[key] = d[key]


    def __setitem__(self, key, value):
        """Set value corresponding to key.

        If key is a tuple, interpret it as a sequence of keys in a tree of
        nested dicts.

        Parameters
        ----------
        key : str or tuple of str
          Dict key
        value : any
          Dict value corresponding to key

        Raises
        ------
        UnknownKeyError
          If the key is not present in the defaults tree
        InvalidValueError
          If the defaults tree has a dict for the key and `value` is not
          a dict
        """

        # If key is a tuple, interpret it as a sequence of keys in a
        # tree of nested dicts and retrieve parent node in tree
        kc = key
        sd = self
        if isinstance(key, tuple):
            kc = key[-1]
            sd = self.__class__.getparent(self, key)

        # If value is not a dict, or if it is dict but also a
        # ConstrainedDict (meaning that it has already been
        # constructed, possibly as a derived class), or if it is a
        # dict but there is no current entry in self for the
        # corresponding key, then the value is inserted via parent
        # class __setitem__. Otherwise the value is itself a dict that
        # must be processed recursively via the update method.
        if not isinstance(value, dict) or \
           isinstance(value, ConstrainedDict) or kc not in sd:
            vc = value
            # If value is a dict but not a ConstrainedDict (if it is a
            # ConstrainedDict instance, it has already been
            # constructed, possibly as a derived class), call
            # constructor to instantiate a ConstrainedDict object
            # which becomes the value actually associated with the key
            if isinstance(value, dict) and \
               not isinstance(value, ConstrainedDict):
                # ConstrainedDict constructor is called instead of the
                # constructor of the derived class because it is
                # undesirable to force the derived class constructor to
                # have the same interface. This implies that only the root
                # node will have derived class type, and all others will
                # be of type ConstrainedDict. Since it is required that
                # all nodes use the derived class defaults class
                # attribute, it is necessary to maintain an object dflts
                # attribute that is initialised from the defaults class
                # attribute and passed down the node tree on construction.
                vc = ConstrainedDict(vc, sd.pth + (kc,), self.dflt)
            # Check that the current key and value are valid with respect
            # to the defaults tree. Any UnknownKeyError or
            # InvalidValueError raised by check propagates unchanged.
            sd.check(kc, vc)
            # Call base class __setitem__ to insert key, value pair
            super(ConstrainedDict, sd).__setitem__(kc, vc)
        else:
            # Call update to handle subtree update
            sd[kc].update(value)


    def __getitem__(self, key):
        """Get value corresponding to key.

        If key is a tuple, interpret it as a sequence of keys in a tree of
        nested dicts.

        Parameters
        ----------
        key : str or tuple of str
          Dict key

        Raises
        ------
        UnknownKeyError
          If the key (or any key along a tuple key path) is not present
        """

        # If key is a tuple, interpret it as a sequence of keys in a
        # tree of nested dicts and retrieve parent node in tree
        kc = key
        sd = self
        if isinstance(key, tuple):
            kc = key[-1]
            sd = self.__class__.getparent(self, key)
        # Return value referenced by key, or by final key in key path
        # if key is a tuple
        if kc not in sd:
            raise UnknownKeyError(key)
        return super(ConstrainedDict, sd).__getitem__(kc)


    def __str__(self):
        """Return string representation of object."""

        return pprint.pformat(self)


    def check(self, key, value):
        """Check whether `key`, `value` pair is allowed.

        The key is allowed if there is a corresponding key in the defaults
        class attribute dict. The value is not allowed if it is a dict in
        the defaults dict and not a dict in value.

        Parameters
        ----------
        key : str
          Dict key (a single key at the level of this node)
        value : any
          Dict value corresponding to key

        Raises
        ------
        UnknownKeyError
          If `key` is not in the corresponding node of the defaults tree
        InvalidValueError
          If the defaults value for `key` is a dict and `value` is not
        """

        # This test necessary to avoid unpickling errors in Python 3
        # (dict items may be set before the instance attributes exist)
        if hasattr(self, 'dflt'):
            # Get corresponding node to self, as determined by pth
            # attribute, of the defaults dict tree
            a = self.__class__.getnode(self.dflt, self.pth)
            # Raise UnknownKeyError exception if key not in corresponding
            # node of defaults tree
            if key not in a:
                raise UnknownKeyError(self.pth + (key,))
            # Raise InvalidValueError if the key value in the defaults
            # tree is a dict and the value parameter is not a dict
            elif isinstance(a[key], dict) and not isinstance(value, dict):
                raise InvalidValueError(self.pth + (key,))


    @staticmethod
    def getparent(d, pth):
        """Get the parent node of a subdict as specified by the key path in
        `pth`.

        Parameters
        ----------
        d : dict
          Dict tree in which access is required
        pth : str or tuple of str
          Dict key path; a single str is treated as a path of length one,
          whose parent is `d` itself

        Returns
        -------
        c : dict
          Node of `d` selected by all but the final key of `pth`
        """

        # A bare str would otherwise be iterated character by character
        if isinstance(pth, str):
            pth = (pth,)
        c = d
        for key in pth[:-1]:
            if not isinstance(c, dict):
                raise InvalidValueError(c)
            elif key not in c:
                raise UnknownKeyError(pth)
            else:
                c = c.__getitem__(key)
        return c


    @staticmethod
    def getnode(d, pth):
        """Get the node of a subdict specified by the key path in `pth`.

        Parameters
        ----------
        d : dict
          Dict tree in which access is required
        pth : str or tuple of str
          Dict key path; a single str is treated as a path of length one

        Returns
        -------
        c : any
          Node (or leaf value) of `d` selected by the keys in `pth`
        """

        # A bare str would otherwise be iterated character by character
        if isinstance(pth, str):
            pth = (pth,)
        c = d
        for key in pth:
            if not isinstance(c, dict):
                raise InvalidValueError(c)
            elif key not in c:
                raise UnknownKeyError(pth)
            else:
                c = c.__getitem__(key)
        return c
def keycmp(a, b, pth=()):
    """Compare keys in nested dicts.

    Recurse down the tree of nested dicts `b`, at each level checking
    that it does not have any keys that are not also at the same level in
    `a`. The key path is recorded in `pth`. If an unknown key is
    encountered in `b`, an `UnknownKeyError` exception is raised. If a
    non-dict value is encountered in `b` for which the corresponding value
    in `a` is a dict, an `InvalidValueError` exception is raised.

    Parameters
    ----------
    a : dict
      Reference dict tree
    b : dict
      Compared dict tree
    pth : tuple of str, optional (default ())
      Key path from the root of the trees to the current level, used to
      report the full path of an offending key. A single str is treated
      as a path of length one.

    Raises
    ------
    UnknownKeyError
      If `b` contains a key that is not at the same level in `a`
    InvalidValueError
      If `a` has a dict value for a key where `b` has a non-dict value
    """

    # A bare str cannot be concatenated with a tuple below
    if isinstance(pth, str):
        pth = (pth,)
    # Iterate over all keys in b
    for key in b:
        # If a key is encountered that is not in a, raise an
        # UnknownKeyError exception. Membership is tested on the dict
        # itself (hash lookup) rather than on a list of its keys.
        if key not in a:
            raise UnknownKeyError(pth + (key,))
        # If corresponding values in a and b for the same key are both
        # dicts, recursively call this function for those values. If the
        # value in a is a dict and the value in b is not, raise an
        # InvalidValueError exception.
        if isinstance(a[key], dict):
            if isinstance(b[key], dict):
                keycmp(a[key], b[key], pth + (key,))
            else:
                raise InvalidValueError(pth + (key,))