using System;
using System.IO;
using System.Collections.Generic;

namespace WindowCapture.Helpers
{
    /// <summary>
    /// GruSpellNet: GRU-based Seq2Seq spell corrector.
    /// Trained with PyTorch, inference in pure C#.
    ///
    /// Architecture: Encoder GRU → Decoder GRU → Linear output
    /// 90.8% character accuracy on Russian words.
    /// </summary>
    /// <remarks>
    /// Symbol table used by this class: indices 0..31 are 'а'..'я', 32 is 'ё',
    /// 33 is the space character. The last index of the loaded alphabet
    /// (ALPHA - 1) is used both as the decoder start token and as the
    /// end-of-sequence token.
    ///
    /// Weights are stored in plain fields without locking, so reloading a model
    /// while another thread is inside <see cref="Correct"/> is not synchronized.
    /// </remarks>
    public class GruSpellNet
    {
        // CharIdx maps every character outside the supported alphabet to
        // ALPHA_STATIC - 1 (Correct additionally clamps it to ALPHA - 1).
        const int ALPHA_STATIC = 35;

        // Model dimensions, read from the file header.
        int ALPHA, EMBED_DIM, HIDDEN, MAX_LEN;

        // Embedding
        float[,] embed;            // [ALPHA, EMBED_DIM]

        // Encoder GRU weights (PyTorch GRU format: weight_ih, weight_hh, bias_ih, bias_hh)
        // Each has 3*HIDDEN rows (reset, update, new gates)
        float[,] enc_wih, enc_whh; // [3*HIDDEN, input/hidden]
        float[] enc_bih, enc_bhh;  // [3*HIDDEN]

        // Decoder GRU (same layout as the encoder)
        float[,] dec_wih, dec_whh;
        float[] dec_bih, dec_bhh;

        // Output linear
        float[,] outW;             // [ALPHA, HIDDEN]
        float[] outB;              // [ALPHA]

        volatile bool ready;

        /// <summary>True once a model has been loaded successfully.</summary>
        public bool IsReady { get { return ready; } }

        /// <summary>
        /// Loads model weights from a binary file.
        /// </summary>
        /// <remarks>
        /// File layout, in reading order:
        /// four Int32 values (ALPHA, EMBED_DIM, HIDDEN, MAX_LEN);
        /// the embedding matrix as raw floats;
        /// encoder weight_ih, weight_hh, bias_ih, bias_hh and then the same four
        /// decoder parameters, each prefixed by an Int32 element count;
        /// the output matrix and the output bias as raw floats.
        ///
        /// Everything is parsed into locals first and committed to the instance
        /// only after the whole file was read, so a failed load never leaves the
        /// instance with partially overwritten weights.
        /// </remarks>
        /// <param name="path">Path of the model file.</param>
        /// <returns>
        /// true when the model was loaded; false when the file could not be
        /// opened or parsed (the error is logged, no exception escapes).
        /// </returns>
        public bool Load(string path)
        {
            try
            {
                int alpha, embedDim, hidden, maxLen;
                float[,] emb, eWih, eWhh, dWih, dWhh, oW;
                float[] eBih, eBhh, dBih, dBhh, oB;

                // Read-only access: FileStream(path, FileMode.Open) would ask for
                // read/write access and fail on read-only model files.
                using (var br = new BinaryReader(new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read)))
                {
                    alpha = br.ReadInt32();
                    embedDim = br.ReadInt32();
                    hidden = br.ReadInt32();
                    maxLen = br.ReadInt32();

                    if (alpha <= 0 || embedDim <= 0 || hidden <= 0 || maxLen <= 0)
                    {
                        throw new InvalidDataException(
                            "Invalid model header: ALPHA=" + alpha + " EMBED=" + embedDim +
                            " HIDDEN=" + hidden + " MAX_LEN=" + maxLen);
                    }

                    // Three gates (reset, update, new) stacked row-wise.
                    int gateRows = checked(3 * hidden);

                    emb = ReadMat(br, alpha, embedDim);

                    // Encoder GRU: 4 params (weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0)
                    eWih = ReadParamMat(br, gateRows, embedDim);
                    eWhh = ReadParamMat(br, gateRows, hidden);
                    eBih = ReadParamVec(br, gateRows);
                    eBhh = ReadParamVec(br, gateRows);

                    // Decoder GRU
                    dWih = ReadParamMat(br, gateRows, embedDim);
                    dWhh = ReadParamMat(br, gateRows, hidden);
                    dBih = ReadParamVec(br, gateRows);
                    dBhh = ReadParamVec(br, gateRows);

                    // Output
                    oW = ReadMat(br, alpha, hidden);
                    oB = ReadVec(br, alpha);
                }

                // Commit only after the whole file parsed cleanly.
                ALPHA = alpha;
                EMBED_DIM = embedDim;
                HIDDEN = hidden;
                MAX_LEN = maxLen;
                embed = emb;
                enc_wih = eWih; enc_whh = eWhh; enc_bih = eBih; enc_bhh = eBhh;
                dec_wih = dWih; dec_whh = dWhh; dec_bih = dBih; dec_bhh = dBhh;
                outW = oW;
                outB = oB;

                ready = true;
                Logger.Log("textproc", "GruSpellNet loaded: ALPHA=" + ALPHA + " EMBED=" + EMBED_DIM + " HIDDEN=" + HIDDEN);
                return true;
            }
            catch (Exception ex)
            {
                Logger.Log("textproc", "GruSpellNet load err: " + ex.Message);
                return false;
            }
        }

