windowcapture
исходный код / Spell/wc_spell.py

wc_spell.py

57 строк · 1,965 байт · модуль Spell
# wc_spell.py <in_utf8> <out_utf8>
# Full-text Russian spelling/punctuation/case correction via the open SAGE model
# (ai-forever/sage-fredt5-distilled-95m, MIT). Reads the input text file (UTF-8), corrects it
# line by line (each line ~= one sentence, to stay within the model's length limit), writes the
# result. Called by the .NET app (Helpers/SageClient.cs) for the "correct whole text" action.
import sys
 7
 8MODEL = "ai-forever/sage-fredt5-distilled-95m"
 9
def main():
    """Correct the text file named by argv[1] and write the result to argv[2].

    The input is processed line by line: every non-blank line is sent through
    the seq2seq model, blank lines are copied as-is. A line that cannot be
    corrected safely (model error, or the line/its correction reaches the
    256-token limit and would be cut off) is written back unchanged.

    Returns:
        0 on success, 1 if the output file cannot be written, 2 on bad usage
        or an unreadable input file, 3 if the model fails to load.
    """
    if len(sys.argv) < 3:
        sys.stderr.write("usage: wc_spell.py <in> <out>\n")
        return 2
    in_path, out_path = sys.argv[1], sys.argv[2]
    try:
        # "utf-8-sig" decodes plain UTF-8 identically but drops a leading BOM,
        # so the marker never reaches the model as part of the first line.
        with open(in_path, encoding="utf-8-sig") as f:
            data = f.read()
    except Exception as e:
        sys.stderr.write("read error: %r\n" % (e,))
        return 2

    try:
        # Heavy imports are deferred so usage/read errors are reported quickly.
        import torch
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        tok = AutoTokenizer.from_pretrained(MODEL)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
        model.eval()
    except Exception as e:
        sys.stderr.write("model load error: %r\n" % (e,))
        return 3

    max_tokens = 256

    def correct(text):
        """Return the corrected form of one line; raise if it would be cut off."""
        ids = tok(text, return_tensors="pt", max_length=max_tokens, truncation=True)
        # With truncation on, an encoding that fills max_tokens may have lost
        # its tail; refuse it rather than silently drop the user's text.
        if ids["input_ids"].shape[-1] >= max_tokens:
            raise ValueError("line reaches the model input length limit")
        with torch.no_grad():
            out = model.generate(**ids, max_length=max_tokens, num_beams=5, early_stopping=True)
        # Same reasoning for the output: generation stopped by max_length
        # means the correction may be incomplete.
        if out.shape[-1] >= max_tokens:
            raise ValueError("correction reaches the model output length limit")
        return tok.batch_decode(out, skip_special_tokens=True)[0]

    res = []
    kept_original = 0
    for line in data.split("\n"):
        if not line.strip():
            res.append(line)  # blank lines pass through untouched
            continue
        try:
            res.append(correct(line))
        except Exception:
            # Best effort: on any failure keep the original line (safe).
            res.append(line)
            kept_original += 1

    if kept_original:
        # Not an error (exit code stays 0), but do not hide it either.
        sys.stderr.write("warning: %d line(s) left uncorrected\n" % kept_original)

    try:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write("\n".join(res))
    except Exception as e:
        sys.stderr.write("write error: %r\n" % (e,))
        return 1
    return 0
55
# Script entry point: run main() and use its return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())