windowcapture
исходный код / Spell/wc_spell_server.py

wc_spell_server.py

159 строк · 7,402 байт · модуль Spell
  1# wc_spell_server.py <port>
  2# Warm SAGE corrector server: loads ai-forever/sage-fredt5-distilled-95m ONCE, then serves
  3# corrections over localhost HTTP so the model isn't reloaded on every request (cold call ~8s,
  4# warm calls ~sub-second). The .NET app (Helpers/SageClient.cs) starts it lazily, polls /ping
  5# until ready, then POSTs UTF-8 text and reads back the corrected UTF-8 text.
  6import sys, os
  7
  8def main():
  9    port = int(sys.argv[1]) if len(sys.argv) > 1 else 8765
 10
 11    import torch
 12    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 13    from http.server import BaseHTTPRequestHandler, HTTPServer
 14
 15    # Default = distilled 95M (fast). Set WC_SAGE_MODEL=ai-forever/sage-fredt5-large for higher
 16    # accuracy (F1 84% vs 79%) at the cost of size/speed.
 17    MODEL = os.environ.get("WC_SAGE_MODEL", "ai-forever/sage-fredt5-distilled-95m")
 18    # GPU if a CUDA torch build + a GPU are present; otherwise fall back to CPU automatically.
 19    try:
 20        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 21    except Exception:
 22        DEVICE = "cpu"
 23    tok = AutoTokenizer.from_pretrained(MODEL)
 24    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
 25    try:
 26        model = model.to(DEVICE)
 27    except Exception:
 28        DEVICE = "cpu"; model = model.to("cpu")
 29    model.eval()
 30    sys.stderr.write("wc_spell_server device=%s\n" % DEVICE)
 31
 32    def correct_text(data):
 33        res = []
 34        for line in data.split("\n"):
 35            if line.strip():
 36                ids = tok(line, return_tensors="pt", max_length=256, truncation=True)
 37                ids = {k: v.to(DEVICE) for k, v in ids.items()}
 38                with torch.no_grad():
 39                    out = model.generate(**ids, max_length=256, num_beams=5, early_stopping=True)
 40                res.append(tok.batch_decode(out, skip_special_tokens=True)[0])
 41            else:
 42                res.append(line)
 43        return "\n".join(res)
 44
 45    # --- Stage-2 CONTEXT rescorer: a tiny masked-LM (default cointegrated/rubert-tiny2, ~29M) ---
 46    # Loaded LAZILY on the first /rescore call (so plain SAGE users never pay for it). Given the
 47    # left/right context and a small set of DICTIONARY candidates, it returns each candidate's
 48    # length-normalized pseudo-log-likelihood P(word|context). It only SCORES candidates supplied
 49    # by the caller (the noisy-channel SpellScore) -> it can re-rank but never hallucinate a word.
 50    _mlm = {}
 51    def _ensure_mlm():
 52        if "model" in _mlm:
 53            return _mlm["model"] is not None
 54        try:
 55            from transformers import AutoModelForMaskedLM, AutoTokenizer as _AT
 56            name = os.environ.get("WC_RESCORE_MODEL", "cointegrated/rubert-tiny2")
 57            t = _AT.from_pretrained(name)
 58            m = AutoModelForMaskedLM.from_pretrained(name)
 59            try: m = m.to(DEVICE)
 60            except Exception: m = m.to("cpu")
 61            m.eval()
 62            _mlm["tok"] = t; _mlm["model"] = m
 63            sys.stderr.write("wc_spell_server mlm=%s loaded on %s\n" % (name, DEVICE))
 64            return True
 65        except Exception as e:
 66            sys.stderr.write("wc_spell_server mlm load failed: %r\n" % e)
 67            _mlm["model"] = None
 68            return False
 69
 70    def rescore(left, right, cands):
 71        if not cands:
 72            return []
 73        if not _ensure_mlm():
 74            return [0.0] * len(cands)
 75        tok = _mlm["tok"]; m = _mlm["model"]
 76        mask_id = tok.mask_token_id
 77        cls = tok.cls_token_id; sep = tok.sep_token_id
 78        pad = tok.pad_token_id if tok.pad_token_id is not None else 0
 79        left_ids = tok(left, add_special_tokens=False)["input_ids"][-24:] if left else []
 80        right_ids = tok(right, add_special_tokens=False)["input_ids"][:12] if right else []
 81        rows = []  # (cand_index, full_row_ids, masked_position, true_token_id)
 82        for ci, w in enumerate(cands):
 83            wt = tok(w, add_special_tokens=False)["input_ids"]
 84            if not wt:
 85                continue
 86            base = ([cls] if cls is not None else []) + left_ids + wt + right_ids + ([sep] if sep is not None else [])
 87            off = (1 if cls is not None else 0) + len(left_ids)
 88            for j, tid in enumerate(wt):
 89                row = list(base); row[off + j] = mask_id
 90                rows.append((ci, row, off + j, tid))
 91        if not rows:
 92            return [0.0] * len(cands)
 93        maxlen = max(len(x[1]) for x in rows)
 94        ids = [x[1] + [pad] * (maxlen - len(x[1])) for x in rows]
 95        att = [[1] * len(x[1]) + [0] * (maxlen - len(x[1])) for x in rows]
 96        ii = torch.tensor(ids, device=DEVICE)
 97        am = torch.tensor(att, device=DEVICE)
 98        with torch.no_grad():
 99            logits = m(input_ids=ii, attention_mask=am).logits
