Source code for renn.data.wordpiece_tokenizer_learner_lib

# coding=utf-8
# Copyright 2020 TF.Text Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Algorithm for learning wordpiece vocabulary."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections

Params = collections.namedtuple(
    'Params', 'upper_thresh lower_thresh '
    'num_iterations max_input_tokens '
    'max_token_length max_unique_chars vocab_size '
    'slack_ratio include_joiner_token joiner '
    'reserved_tokens')


def extract_char_tokens(word_counts):
  """Extracts all single-character tokens from word_counts.

  Args:
    word_counts: list of (string, int) tuples

  Returns:
    set of single-character strings contained within word_counts
  """
  seen_chars = set()
  for word, _ in word_counts:
    for char in word:
      seen_chars.add(char)
  return seen_chars
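
# Illustrative example (added for exposition; not part of the original
# module): the counts are ignored here, only the characters matter.
#
#   extract_char_tokens([('hello', 2), ('hi', 3)])
#   # -> {'h', 'e', 'l', 'o', 'i'}
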

def ensure_all_tokens_exist(input_tokens, output_tokens, include_joiner_token,
                            joiner):
  """Adds all tokens in input_tokens to output_tokens if not already present.

  Args:
    input_tokens: set of strings (tokens) we want to include
    output_tokens: string to int dictionary mapping token to count
    include_joiner_token: bool whether to include joiner token
    joiner: string used to indicate suffixes

  Returns:
    string to int dictionary with all tokens in input_tokens included
  """
  for token in input_tokens:
    if token not in output_tokens:
      output_tokens[token] = 1

    if include_joiner_token:
      joined_token = joiner + token
      if joined_token not in output_tokens:
        output_tokens[joined_token] = 1

  return output_tokens
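
# Illustrative example (added for exposition; not part of the original
# module): existing counts are preserved, while missing tokens and their
# joiner-prefixed forms are added with a count of 1.
#
#   ensure_all_tokens_exist({'a', 'b'}, {'a': 7},
#                           include_joiner_token=True, joiner='##')
#   # -> {'a': 7, '##a': 1, 'b': 1, '##b': 1}
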

def get_split_indices(word, curr_tokens, include_joiner_token, joiner):
  """Gets indices for valid substrings of word, for iterations > 0.

  For iterations > 0, rather than considering every possible substring, we
  only want to consider starting points corresponding to the start of
  wordpieces in the current vocabulary.

  Args:
    word: string we want to split into substrings
    curr_tokens: string to int dict of tokens in vocab (from previous
      iteration)
    include_joiner_token: bool whether to include joiner token
    joiner: string used to indicate suffixes

  Returns:
    list of ints containing valid starting indices for word
  """
  indices = []
  start = 0
  while start < len(word):
    end = len(word)
    while end > start:
      subtoken = word[start:end]
      # Subtoken includes the joiner token.
      if include_joiner_token and start > 0:
        subtoken = joiner + subtoken

      # If subtoken is part of vocab, 'end' is a valid start index.
      if subtoken in curr_tokens:
        indices.append(end)
        break
      end -= 1

    if end == start:
      return None
    start = end

  return indices
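
# Illustrative example (added for exposition; not part of the original
# module): the scan is greedy longest-match-first at each start position, and
# the function returns None if the word cannot be covered by the current
# vocabulary.
#
#   get_split_indices('abc', {'a': 1, 'ab': 1, '##b': 1, '##c': 1},
#                     include_joiner_token=True, joiner='##')
#   # -> [2, 3]   ('abc' splits as 'ab' + '##c'; 'a' loses to the longer 'ab')
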

def get_search_threshs(word_counts, upper_thresh, lower_thresh):
  """Clips the thresholds for binary search based on current word counts.

  The upper threshold parameter typically has a large default value that can
  result in many iterations of unnecessary search. Thus we clip the upper and
  lower bounds of search to the maximum and the minimum wordcount values.

  Args:
    word_counts: list of (string, int) tuples
    upper_thresh: int, upper threshold for binary search
    lower_thresh: int, lower threshold for binary search

  Returns:
    upper_search: int, clipped upper threshold for binary search
    lower_search: int, clipped lower threshold for binary search
  """
  counts = [count for _, count in word_counts]

  max_count = max(counts)
  min_count = min(counts)

  upper_search = max_count if max_count < upper_thresh else upper_thresh
  lower_search = min_count if min_count > lower_thresh else lower_thresh

  return upper_search, lower_search
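
# Illustrative example (added for exposition; not part of the original
# module): the search bounds are clipped to the observed word counts.
#
#   get_search_threshs([('the', 100), ('cat', 5)], upper_thresh=10000,
#                      lower_thresh=1)
#   # -> (100, 5)
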

def get_input_words(word_counts, reserved_tokens, max_token_length):
  """Filters out words that are longer than max_token_length or are reserved.

  Args:
    word_counts: list of (string, int) tuples
    reserved_tokens: list of strings
    max_token_length: int, maximum length of a token

  Returns:
    list of (string, int) tuples of filtered wordcounts
  """
  all_counts = []

  for word, count in word_counts:
    if len(word) > max_token_length or word in reserved_tokens:
      continue
    all_counts.append((word, count))

  return all_counts

def get_allowed_chars(all_counts, max_unique_chars):
  """Get the top max_unique_chars characters within our wordcounts.

  We want each character to be in the vocabulary so that we can keep splitting
  down to the character level if necessary. However, in order not to inflate
  our vocabulary with rare characters, we only keep the top max_unique_chars
  characters.

  Args:
    all_counts: list of (string, int) tuples
    max_unique_chars: int, maximum number of unique single-character tokens

  Returns:
    set of strings containing top max_unique_chars characters in all_counts
  """
  char_counts = collections.defaultdict(int)

  for word, count in all_counts:
    for char in word:
      char_counts[char] += count

  # Sort by count, then alphabetically.
  sorted_counts = sorted(sorted(char_counts.items(), key=lambda x: x[0]),
                         key=lambda x: x[1], reverse=True)

  allowed_chars = set()
  for i in range(min(len(sorted_counts), max_unique_chars)):
    allowed_chars.add(sorted_counts[i][0])
  return allowed_chars
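
# Illustrative example (added for exposition; not part of the original
# module): character counts are weighted by word frequency, so here 'a'
# scores 7 (2 * 3 from 'aa' plus 1 from 'ab'), 'b' scores 2, and 'c' scores 1.
#
#   get_allowed_chars([('aa', 3), ('ab', 1), ('bc', 1)], max_unique_chars=2)
#   # -> {'a', 'b'}
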

def filter_input_words(all_counts, allowed_chars, max_input_tokens):
  """Filters out words with unallowed chars and limits to max_input_tokens.

  Args:
    all_counts: list of (string, int) tuples
    allowed_chars: list of single-character strings
    max_input_tokens: int, maximum number of tokens accepted as input

  Returns:
    list of (string, int) tuples of filtered wordcounts
  """
  filtered_counts = []
  for word, count in all_counts:
    if (max_input_tokens != -1 and
        len(filtered_counts) >= max_input_tokens):
      break
    has_unallowed_chars = False
    for char in word:
      if char not in allowed_chars:
        has_unallowed_chars = True
        break
    if has_unallowed_chars:
      continue
    filtered_counts.append((word, count))

  return filtered_counts

def generate_final_vocabulary(reserved_tokens, char_tokens, curr_tokens):
  """Generates final vocab given reserved, single-character, and current tokens.

  Args:
    reserved_tokens: list of strings (tokens) that must be included in vocab
    char_tokens: set of single-character strings
    curr_tokens: string to int dict mapping token to count

  Returns:
    list of strings representing final vocabulary
  """
  sorted_char_tokens = sorted(list(char_tokens))
  vocab_char_arrays = []
  vocab_char_arrays.extend(reserved_tokens)
  vocab_char_arrays.extend(sorted_char_tokens)

  # Sort by count, then alphabetically.
  sorted_tokens = sorted(sorted(curr_tokens.items(), key=lambda x: x[0]),
                         key=lambda x: x[1], reverse=True)
  for token, _ in sorted_tokens:
    vocab_char_arrays.append(token)

  seen_tokens = set()
  # Adding unique tokens to list to maintain sorted order.
  vocab_words = []
  for word in vocab_char_arrays:
    if word in seen_tokens:
      continue
    seen_tokens.add(word)
    vocab_words.append(word)

  return vocab_words
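
# Illustrative example (added for exposition; not part of the original
# module): reserved tokens come first, then single characters in alphabetical
# order, then the remaining tokens by descending count, with duplicates
# removed while preserving order.
#
#   generate_final_vocabulary(['<unk>'], {'a', 'b'},
#                             {'ab': 5, 'a': 3, 'b': 3})
#   # -> ['<unk>', 'a', 'b', 'ab']
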

def learn_with_thresh(word_counts, thresh, params):
  """Wordpiece learning algorithm to produce a vocab given frequency threshold.

  Args:
    word_counts: list of (string, int) tuples
    thresh: int, frequency threshold for a token to be included in the vocab
    params: Params namedtuple, parameters for learning

  Returns:
    list of strings, vocabulary generated for the given thresh
  """

  # Set of single-character tokens.
  char_tokens = extract_char_tokens(word_counts)
  curr_tokens = ensure_all_tokens_exist(char_tokens, {},
                                        params.include_joiner_token,
                                        params.joiner)

  for iteration in range(params.num_iterations):
    subtokens = [dict() for _ in range(params.max_token_length + 1)]
    # Populate array with counts of each subtoken.
    for word, count in word_counts:
      if iteration == 0:
        split_indices = range(1, len(word) + 1)
      else:
        split_indices = get_split_indices(word, curr_tokens,
                                          params.include_joiner_token,
                                          params.joiner)
        if not split_indices:
          continue

      start = 0
      for index in split_indices:
        for end in range(start + 1, len(word) + 1):
          subtoken = word[start:end]
          length = len(subtoken)
          if params.include_joiner_token and start > 0:
            subtoken = params.joiner + subtoken
          if subtoken in subtokens[length]:
            # Subtoken exists, increment count.
            subtokens[length][subtoken] += count
          else:
            # New subtoken, add to dict.
            subtokens[length][subtoken] = count
        start = index

    next_tokens = {}
    # Get all tokens that have a count above the threshold.
    for length in range(params.max_token_length, 0, -1):
      for token, count in subtokens[length].items():
        if count >= thresh:
          next_tokens[token] = count

          # Decrement the count of all prefixes.
          if len(token) > length:  # This token includes the joiner.
            joiner_len = len(params.joiner)
            for i in range(1 + joiner_len, length + joiner_len):
              prefix = token[0:i]
              if prefix in subtokens[i - joiner_len]:
                subtokens[i - joiner_len][prefix] -= count
          else:
            for i in range(1, length):
              prefix = token[0:i]
              if prefix in subtokens[i]:
                subtokens[i][prefix] -= count

    # Add back single-character tokens.
    curr_tokens = ensure_all_tokens_exist(char_tokens, next_tokens,
                                          params.include_joiner_token,
                                          params.joiner)

  vocab_words = generate_final_vocabulary(params.reserved_tokens, char_tokens,
                                          curr_tokens)

  return vocab_words
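
# Illustrative usage sketch (added for exposition; the parameter values below
# are arbitrary and not defaults of this module):
#
#   params = Params(upper_thresh=10000000, lower_thresh=10, num_iterations=4,
#                   max_input_tokens=5000000, max_token_length=50,
#                   max_unique_chars=1000, vocab_size=32, slack_ratio=0.05,
#                   include_joiner_token=True, joiner='##',
#                   reserved_tokens=['<unk>'])
#   vocab = learn_with_thresh([('hello', 12), ('hell', 7), ('help', 4)],
#                             thresh=5, params=params)
#
# The result is a list of strings: the reserved tokens, every single character
# seen in the input, and each subtoken whose (prefix-adjusted) count reached
# the threshold, with suffix pieces carrying the '##' joiner.
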

def learn(word_counts, params):
  """Takes in wordcounts and returns wordpiece vocabulary.

  Args:
    word_counts: list of (string, int) tuples
    params: Params namedtuple, parameters for learning

  Returns:
    string, final vocabulary with each word separated by newline
  """
  upper_search, lower_search = get_search_threshs(word_counts,
                                                  params.upper_thresh,
                                                  params.lower_thresh)
  all_counts = get_input_words(word_counts, params.reserved_tokens,
                               params.max_token_length)

  allowed_chars = get_allowed_chars(all_counts, params.max_unique_chars)

  filtered_counts = filter_input_words(all_counts, allowed_chars,
                                       params.max_input_tokens)

  vocab = learn_binary_search(filtered_counts, lower_search, upper_search,
                              params)

  return vocab
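
# End-to-end usage sketch (added for exposition): learn() clips the search
# bounds, filters the input word counts, and delegates to learn_binary_search
# (defined elsewhere in this module, not shown on this page), which
# presumably searches the clipped threshold range for a vocabulary whose size
# best matches params.vocab_size.
#
#   word_counts = [('hello', 12), ('world', 9), ('hell', 7)]
#   vocab = learn(word_counts, params)   # params as in the sketch above
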