Source code for nachos.similarity_functions.set_intersection

from typing import Union, Any, Iterable, Hashable
from nachos.similarity_functions.abstract_similarity import AbstractSimilarity
from nachos.similarity_functions import register
from nachos.utils import check_iterable


[docs]@register('set_intersection') class SetIntersection(AbstractSimilarity):
[docs] @classmethod def build(cls, conf: dict): return cls()
[docs] def __call__(self, f: Any, g: Any) -> float: ''' Summary: Computes the similarity between inputs f and g. f, g are assumed to be multi-valued objects, i.e., represent sets of values. We use the size of the intersection of the elements as the similiarity. Inputs ------------------------------- :param f: a value to compare :type f: Union[Any, Iterable]. I.e., a set or something which can be converted to a set :param g: a value to compare :type g: Union[Any, Iterable]. I.e., a set or something which can be converted to a set Returns ----------------------------------- :return: returns the similarity score :rtype: float ''' # Both objects are just values --> Use Boolean comparison if not check_iterable(f) and not check_iterable(g): return float(f == g) # One object is a value the other is a set --> convert value to # iterable, then set and then compute size of set intersection elif check_iterable(f) and not check_iterable(g): g = [g] elif not check_iterable(f) and check_iterable(g): f = [f] # At this point we know that f, g are both iterable. If their elements # are hashable, convert the iterable objects to sets. elif isinstance(next(iter(f)), Hashable) and not isinstance(next(iter(g)), Hashable): f = set(f) elif not isinstance(next(iter(f)), Hashable) and isinstance(next(iter(g)), Hashable): g = set(g) elif isinstance(next(iter(f)), Hashable) and isinstance(next(iter(g)), Hashable): f, g = set(f), set(g) # At this point, we know the elements of f, and g are not hashable. We # assume they are iterable, flatten them and convert them to sets. elif isinstance(next(iter(f)), Iterable) and isinstance(next(iter(g)), Iterable): f = set().union(*f) g = set().union(*g) return float(len(f.intersection(g)))