"""
Compliance Checker suite runner
"""

import codecs
import inspect
import itertools
import os
import re
import subprocess
import sys
import textwrap
import warnings
from collections import defaultdict
from datetime import datetime, timezone
from operator import itemgetter
from pathlib import Path
from urllib.parse import urlparse

import importlib_metadata
import requests
from lxml import etree as ET
from netCDF4 import Dataset
from owslib.sos import SensorObservationService
from owslib.swe.sensor.sml import SensorML
from packaging.version import parse

from compliance_checker import __version__, tempnc
from compliance_checker.base import BaseCheck, GenericFile, Result, fix_return_value
from compliance_checker.protocols import cdl, netcdf, opendap

# Ensure output is encoded as Unicode when checker output is redirected or piped
if sys.stdout.encoding is None:
    sys.stdout = codecs.getwriter("utf8")(sys.stdout)
if sys.stderr.encoding is None:
    sys.stderr = codecs.getwriter("utf8")(sys.stderr)


def extract_docstring_summary(docstring):
    """
    Returns a dedented docstring without parameter information

    :param docstring: A docstring
    :type docstring: str
    :returns: str
    """
    # Dedent the docstring, strip leading and trailing whitespace, then
    # indent each remaining line by two spaces.
    return re.sub(
        r"^(?=.)",
        "  ",
        textwrap.dedent(
            re.split(r"\n\s*:\w", docstring, flags=re.MULTILINE)[0],
        ).strip(),
        flags=re.MULTILINE,
    )

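# Illustrative sketch of extract_docstring_summary's behavior (not part of the
# original module; the docstring below is hypothetical):
#
#     doc = """
#         Checks the units attribute.
#
#         :param ds: An open netCDF dataset
#         """
#     extract_docstring_summary(doc)
#     # -> "  Checks the units attribute."
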
class CheckSuite:
    # Base dict of checker names to BaseCheck derived types; override this in
    # your CheckSuite implementation.
    checkers = {}
    # Modify to load alternative Jinja2 templates.
    templates_root = "compliance_checker"

    def __init__(self, options=None):
        self.col_width = 40
        self.options = options or {}

    @classmethod
    def _get_generator_plugins(cls):
        """
        Return a list of classes from external plugins that are used to
        generate checker classes
        """
        if not hasattr(cls, "suite_generators"):
            gens = importlib_metadata.entry_points(
                group="compliance_checker.generators",
            )
            cls.suite_generators = [x.load() for x in gens]
        return cls.suite_generators

    def _print_suites(self, verbose=0):
        """
        Prints out available check suites. If the verbose argument is greater
        than zero, includes the internal module version number of the check
        and also displays "latest" meta-versions.

        :param verbose: Integer indicating whether to print verbose output
        :type verbose: int
        """
        for checker in sorted(self.checkers.keys()):
            version = getattr(self.checkers[checker], "_cc_checker_version", "???")
            if verbose > 0:
                print(f" - {checker} (v{version})")
            elif ":" in checker and not checker.endswith(
                ":latest",
            ):  # Skip the "latest" output
                print(f" - {checker}")

    def _print_checker(self, checker_obj):
        """
        Prints each available check and a description with an abridged
        docstring for a given checker object

        :param checker_obj: Checker object on which to operate
        :type checker_obj: subclass of compliance_checker.base.BaseChecker
        """
        check_functions = self._get_checks(checker_obj, {}, defaultdict(lambda: None))
        for c, _ in check_functions:
            print(f"- {c.__name__}")
            if c.__doc__ is not None:
                u_doc = c.__doc__
                print(f"\n{extract_docstring_summary(u_doc)}\n")

    @classmethod
    def add_plugin_args(cls, parser):
        """
        Add command line arguments for external plugins that generate checker
        classes
        """
        for gen in cls._get_generator_plugins():
            gen.add_arguments(parser)

    @classmethod
    def load_generated_checkers(cls, args):
        """
        Load checker classes from generator plugins
        """
        for gen in cls._get_generator_plugins():
            checkers = gen.get_checkers(args)
            cls.checkers.update(checkers)

    @classmethod
    def load_all_available_checkers(cls):
        """
        Helper method to retrieve all sub checker classes derived from various
        base classes.
        """
        cls._load_checkers(
            importlib_metadata.entry_points(group="compliance_checker.suites"),
        )

    @classmethod
    def _load_checkers(cls, checkers):
        """
        Loads up checkers in an iterable into the class checkers dict

        :param checkers: An iterable containing the checker objects
        """
        for c in checkers:
            try:
                check_obj = c.load()
                if hasattr(check_obj, "_cc_spec") and hasattr(
                    check_obj,
                    "_cc_spec_version",
                ):
                    check_version_str = ":".join(
                        (check_obj._cc_spec, check_obj._cc_spec_version),
                    )
                    cls.checkers[check_version_str] = check_obj
                # TODO: remove this once all checkers move over to the new
                #       _cc_spec, _cc_spec_version
                else:
                    # if _cc_spec and _cc_spec_version attributes aren't
                    # present, fall back to using the name attribute
                    checker_name = getattr(check_obj, "name", None) or getattr(
                        check_obj,
                        "_cc_spec",
                        None,
                    )
                    warnings.warn(
                        f"Checker for {checker_name} should implement both "
                        '"_cc_spec" and "_cc_spec_version" '
                        'attributes. "name" attribute is deprecated. '
                        "Assuming checker is latest version.",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    # append "unknown" to the version string since no
                    # versioning info was provided
                    cls.checkers[f"{checker_name}:unknown"] = check_obj
            except Exception as e:
                print("Could not load", c, ":", e, file=sys.stderr)

        # find the latest version of versioned checkers and set that as the
        # default checker for compliance checker if no version is specified
        ver_checkers = sorted([c.split(":", 1) for c in cls.checkers if ":" in c])
        for spec, versions in itertools.groupby(ver_checkers, itemgetter(0)):
            version_nums = [v[-1] for v in versions]
            try:
                latest_version = str(max(parse(v) for v in version_nums))
            # if the version can't be parsed, sort according to character
            # collation instead
            except ValueError:
                latest_version = max(version_nums)
            cls.checkers[spec] = cls.checkers[spec + ":latest"] = cls.checkers[
                ":".join((spec, latest_version))
            ]

    def _get_checks(self, checkclass, include_checks, skip_checks):
        """
        Helper method to retrieve check methods from a Checker class.
        Excludes any checks whose entry in `skip_checks` indicates the entire
        check should be skipped.

        The names of the methods in the Checker class should start with
        "check_" for this method to find them.

        :param checkclass BaseCheck: The checker class being considered
        :param include_checks: A collection of check method names; if
            non-empty, only these checks are returned
        :param skip_checks dict: A mapping of check method names to the
            maximum priority level to skip for that check (BaseCheck.HIGH
            skips the entire check); names not present map to None
        """
        meths = inspect.getmembers(checkclass, inspect.isroutine)
        # return all check methods not among the skipped checks
        returned_checks = []
        if include_checks:
            for fn_name, fn_obj in meths:
                if fn_name in include_checks:
                    returned_checks.append((fn_obj, skip_checks[fn_name]))
        else:
            for fn_name, fn_obj in meths:
                if (
                    fn_name.startswith("check_")
                    and skip_checks[fn_name] != BaseCheck.HIGH
                ):
                    returned_checks.append((fn_obj, skip_checks[fn_name]))

        return returned_checks

    def _run_check(self, check_method, ds, max_level):
        """
        Runs a check and appends a result to the values list.

        @param bound method check_method: a given check method
        @param netCDF4 dataset ds
        @param int max_level: check level
        @return list: list of Result objects
        """
        val = check_method(ds)

        if hasattr(val, "__iter__"):
            # Handle OrderedDict when we need to modify results in a
            # superclass, e.g. some checks in CF 1.7 which extend CF 1.6
            # behaviors
            if isinstance(val, dict):
                val_iter = val.values()
            else:
                val_iter = val
            check_val = []
            for v in val_iter:
                res = fix_return_value(
                    v,
                    check_method.__func__.__name__,
                    check_method,
                    check_method.__self__,
                )
                if max_level is None or res.weight > max_level:
                    check_val.append(res)
            return check_val
        else:
            check_val = fix_return_value(
                val,
                check_method.__func__.__name__,
                check_method,
                check_method.__self__,
            )
            if max_level is None or check_val.weight > max_level:
                return [check_val]
            else:
                return []

    def _get_check_versioned_name(self, check_name):
        """
        The compliance checker allows the user to specify a check without a
        version number, but we want the report to specify the version number.

        Returns the check name with the version number it checked.
        """
        if ":" not in check_name or ":latest" in check_name:
            check_name = ":".join(
                (check_name.split(":")[0], self.checkers[check_name]._cc_spec_version),
            )
        return check_name

    def _get_check_url(self, check_name):
        """
        Return the check's reference URL if it exists. If not, return an
        empty string.

        @param check_name str: name of the check being run returned by
                               _get_check_versioned_name()
        """
        return getattr(self.checkers[check_name], "_cc_url", "")

    def _get_valid_checkers(self, ds, checker_names):
        """
        Returns a filtered list of 2-tuples: (name, valid checker) based on
        the ds object's type and the user-selected names.
        """
        assert len(self.checkers) > 0, "No checkers could be found."

        if len(checker_names) == 0:
            checker_names = list(self.checkers.keys())

        args = [
            (name, self.checkers[name])
            for name in checker_names
            if name in self.checkers
        ]
        valid = []

        all_checked = {a[1] for a in args}  # only class types
        checker_queue = set(args)
        while len(checker_queue):
            name, a = checker_queue.pop()
            # is the current dataset type in the supported filetypes
            # for the checker class?
            if type(ds) in a().supported_ds:
                valid.append((name, a))

            # add subclasses of SOS checks
            if "ioos_sos" in name:
                for subc in a.__subclasses__():
                    if subc not in all_checked:
                        all_checked.add(subc)
                        checker_queue.add((name, subc))

        return valid

    @classmethod
    def _process_skip_checks(cls, skip_checks):
        """
        Processes an iterable of skip_checks strings and returns a dict with
        <check_name>: <max_skip_level> pairs
        """
        check_dict = defaultdict(lambda: None)
        # "A" is for "all", "M" is for medium, "L" is for low
        check_lookup = {"A": BaseCheck.HIGH, "M": BaseCheck.MEDIUM, "L": BaseCheck.LOW}

        for skip_check_spec in skip_checks:
            split_check_spec = skip_check_spec.split(":")
            check_name = split_check_spec[0]
            if len(split_check_spec) < 2:
                check_max_level = BaseCheck.HIGH
            else:
                try:
                    check_max_level = check_lookup[split_check_spec[1]]
                except KeyError:
                    warnings.warn(
                        f"Skip specifier '{split_check_spec[1]}' on check "
                        f"'{check_name}' not found, defaulting to skip entire "
                        "check",
                        stacklevel=2,
                    )
                    check_max_level = BaseCheck.HIGH

            check_dict[check_name] = check_max_level

        return check_dict

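    # Illustrative sketch of the skip-check spec format consumed by
    # _process_skip_checks (not part of the original module; the check names
    # are hypothetical):
    #
    #     CheckSuite._process_skip_checks(
    #         ["check_standard_name", "check_flags:M"],
    #     )
    #     # -> {"check_standard_name": BaseCheck.HIGH,
    #     #     "check_flags": BaseCheck.MEDIUM}
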
    def run(self, ds, skip_checks, *checker_names):
        warnings.warn(
            "suite.run is deprecated, use suite.run_all in calls instead",
            stacklevel=2,
        )
        return self.run_all(ds, checker_names, skip_checks=skip_checks)

    def run_all(self, ds, checker_names, include_checks=None, skip_checks=None):
        """
        Runs this CheckSuite on the dataset with all the passed Checker
        instances.

        Returns a dictionary mapping checker names to a 2-tuple of their
        grouped scores and errors/exceptions while running checks.
        """
        ret_val = {}
        checkers = self._get_valid_checkers(ds, checker_names)
        if skip_checks is not None:
            skip_check_dict = CheckSuite._process_skip_checks(skip_checks)
        else:
            skip_check_dict = defaultdict(lambda: None)

        if include_checks:
            include_dict = {check_name: 0 for check_name in include_checks}
        else:
            include_dict = {}

        if len(checkers) == 0:
            print(
                "No valid checkers found for tests '{}'".format(
                    ",".join(checker_names),
                ),
            )

        for checker_name, checker_class in checkers:
            # TODO: maybe make this a little more reliable than depending on
            #       a string to determine the type of the checker -- perhaps
            #       use some kind of checker object with checker type and
            #       version baked in
            checker_type_name = checker_name.split(":")[0]
            checker_opts = self.options.get(checker_type_name, set())

            # instantiate a Checker object
            try:
                checker = checker_class(options=checker_opts)
            # hacky fix for checkers without options in the constructor
            except TypeError:
                checker = checker_class()
            # TODO?: Why is setup(ds) called at all instead of just moving
            #        the checker setup into the constructor?
            # setup method to prep
            checker.setup(ds)

            checks = self._get_checks(checker, include_dict, skip_check_dict)
            vals = []
            errs = {}  # check method name -> (exc, traceback)

            for c, max_level in checks:
                try:
                    vals.extend(self._run_check(c, ds, max_level))
                except Exception as e:
                    errs[c.__func__.__name__] = (e, sys.exc_info()[2])

            # score the results we got back
            groups = self.scores(vals)

            # invoke finalizer explicitly
            del checker

            ret_val[checker_name] = groups, errs

        return ret_val

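    # Illustrative end-to-end usage sketch (not part of the original module);
    # "example.nc" and the "cf:latest" checker name are assumptions:
    #
    #     CheckSuite.load_all_available_checkers()
    #     suite = CheckSuite()
    #     ds = suite.load_dataset("example.nc")
    #     results = suite.run_all(ds, ["cf:latest"])
    #     groups, errors = results["cf:latest"]
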
    @classmethod
    def passtree(cls, groups, limit):
        for r in groups:
            if r.children:
                x = cls.passtree(r.children, limit)
                if r.weight >= limit and x is False:
                    return False
            if r.weight >= limit and r.value[0] != r.value[1]:
                return False
        return True

    def build_structure(self, check_name, groups, source_name, limit=1):
        """
        Compiles the checks, results and scores into an aggregate structure
        which looks like:

            {
              "scored_points": 396,
              "low_count": 0,
              "possible_points": 400,
              "testname": "gliderdac",
              "medium_count": 2,
              "source_name": ".//rutgers/ru01-20140120T1444/ru01-20140120T1649.nc",
              "high_count": 0,
              "all_priorities": [...],
              "high_priorities": [...],
              "medium_priorities": [...],
              "low_priorities": [...]
            }

        @param check_name  The test which was run
        @param groups      List of results from compliance checker
        @param source_name Source of the dataset, used for title
        """
        aggregates = {}

        aggregates["scored_points"] = 0
        aggregates["possible_points"] = 0
        high_priorities = []
        medium_priorities = []
        low_priorities = []
        all_priorities = []

        aggregates["high_count"] = 0
        aggregates["medium_count"] = 0
        aggregates["low_count"] = 0

        def named_function(result):
            for child in result.children:
                all_priorities.append(child)
                named_function(child)

        # For each result, bin them into the appropriate category, put them
        # all into the all_priorities category and add up the point values
        for res in groups:
            if res.weight < limit:
                continue
            # If the result has 0 possible points, then it was not valid for
            # this dataset and contains no meaningful information
            if res.value[1] == 0:
                continue
            aggregates["scored_points"] += res.value[0]
            aggregates["possible_points"] += res.value[1]
            if res.weight == 3:
                high_priorities.append(res)
                if res.value[0] < res.value[1]:
                    aggregates["high_count"] += 1
            elif res.weight == 2:
                medium_priorities.append(res)
                if res.value[0] < res.value[1]:
                    aggregates["medium_count"] += 1
            else:
                low_priorities.append(res)
                if res.value[0] < res.value[1]:
                    aggregates["low_count"] += 1
            all_priorities.append(res)
            # Some results have children.
            # We don't render children inline with the top three tables, but
            # we do total the points and display the messages.
            named_function(res)

        aggregates["high_priorities"] = high_priorities
        aggregates["medium_priorities"] = medium_priorities
        aggregates["low_priorities"] = low_priorities
        aggregates["all_priorities"] = all_priorities
        aggregates["testname"] = self._get_check_versioned_name(check_name)
        aggregates["source_name"] = source_name
        aggregates["scoreheader"] = self.checkers[check_name]._cc_display_headers
        aggregates["cc_spec_version"] = self.checkers[check_name]._cc_spec_version
        aggregates["cc_url"] = self._get_check_url(aggregates["testname"])
        aggregates["report_timestamp"] = datetime.now(timezone.utc).strftime(
            "%Y-%m-%dT%H:%M:%SZ",
        )
        aggregates["cc_version"] = __version__
        return aggregates

    def dict_output(self, check_name, groups, source_name, limit):
        """
        Builds the results into a JSON-serializable structure and returns it.

        @param check_name  The test which was run
        @param groups      List of results from compliance checker
        @param source_name Source of the dataset, used for title
        @param limit       Integer value for limiting output
        """
        aggregates = self.build_structure(check_name, groups, source_name, limit)
        return self.serialize(aggregates)

    def serialize(self, o):
        """
        Returns a safe serializable object that can be serialized into JSON.

        @param o Python object to serialize
        """
        if isinstance(o, (list, tuple)):
            return [self.serialize(i) for i in o]
        if isinstance(o, dict):
            return {k: self.serialize(v) for k, v in o.items()}
        if isinstance(o, datetime):
            return o.isoformat()
        if isinstance(o, Result):
            return self.serialize(o.serialize())
        return o

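    # Illustrative sketch of serialize (not part of the original module;
    # values are hypothetical):
    #
    #     suite.serialize({"when": datetime(2024, 1, 1), "scores": (3, 4)})
    #     # -> {"when": "2024-01-01T00:00:00", "scores": [3, 4]}
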
    def checker_html_output(self, check_name, groups, source_name, limit):
        """
        Renders the HTML output for a single test using Jinja2 and returns it
        as a string.

        @param check_name  The test which was run
        @param groups      List of results from compliance checker
        @param source_name Source of the dataset, used for title
        @param limit       Integer value for limiting output
        """
        from jinja2 import Environment, PackageLoader

        self.j2 = Environment(
            loader=PackageLoader(self.templates_root, "data/templates"),
        )
        template = self.j2.get_template("ccheck.html.j2")
        template_vars = self.build_structure(check_name, groups, source_name, limit)
        return template.render(**template_vars)

    def html_output(self, checkers_html):
        """
        Renders the HTML output for multiple tests and returns it as a string.

        @param checkers_html List of HTML for single tests as returned by
                             checker_html_output
        """
        # Note: this relies on checker_html_output having been called so that
        # self.j2 is initialised
        template = self.j2.get_template("ccheck_wrapper.html.j2")
        return template.render(checkers=checkers_html)

    def get_points(self, groups, limit):
        score_list = []
        score_only_list = []

        for g in groups:
            if g.weight >= limit:
                score_only_list.append(g.value)

        # checks where all pertinent sections passed
        all_passed = sum(x[0] == x[1] for x in score_only_list)
        out_of = len(score_only_list)

        # sorts lists into high/medium/low order
        score_list.sort(key=lambda x: x.weight, reverse=True)

        return score_list, all_passed, out_of

    def standard_output(self, ds, limit, check_name, groups):
        """
        Generates the Terminal Output for Standard cases

        Returns the dataset needed for the verbose output, as well as the
        failure flags.
        """
        score_list, points, out_of = self.get_points(groups, limit)
        issue_count = out_of - points

        # Let's add the version number to the check name if it's missing
        check_name = self._get_check_versioned_name(check_name)
        check_url = self._get_check_url(check_name)
        width = 2 * self.col_width

        # NOTE on printing and the use of .center():
        # Nested .format() calls should be avoided when possible.
        # As a future enhancement, a string.Template string might work best
        # here, but for the time being individual lines are printed and
        # centered with .center()
        print("\n")
        print("-" * width)
        print("IOOS Compliance Checker Report".center(width))
        print(f"Version {__version__}".center(width))
        print(
            "Report generated {}".format(
                datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
            ).center(width),
        )
        print(f"{check_name}".center(width))
        print(f"{check_url}".center(width))
        print("-" * width)
        if issue_count > 0:
            print("Corrective Actions".center(width))
            plural = "" if issue_count == 1 else "s"
            print(
                f"{os.path.basename(ds)} has {issue_count} potential issue{plural}",
            )

        return [groups, points, out_of]

    def standard_output_generation(self, groups, limit, points, out_of, check):
        """
        Generates the Terminal Output
        """
        if points < out_of:
            self.reasoning_routine(groups, check, priority_flag=limit)
        else:
            print("All tests passed!")

    def reasoning_routine(self, groups, check, priority_flag=3, _top_level=True):
        """
        Print routine for failed check results.

        @param list groups: the Result groups
        @param str check: checker name
        @param int priority_flag: indicates the weight of the groups
        @param bool _top_level: indicates the level of the group so as to
                                print out the appropriate header string
        """

        def weight_sort(result):
            return result.weight

        groups_sorted = sorted(groups, key=weight_sort, reverse=True)

        # create dict of the groups -> {level: [reasons]}
        result = {
            key: [v for v in valuesiter if v.value[0] != v.value[1]]
            for key, valuesiter in itertools.groupby(groups_sorted, key=weight_sort)
        }

        priorities = self.checkers[check]._cc_display_headers

        def process_table(res, check):
            """
            Recursively calls reasoning_routine to parse out child reasons
            from the parent reasons.

            @param Result res: Result object
            @param str check: checker name
            """
            issue = res.name
            if not res.children:
                reasons = res.msgs
            else:
                child_reasons = self.reasoning_routine(
                    res.children,
                    check,
                    _top_level=False,
                )
                # there shouldn't be messages if there are children --
                # is this a valid assumption?
                reasons = child_reasons

            return issue, reasons

        # iterate in reverse to the min priority requested;
        # the higher the limit, the more lenient the output
        proc_strs = ""
        for level in range(3, priority_flag - 1, -1):
            level_name = priorities.get(level, level)
            proc_strs = []

            # skip any levels that aren't in the result
            if level not in result:
                continue

            # skip any empty result levels
            if len(result[level]) > 0:
                # print headers only at the top level, i.e. for non-child
                # datasets
                if _top_level:
                    width = 2 * self.col_width
                    print("\n")
                    print("{:^{width}}".format(level_name, width=width))
                    print("-" * width)

                data_issues = [process_table(res, check) for res in result[level]]

                has_printed = False
                for issue, reasons in data_issues:
                    # if this isn't the first printed issue, add a newline
                    # separating this and the previous level
                    if has_printed:
                        print("")
                    # join alphabetized reasons together
                    reason_str = "\n".join(
                        f"* {r}" for r in sorted(reasons, key=lambda x: x[0])
                    )
                    proc_str = f"{issue}\n{reason_str}"
                    print(proc_str)
                    proc_strs.append(proc_str)
                    has_printed = True

        return "\n".join(proc_strs)

    def process_doc(self, doc):
        """
        Attempt to parse an XML string conforming to either an SOS or a
        SensorML document and return the resulting dataset object
        """
        xml_doc = ET.fromstring(doc)
        if xml_doc.tag == "{http://www.opengis.net/sos/1.0}Capabilities":
            ds = SensorObservationService(None, xml=doc)
            # SensorObservationService does not store the etree doc root,
            # so maybe use monkey patching here for now?
            ds._root = xml_doc
        elif xml_doc.tag == "{http://www.opengis.net/sensorML/1.0.1}SensorML":
            ds = SensorML(xml_doc)
        else:
            raise ValueError(f"Unrecognized XML root element: {xml_doc.tag}")
        return ds

    def generate_dataset(self, cdl_path):
        """
        Use ncgen to generate a netCDF file from a .cdl file and return the
        path to the generated netCDF file. If ncgen fails, terminates the
        program with sys.exit(1) so that a long stack trace is not reported
        to the user.

        :param str cdl_path: Absolute path to the cdl file that is used to
                             generate the netCDF file
        """
        if isinstance(cdl_path, str):
            cdl_path = Path(cdl_path)
        ds_str = cdl_path.with_suffix(".nc")

        # generate a netCDF-4 file
        iostat = subprocess.run(
            ["ncgen", "-k", "nc4", "-o", ds_str, cdl_path],
            stderr=subprocess.PIPE,
        )
        if iostat.returncode != 0:
            # if not successful, create a netCDF classic file instead
            print(
                "netCDF-4 file could not be generated from cdl file with "
                "message:",
            )
            print(iostat.stderr.decode())
            print("Trying to create netCDF Classic file instead.")
            iostat = subprocess.run(
                ["ncgen", "-k", "nc3", "-o", ds_str, cdl_path],
                stderr=subprocess.PIPE,
            )
            if iostat.returncode != 0:
                # Exit the program if neither a netCDF Classic nor a netCDF-4
                # file could be created.
                print(
                    "netCDF Classic file could not be generated from cdl file "
                    "with message:",
                )
                print(iostat.stderr.decode())
                sys.exit(1)
        return ds_str

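    # The subprocess calls above are equivalent to invoking ncgen directly,
    # e.g. (illustrative; "profile.cdl" is a hypothetical input file):
    #
    #     ncgen -k nc4 -o profile.nc profile.cdl
    #     ncgen -k nc3 -o profile.nc profile.cdl   # classic-format fallback
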
    def load_dataset(self, ds_str):
        """
        Returns an instantiated instance of either a netCDF file or an SOS
        mapped DS object.

        :param str ds_str: URL or path of the resource to load
        """
        if isinstance(ds_str, Path):
            ds_str = str(ds_str)
        # If it's a remote URL load it as a remote resource, otherwise treat
        # it as a local resource.
        pr = urlparse(ds_str)
        if pr.netloc:
            return self.load_remote_dataset(ds_str)
        else:
            return self.load_local_dataset(ds_str)

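    # Dispatch sketch (illustrative URLs, not part of the original module):
    # urlparse("https://example.com/data.nc").netloc is "example.com", so the
    # resource is loaded remotely; urlparse("/tmp/data.nc").netloc is empty,
    # so it is treated as a local file.
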
    def check_remote_netcdf(self, ds_str):
        if netcdf.is_remote_netcdf(ds_str):
            response = requests.get(ds_str, allow_redirects=True, timeout=60)
            try:
                return Dataset(
                    urlparse(response.url).path,
                    memory=response.content,
                )
            except OSError:
                # handle the case where the netCDF C libs weren't compiled
                # with in-memory support by using a tempfile instead
                with tempnc(response.content) as _nc:
                    return Dataset(_nc)

    def load_remote_dataset(self, ds_str):
        """
        Returns a dataset instance for the remote resource, either OPeNDAP or
        SOS

        :param str ds_str: URL to the remote resource
        """
        url_parsed = urlparse(ds_str)

        nc_remote_result = self.check_remote_netcdf(ds_str)
        if nc_remote_result:
            return nc_remote_result
        # if application/x-netcdf wasn't detected in the Content-Type headers
        # and this is some kind of ERDDAP TableDAP request, then try to get
        # the .ncCF file from ERDDAP
        elif "tabledap" in ds_str and not url_parsed.query:
            # modify ds_str to contain the full variable request
            variables_str = opendap.create_DAP_variable_str(ds_str)
            # join to create a URL to an .ncCF resource
            ds_str = f"{ds_str}.ncCF?{variables_str}"
            nc_remote_result = self.check_remote_netcdf(ds_str)
            if nc_remote_result:
                return nc_remote_result
        # if it's just an OPeNDAP endpoint, use that
        elif opendap.is_opendap(ds_str):
            return Dataset(ds_str)

        # Check if the HTTP response is XML; if it is, it's likely SOS, so
        # we'll attempt to parse the response as SOS.
        # Some SOS servers don't seem to support HEAD requests.
        # Issue a GET instead if we reach here and can't get the response.
        response = requests.get(ds_str, allow_redirects=True, timeout=60)
        content_type = response.headers.get("content-type")
        if content_type.split(";")[0] == "text/xml":
            return self.process_doc(response.content)
        elif content_type.split(";")[0] == "application/x-netcdf":
            return Dataset(
                urlparse(response.url).path,
                memory=response.content,
            )
        else:
            raise ValueError(
                f"Unknown service with content-type: {content_type}",
            )

    def load_local_dataset(self, ds_str):
        """
        Returns a dataset instance for the local resource

        :param ds_str: Path to the resource
        """
        if cdl.is_cdl(ds_str):
            ds_str = self.generate_dataset(ds_str)

        if netcdf.is_netcdf(ds_str):
            return Dataset(ds_str)

        # Assume this is just a Generic File if it exists
        if os.path.isfile(ds_str):
            return GenericFile(ds_str)

        raise ValueError("File is an unknown format")

    def scores(self, raw_scores):
        """
        Transforms raw scores from a single checker into a fully tallied and
        grouped scoreline.
        """
        grouped = self._group_raw(raw_scores)
        return grouped

    def _group_raw(self, raw_scores, cur=None, level=1):
        """
        Internal recursive method to group raw scores into a cascading score
        summary.

        Only top level items are tallied for scores.

        @param list raw_scores: list of raw scores (Result objects)
        """

        def trim_groups(r):
            if isinstance(r.name, (tuple, list)):
                new_name = r.name[1:]
            else:
                new_name = []

            return Result(r.weight, r.value, new_name, r.msgs)

        # CHECK FOR TERMINAL CONDITION: all raw_scores.name are single length
        # @TODO: could have a problem here with a scalar name, but probably
        #        still works
        terminal = [len(x.name) for x in raw_scores]
        if terminal == [0] * len(raw_scores):
            return []

        def group_func(r):
            """
            Takes a Result object and slices off the first element of its
            name if it's a tuple. Otherwise, leaves the name unchanged.

            Returns the Result's name and weight in a tuple to be used for
            sorting in that order in a groupby function.

            @param Result r
            @return tuple (str, int)
            """
            if isinstance(r.name, (tuple, list)):
                if len(r.name) == 0:
                    retval = ""
                else:
                    retval = r.name[0:1][0]
            else:
                retval = r.name
            return retval, r.weight

        # END INTERNAL FUNCS ##########################################

        # NOTE: until this point, *ALL* Results in raw_scores are individual
        # Result objects.

        # sort then group by name, then by priority weighting
        grouped = itertools.groupby(sorted(raw_scores, key=group_func), key=group_func)

        # NOTE: post-grouping, grouped looks something like
        # [(('Global Attributes', 1), <itertools._grouper at 0x7f10982b5390>),
        #  (('Global Attributes', 3), <itertools._grouper at 0x7f10982b5438>),
        #  (('Not a Global Attr', 1), <itertools._grouper at 0x7f10982b5470>),
        #  (('Some Variable', 2), <itertools._grouper at 0x7f10982b5400>)]

        ret_val = []

        for k, v in grouped:  # iterate through the grouped tuples
            k = k[0]  # slice ("name", weight_val) --> "name"
            v = list(v)  # from itertools._grouper to list

            cv = self._group_raw(list(map(trim_groups, v)), k, level + 1)
            if len(cv):
                # if this node has children, it gets the max weight of its
                # children and the sum of all their scores
                max_weight = max([x.weight for x in cv])
                sum_scores = tuple(map(sum, list(zip(*[x.value for x in cv]))))
                msgs = []
            else:
                max_weight = max([x.weight for x in v])
                sum_scores = tuple(
                    map(sum, list(zip(*[self._translate_value(x.value) for x in v]))),
                )
                msgs = sum([x.msgs for x in v], [])

            ret_val.append(
                Result(
                    name=k,
                    weight=max_weight,
                    value=sum_scores,
                    children=cv,
                    msgs=msgs,
                ),
            )

        return ret_val

    def _translate_value(self, val):
        """
        Turns shorthand True/False/None checks into full scores
        (1, 1)/(0, 1)/(0, 0). Leaves full scores alone.
        """
        if val is True:
            return (1, 1)
        elif val is False:
            return (0, 1)
        elif val is None:
            return (0, 0)

        return val

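    # Score-tallying sketch (hypothetical Results, not part of the original
    # module): two raw Results named ("Global Attributes", "title") with
    # value True and ("Global Attributes", "summary") with value False are
    # grouped under a single "Global Attributes" parent whose value is
    # (1, 2), since _translate_value maps True -> (1, 1) and False -> (0, 1)
    # and the parent sums its children's score tuples.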