# wc_spell_server.py <port>
#
# Warm SAGE spelling-corrector server. It loads ai-forever/sage-fredt5-distilled-95m ONCE and then
# serves corrections over localhost HTTP, so the model is not reloaded for every request (a cold
# call takes ~8 s, a warm call well under a second). The .NET app (Helpers/SageClient.cs) starts
# this process lazily, polls /ping until it answers, then POSTs UTF-8 text and reads the corrected
# UTF-8 text back.
import os
import sys

# Token budget (special tokens included) for one SAGE input line and for its generated output.
_MAX_TOKENS = 256


def _map_nonblank_lines(data, fix):
    """Apply ``fix`` to the content of every non-blank line of ``data`` and rejoin the lines.

    The text is split on "\\n". Blank (whitespace-only) lines are kept verbatim. For every other
    line, leading and trailing whitespace -- including the "\\r" of CRLF line endings -- is cut
    off before calling ``fix`` and glued back on afterwards, so a model that normalizes
    whitespace cannot drop indentation or change the line endings.
    """
    out = []
    for line in data.split("\n"):
        core = line.strip()
        if not core:
            out.append(line)
            continue
        lead = line[: len(line) - len(line.lstrip())]
        trail = line[len(line.rstrip()):]
        out.append(lead + fix(core) + trail)
    return "\n".join(out)


def _parse_rescore_body(data):
    """Split a /rescore request body into ``(left_context, right_context, candidates)``.

    Line 0 is the left context, line 1 the right context and lines 2+ the candidate words.
    A trailing "\\r" (CRLF senders) is removed from every line; empty candidate lines are skipped.
    """
    lines = [line.rstrip("\r") for line in data.split("\n")]
    left = lines[0]  # str.split always returns at least one element
    right = lines[1] if len(lines) > 1 else ""
    cands = [c for c in lines[2:] if c]
    return left, right, cands


def _build_mask_rows(left_ids, right_ids, cand_ids, cls_id, sep_id, mask_id):
    """Build one masked input row per candidate sub-token for pseudo-log-likelihood scoring.

    Each candidate's token ids are placed between the left and right context ids (wrapped in
    ``cls_id`` / ``sep_id`` when those are not None). One row is produced per sub-token, with
    that position replaced by ``mask_id``. Candidates with no tokens produce no rows.

    Returns:
        A list of ``(cand_index, row_ids, masked_position, true_token_id)`` tuples.
    """
    head = [cls_id] if cls_id is not None else []
    tail = [sep_id] if sep_id is not None else []
    offset = len(head) + len(left_ids)
    rows = []
    for ci, word_ids in enumerate(cand_ids):
        base = head + list(left_ids) + list(word_ids) + list(right_ids) + tail
        for j, tid in enumerate(word_ids):
            row = list(base)
            row[offset + j] = mask_id
            rows.append((ci, row, offset + j, tid))
    return rows


