# Standalone proof that the stage-2 context rescorer (rubert-tiny2 masked-LM) actually ranks
# look-alike / declension / context-sensitive candidates correctly. Mirrors wc_spell_server.rescore.
# Run: python Spell/test_rescore.py (downloads cointegrated/rubert-tiny2 ~118MB on first run)
# Writes UTF-8 results to Spell/_rescore_test_out.txt (console codepage can mangle Cyrillic).
# Exits with status 1 if any case's top-ranked candidate differs from the expected winner.
import functools
import io
import os
import sys

import torch

NAME = os.environ.get("WC_RESCORE_MODEL", "cointegrated/rubert-tiny2")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@functools.lru_cache(maxsize=1)
def _load_default():
    """Load (tokenizer, model) for NAME once, in eval mode, with the model moved to DEVICE."""
    # Imported and loaded lazily so that merely importing this module (e.g. pytest collecting
    # a file named test_*.py) neither needs transformers nor triggers a model download.
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(NAME)
    model = AutoModelForMaskedLM.from_pretrained(NAME).to(DEVICE)
    model.eval()
    return tokenizer, model


def rescore(left, right, cands, *, tokenizer=None, model=None, device=None):
    """Score each candidate word by how well a masked LM fits it between `left` and `right`.

    For every candidate w, the sequence [CLS] left w right [SEP] is built once per
    sub-token of w with that sub-token replaced by the mask token. The candidate's score
    is the mean log-probability the model assigns to w's true sub-tokens at the masked
    positions (higher means a better fit).

    Args:
        left: text before the slot; only its last 24 tokens are used.
        right: text after the slot; only its first 12 tokens are used.
        cands: candidate words to score.
        tokenizer, model: optional overrides; when omitted, the NAME model is loaded once
            and reused. `model(input_ids=..., attention_mask=...)` must return an object
            with a `.logits` tensor indexed [row, position, token id].
        device: device for the input tensors; defaults to DEVICE (the default model is
            placed there; an injected model must live on the same device).

    Returns:
        One float per candidate, aligned with `cands`. A candidate that tokenizes to
        nothing scores -99.0; if no candidate tokenizes at all, every score is 0.0.
    """
    if tokenizer is None or model is None:
        default_tok, default_model = _load_default()
        tokenizer = default_tok if tokenizer is None else tokenizer
        model = default_model if model is None else model
    device = DEVICE if device is None else device

    mask_id = tokenizer.mask_token_id
    cls = tokenizer.cls_token_id
    sep = tokenizer.sep_token_id
    pad = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    left_ids = tokenizer(left, add_special_tokens=False)["input_ids"][-24:] if left else []
    right_ids = tokenizer(right, add_special_tokens=False)["input_ids"][:12] if right else []
    prefix = ([cls] if cls is not None else []) + left_ids
    suffix = right_ids + ([sep] if sep is not None else [])
    off = len(prefix)  # index of the candidate's first sub-token in every row

    rows = []  # (candidate index, masked input ids, masked position, true token id)
    for ci, word in enumerate(cands):
        word_ids = tokenizer(word, add_special_tokens=False)["input_ids"]
        if not word_ids:
            continue
        base = prefix + word_ids + suffix
        for j, tid in enumerate(word_ids):
            row = list(base)
            row[off + j] = mask_id
            rows.append((ci, row, off + j, tid))
    if not rows:
        return [0.0] * len(cands)

    maxlen = max(len(row) for _, row, _, _ in rows)
    ids = [row + [pad] * (maxlen - len(row)) for _, row, _, _ in rows]
    att = [[1] * len(row) + [0] * (maxlen - len(row)) for _, row, _, _ in rows]
    ii = torch.tensor(ids, device=device)
    am = torch.tensor(att, device=device)
    positions = torch.tensor([pos for _, _, pos, _ in rows], device=device)
    with torch.no_grad():
        logits = model(input_ids=ii, attention_mask=am).logits
        # Only the masked position of each row is read, so normalize just those
        # (rows, vocab) slices instead of the whole (rows, maxlen, vocab) tensor.
        masked = logits[torch.arange(len(rows), device=device), positions]
        logp = torch.log_softmax(masked, dim=-1)

    sc = [0.0] * len(cands)
    cnt = [0] * len(cands)
    for k, (ci, _, _, tid) in enumerate(rows):
        sc[ci] += float(logp[k, tid])
        cnt[ci] += 1
    return [sc[i] / cnt[i] if cnt[i] else -99.0 for i in range(len(cands))]


# (left context, candidates, expected winner) — context/declension/look-alike disambiguation
CASES = [
    ("мама мыла", ["раму", "рому"], "раму"),
    ("кот поймал", ["мышь", "мошь"], "мышь"),
    ("мы плыли на", ["лодке", "ложке"], "лодке"),
    ("я ел суп", ["ложкой", "лодкой"], "ложкой"),
    ("я думаю о", ["тебе", "тебя"], "тебе"),
    ("я люблю", ["тебя", "тебе"], "тебя"),
    ("он играл на", ["гитаре", "гитари"], "гитаре"),
    ("я выпил стакан", ["воды", "моды"], "воды"),
]


def main():
    """Rank every CASES entry with `rescore`, write the UTF-8 report, return the hit count.

    The report has a model/device header, one line per case (verdict, candidates ranked
    by score, expected winner) and a final "N/M correct" summary, and is written to
    _rescore_test_out.txt in this file's directory.
    """
    out = io.StringIO()
    out.write("model=%s device=%s\n\n" % (NAME, DEVICE))
    ok = 0
    for left, cands, exp in CASES:
        sc = rescore(left, "", cands)
        # sorted() is stable: on a score tie the earlier-listed candidate ranks first.
        ranked = sorted(zip(cands, sc), key=lambda pair: -pair[1])
        good = ranked[0][0] == exp
        ok += good
        out.write("[%s] '%s ___' -> " % ("OK " if good else "MISS", left))
        out.write(" | ".join("%s %.3f" % (w, s) for w, s in ranked))
        out.write(" (expected: %s)\n" % exp)
    out.write("\n%d/%d correct\n" % (ok, len(CASES)))

    path = os.path.join(os.path.dirname(__file__), "_rescore_test_out.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write(out.getvalue())
    return ok


if __name__ == "__main__":
    n_ok = main()
    print("done", n_ok, "/", len(CASES))
    # Non-zero exit lets CI treat any ranking regression as a failure.
    sys.exit(0 if n_ok == len(CASES) else 1)