        /// <summary>
        /// Throws when a seekable stream holds fewer than <paramref name="floatCount"/>
        /// floats after the current position. Checked before allocating so that a
        /// corrupt header cannot trigger a huge allocation.
        /// </summary>
        static void EnsureAvailable(BinaryReader br, long floatCount)
        {
            Stream s = br.BaseStream;
            if (!s.CanSeek)
            {
                return;
            }

            // Compare in units of floats; multiplying floatCount by 4 could overflow.
            long availableFloats = (s.Length - s.Position) / sizeof(float);
            if (availableFloats < floatCount)
            {
                throw new EndOfStreamException(
                    "Model file is truncated: need " + floatCount + " floats, only " + availableFloats + " left");
            }
        }

        /// <summary>Reads an r×c matrix of floats stored row by row, without a count prefix.</summary>
        static float[,] ReadMat(BinaryReader br, int r, int c)
        {
            EnsureAvailable(br, (long)r * c);
            var m = new float[r, c];
            for (int i = 0; i < r; i++)
            {
                for (int j = 0; j < c; j++)
                {
                    m[i, j] = br.ReadSingle();
                }
            }
            return m;
        }

        /// <summary>Reads n floats without a count prefix.</summary>
        static float[] ReadVec(BinaryReader br, int n)
        {
            EnsureAvailable(br, n);
            var v = new float[n];
            for (int i = 0; i < n; i++)
            {
                v[i] = br.ReadSingle();
            }
            return v;
        }

        /// <summary>
        /// Reads a count-prefixed matrix parameter and verifies that the stored
        /// element count matches the expected shape.
        /// </summary>
        /// <exception cref="InvalidDataException">The stored count differs from expectedR * expectedC.</exception>
        static float[,] ReadParamMat(BinaryReader br, int expectedR, int expectedC)
        {
            int n = br.ReadInt32(); // total elements
            long expected = (long)expectedR * expectedC;
            if (n != expected)
            {
                // A mismatch means the stream would be read misaligned from here on.
                throw new InvalidDataException(
                    "Parameter has " + n + " elements, expected " + expected +
                    " (" + expectedR + "x" + expectedC + ")");
            }
            return ReadMat(br, expectedR, expectedC);
        }

        /// <summary>
        /// Reads a count-prefixed vector parameter and verifies that the stored
        /// element count matches the expected length.
        /// </summary>
        /// <exception cref="InvalidDataException">The stored count differs from expectedN.</exception>
        static float[] ReadParamVec(BinaryReader br, int expectedN)
        {
            int n = br.ReadInt32();
            if (n != expectedN)
            {
                throw new InvalidDataException(
                    "Parameter has " + n + " elements, expected " + expectedN);
            }
            return ReadVec(br, expectedN);
        }

        /// <summary>Correct a word using the trained GRU model.</summary>
        /// <param name="input">Word to correct. It is lower-cased before encoding.</param>
        /// <returns>
        /// The decoded (lower-case) word when it has at least two characters;
        /// otherwise <paramref name="input"/> unchanged. The input is also
        /// returned unchanged when no model is loaded, or when it is null or
        /// shorter than two characters.
        /// </returns>
        public string Correct(string input)
        {
            if (!ready || input == null || input.Length < 2)
            {
                return input;
            }

            string lower = input.ToLowerInvariant();
            int eos = ALPHA - 1; // shared start / end-of-sequence index

            // Encode: feed at most MAX_LEN characters, then the EOS token.
            float[] h = new float[HIDDEN];
            int steps = Math.Min(lower.Length, MAX_LEN);
            for (int t = 0; t < steps; t++)
            {
                int ci = CharIdx(lower[t]);
                if (ci < 0 || ci >= ALPHA)
                {
                    ci = eos;
                }
                h = GruStep(EmbedRow(ci), h, enc_wih, enc_whh, enc_bih, enc_bhh);
            }
            h = GruStep(EmbedRow(eos), h, enc_wih, enc_whh, enc_bih, enc_bhh);

            // Decode (greedy). The decoder starts from the final encoder state;
            // GruStep never mutates its input, so no copy is needed.
            var result = new System.Text.StringBuilder();
            int prevChar = eos; // start token
            for (int t = 0; t < MAX_LEN; t++)
            {
                h = GruStep(EmbedRow(prevChar), h, dec_wih, dec_whh, dec_bih, dec_bhh);

                int bestC = ArgMaxOutput(h);
                if (bestC == eos)
                {
                    break; // EOS
                }

                // Indices without a printable symbol are skipped in the output
                // but still fed back as the previous token.
                char ch = IdxToChar(bestC);
                if (ch != '\0')
                {
                    result.Append(ch);
                }
                prevChar = bestC;
            }

            string output = result.ToString();
            return output.Length >= 2 ? output : input;
        }