100            logp = torch.log_softmax(logits, dim=-1)
101        scores = [0.0] * len(cands); counts = [0] * len(cands)
102        for k, (ci, row, pos, tid) in enumerate(rows):
103            scores[ci] += float(logp[k, pos, tid]); counts[ci] += 1
104        for ci in range(len(cands)):
105            scores[ci] = scores[ci] / counts[ci] if counts[ci] > 0 else -99.0
106        return scores
107
108    class Handler(BaseHTTPRequestHandler):
109        def log_message(self, *a):
110            pass  # silence
111
112        def do_GET(self):
113            if self.path == "/ping":
114                self.send_response(200)
115                self.send_header("Content-Type", "text/plain")
116                self.end_headers()
117                self.wfile.write(("ok " + DEVICE).encode("utf-8"))
118            else:
119                self.send_response(404)
120                self.end_headers()
121
122        def do_POST(self):
123            try:
124                n = int(self.headers.get("Content-Length", "0"))
125                data = self.rfile.read(n).decode("utf-8") if n > 0 else ""
126                if self.path == "/rescore":
127                    # body: line0=left ctx, line1=right ctx, line2+=candidate words.
128                    # reply: CSV of length-normalized log P(word|context), highest = best fit.
129                    try:
130                        lines = data.split("\n")
131                        left = lines[0] if len(lines) > 0 else ""
132                        right = lines[1] if len(lines) > 1 else ""
133                        cands = [l for l in lines[2:] if l]
134                        out = ",".join("%.5f" % s for s in rescore(left, right, cands))
135                    except Exception:
136                        out = ""  # caller falls back to the noisy-channel ranking
137                    body = out.encode("utf-8")
138                else:
139                    try:
140                        corrected = correct_text(data)
141                    except Exception:
142                        corrected = data  # never fail destructively
143                    body = corrected.encode("utf-8")
144                self.send_response(200)
145                self.send_header("Content-Type", "text/plain; charset=utf-8")
146                self.send_header("Content-Length", str(len(body)))
147                self.end_headers()
148                self.wfile.write(body)
149            except Exception:
150                try:
151                    self.send_response(500); self.end_headers()
152                except Exception:
153                    pass
154
155    # Bind only AFTER the model is loaded, so /ping succeeds == "ready".
156    HTTPServer(("127.0.0.1", port), Handler).serve_forever()
157
# Script entry point: start the server only when executed directly, not on import.
if __name__ == "__main__":
    sys.exit(main())