Sentence masking is a common procedure in the pre-training of deep learning models. For instance, the pre-training of the BART model includes a training objective named text infilling, whose aim is to teach the model to fill in blanks; this requires a masking algorithm that generates masked sentences for training. However, existing papers (TODO: who?) provide only vague descriptions of the masking methods used in their experiments, which are too abstract to be implemented directly as algorithms, while most existing implementations offer only roughly usable algorithms without in-depth analysis. In response, we propose an appropriate and robust masking algorithm (the wakong algorithm) and publish a Python library that can be used directly in production.
A sentence masking algorithm is an algorithm for masking a sentence of a fixed length. The input of the algorithm is the length of the sentence, denoted by seq_len, and the output is a list of tuples. Each tuple represents a mask in the original sentence: the first element is the starting position of the mask, and the second element is its length.
For example, when the length of the input sentence is 40, the output of the algorithm might be:
[(5, 4), (23, 2)]
This output indicates that two spans of the sentence are masked: the first starts at position 5 and has a length of 4, while the second starts at position 23 and has a length of 2. It can also be represented graphically as follows:
.....(xxxx)..............(xx)...............
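For illustration, the dotted representation above can be produced with a short helper (a hypothetical function written for this document, not part of the library):

def render_mask_scheme(seq_len: int, mask_scheme: list[tuple[int, int]]) -> str:
    out, pos = [], 0
    for start, length in sorted(mask_scheme):
        out.append('.' * (start - pos))       # unmasked words before the span
        out.append('(' + 'x' * length + ')')  # the masked span
        pos = start + length
    out.append('.' * (seq_len - pos))         # trailing unmasked words
    return ''.join(out)

print(render_mask_scheme(40, [(5, 4), (23, 2)]))
# .....(xxxx)..............(xx)...............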
Based on the BART paper, we suggest that a reasonable sentence masking algorithm should satisfy four requirements, which are discussed in detail in the sections that follow. Formally, the algorithm is defined by the following components:
1. Constants
\mathsf{proposedMaskRate} = 0.188
\mathsf{poissonRate} = 4.2
\mathsf{maxSpanLen} = 10
2. probsList
\mathsf{probsList} = \left[ \mathrm{normalise} \left( \mathsf{probs} \left[ {:}\,i \right] \right) \mathrm{for} \; i \; \mathrm{in} \left[2, \; .. , \; \mathsf{maxSpanLen} + 1 \right] \right]
\mathsf{probs} = \left[ \Pr(X=i) \; \mathrm{for} \; i \; \mathrm{in} \left[ 0, \; .., \; \mathsf{maxSpanLen} \right] \right]
X \sim \mathrm{Pois}(\mathsf{poissonRate})
3. determineShouldMaskLen
\mathsf{determineShouldMaskLen} \left( \mathsf{seqLen} \right) = \begin{cases} \lceil x \rceil, & \text{if} \; \omega < p \\ \lfloor x \rfloor, & \text{otherwise} \\ \end{cases}
\omega \sim \mathrm{U} \left( 0, 1 \right)
x = \mathsf{seqLen} * \mathsf{proposedMaskRate}
p = x - \lfloor x \rfloor
4. generateSpans
\mathsf{generateSpans} \left( m \right) = \mathrm{shuffle} \left( \mathrm{anamorphism} \left( f \right) \left( m \right) \right)
f \left( \mathsf{remainedLen} \right) = \begin{cases} \mathrm{Nothing}, & \text{if} \; \mathsf{remainedLen} \leq 0 \\ \left( \mathsf{span}, \; \mathrm{Just} \left( \mathsf{remainedLen} - \mathsf{span} - 1 \right) \right), & \text{otherwise} \end{cases}
\mathsf{span} \sim \mathrm{Categorical} \left( [0, \; .., \; n], \; \mathsf{probsList} \left[ n - 1 \right] \right)
n = \min \left( \mathsf{maxSpanLen}, \; \mathsf{remainedLen} \right)
5. distributeInsertPoses
\mathsf{distributeInsertPoses} \left( \mathsf{xs} \right) = f \left( 0, \; \mathsf{xs} \right)
f \left( n, \; \mathsf{xs} \right) = \begin{cases} \mathsf{\left[ \, \right]}, & \text{if} \; \mathrm{empty} \left( \mathsf{xs} \right) \\ \left[ \left( p + n, \; s \right) \right] + f \left(n + s + 1, \; \mathsf{ys} \right), & \text{otherwise} \\ \end{cases}
\left[ \left( p, s \right) \right] + \mathsf{ys} \leftarrow \mathsf{xs}
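As a worked example of this definition:
\mathsf{distributeInsertPoses} \left( \left[ (2, 3), \; (5, 2) \right] \right) = \left[ (2, 3), \; (9, 2) \right]
The first span keeps its position, while the second is shifted right by the length of the first span plus its reserved gap position, 3 + 1 = 4.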
6. randomAddOne
\mathsf{randomAddOne} \left( \mathsf{xs} \right) = \begin{cases} \mathsf{xs}, & \text{if} \; \omega < 0.5 \\ \left[ (p + 1, s) \; \mathrm{for} \; (p, s) \; \mathrm{in} \; \mathsf{xs} \right], & \text{otherwise} \\ \end{cases}
\omega \sim \mathrm{U} \left( 0, 1 \right)
7. wakong
\mathsf{wakong} \left( \mathsf{seqLen} \right) = \mathsf{randomAddOne} \left( \mathsf{distributeInsertPoses} \left( \mathrm{zip} \left( \mathsf{absInsertPoses}, \; \mathsf{spans} \right) \right) \right)
\mathsf{absInsertPoses} = \mathrm{sort} \left( X \right)
X = X_{1, \; .., \; \mathsf{nSpans}} \sim \mathrm{DiscreteUniform} \left[ 0, \; \mathsf{nPossibleInsertPoses} - 1 \right]
\left( \forall \; i, j \in \left\{ 1, \; .., \; \mathsf{nSpans} \right\}, X_i \ne X_j \right)
\mathsf{nPossibleInsertPoses} = \mathsf{seqLen} - \mathrm{sum} \left( \mathsf{spans} \right) - \mathsf{nSpans} + 1
\mathsf{nSpans} = \mathrm{len} \left( \mathsf{spans} \right)
\mathsf{spans} = \mathsf{generateSpans} \left( \mathsf{shouldMaskLen} \right)
\mathsf{shouldMaskLen} = \mathsf{determineShouldMaskLen} \left( \mathsf{seqLen} \right)
The step with the highest time complexity in the algorithm is sorting the n randomly generated insertion positions, so the overall time complexity is O \left( n \log n \right).
The algorithm requires that an average of 15% of the words in a sentence should be masked, but this calculation often yields a fractional number of words. When this occurs, we round down or up randomly, with the probability of rounding up equal to the fractional part. For example, if the number of words to be masked is calculated to be 3.3, a random number is generated once from a uniform distribution between 0 and 1; the result is rounded up to 4 if the number is less than 0.3, and rounded down to 3 otherwise.
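A minimal sketch of this stochastic rounding (it mirrors the determine_should_mask_len function in the implementation below):

from random import Random

def stochastic_round(rng: Random, x: float) -> int:
    # Round up with probability equal to the fractional part of x,
    # so the result equals x in expectation (e.g. 3.3 -> 4 with p = 0.3).
    integer_part = int(x)
    return integer_part + (rng.random() < x - integer_part)

print(stochastic_round(Random(42), 3.3))  # prints 3 or 4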
Following the BART paper, we sample from a Poisson distribution to randomly generate the length of each mask. Instead of setting the parameter of the Poisson distribution to 4 as in the BART paper, we initially set it to 3.5 so that masks of length 3 would occur most frequently (as will be mentioned in a subsequent step, we later revised this parameter to 4.2). For values greater than 10, we set the probability to 0 and normalise the probabilities of the values between 0 and 10 to sum to 1. With the revised parameter, this produces a distribution with a cumulative distribution function of [0.0151, 0.0783, 0.2111, 0.3970, 0.5922, 0.7562, 0.8710, 0.9399, 0.9760, 0.9929, 1.0000].
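As a quick check, this truncated distribution and its cumulative distribution function can be reproduced with the standard library alone (a minimal sketch; the implementation below computes the same tables with numpyro):

import math

poisson_rate = 4.2
max_span_len = 10

# Poisson pmf for k = 0 .. 10; the tail beyond 10 is cut off.
probs = [math.exp(-poisson_rate) * poisson_rate ** k / math.factorial(k)
         for k in range(max_span_len + 1)]
total = sum(probs)
probs = [p / total for p in probs]  # renormalise to sum to 1

# A running sum gives the cumulative distribution function quoted above.
cdf, acc = [], 0.0
for p in probs:
    acc += p
    cdf.append(round(acc, 4))
print(cdf)
# [0.0151, 0.0783, 0.2111, 0.397, 0.5922, 0.7562, 0.871, 0.9399, 0.976, 0.9929, 1.0]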
The list of mask lengths is generated by repeatedly sampling from the above distribution. The sampling stops when the sum of the lengths of the masks reaches the number of words to be masked.
If the sum of the mask lengths has not yet reached the target number of words, but adding a newly sampled length would exceed it (e.g. the target is 10, the current sum is 9, and the sampled length is 5, so 9 plus 5 exceeds 10), the algorithm discards that sample and resamples until the result fits within the target. In practice, to keep the algorithm efficient, it should not resample on failure; instead, it should first calculate the range of admissible sampling results from the target and the current sum of the mask lengths, derive a new distribution from the one above by excluding the values outside that range, normalise the probabilities to sum to 1, and sample from the new distribution.
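Using the probs list and max_span_len from the previous sketch, this truncate-and-renormalise step might look as follows (a sketch; the implementation below instead precomputes one cumulative table per admissible maximum in generate_probs_list):

from random import Random

def sample_span_len(rng: Random, remaining: int) -> int:
    # Only lengths 0 .. min(max_span_len, remaining) still fit the remaining
    # budget, so sample from the truncated distribution directly;
    # rng.choices renormalises the weights implicitly.
    n = min(max_span_len, remaining)
    return rng.choices(range(n + 1), weights=probs[:n + 1])[0]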
The algorithm requires that no two masks be directly adjacent, and a mask of length k also occupies the position to its right, i.e. its effective length is k+1. Therefore, when calculating the sum of the mask lengths, an extra 1 needs to be added per mask, i.e. the total is the sum of the mask lengths plus the number of masks. Although this neatly avoids the problem of two masks being directly adjacent, it results in a smaller average proportion of masked words than the expected 15% (for this reason, the target mask rate is adjusted to 18.8% in a subsequent step to bring the final result closer to 15%).
An asymmetry arises because samples of length 0 may occur at the start of sampling, while, conditional on the target length being reached, they are unlikely to occur at the end of sampling. For this reason, the list of mask lengths is randomly shuffled after sampling so that the mask lengths are evenly distributed.
Let the length of the sentence be m, the sum of the lengths of the masks be K, and the number of masks be n. There are m-K-n+1 possible starting positions, and n of these starting positions are chosen at random as the starting positions of the masks. The reason for subtracting n is that, as mentioned above, a mask of length k actually occupies a position to its right, so n masks will occupy an additional n positions.
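A sketch of this placement step, using the m, K, n notation above (it mirrors the composition of wakong and distribute_insert_poses in the implementation below):

from random import Random

def place_spans(rng: Random, m: int, spans: list[int]) -> list[tuple[int, int]]:
    n = len(spans)
    n_slots = m - sum(spans) - n + 1               # m - K - n + 1 possible slots
    slots = sorted(rng.sample(range(n_slots), n))  # n distinct slots, in order
    scheme, offset = [], 0
    for slot, span in zip(slots, spans):
        scheme.append((slot + offset, span))  # shift by the space already consumed
        offset += span + 1                    # each span also reserves one gap position
    return scheme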
However, this would result in the last word of the sentence never being masked. For this reason, after the above step, a random number between 0 and 1 is generated, and if this number is less than 0.5, all masks are shifted one place to the right, i.e. the reserved empty position is placed on the left instead, thus restoring the symmetry of the algorithm.
After implementing the algorithm, we found that the average proportion of masked words was less than 15%. This is because, as mentioned above, the sum of the mask lengths is calculated with an extra 1 added per mask, so the actual number of masked words falls short of the target. We found that adjusting the target mask rate in the algorithm's parameters to 18.8% brings the final result to approximately 15.17%, close to the expected value of 15%.
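This figure can be checked empirically with a quick Monte Carlo estimate over the wakong function given below (a sketch; the sequence length is fixed at 100 here, so the exact value may differ slightly from 15.17%):

from random import Random

def estimate_mask_rate(seq_len: int = 100, n_trials: int = 10_000) -> float:
    rng = Random(0)
    # Total masked words across all trials, divided by total words.
    masked = sum(sum(span for _, span in wakong(rng, seq_len))
                 for _ in range(n_trials))
    return masked / (seq_len * n_trials)

print(estimate_mask_rate())  # expected to print roughly 0.15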
In addition, the algorithm generates shorter masks more frequently than expected, because as sampling nears its end, only smaller values remain admissible. This is acceptable, since the algorithm only needs to ensure that masks of length 3 occur most frequently. However, to make the algorithm generate longer masks more often, and thus make training more difficult, we raised the parameter of the Poisson distribution from 3.5 to 4.2.
import jax.numpy as np
import numpyro.distributions as dist
from random import Random
proposed_mask_rate = 0.188  # resulting mask rate would be approximately 0.15
poisson_rate = 4.2  # span length = 3 would be the most frequent in the resulting distribution
max_span_len = 10

def normalise_probs(a: np.ndarray) -> np.ndarray:
    return a / a.sum()

def generate_probs_list() -> list[list[float]]:
    # Precompute one cumulative weight table per admissible maximum span
    # length, so that truncated sampling never needs to reject and resample.
    probs_list = []

    poisson = dist.Poisson(rate=poisson_rate)
    probs = np.exp(poisson.log_prob(np.arange(max_span_len + 1)))

    probs_ = normalise_probs(probs)
    probs_list.append(probs_.cumsum().tolist())

    for i in range(max_span_len - 1):
        probs_ = normalise_probs(probs[:-i-1])
        probs_list.append(probs_.cumsum().tolist())

    # Reverse so that probs_list[n - 1] covers span lengths 0 .. n.
    return probs_list[::-1]

probs_list = generate_probs_list()

MaskScheme = list[tuple[int, int]]

def determine_should_mask_len(rng: Random, seq_len: int) -> int:
    # Stochastic rounding: round up with probability equal to the fractional part.
    x = seq_len * proposed_mask_rate
    integer_part = int(x)
    fractional_part = x - float(integer_part)
    should_add = rng.random() < fractional_part
    should_mask_len = integer_part + should_add
    return should_mask_len

def generate_spans(rng: Random, should_mask_len: int) -> list[int]:
    spans = []
    while should_mask_len > 0:
        current_max_span_len = min(max_span_len, should_mask_len)
        probs = probs_list[current_max_span_len - 1]
        span_len = rng.choices(range(current_max_span_len + 1), cum_weights=probs)[0]
        spans.append(span_len)
        should_mask_len -= span_len + 1  # each span also consumes one gap position
    rng.shuffle(spans)  # remove the asymmetry between early and late samples
    return spans

def distribute_insert_poses(abs_insert_poses: list[int], spans: list[int]) -> MaskScheme:
    # Convert slot indices into absolute starting positions by shifting each
    # slot by the space already consumed by the preceding spans and gaps.
    offset = 0
    mask_scheme = []
    for abs_insert_pos, span in zip(abs_insert_poses, spans):
        insert_pos = abs_insert_pos + offset
        mask_scheme.append((insert_pos, span))
        offset += span + 1
    return mask_scheme

def random_add_one(rng: Random, mask_scheme: MaskScheme) -> MaskScheme:
    # With probability 0.5, shift all masks one place to the right, so that
    # the last word of the sentence can also be masked.
    should_add_one = rng.random() < 0.5
    if should_add_one:
        mask_scheme = [(insert_pos + 1, span) for insert_pos, span in mask_scheme]
    return mask_scheme

def wakong(rng: Random, seq_len: int) -> MaskScheme:
    should_mask_len = determine_should_mask_len(rng, seq_len)
    spans = generate_spans(rng, should_mask_len)

    n_spans = len(spans)
    n_possible_insert_poses = seq_len - sum(spans) - n_spans + 1
    abs_insert_poses = sorted(rng.sample(range(n_possible_insert_poses), n_spans))

    mask_scheme = distribute_insert_poses(abs_insert_poses, spans)
    mask_scheme = random_add_one(rng, mask_scheme)
    return mask_scheme

def test():
    seed = 42
    rng = Random(seed)
    mask_scheme = wakong(rng, 100)
    print(mask_scheme)

if __name__ == '__main__':
    test()