The goal of this experiment is to detect and correct mistakes made during fast typing on a phone with the swipe feature. Fast swipe gestures currently produce some wrong words, and nothing flags or corrects them after a sentence is typed; the user has to go back and proofread, or reduce the swiping speed. Using language models we can detect these mistakes automatically and improve typing speed.

Setup

I typed a few sentences using Google Keyboard and SwiftKey as fast as I could and found that 80-90% of the words were correct, but the rest did not make sense in the context of the sentence. Using the RoBERTa masked language model, those errors can be detected and rectified after a sentence has been typed. This feature could be used in phone keyboards as a second layer of checking once a sentence is complete.

Algorithm

Mask each token of the raw sentence in turn and pass the masked sentence to the RoBERTa model. From the prediction scores at the masked position, collect the probability of the original token along with the top 3 suggestions and their probabilities. If a sentence has N tokens, then N forward passes are done.
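
The idea can be sanity-checked on a single masked position with the fill-mask pipeline from transformers before writing the full per-token loop below (a minimal sketch; roberta-base is an assumption, any RoBERTa checkpoint works):

from transformers import pipeline

unmasker = pipeline('fill-mask', model='roberta-base')
# top suggestions for the masked slot, each with a probability score
for pred in unmasker("I am trying to write a text using <mask> which is sometimes wing."):
    print(pred['token_str'], round(pred['score'], 3))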


from transformers import RobertaTokenizer, RobertaForMaskedLM
import torch
import pandas as pd
from tqdm.notebook import tqdm

# use the same checkpoint for the tokenizer and the model
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')
model.eval()

def autocorrect(sentence):
    inputs = tokenizer.encode(sentence, add_special_tokens=True)

    # mask each token in turn and collect the prediction scores at the
    # masked position (N forward passes for a sentence of N tokens)
    final_scores = []
    for i in tqdm(range(len(inputs)), total=len(inputs)):
        input_masked = inputs.copy()
        input_masked[i] = tokenizer.mask_token_id
        input_ids = torch.tensor(input_masked).unsqueeze(0)
        with torch.no_grad():
            prediction_scores = model(input_ids)[0]
        final_scores.append(prediction_scores.squeeze()[i])
    prediction_scores = torch.stack(final_scores).softmax(dim=-1)

    # convert a list of ids to a list of words, stripping the BPE
    # tokenizer's leading-space marker
    convert_to_words = lambda ids: \
        [tokenizer.convert_tokens_to_string([tok]).strip() for tok in tokenizer.convert_ids_to_tokens(ids)]

    # probability of each original token, plus the top 3 suggestions
    sorted_scores = prediction_scores.sort(descending=True)
    indexes_list = sorted_scores.indices[:, :3]
    values_list = sorted_scores.values[:, :3]
    original_prob = ["%.4f" % prediction_scores[pos, index] for pos, index in enumerate(inputs)]
    options = [convert_to_words(indexes.tolist()) for indexes in indexes_list]
    values_list = [["%.3f" % v for v in vals] for vals in values_list.tolist()]
    original = convert_to_words(inputs)

    # display the results as a pandas DataFrame
    print(f'Sentence = {tokenizer.decode(inputs)}')
    return pd.DataFrame({'original_text': original, 'original_prob': original_prob,
                         'suggestions': options, 'suggestions_prob': values_list})
sentences = ["I am trying to write a text using drawer which is sometimes wing." ,
"Two words were wrongly typed here. More I I will need to go back to reach quoted and correct then. ",
"There will be probability calculated for each word which will decide whether the word is appropriate at that place. If not, it will either be replaced or deleted. Special checks to handle succeed tons." ,
"I implemented the salary auto correct algorithm I wanted. It is identifying mistakes but the suggestions are not what I wanted. It probably needs fine-tuning on my father. "]

df = autocorrect(sentences[0])
df

Sentence = <s>I am trying to write a text using drawer which is sometimes wing.</s>

original_text original_prob suggestions suggestions_prob
0 <s> 1.0000 [<s>, ., </s>] [1.000, 0.000, 0.000]
1 I 0.9914 [I, i, I] [0.991, 0.007, 0.000]
2 am 0.2943 ['m, am, was] [0.463, 0.294, 0.214]
3 trying 0.6924 [trying, able, going] [0.692, 0.094, 0.047]
4 to 0.9995 [to, and, the] [0.999, 0.000, 0.000]
5 write 0.0959 [send, type, write] [0.323, 0.107, 0.096]
6 a 0.1075 [the, this, some] [0.361, 0.140, 0.134]
7 text 0.0042 [story, novel, poem] [0.230, 0.152, 0.094]
8 using 0.0009 [for, book, about] [0.129, 0.070, 0.067]
9 drawer 0.0000 [html, HTML, JavaScript] [0.119, 0.096, 0.074]
10 which 0.2974 [that, which, it] [0.443, 0.297, 0.054]
11 is 0.1893 [I, is, can] [0.451, 0.189, 0.089]
12 sometimes 0.0004 [a, my, very] [0.145, 0.136, 0.061]
13 wing 0.0000 [difficult, hard, tricky] [0.295, 0.150, 0.045]
14 . 0.7074 [., :, !] [0.707, 0.078, 0.028]
15 </s> 1.0000 [</s>, I, (] [1.000, 0.000, 0.000]
df = autocorrect(sentences[1])
df

Sentence = <s>Two words were wrongly typed here. More I I will need to go back to reach quoted and correct then.</s>

original_text original_prob suggestions suggestions_prob
0 <s> 1.0000 [<s>, ., </s>] [1.000, 0.000, 0.000]
1 Two 0.0267 [Some, Many, Several] [0.559, 0.113, 0.057]
2 words 0.2290 [words, sentences, lines] [0.229, 0.185, 0.155]
3 were 0.2506 [are, were, I] [0.465, 0.251, 0.212]
4 wrongly 0.0491 [not, incorrectly, already] [0.245, 0.133, 0.090]
5 typed 0.0025 [quoted, used, written] [0.597, 0.067, 0.045]
6 here 0.1820 [here, in, there] [0.182, 0.120, 0.109]
7 . 0.8207 [., ,, and] [0.821, 0.050, 0.045]
8 More 0.0001 [So, Which, I] [0.447, 0.044, 0.044]
9 I 0.0000 [., so, ,] [0.398, 0.061, 0.049]
10 I 0.0028 [think, guess, suspect] [0.455, 0.111, 0.048]
11 will 0.4178 [will, 'll, may] [0.418, 0.256, 0.054]
12 need 0.1014 [have, need, try] [0.848, 0.101, 0.030]
13 to 0.9965 [to, and, will] [0.997, 0.001, 0.000]
14 go 0.7911 [go, get, come] [0.791, 0.020, 0.019]
15 back 0.7227 [back, through, further] [0.723, 0.070, 0.019]
16 to 0.1531 [and, ,, to] [0.589, 0.161, 0.153]
17 reach 0.0000 [the, be, them] [0.229, 0.090, 0.072]
18 quoted 0.0000 [them, out, back] [0.114, 0.083, 0.035]
19 and 0.0653 [it, word, them] [0.167, 0.107, 0.104]
20 correct 0.0497 [write, correct, post] [0.057, 0.050, 0.042]
21 then 0.0001 [them, it, this] [0.332, 0.321, 0.054]
22 . 0.8164 [., :, !] [0.816, 0.095, 0.019]
23 </s> 0.9997 [</s>, ., "] [1.000, 0.000, 0.000]
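
Since a sentence of N tokens costs N sequential forward passes in the loop above, the N masked copies can instead be stacked into one batch and scored in a single call. A sketch, reusing the tokenizer and model defined above (it trades memory for speed, so very long sentences may need chunking):

def masked_scores_batched(sentence):
    # one row per token, each row with a different position masked
    ids = tokenizer.encode(sentence, add_special_tokens=True)
    batch = torch.tensor(ids).repeat(len(ids), 1)
    diag = torch.arange(len(ids))
    batch[diag, diag] = tokenizer.mask_token_id
    with torch.no_grad():
        logits = model(batch)[0]                    # (N, N, vocab_size)
    # row i holds the predictions for masked position i
    return logits[diag, diag].softmax(dim=-1)       # (N, vocab_size)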

Conclusion

As is evident from the results, the RoBERTa model is very confident about the mistake positions: the mistyped tokens ('drawer', 'wing', 'reach quoted') all receive near-zero probability (displayed as 0.0000), so a simple low-probability threshold reliably flags them.
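
One way to turn this into an automatic correction pass is to flag every token whose original probability falls below a threshold and propose the top suggestion in its place. A sketch over the DataFrame returned by autocorrect (the 1e-3 threshold is an illustrative guess, not tuned):

def flag_mistakes(df, threshold=1e-3):
    # original_prob is stored as a formatted string, so cast it back
    flagged = df[df['original_prob'].astype(float) < threshold].copy()
    flagged['replacement'] = flagged['suggestions'].str[0]
    return flagged[['original_text', 'original_prob', 'replacement']]

flag_mistakes(df)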

The suggestions are also good but not great; fine-tuning on my own keyboard input should improve them. Masking whole words instead of individual tokens might also remove a bias, because pieces of a wrongly typed word can leak into the predictions and influence the results, as sketched below.
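
A sketch of whole-word masking, reusing the tokenizer and model above: every BPE piece of a word is masked at once, and the suggestions are read off the first masked slot (a rough sketch; it splits on whitespace, so punctuation stays attached to its word):

def whole_word_suggestions(sentence, k=3):
    words = sentence.split()
    # encode each word separately to know how many BPE pieces it spans;
    # words after the first get a leading space so RoBERTa's BPE matches
    pieces = [tokenizer.encode(w if i == 0 else ' ' + w, add_special_tokens=False)
              for i, w in enumerate(words)]
    ids = [tokenizer.bos_token_id] + [t for p in pieces for t in p] + [tokenizer.eos_token_id]
    results, pos = [], 1                            # position 0 is <s>
    for word, piece in zip(words, pieces):
        masked = list(ids)
        for j in range(pos, pos + len(piece)):      # mask every piece of the word
            masked[j] = tokenizer.mask_token_id
        with torch.no_grad():
            logits = model(torch.tensor([masked]))[0]
        top = logits[0, pos].softmax(dim=-1).topk(k)
        results.append((word, [tokenizer.decode([int(t)]).strip() for t in top.indices]))
        pos += len(piece)
    return results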