import inspect
import logging
from importlib import import_module
from benchmarkstt.docblock import format_docs
from collections import namedtuple
from typing import Dict
from benchmarkstt.registry import Registry
from benchmarkstt.modules import load_object
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class ClassConfig(namedtuple('ClassConfigTuple', ['name', 'cls', 'docs', 'optional_args', 'required_args'])):
    """
    Description of a registered class: its alias, the class object and the
    names of the optional and required arguments of its constructor.

    The ``docs`` tuple slot is only a placeholder; attribute access to
    ``docs`` is served by the property below, which shadows the tuple field.
    """

    @property
    def docs(self):
        """
        Formatted docstring of the wrapped class.

        A class without a docstring is reported through a warning and is
        treated as having an empty docstring.

        :return: The result of ``format_docs`` applied to the docstring
        """
        raw_doc = self.cls.__doc__
        if raw_doc is None:
            raw_doc = ''
            logger.warning("No docstring for '%s'", self.name)
        return format_docs(raw_doc)
class Factory(Registry):
    """
    Factory class with auto-loading of namespaces according to a base class.

    Classes are registered under a normalized alias (see
    :meth:`normalize_class_name`) and looked up with the same normalization.

    :param base_class: Class that registered classes are expected to extend
    :param list|None namespaces: Namespaces to scan on construction, by
                                 default only the module of ``base_class``
    :param list|None methods: Method names that, when all present and
                              callable on a class, make it acceptable even
                              if it does not extend ``base_class``
    """

    def __init__(self, base_class, namespaces=None, methods=None):
        super().__init__()
        self.base_class = base_class
        self.methods = methods
        self.namespaces = [base_class.__module__] if namespaces is None else namespaces
        for namespace in self.namespaces:
            self.register_namespace(namespace)

    def __contains__(self, item):
        return super().__contains__(self.normalize_class_name(item))

    def __getitem__(self, item):
        """
        Loads the proper class based on a name

        :param item: Name of the class, normalized before lookup
        :return: The class
        :rtype: class
        :raises ImportError: if no class is registered under that name
        """
        name = self.normalize_class_name(item)
        if name not in self._registry:
            raise ImportError("Could not find class '%s', available: %s" % (name, ', '.join(self._registry.keys())))
        return super().__getitem__(name)

    def __delitem__(self, key):
        """
        Removes a registered class

        :param str|class key: Alias of the class, or the class itself (in
                              which case its ``__name__`` is used)
        """
        if not isinstance(key, str):
            key = key.__name__
        # normalize string keys as well, so deletion accepts the same
        # spellings as lookup and membership tests
        super().__delitem__(self.normalize_class_name(key))

    def create(self, alias, *args, **kwargs):
        """
        Instantiates the class registered under ``alias``

        :param alias: Name of the class, normalized before lookup
        :param args: Positional arguments passed to the class
        :param kwargs: Keyword arguments passed to the class
        :return: The new instance
        :raises ImportError: if no class is registered under that name
        """
        return self[alias](*args, **kwargs)

    @staticmethod
    def normalize_class_name(clsname):
        """
        Normalizes the class name for automatic lookup of a class, by default
        this means lowercasing the class name, but may be overridden by a child
        class.

        :param clsname: The class name
        :return: The normalized class name
        :rtype: str
        """
        return clsname.lower()

    def is_valid(self, tocheck):
        """
        Checks that tocheck is a valid class extending base_class

        The base class itself, non-classes and abstract classes are always
        rejected. A class that does not extend base_class is still accepted
        when ``self.methods`` is set and every listed method is callable on it.

        :param tocheck: The class to check
        :rtype: bool
        """
        if tocheck is self.base_class:
            return False

        if not inspect.isclass(tocheck):
            return False

        if inspect.isabstract(tocheck):
            return False

        if issubclass(tocheck, self.base_class):
            return True

        # if it contains all required methods, accept as duck
        if self.methods:
            return all(callable(getattr(tocheck, method, None))
                       for method in self.methods)
        return False

    def register_namespace(self, namespace):
        """
        Registers all valid classes from a given namespace

        :param str|module|None namespace: Dotted module name (a leading ``?``
            marks it as optional: a failing import is logged instead of
            raised), an already imported module, or None to scan the globals
            of the module defining this class.
        :raises ImportError: if a non-optional namespace cannot be imported
        """
        if namespace is None:
            # globals() is a dict: iterate its entries, dir()/getattr() on it
            # would only inspect the methods of the dict type
            members = sorted(globals().items())
        else:
            if inspect.ismodule(namespace):
                module = namespace
            else:
                optional = namespace.startswith('?')
                if optional:
                    namespace = namespace[1:]

                # drop empty segments (leading, trailing or doubled dots)
                module_name = '.'.join(filter(len, namespace.split('.')))

                try:
                    module = import_module(module_name)
                except ImportError as e:
                    if not optional:
                        raise
                    logger.info(
                        "Could not load optional namespace %s for %s: %s",
                        namespace,
                        self.base_class.__name__,
                        e)
                    # nothing was imported, so there is nothing to register
                    return

            members = ((name, getattr(module, name)) for name in dir(module))

        for clsname, cls in members:
            if self.is_valid(cls):
                self.register(cls, clsname)

    def register_classname(self, name, alias=None):
        """
        Registers the object that ``load_object`` resolves ``name`` to

        :param str name: Name passed to ``load_object``
        :param str|None alias: The alias to register under, by default
                               ``name`` itself (normalized on registration)
        :raises ValueError: if the loaded object is not a valid class
        """
        if alias is None:
            alias = name
        self.register(load_object(name), alias)

    def register(self, cls, alias=None):
        """
        Register an alias for a class

        :param self.base_class cls:
        :param str|None alias: The alias to use when trying to get the class back,
                               by default will use normalized class name.
        :return: None
        :raises ValueError: if ``cls`` does not pass :meth:`is_valid`
        """
        if not self.is_valid(cls):
            raise ValueError('Invalid class, not recognized as a %s' % (self.base_class.__name__,))

        if alias is None:
            alias = cls.__name__

        alias = self.normalize_class_name(alias)
        super().register(alias, cls)

    def __iter__(self):
        """
        Get available classes with a proper ClassConfig

        Positional arguments of each class' ``__init__`` (without ``self``)
        are split into required and optional ones based on their defaults.

        :return: A generator yielding one ClassConfig per registered class
        :rtype: Iterator[ClassConfig]
        """
        for clsname, cls in self._registry.items():
            argspec = inspect.getfullargspec(cls.__init__)
            # skip the first argument (self)
            args = argspec.args[1:]
            defaults = argspec.defaults or ()

            # defaults always belong to the trailing arguments
            defaults_idx = len(args) - len(defaults)

            # docs is a placeholder, ClassConfig.docs is computed from cls
            yield ClassConfig(name=clsname,
                              cls=cls,
                              docs=None,
                              optional_args=args[defaults_idx:],
                              required_args=args[:defaults_idx])
