from typing import Union, Generator
from nachos.constraints.abstract_constraint import AbstractConstraint
from nachos.constraints import register
from itertools import chain
import numpy as np
@register('kl')
class KL(AbstractConstraint):
    r'''
        Summary:
            Defines the constraint on the categorical distribution over values
            between two datasets. The cost of mismatch is computed as the
            KL-divergence between the two sets. In general, the smaller
            set is the test set and we would like it to have specific
            characteristics with respect to the larger (training) set.

            The forward KL, i.e.,

            .. math::
                kl\left(p \vert\vert q_\theta\right)

            is mean seeking

                cost = KL(d1_train || d2_test)

            This will encourage selecting data with good coverage of the
            dataset, including data points that may have been seen only
            occasionally in the training data.

            The reverse KL is

            .. math::
                kl\left(q_\theta \vert\vert p\right)

                cost = KL(d2_test || d1_train)

            This encourages mode seeking behavior.

            The Jeffreys divergence symmetrizes the KL divergence as

            .. math::
                \frac{1}{2}\left[KL\left(p \vert\vert q_\theta\right) + KL\left(q_\theta \vert\vert p\right)\right]

            The ``direction`` argument selects between these three costs with
            the values ``'forward'``, ``'reverse'`` and ``'symmetric'``.
    '''
    # Values of ``direction`` that __call__ knows how to evaluate.
    _DIRECTIONS = ('forward', 'reverse', 'symmetric')

    @classmethod
    def build(cls, conf: dict):
        '''
            Summary:
                Builds the constraint from a configuration dictionary.

            Inputs
            ---------------------------
            :param conf: configuration holding the keys ``'kl_smooth'`` and
                ``'kl_direction'``
            :type conf: dict

            Returns
            ---------------------------
            :return: the constraint configured from conf
            :rtype: KL
        '''
        return cls(smooth=conf['kl_smooth'], direction=conf['kl_direction'])

    def __init__(self, smooth: float = 0.000001, direction: str = 'forward'):
        '''
            Inputs
            ---------------------------
            :param smooth: pseudo-count added to every vocabulary entry of
                both datasets before normalizing. With smooth = 0 an unseen
                value has probability 0 and the logarithm is no longer finite.
            :type smooth: float
            :param direction: one of 'forward', 'reverse', 'symmetric'
            :type direction: str
        '''
        super().__init__()
        self.smooth = smooth
        self.direction = direction
        # Set of every value seen so far; created on the first call and
        # extended by later calls.
        self.vocab = None

    @staticmethod
    def _values(item):
        '''Returns an iterator over the values in item (item itself if it is not iterable).'''
        try:
            return iter(item)
        except TypeError:
            return iter((item,))

    @classmethod
    def _tally(cls, items) -> dict:
        '''Counts, in a single pass over items, the occurrences of each value.'''
        tally = {}
        for item in items:
            for value in cls._values(item):
                tally[value] = tally.get(value, 0) + 1
        return tally

    @staticmethod
    def _kl(p, q) -> float:
        '''Returns KL(p || q) for two strictly positive, aligned distributions.'''
        return float(np.dot(p, np.log(p) - np.log(q)))

    def __call__(self,
        c1: Union[list, Generator],
        c2: Union[list, Generator],
    ) -> float:
        '''
            Summary:
                Computes the KL divergence between the empirical distributions
                defined by values in c1 and values in c2.

                Each element of c1 and c2 is either a collection of values
                (in general a set) or a single non-iterable value. Note that a
                str element is iterable, so it contributes its characters.

                The vocabulary is remembered on the instance, so values seen in
                an earlier call still receive the smoothing count in later
                calls.

            Inputs
            ---------------------------
            :param c1: the values to constrain seen in dataset 1
            :type c1: Union[list, Generator]
            :param c2: the values to constrain seen in dataset 2
            :type c2: Union[list, Generator]

            Returns
            ---------------------------
            :return: how closely (0 is best) the sets c1, c2 satisfy the constraint
            :rtype: float

            Raises
            ---------------------------
            :raises ValueError: if self.direction is not one of 'forward',
                'reverse', 'symmetric'
        '''
        if self.direction not in self._DIRECTIONS:
            raise ValueError(
                f"Unknown kl direction {self.direction!r}; "
                f"expected one of {self._DIRECTIONS}"
            )

        # Each input is traversed exactly once so that generators work too.
        c1_tally = self._tally(c1)
        c2_tally = self._tally(c2)

        if self.vocab is None:
            self.vocab = set()
        self.vocab.update(c1_tally, c2_tally)

        # A fixed ordering of the vocabulary keeps both vectors aligned.
        values = sorted(self.vocab)
        c1_dist = np.array(
            [c1_tally.get(v, 0) + self.smooth for v in values], dtype=float
        )
        c2_dist = np.array(
            [c2_tally.get(v, 0) + self.smooth for v in values], dtype=float
        )
        # Normalize by the full (observed + smoothing) mass so that each
        # vector sums to one.
        c1_dist = c1_dist / c1_dist.sum()
        c2_dist = c2_dist / c2_dist.sum()

        # Return the appropriate direction kl
        if self.direction == 'forward':
            return self._kl(c1_dist, c2_dist)
        if self.direction == 'reverse':
            return self._kl(c2_dist, c1_dist)
        return 0.5 * (self._kl(c1_dist, c2_dist) + self._kl(c2_dist, c1_dist))