# Copyright (C) 2017-2023 Cleanlab Inc.
# This file is part of cleanlab.
#
# cleanlab is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cleanlab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab. If not, see <https://www.gnu.org/licenses/>.
"""
Helper methods used internally in cleanlab.token_classification
"""
from __future__ import annotations
import re
import string
import numpy as np
from termcolor import colored
from typing import List, Optional, Callable, Tuple, TypeVar, TYPE_CHECKING
if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt
T = TypeVar("T", bound=npt.NBitBase)
[docs]def get_sentence(words: List[str]) -> str:
"""
Get sentence formed by a list of words with minor processing for readability
Parameters
----------
words:
list of word-level tokens
Returns
----------
sentence:
sentence formed by list of word-level tokens
Examples
--------
>>> from cleanlab.internal.token_classification_utils import get_sentence
>>> words = ["This", "is", "a", "sentence", "."]
>>> get_sentence(words)
'This is a sentence.'
"""
sentence = ""
for word in words:
if word not in string.punctuation or word in ["-", "("]:
word = " " + word
sentence += word
sentence = sentence.replace(" '", "'").replace("( ", "(").strip()
return sentence
[docs]def filter_sentence(
sentences: List[str],
condition: Optional[Callable[[str], bool]] = None,
) -> Tuple[List[str], List[bool]]:
"""
Filter sentence based on some condition, and returns filter mask
Parameters
----------
sentences:
list of sentences
condition:
sentence filtering condition
Returns
---------
sentences:
list of sentences filtered
mask:
boolean mask such that `mask[i] == True` if the i'th sentence is included in the
filtered sentence, otherwise `mask[i] == False`
Examples
--------
>>> from cleanlab.internal.token_classification_utils import filter_sentence
>>> sentences = ["Short sentence.", "This is a longer sentence."]
>>> condition = lambda x: len(x.split()) > 2
>>> long_sentences, _ = filter_sentence(sentences, condition)
>>> long_sentences
['This is a longer sentence.']
>>> document = ["# Headline", "Sentence 1.", "&", "Sentence 2."]
>>> sentences, mask = filter_sentence(document)
>>> sentences, mask
(['Sentence 1.', 'Sentence 2.'], [False, True, False, True])
"""
if not condition:
condition = lambda sentence: len(sentence) > 1 and "#" not in sentence
mask = list(map(condition, sentences))
sentences = [sentence for m, sentence in zip(mask, sentences) if m]
return sentences, mask
[docs]def process_token(token: str, replace: List[Tuple[str, str]] = [("#", "")]) -> str:
"""
Replaces special characters in the tokens
Parameters
----------
token:
token which potentially contains special characters
replace:
list of tuples `(s1, s2)`, where all occurances of s1 are replaced by s2
Returns
---------
processed_token:
processed token whose special character has been replaced
Note
----
Only applies to characters in the original input token.
Examples
--------
>>> from cleanlab.internal.token_classification_utils import process_token
>>> token = "#Comment"
>>> process_token("#Comment")
'Comment'
Specify custom replacement rules
>>> replace = [("C", "a"), ("a", "C")]
>>> process_token("Cleanlab", replace)
'aleCnlCb'
"""
replace_dict = {re.escape(k): v for (k, v) in replace}
pattern = "|".join(replace_dict.keys())
compiled_pattern = re.compile(pattern)
replacement = lambda match: replace_dict[re.escape(match.group(0))]
processed_token = compiled_pattern.sub(replacement, token)
return processed_token
[docs]def mapping(entities: List[int], maps: List[int]) -> List[int]:
"""
Map a list of entities to its corresponding entities
Parameters
----------
entities:
a list of given entities
maps:
a list of mapped entities, such that the i'th indexed token should be mapped to `maps[i]`
Returns
---------
mapped_entities:
a list of mapped entities
Examples
--------
>>> unique_identities = [0, 1, 2, 3, 4] # ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]
>>> maps = [0, 1, 1, 2, 2] # ["O", "PER", "PER", "LOC", "LOC"]
>>> mapping(unique_identities, maps)
[0, 1, 1, 2, 2] # ["O", "PER", "PER", "LOC", "LOC"]
>>> mapping([0, 0, 4, 4, 3, 4, 0, 2], maps)
[0, 0, 2, 2, 2, 2, 0, 1] # ["O", "O", "LOC", "LOC", "LOC", "LOC", "O", "PER"]
"""
f = lambda x: maps[x]
return list(map(f, entities))
[docs]def merge_probs(
probs: npt.NDArray["np.floating[T]"], maps: List[int]
) -> npt.NDArray["np.floating[T]"]:
"""
Merges model-predictive probabilities with desired mapping
Parameters
----------
probs:
A 2D np.array of shape `(N, K)`, where N is the number of tokens, and K is the number of classes for the model
maps:
a list of mapped index, such that the probability of the token being in the i'th class is mapped to the
`maps[i]` index. If `maps[i] == -1`, the i'th column of `probs` is ignored. If `np.any(maps == -1)`, the
returned probability is re-normalized.
Returns
---------
probs_merged:
A 2D np.array of shape ``(N, K')``, where `K'` is the number of new classes. Probabilities are merged and
re-normalized if necessary.
Examples
--------
>>> import numpy as np
>>> from cleanlab.internal.token_classification_utils import merge_probs
>>> probs = np.array([
... [0.55, 0.0125, 0.0375, 0.1, 0.3],
... [0.1, 0.8, 0, 0.075, 0.025],
... ])
>>> maps = [0, 1, 1, 2, 2]
>>> merge_probs(probs, maps)
array([[0.55, 0.05, 0.4 ],
[0.1 , 0.8 , 0.1 ]])
"""
old_classes = probs.shape[1]
map_size = np.max(maps) + 1
probs_merged = np.zeros([len(probs), map_size], dtype=probs.dtype.type)
for i in range(old_classes):
if maps[i] >= 0:
probs_merged[:, maps[i]] += probs[:, i]
if -1 in maps:
row_sums = probs_merged.sum(axis=1)
probs_merged /= row_sums[:, np.newaxis]
return probs_merged
[docs]def color_sentence(sentence: str, word: str) -> str:
"""
Searches for a given token in the sentence and returns the sentence where the given token is colored red
Parameters
----------
sentence:
a sentence where the word is searched
word:
keyword to find in `sentence`. Assumes the word exists in the sentence.
Returns
---------
colored_sentence:
`sentence` where the every occurrence of the word is colored red, using ``termcolor.colored``
Examples
--------
>>> from cleanlab.internal.token_classification_utils import color_sentence
>>> sentence = "This is a sentence."
>>> word = "sentence"
>>> color_sentence(sentence, word)
'This is a \x1b[31msentence\x1b[0m.'
Also works for multiple occurrences of the word
>>> document = "This is a sentence. This is another sentence."
>>> word = "sentence"
>>> color_sentence(document, word)
'This is a \x1b[31msentence\x1b[0m. This is another \x1b[31msentence\x1b[0m.'
"""
colored_word = colored(word, "red")
return _replace_sentence(sentence=sentence, word=word, new_word=colored_word)
def _replace_sentence(sentence: str, word: str, new_word: str) -> str:
"""
Searches for a given token in the sentence and returns the sentence where the given token has been replaced by
`new_word`.
Parameters
----------
sentence:
a sentence where the word is searched
word:
keyword to find in `sentence`. Assumes the word exists in the sentence.
new_word:
the word to replace the keyword with
Returns
---------
new_sentence:
`sentence` where the every occurrence of the word is replaced by `colored_word`
"""
new_sentence, number_of_substitions = re.subn(
r"\b{}\b".format(re.escape(word)), new_word, sentence
)
if number_of_substitions == 0:
# Use basic string manipulation if regex fails
new_sentence = sentence.replace(word, new_word)
return new_sentence