def main():
    """Load the SAGE model, then serve /ping, /rescore and corrections on 127.0.0.1:<port>."""
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 8765

    # Imported here rather than at module level, so importing this module does not pull in
    # torch/transformers.
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    from http.server import BaseHTTPRequestHandler, HTTPServer

    # Default = distilled 95M (fast). Set WC_SAGE_MODEL=ai-forever/sage-fredt5-large for higher
    # accuracy (F1 84% vs 79%) at the cost of size/speed.
    MODEL = os.environ.get("WC_SAGE_MODEL", "ai-forever/sage-fredt5-distilled-95m")
    # GPU if a CUDA torch build + a GPU are present; otherwise fall back to CPU automatically.
    try:
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    except Exception:
        DEVICE = "cpu"
    tok = AutoTokenizer.from_pretrained(MODEL)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
    try:
        model = model.to(DEVICE)
    except Exception:
        DEVICE = "cpu"
        model = model.to("cpu")
    model.eval()
    sys.stderr.write("wc_spell_server device=%s\n" % DEVICE)

    def _correct_line(text):
        """Return SAGE's correction of one stripped, non-empty line, or ``text`` unchanged.

        Lines longer than the token budget are returned untouched: truncating them (as a plain
        ``truncation=True`` would) silently drops the tail of the line from the result.
        """
        enc = tok(text, return_tensors="pt", max_length=_MAX_TOKENS + 1, truncation=True)
        if enc["input_ids"].shape[-1] > _MAX_TOKENS:
            return text
        enc = {k: v.to(DEVICE) for k, v in enc.items()}
        with torch.no_grad():
            out = model.generate(**enc, max_length=_MAX_TOKENS, num_beams=5, early_stopping=True)
        if out.shape[-1] >= _MAX_TOKENS:
            # Generation reached the length cap, so the correction is probably cut short.
            return text
        return tok.batch_decode(out, skip_special_tokens=True)[0]

    def correct_text(data):
        """Correct every non-blank line of ``data``, preserving blank lines and edge whitespace."""
        return _map_nonblank_lines(data, _correct_line)

    # --- Stage-2 CONTEXT rescorer: a tiny masked LM (default cointegrated/rubert-tiny2, ~29M) ---
    # Loaded LAZILY on the first /rescore call, so plain SAGE users never pay for it. Given the
    # left/right context and a small set of DICTIONARY candidates, it returns each candidate's
    # length-normalized pseudo-log-likelihood P(word|context). It only SCORES candidates supplied
    # by the caller (the noisy-channel SpellScore), so it can re-rank but never invent a word.
    _mlm = {}

    def _ensure_mlm():
        """Load the rescoring MLM on first use; return True if it is available.

        A failed load is remembered (``_mlm["model"] = None``) and not retried.
        """
        if "model" in _mlm:
            return _mlm["model"] is not None
        try:
            from transformers import AutoModelForMaskedLM
            name = os.environ.get("WC_RESCORE_MODEL", "cointegrated/rubert-tiny2")
            mlm_tok = AutoTokenizer.from_pretrained(name)
            mlm = AutoModelForMaskedLM.from_pretrained(name)
            mlm_device = DEVICE
            try:
                mlm = mlm.to(mlm_device)
            except Exception:
                mlm_device = "cpu"
                mlm = mlm.to(mlm_device)
            mlm.eval()
            # Remember the device the MLM actually landed on: input tensors must be built there,
            # not on DEVICE, or a CPU fallback would make every rescore fail.
            _mlm.update(tok=mlm_tok, model=mlm, device=mlm_device)
            sys.stderr.write("wc_spell_server mlm=%s loaded on %s\n" % (name, mlm_device))
            return True
        except Exception as e:
            sys.stderr.write("wc_spell_server mlm load failed: %r\n" % e)
            _mlm["model"] = None
            return False

    def rescore(left, right, cands):
        """Score each candidate word by its mean masked-token log-probability in context.

        Returns one float per candidate (higher = better fit). Candidates that tokenize to
        nothing get -99.0. If the MLM is unavailable, or no candidate produced any token,
        every score is 0.0.
        """
        if not cands:
            return []
        if not _ensure_mlm():
            return [0.0] * len(cands)
        mlm_tok = _mlm["tok"]
        mlm = _mlm["model"]
        mlm_device = _mlm["device"]
        # Keep the context short: the last 24 tokens on the left, the first 12 on the right.
        left_ids = mlm_tok(left, add_special_tokens=False)["input_ids"][-24:] if left else []
        right_ids = mlm_tok(right, add_special_tokens=False)["input_ids"][:12] if right else []
        cand_ids = [mlm_tok(w, add_special_tokens=False)["input_ids"] for w in cands]
        rows = _build_mask_rows(left_ids, right_ids, cand_ids,
                                mlm_tok.cls_token_id, mlm_tok.sep_token_id, mlm_tok.mask_token_id)
        if not rows:
            return [0.0] * len(cands)
        pad = mlm_tok.pad_token_id if mlm_tok.pad_token_id is not None else 0
        maxlen = max(len(r[1]) for r in rows)
        ids = [r[1] + [pad] * (maxlen - len(r[1])) for r in rows]
        att = [[1] * len(r[1]) + [0] * (maxlen - len(r[1])) for r in rows]
        with torch.no_grad():
            logits = mlm(input_ids=torch.tensor(ids, device=mlm_device),
                         attention_mask=torch.tensor(att, device=mlm_device)).logits
            logp = torch.log_softmax(logits, dim=-1)
        scores = [0.0] * len(cands)
        counts = [0] * len(cands)
        for k, (ci, _row, pos, tid) in enumerate(rows):
            scores[ci] += float(logp[k, pos, tid])
            counts[ci] += 1
        return [s / c if c > 0 else -99.0 for s, c in zip(scores, counts)]

    class Handler(BaseHTTPRequestHandler):
        def log_message(self, *a):
            pass  # silence per-request access logging

        def do_GET(self):
            if self.path == "/ping":
                self.send_response(200)
                self.send_header("Content-Type", "text/plain")
                self.end_headers()
                self.wfile.write(("ok " + DEVICE).encode("utf-8"))
            else:
                self.send_response(404)
                self.end_headers()

        def do_POST(self):
            try:
                n = int(self.headers.get("Content-Length", "0"))
                data = self.rfile.read(n).decode("utf-8") if n > 0 else ""
                if self.path == "/rescore":
                    # body: line0=left ctx, line1=right ctx, line2+=candidate words.
                    # reply: CSV of length-normalized log P(word|context), highest = best fit.
                    try:
                        left, right, cands = _parse_rescore_body(data)
                        out = ",".join("%.5f" % s for s in rescore(left, right, cands))
                    except Exception:
                        out = ""  # caller falls back to the noisy-channel ranking
                    body = out.encode("utf-8")
                else:
                    try:
                        corrected = correct_text(data)
                    except Exception:
                        corrected = data  # never fail destructively: echo the input back
                    body = corrected.encode("utf-8")
                self.send_response(200)
                self.send_header("Content-Type", "text/plain; charset=utf-8")
                self.send_header("Content-Length", str(len(body)))
                self.end_headers()
                self.wfile.write(body)
            except Exception:
                try:
                    self.send_response(500)
                    self.end_headers()
                except Exception:
                    pass

    # Bind only AFTER the model is loaded, so a successful /ping means "ready".
    HTTPServer(("127.0.0.1", port), Handler).serve_forever()


if __name__ == "__main__":
    main()