1# wc_spell.py <in_utf8> <out_utf8> 2# Full-text Russian spelling/punctuation/case correction via the open SAGE model 3# (ai-forever/sage-fredt5-distilled-95m, MIT). Reads the input text file (UTF-8), corrects it 4# line by line (each line ~= one sentence to respect the model's length), writes the result. 5# Called by the .NET app (Helpers/SageClient.cs) for the "correct whole text" action. 6import sys 7 8MODEL = "ai-forever/sage-fredt5-distilled-95m" 9 10def main(): 11 if len(sys.argv) < 3: 12 sys.stderr.write("usage: wc_spell.py <in> <out>\n") 13 return 2 14 in_path, out_path = sys.argv[1], sys.argv[2] 15 try: 16 with open(in_path, encoding="utf-8") as f: 17 data = f.read() 18 except Exception as e: 19 sys.stderr.write("read error: %r\n" % (e,)) 20 return 2 21 22 try: 23 import torch 24 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 25 tok = AutoTokenizer.from_pretrained(MODEL) 26 model = AutoModelForSeq2SeqLM.from_pretrained(MODEL) 27 model.eval() 28 except Exception as e: 29 sys.stderr.write("model load error: %r\n" % (e,)) 30 return 3 31 32 def correct(text): 33 ids = tok(text, return_tensors="pt", max_length=256, truncation=True) 34 with torch.no_grad(): 35 out = model.generate(**ids, max_length=256, num_beams=5, early_stopping=True) 36 return tok.batch_decode(out, skip_special_tokens=True)[0] 37 38 res = [] 39 for line in data.split("\n"): 40 if line.strip(): 41 try: 42 res.append(correct(line)) 43 except Exception: 44 res.append(line) # on failure keep the original line (safe) 45 else: 46 res.append(line) 47 48 try: 49 with open(out_path, "w", encoding="utf-8") as f: 50 f.write("\n".join(res)) 51 except Exception as e: 52 sys.stderr.write("write error: %r\n" % (e,)) 53 return 1 54 return 0 55 56if __name__ == "__main__": 57 sys.exit(main())