Source code for nachos.constraints.Constraints

from typing import Optional, List, Any, Dict, Union
from import Dataset, Split, collapse_factored_split
from nachos.constraints.abstract_constraint import AbstractConstraint

# Sentinel score returned by Constraints.__call__ when either side of a
# proposed split is empty.
WORST_SCORE = 99999.

[docs]class Constraints(object):
[docs] @classmethod def build(cls, conf: dict): fns = [ for fn in conf['constraints']] weights = conf['constraint_weights'] return cls(fns, weights)
def __init__(self, fns: List[AbstractConstraint], weights: List[float]):
    '''
        Summary:
            Keep references to the constraint objects and to the weights
            that are paired with them when scores are combined.

        :param fns: The constraints; each one is called on a pair of
            constraint values and also exposes a ``stat`` method
        :type fns: List[AbstractConstraint]
        :param weights: One weight per constraint, matched to ``fns`` by
            position
        :type weights: List[float]
    '''
    # The arguments are stored as given (no copy is made).
    self.fns, self.weights = fns, weights
def __call__(self,
    u: Dataset,
    s: Split,
    n: Optional[int] = None
) -> float:
    '''
        Summary:
            Compute the discompatibility score of a proposed split with
            respect to the predefined constraints in self.fns.

        Inputs
        ----------------------------
        :param u: A dataset
        :type u: Dataset
        :param s: A proposed split; s[0] is the subset and s[1] is its
            complement, i.e., the points in u that are not in s[0]
        :type s: Split
        :param n: The index of the single constraint to score against.
            None means compute the weighted sum over all constraints
        :type n: Optional[int]

        Returns
        --------------------------
        :return: WORST_SCORE if either side of the split is empty;
            otherwise the weighted sum of all constraint scores (n is
            None) or the unweighted score of constraint n
        :rtype: float
    '''
    # Either side being empty short-circuits to the sentinel worst score.
    if not len(s[0]) or not len(s[1]):
        return WORST_SCORE

    if n is not None:
        # Single-constraint mode: the weight is not applied here.
        inside = u.get_constraints(subset=s[0], n=n)
        outside = u.get_constraints(subset=s[1], n=n)
        return self.fns[n](inside, outside)

    # Gather the constraint values of both sides for every constraint
    # index before any scoring is done.
    inside_values = [
        u.get_constraints(subset=s[0], n=m)
        for m in range(len(u.constraint_idxs))
    ]
    outside_values = [
        u.get_constraints(subset=s[1], n=m)
        for m in range(len(u.constraint_idxs))
    ]

    # zip stops at the shortest of weights, fns and the gathered values.
    total = 0
    for weight, constraint, inside, outside in zip(
        self.weights, self.fns, inside_values, outside_values
    ):
        total = total + weight * constraint(inside, outside)
    return total
def stats(self,
    u: Dataset,
    s: set,
) -> dict:
    '''
        Summary:
            Compute the "stats" of a subset, one per constraint in
            self.fns, together with the size of the subset.

        :param u: The Dataset from which the subset is drawn
        :type u: Dataset
        :param s: The proposed subset of the dataset
        :type s: set
        :return: dictionary mapping each constraint's field name to its
            stat for s (rounded to 4 decimal places), plus the key
            'length' holding len(s)
        :rtype: dict
    '''
    report = {}
    for idx, constraint in enumerate(self.fns):
        values = u.get_constraints(subset=s, n=idx)
        # The idx-th constraint is reported under the name of the field
        # it is attached to.
        field = u.field_names[u.constraint_idxs[idx]]
        report[field] = round(constraint.stat(values), 4)
    # The size of the subset is reported alongside the constraint stats.
    report['length'] = len(s)
    return report