using System;
using System.IO;
using System.Collections.Generic;

namespace WindowCapture.Helpers
{
    /// <summary>
    /// Seq2Spell: Sequence-to-Sequence RNN for spelling correction.
    ///
    /// Instead of just scoring candidates, this model GENERATES
    /// the corrected word character-by-character.
    ///
    /// Architecture (Elman RNN encoder/decoder):
    ///   Encoder: reads the input char-by-char and builds a hidden state.
    ///   Decoder: starts from the encoder's final hidden state with EOS as the
    ///            start token and greedily emits one char per step until EOS.
    ///
    ///   h_t = tanh(W_ih * x_t + W_hh * h_{t-1} + b_h)      (x_t is a one-hot char)
    ///   y_t = softmax(W_ho * h_t + b_o)
    ///
    /// Dimensions:
    ///   Input/Output: ALPHA = 35 (а-я, ё, space, EOS)
    ///   Hidden:       HIDDEN
    ///   Max sequence: MAX_LEN
    ///
    /// Training: teacher forcing. Half of the samples are (corrupted word, word)
    /// pairs, the other half are identity pairs. Only the output projection
    /// (W_ho, b_o) is updated; encoder and decoder recurrent weights keep their
    /// random initialisation (no BPTT).
    ///
    /// File: 15,075 little-endian float32 values (~59 KB), see WEIGHT_FILE_BYTES.
    /// </summary>
    public class Seq2Spell
    {
        const int ALPHA = 35;   // а-я(32) + ё + space + EOS
        const int HIDDEN = 64;
        const int MAX_LEN = 25;
        const int EOS = ALPHA - 1; // also used as the decoder start token

        // Exact size of a weight file written by Train (8 tensors of float32).
        const long WEIGHT_FILE_BYTES =
            4L * (3 * ALPHA * HIDDEN + 2 * HIDDEN * HIDDEN + 2 * HIDDEN + ALPHA);

        // Encoder weights
        float[,] Wih;   // [ALPHA, HIDDEN]  — input to hidden
        float[,] Whh;   // [HIDDEN, HIDDEN] — hidden to hidden
        float[] bh;     // [HIDDEN]

        // Decoder weights
        float[,] Wdh;   // [ALPHA, HIDDEN]  — decoder input to hidden
        float[,] Wdhh;  // [HIDDEN, HIDDEN] — decoder hidden to hidden
        float[] bdh;    // [HIDDEN]
        float[,] Who;   // [HIDDEN, ALPHA]  — hidden to output
        float[] bo;     // [ALPHA]

        volatile bool ready;

        /// <summary>True once <see cref="Load"/> has succeeded at least once.</summary>
        public bool IsReady { get { return ready; } }

        /// <summary>Correct a word using the RNN (greedy decoding).</summary>
        /// <param name="input">Word to correct; may be null.</param>
        /// <returns>
        /// The generated correction (lower-case), or <paramref name="input"/> unchanged when:
        /// the model is not loaded; the input is null, shorter than 2 or longer than
        /// MAX_LEN chars; it contains a char outside the model alphabet; the decoder does
        /// not emit EOS within MAX_LEN steps; the output is shorter than 2 chars; or the
        /// model reproduces the input (so the caller's original casing is kept).
        /// </returns>
        public string Correct(string input)
        {
            if (!ready || input == null || input.Length < 2) return input;

            // Longer words would have to be truncated to fit the encoder and the decoder
            // emits at most MAX_LEN chars, so the result would be a cut-off word.
            if (input.Length > MAX_LEN) return input;

            string lower = input.ToLowerInvariant();

            // The decoder can only emit alphabet chars; anything else (Latin, digits,
            // hyphen, ...) could never be reproduced and would be rewritten.
            if (!IsEncodable(lower)) return input;

            // Encode
            float[] h = new float[HIDDEN];
            for (int t = 0; t < lower.Length; t++)
                h = RnnStep(Wih, Whh, bh, CharIdx(lower[t]), h);

            // Decode greedily, starting from the encoder state and the EOS start token.
            var result = new System.Text.StringBuilder();
            int prevChar = EOS;
            float[] dh = h; // RnnStep never mutates its input, so no clone is needed
            bool sawEos = false;

            for (int t = 0; t < MAX_LEN; t++)
            {
                dh = RnnStep(Wdh, Wdhh, bdh, prevChar, dh);
                int bestChar = ArgMax(OutputLogits(Who, bo, dh));

                if (bestChar == EOS) { sawEos = true; break; }
                char ch = IdxToChar(bestChar);
                if (ch != '\0') result.Append(ch);
                prevChar = bestChar;
            }

            string output = result.ToString();
            // A run-on without EOS is longer than anything seen in training; distrust it.
            if (!sawEos || output.Length < 2) return input;
            // No correction: return the caller's string instead of a lower-cased copy.
            if (string.Equals(output, lower, StringComparison.Ordinal)) return input;
            return output;
        }

