windowcapture
исходный код / Spell/test_rescore.py

test_rescore.py

69 строк · 3,466 байт · модуль Spell
 1# Standalone proof that the stage-2 context rescorer (rubert-tiny2 masked-LM) actually ranks
 2# look-alike / declension / context-sensitive candidates correctly. Mirrors wc_spell_server.rescore.
 3# Run: python Spell/test_rescore.py   (downloads cointegrated/rubert-tiny2 ~118MB on first run)
 4# Writes UTF-8 results to Spell/_rescore_test_out.txt (console codepage can mangle Cyrillic).
 5import os, io, sys
 6import torch
 7from transformers import AutoModelForMaskedLM, AutoTokenizer
 8
 9NAME = os.environ.get("WC_RESCORE_MODEL", "cointegrated/rubert-tiny2")
10tok = AutoTokenizer.from_pretrained(NAME)
11m = AutoModelForMaskedLM.from_pretrained(NAME); m.eval()
12DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13m = m.to(DEVICE)
14
def rescore(left, right, cands, *, tokenizer=None, model=None, device=None):
    """Score candidate words for a gap by masked-LM pseudo-log-likelihood.

    Each candidate is spliced between ``left`` and ``right``. Each of its
    sub-tokens is masked in turn, one masked position per row. The score is the
    mean log-probability the model gives to the candidate's true sub-tokens.
    The rows for all candidates are sent through the model as one padded batch.

    Args:
        left: Text before the gap; only its last 24 tokens are kept.
        right: Text after the gap; only its first 12 tokens are kept.
        cands: Candidate words to score.
        tokenizer, model, device: Optional overrides for the module-level
            ``tok``, ``m`` and ``DEVICE`` (e.g. to exercise the scorer with
            stand-ins instead of downloading the real checkpoint).

    Returns:
        One float per candidate, in input order; higher is better. A
        candidate that tokenizes to nothing scores -99.0, but if *no*
        candidate tokenizes at all every score is 0.0.
    """
    tokenizer = tok if tokenizer is None else tokenizer
    model = m if model is None else model
    device = DEVICE if device is None else device

    mask_id = tokenizer.mask_token_id
    cls = tokenizer.cls_token_id
    sep = tokenizer.sep_token_id
    pad = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    prefix = [cls] if cls is not None else []
    suffix = [sep] if sep is not None else []

    left_ids = tokenizer(left, add_special_tokens=False)["input_ids"][-24:] if left else []
    right_ids = tokenizer(right, add_special_tokens=False)["input_ids"][:12] if right else []
    # Index of the candidate's first sub-token inside every row.
    offset = len(prefix) + len(left_ids)

    rows = []  # (candidate index, token ids, masked position, true token id)
    for ci, word in enumerate(cands):
        word_ids = tokenizer(word, add_special_tokens=False)["input_ids"]
        if not word_ids:
            continue
        base = prefix + left_ids + word_ids + right_ids + suffix
        # NOTE(review): only one sub-token is masked per row, so the other
        # pieces of the same candidate stay visible and can "predict" each
        # other, which may flatter multi-token candidates. Kept as-is so the
        # scores mirror wc_spell_server.rescore.
        for j, true_id in enumerate(word_ids):
            row = list(base)
            row[offset + j] = mask_id
            rows.append((ci, row, offset + j, true_id))
    if not rows:
        return [0.0] * len(cands)

    maxlen = max(len(r[1]) for r in rows)
    ids = [r[1] + [pad] * (maxlen - len(r[1])) for r in rows]
    att = [[1] * len(r[1]) + [0] * (maxlen - len(r[1])) for r in rows]
    input_ids = torch.tensor(ids, device=device)
    attention_mask = torch.tensor(att, device=device)
    row_idx = torch.arange(len(rows), device=device)
    positions = torch.tensor([r[2] for r in rows], device=device)
    targets = torch.tensor([r[3] for r in rows], device=device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        # Normalise only the one masked position per row: (rows, vocab)
        # instead of the full (rows, seq, vocab) tensor. Same values, since
        # log_softmax over the vocab axis is independent per position.
        masked_logp = torch.log_softmax(logits[row_idx, positions], dim=-1)
        token_logps = masked_logp[row_idx, targets].tolist()

    totals = [0.0] * len(cands)
    counts = [0] * len(cands)
    for (ci, _, _, _), lp in zip(rows, token_logps):
        totals[ci] += lp
        counts[ci] += 1
    # Mean rather than sum, so a candidate is not scored lower merely for
    # splitting into more sub-tokens.
    return [totals[i] / counts[i] if counts[i] else -99.0 for i in range(len(cands))]
40
41# (left context, candidates, expected winner) — context/declension/look-alike disambiguation
42CASES = [
43    ("мама мыла", ["раму","рому"], "раму"),
44    ("кот поймал", ["мышь","мошь"], "мышь"),
45    ("мы плыли на", ["лодке","ложке"], "лодке"),
46    ("я ел суп", ["ложкой","лодкой"], "ложкой"),
47    ("я думаю о", ["тебе","тебя"], "тебе"),
48    ("я люблю", ["тебя","тебе"], "тебя"),
49    ("он играл на", ["гитаре","гитари"], "гитаре"),
50    ("я выпил стакан", ["воды","моды"], "воды"),
51]
52
def main():
    """Score every case in CASES and write a UTF-8 ranking report.

    Each case is scored with an empty right context. The per-case ranking is
    written to ``_rescore_test_out.txt`` next to this script because the
    console codepage can mangle Cyrillic; only a short tally is printed.

    Returns:
        int: number of cases whose top-ranked candidate was the expected one.
    """
    out = io.StringIO()
    out.write("model=%s device=%s\n\n" % (NAME, DEVICE))
    ok = 0
    for left, cands, exp in CASES:
        sc = rescore(left, "", cands)
        # Best score first; sorted() is stable, so ties keep candidate order.
        ranked = sorted(zip(cands, sc), key=lambda x: -x[1])
        good = ranked[0][0] == exp
        ok += good  # bool adds as 0/1
        out.write("[%s] '%s ___'  ->  " % ("OK " if good else "MISS", left))
        out.write(" | ".join("%s %.3f" % (w, s) for w, s in ranked))
        out.write("   (expected: %s)\n" % exp)
    out.write("\n%d/%d correct\n" % (ok, len(CASES)))

    report_path = os.path.join(os.path.dirname(__file__), "_rescore_test_out.txt")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(out.getvalue())
    print("done", ok, "/", len(CASES))
    return ok


# Run only when executed as a script. This file is named test_*.py, so a test
# runner that collects such files would otherwise run the whole benchmark
# (and overwrite the report) just by importing it.
if __name__ == "__main__":
    main()