class CoreFactory:
    """
    Lazy wrapper around a :class:`Factory` for a base class.

    The wrapped factory is only built on first use. It scans the ``core``
    submodule next to the module of the base class, plus every namespace
    added through :meth:`add_supported_namespace`.

    :param base_class: Class that registered classes are expected to extend
    :param bool|None allow_duck: When true (the default, also used for None)
        the abstract methods of ``base_class`` are passed to the factory so
        classes implementing all of them are accepted without extending it.
    """

    # Class-level list: namespaces appended here are seen by every
    # CoreFactory (and subclass not defining its own list) built afterwards.
    _extra_namespaces = []

    def __init__(self, base_class, allow_duck=None):
        if allow_duck is None:
            allow_duck = True

        self._base_class = base_class
        self._base_class_abstract_methods = self._abstract_methods(base_class) if allow_duck else None
        self._instance = None

    def _factory(self):
        """
        Returns the wrapped Factory, creating it on the first call

        :rtype: Factory
        """
        # defers registration until first usage
        if self._instance is None:
            self._instance = Factory(
                self._base_class,
                [self._base_class.__module__ + '.core'] + self._extra_namespaces,
                self._base_class_abstract_methods
            )
        return self._instance

    def __iter__(self):
        return self._factory().__iter__()

    def __getitem__(self, item):
        return self._factory().__getitem__(item)

    def __delitem__(self, item):
        return self._factory().__delitem__(item)

    def __contains__(self, item):
        return self._factory().__contains__(item)

    def keys(self):
        """Delegates to ``keys()`` of the wrapped factory."""
        return self._factory().keys()

    def create(self, *args, **kwargs):
        """Delegates to ``create()`` of the wrapped factory."""
        return self._factory().create(*args, **kwargs)

    def is_valid(self, *args, **kwargs):
        """Delegates to ``is_valid()`` of the wrapped factory."""
        return self._factory().is_valid(*args, **kwargs)

    def register(self, *args, **kwargs):
        """Delegates to ``register()`` of the wrapped factory."""
        return self._factory().register(*args, **kwargs)

    @classmethod
    def add_supported_namespace(cls, namespace):
        """
        Adds a namespace to scan when a wrapped factory is built

        Instances whose factory was already built are not affected.

        :param str namespace: Namespace as accepted by the wrapped factory
        """
        cls._extra_namespaces.append(namespace)

    @staticmethod
    def _abstract_methods(base_class):
        """
        Lists the names of the abstract methods of ``base_class``

        :param base_class: The class to inspect
        :return: The abstract method names, empty when the class does not
                 define ``__abstractmethods__`` (i.e. is not an ABC)
        :rtype: list
        """
        # plain (non-ABC) classes raise AttributeError on __abstractmethods__
        return list(getattr(base_class, '__abstractmethods__', ()))