        /// <summary>
        /// Load weights written by <see cref="Train"/>. The file is opened read-only and
        /// must be exactly WEIGHT_FILE_BYTES long. The model fields are replaced only after
        /// the whole file has been read, so a failed load leaves any previously loaded
        /// model intact.
        /// </summary>
        /// <param name="path">Weight file path.</param>
        /// <returns>True on success; false (after logging) on any error.</returns>
        public bool Load(string path)
        {
            try
            {
                float[,] wih, whh, wdh, wdhh, who;
                float[] bhLoaded, bdhLoaded, boLoaded;

                using (var br = new BinaryReader(File.OpenRead(path)))
                {
                    long size = br.BaseStream.Length;
                    if (size != WEIGHT_FILE_BYTES)
                    {
                        Logger.Log("textproc", "Seq2Spell load err: expected " + WEIGHT_FILE_BYTES +
                            " bytes, got " + size);
                        return false;
                    }

                    // Order must match the write order in Train.
                    wih = ReadMatrix(br, ALPHA, HIDDEN);
                    whh = ReadMatrix(br, HIDDEN, HIDDEN);
                    bhLoaded = ReadVector(br, HIDDEN);
                    wdh = ReadMatrix(br, ALPHA, HIDDEN);
                    wdhh = ReadMatrix(br, HIDDEN, HIDDEN);
                    bdhLoaded = ReadVector(br, HIDDEN);
                    who = ReadMatrix(br, HIDDEN, ALPHA);
                    boLoaded = ReadVector(br, ALPHA);
                }

                Wih = wih; Whh = whh; bh = bhLoaded;
                Wdh = wdh; Wdhh = wdhh; bdh = bdhLoaded;
                Who = who; bo = boLoaded;
                ready = true;

                Logger.Log("textproc", "Seq2Spell loaded: HIDDEN=" + HIDDEN + " ALPHA=" + ALPHA);
                return true;
            }
            catch (Exception ex)
            {
                Logger.Log("textproc", "Seq2Spell load err: " + ex.Message);
                return false;
            }
        }

        static float[,] ReadMatrix(BinaryReader br, int rows, int cols)
        {
            var m = new float[rows, cols];
            for (int i = 0; i < rows; i++)
                for (int j = 0; j < cols; j++)
                    m[i, j] = br.ReadSingle();
            return m;
        }

        static float[] ReadVector(BinaryReader br, int n)
        {
            var v = new float[n];
            for (int i = 0; i < n; i++) v[i] = br.ReadSingle();
            return v;
        }

        // ====================================================================
        // SHARED FORWARD-PASS HELPERS (used by both Correct and Train)
        // ====================================================================

        /// <summary>
        /// One Elman step: returns tanh(wIn[inputIdx, :] + hPrev · wRec + bias).
        /// An out-of-range <paramref name="inputIdx"/> contributes no input term.
        /// Always allocates a new array; <paramref name="hPrev"/> is not modified.
        /// </summary>
        static float[] RnnStep(float[,] wIn, float[,] wRec, float[] bias, int inputIdx, float[] hPrev)
        {
            var h = new float[HIDDEN];
            for (int j = 0; j < HIDDEN; j++)
            {
                float sum = bias[j];
                if (inputIdx >= 0 && inputIdx < ALPHA) sum += wIn[inputIdx, j];
                for (int k = 0; k < HIDDEN; k++) sum += wRec[k, j] * hPrev[k];
                h[j] = Tanh(sum);
            }
            return h;
        }

        /// <summary>Output layer: logits[c] = bo[c] + sum_j who[j, c] * h[j].</summary>
        static float[] OutputLogits(float[,] who, float[] bo, float[] h)
        {
            var logits = new float[ALPHA];
            for (int c = 0; c < ALPHA; c++)
            {
                logits[c] = bo[c];
                for (int j = 0; j < HIDDEN; j++) logits[c] += who[j, c] * h[j];
            }
            return logits;
        }

        /// <summary>Numerically stable softmax (subtracts the max logit first).</summary>
        static float[] Softmax(float[] logits)
        {
            float maxL = float.MinValue;
            for (int c = 0; c < logits.Length; c++)
                if (logits[c] > maxL) maxL = logits[c];

            var probs = new float[logits.Length];
            float sumExp = 0;
            for (int c = 0; c < logits.Length; c++)
            {
                probs[c] = (float)Math.Exp(logits[c] - maxL);
                sumExp += probs[c];
            }
            for (int c = 0; c < logits.Length; c++) probs[c] /= sumExp;
            return probs;
        }