        /// <summary>Copies the embedding row of symbol <paramref name="idx"/> into a new vector.</summary>
        float[] EmbedRow(int idx)
        {
            var x = new float[EMBED_DIM];
            for (int d = 0; d < EMBED_DIM; d++)
            {
                x[d] = embed[idx, d];
            }
            return x;
        }

        /// <summary>
        /// Applies the output linear layer to the hidden state and returns the
        /// index of the largest logit (the first one on ties).
        /// </summary>
        int ArgMaxOutput(float[] hidden)
        {
            int bestC = 0;
            float bestV = float.MinValue;
            for (int c = 0; c < ALPHA; c++)
            {
                float v = outB[c];
                for (int j = 0; j < HIDDEN; j++)
                {
                    v += outW[c, j] * hidden[j];
                }
                if (v > bestV)
                {
                    bestV = v;
                    bestC = c;
                }
            }
            return bestC;
        }

        // GRU step: PyTorch GRU formula
        // r = σ(W_ir @ x + b_ir + W_hr @ h + b_hr)
        // z = σ(W_iz @ x + b_iz + W_hz @ h + b_hz)
        // n = tanh(W_in @ x + b_in + r * (W_hn @ h + b_hn))
        // h' = (1 - z) * n + z * h
        // Rows [0, H) of the weight/bias arrays hold the reset gate, [H, 2H) the
        // update gate and [2H, 3H) the new gate. Returns a new array; x and h
        // are not modified.
        float[] GruStep(float[] x, float[] h, float[,] wih, float[,] whh, float[] bih, float[] bhh)
        {
            int H = HIDDEN;
            int inDim = x.Length;
            float[] newH = new float[H];

            for (int j = 0; j < H; j++)
            {
                int rRow = j;
                int zRow = H + j;
                int nRow = 2 * H + j;

                // Reset gate
                float rVal = bih[rRow] + bhh[rRow];
                for (int k = 0; k < inDim; k++) { rVal += wih[rRow, k] * x[k]; }
                for (int k = 0; k < H; k++) { rVal += whh[rRow, k] * h[k]; }
                float r = Sigmoid(rVal);

                // Update gate
                float zVal = bih[zRow] + bhh[zRow];
                for (int k = 0; k < inDim; k++) { zVal += wih[zRow, k] * x[k]; }
                for (int k = 0; k < H; k++) { zVal += whh[zRow, k] * h[k]; }
                float z = Sigmoid(zVal);

                // New gate: the reset gate scales only the hidden contribution.
                float nVal = bih[nRow];
                for (int k = 0; k < inDim; k++) { nVal += wih[nRow, k] * x[k]; }
                float hn = bhh[nRow];
                for (int k = 0; k < H; k++) { hn += whh[nRow, k] * h[k]; }
                nVal += r * hn;
                float n = Tanh(nVal);

                // Output
                newH[j] = (1f - z) * n + z * h[j];
            }
            return newH;
        }

        /// <summary>Logistic sigmoid, saturated to exactly 0 or 1 outside [-10, 10].</summary>
        static float Sigmoid(float x)
        {
            if (x > 10) return 1f;
            if (x < -10) return 0f;
            return 1f / (1f + (float)Math.Exp(-x));
        }

        /// <summary>Hyperbolic tangent, saturated to exactly -1 or 1 outside [-10, 10].</summary>
        static float Tanh(float x)
        {
            if (x > 10) return 1f;
            if (x < -10) return -1f;
            float e = (float)Math.Exp(2 * x);
            return (e - 1) / (e + 1);
        }

        /// <summary>
        /// Maps a lower-case character to its symbol index: 'а'..'я' → 0..31,
        /// 'ё' → 32, space → 33, anything else → ALPHA_STATIC - 1.
        /// </summary>
        static int CharIdx(char c)
        {
            if (c >= 'а' && c <= 'я') return c - 'а';
            if (c == 'ё') return 32;
            if (c == ' ') return 33;
            return ALPHA_STATIC - 1;
        }

        /// <summary>
        /// Inverse of <see cref="CharIdx"/> for printable symbols; returns '\0'
        /// for any index that has no character.
        /// </summary>
        static char IdxToChar(int i)
        {
            if (i >= 0 && i <= 31) return (char)('а' + i);
            if (i == 32) return 'ё';
            if (i == 33) return ' ';
            return '\0';
        }
    }
}