        /// <summary>Index of the largest value (first one wins on ties).</summary>
        static int ArgMax(float[] v)
        {
            int best = 0;
            for (int i = 1; i < v.Length; i++)
                if (v[i] > v[best]) best = i;
            return best;
        }

        /// <summary>Clamp a gradient step to [-limit, limit].</summary>
        static float Clip(float g, float limit)
        {
            return Math.Max(-limit, Math.Min(limit, g));
        }

        /// <summary>Fast tanh; saturates to ±1 beyond |x| > 10 to avoid overflow in Exp.</summary>
        static float Tanh(float x)
        {
            if (x > 10) return 1f;
            if (x < -10) return -1f;
            float e2x = (float)Math.Exp(2 * x);
            return (e2x - 1) / (e2x + 1);
        }

        /// <summary>Map a lower-case char to its alphabet index (0..33), or -1 if unknown.</summary>
        static int CharIdx(char c)
        {
            if (c >= 'а' && c <= 'я') return c - 'а';
            if (c == 'ё') return 32;
            if (c == ' ') return 33;
            return -1; // unknown
        }

        /// <summary>Inverse of <see cref="CharIdx"/>; returns '\0' for EOS or out-of-range indices.</summary>
        static char IdxToChar(int idx)
        {
            if (idx >= 0 && idx <= 31) return (char)('а' + idx);
            if (idx == 32) return 'ё';
            if (idx == 33) return ' ';
            return '\0';
        }

        /// <summary>True when every char of <paramref name="s"/> is in the model alphabet.</summary>
        static bool IsEncodable(string s)
        {
            for (int i = 0; i < s.Length; i++)
                if (CharIdx(s[i]) < 0) return false;
            return true;
        }

        // ====================================================================
        // TRAINING
        // ====================================================================

        /// <summary>
        /// Train a model on <paramref name="dictionary"/> and write its weights to
        /// <paramref name="savePath"/> (format read by <see cref="Load"/>).
        /// Only the output projection is learned; recurrent weights stay random.
        /// </summary>
        /// <param name="dictionary">
        /// Correctly spelled words; only the first 80,000 are sampled. Null entries, words
        /// shorter than 2 or longer than MAX_LEN - 1 chars, and words containing chars
        /// outside the alphabet are skipped.
        /// </param>
        /// <param name="savePath">Destination weight file (overwritten).</param>
        /// <param name="epochs">Number of epochs of 50,000 random samples each.</param>
        /// <exception cref="ArgumentNullException"><paramref name="dictionary"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="dictionary"/> is empty.</exception>
        public static void Train(string[] dictionary, string savePath, int epochs)
        {
            if (dictionary == null) throw new ArgumentNullException("dictionary");
            int dictLen = Math.Min(dictionary.Length, 80000);
            if (dictLen == 0)
                throw new ArgumentException("Dictionary must contain at least one word.", "dictionary");

            var rng = new Random(42);

            // Init weights: uniform in [-s, s] with s = sqrt(2 / fan_in) (He-style scale)
            float s1 = (float)Math.Sqrt(2.0 / ALPHA);
            float s2 = (float)Math.Sqrt(2.0 / HIDDEN);
            float[,] wih = RandMatrix(ALPHA, HIDDEN, s1, rng);
            float[,] whh = RandMatrix(HIDDEN, HIDDEN, s2, rng);
            float[] bh_t = new float[HIDDEN];
            float[,] wdh = RandMatrix(ALPHA, HIDDEN, s1, rng);
            float[,] wdhh = RandMatrix(HIDDEN, HIDDEN, s2, rng);
            float[] bdh_t = new float[HIDDEN];
            float[,] who = RandMatrix(HIDDEN, ALPHA, s2, rng);
            float[] bo_t = new float[ALPHA];

            float lr = 0.005f;
            const int samplesPerEpoch = 50000;
            // Applied to the lr-scaled step. With |dLogits| <= 1, |h| <= 1 (tanh) and
            // lr <= 0.005 a step never exceeds 0.005, so this is only a safety net.
            const float gradClip = 0.5f;

            Logger.Log("textproc", "Seq2Spell training: " + dictLen + " words, " + epochs + " epochs");

            for (int epoch = 0; epoch < epochs; epoch++)
            {
                float totalLoss = 0;
                int totalChars = 0, correctChars = 0;

                for (int s = 0; s < samplesPerEpoch; s++)
                {
                    string word = dictionary[rng.Next(dictLen)];
                    if (word == null) continue;
                    string target = word.ToLowerInvariant();
                    // Words with chars outside the alphabet cannot be reproduced by the
                    // decoder; training on them would teach it wrong substitutions.
                    if (target.Length < 2 || target.Length > MAX_LEN - 1 || !IsEncodable(target)) continue;
                    string input = (s % 2 == 0) ? CorruptWord(target, rng) : target;

                    // Encode. Only the final state is needed because there is no BPTT.
                    float[] h = new float[HIDDEN];
                    for (int t = 0; t < input.Length; t++)
                        h = RnnStep(wih, whh, bh_t, CharIdx(input[t]), h);

                    // Decode with teacher forcing: step t is fed target[t-1] (EOS at t = 0)
                    // and must predict target[t] (EOS after the last char).
                    float[] dh = h;
                    for (int t = 0; t <= target.Length; t++)
                    {
                        int prevC = (t == 0) ? EOS : CharIdx(target[t - 1]);
                        dh = RnnStep(wdh, wdhh, bdh_t, prevC, dh);

                        float[] probs = Softmax(OutputLogits(who, bo_t, dh));
                        int targetC = (t < target.Length) ? CharIdx(target[t]) : EOS;

                        // Cross-entropy loss and argmax accuracy
                        totalLoss -= (float)Math.Log(probs[targetC] + 1e-8f);
                        totalChars++;
                        if (ArgMax(probs) == targetC) correctChars++;

                        // Gradient: dL/dlogits = probs - one_hot(target)
                        float[] dLogits = probs; // probs is not needed after this point
                        dLogits[targetC] -= 1f;

                        // Online SGD on the output layer only (simplified: no full BPTT)
                        for (int j = 0; j < HIDDEN; j++)
                            for (int c = 0; c < ALPHA; c++)
                                who[j, c] -= Clip(dLogits[c] * dh[j] * lr, gradClip);
                        for (int c = 0; c < ALPHA; c++)
                            bo_t[c] -= Clip(dLogits[c] * lr, gradClip);
                    }
                }

                if (epoch % 5 == 0 || epoch == epochs - 1)
                {
                    int denom = Math.Max(1, totalChars); // every sample may have been skipped
                    Logger.Log("textproc", " Epoch " + epoch + ": loss=" + (totalLoss / denom).ToString("F3") +
                        " charAcc=" + ((float)correctChars / denom * 100).ToString("F1") + "%");
                }

                if (epoch > 0 && epoch % 10 == 0) lr *= 0.7f;
            }

            // Save (order must match Load)
            using (var bw = new BinaryWriter(new FileStream(savePath, FileMode.Create)))
            {
                WriteMatrix(bw, wih); WriteMatrix(bw, whh); WriteVector(bw, bh_t);
                WriteMatrix(bw, wdh); WriteMatrix(bw, wdhh); WriteVector(bw, bdh_t);
                WriteMatrix(bw, who); WriteVector(bw, bo_t);
            }
            Logger.Log("textproc", "Seq2Spell saved: " + new FileInfo(savePath).Length / 1024 + "KB");
        }

        static float[,] RandMatrix(int r, int c, float scale, Random rng)
        {
            var m = new float[r, c];
            for (int i = 0; i < r; i++)
                for (int j = 0; j < c; j++)
                    m[i, j] = (float)(rng.NextDouble() * 2 - 1) * scale;
            return m;
        }

        static void WriteMatrix(BinaryWriter bw, float[,] m)
        {
            for (int i = 0; i < m.GetLength(0); i++)
                for (int j = 0; j < m.GetLength(1); j++)
                    bw.Write(m[i, j]);
        }

        static void WriteVector(BinaryWriter bw, float[] v)
        {
            for (int i = 0; i < v.Length; i++) bw.Write(v[i]);
        }

        /// <summary>
        /// Apply one random typo (transpose, substitute, delete, insert) or two in a row.
        /// Words shorter than 3 chars are returned unchanged. Substituted/inserted chars
        /// are drawn from а-я only.
        /// </summary>
        static string CorruptWord(string w, Random rng)
        {
            if (w.Length < 3) return w;
            switch (rng.Next(5))
            {
                case 0: // transpose two adjacent chars
                {
                    int p = rng.Next(w.Length - 1);
                    char[] a = w.ToCharArray();
                    char tmp = a[p]; a[p] = a[p + 1]; a[p + 1] = tmp;
                    return new string(a);
                }
                case 1: // substitute one char
                {
                    char[] a = w.ToCharArray();
                    a[rng.Next(a.Length)] = (char)('а' + rng.Next(32));
                    return new string(a);
                }
                case 2: // delete one char
                    return w.Remove(rng.Next(w.Length), 1);
                case 3: // insert one char
                    return w.Insert(rng.Next(w.Length + 1), ((char)('а' + rng.Next(32))).ToString());
                case 4: // double corruption
                    string w2 = CorruptWord(w, rng);
                    return CorruptWord(w2, rng);
                default:
                    return w;
            }
        }
    }
}