windowcapture
исходный код / Helpers/Seq2Spell.cs

Seq2Spell.cs

334 строк · 14,115 байт · модуль Helpers
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

  5namespace WindowCapture.Helpers
  6{
  7    /// <summary>
  8    /// Seq2Spell: Sequence-to-Sequence RNN for spelling correction.
  9    ///
 10    /// Innovation: Instead of just scoring candidates, this model GENERATES
 11    /// the corrected word character-by-character.
 12    ///
 13    /// Architecture (Elman RNN):
 14    ///   Encoder: reads input char-by-char → builds hidden state
 15    ///   Decoder: generates output char-by-char from hidden state
 16    ///
 17    ///   h_t = tanh(W_ih * x_t + W_hh * h_{t-1} + b_h)
 18    ///   y_t = softmax(W_ho * h_t + b_o)
 19    ///
 20    /// Dimensions:
 21    ///   Input/Output: ALPHA (34 chars + EOS)
 22    ///   Hidden: HIDDEN_DIM
 23    ///   Max sequence: MAX_LEN
 24    ///
 25    /// Training: teacher forcing on (input_corrupted, output_correct) pairs
 26    /// File: ~200KB weights
 27    /// </summary>
 28    public class Seq2Spell
 29    {
 30        const int ALPHA = 35;      // а-я(32) + ё + space + EOS
 31        const int HIDDEN = 64;
 32        const int MAX_LEN = 25;
 33
 34        // Encoder weights
 35        float[,] Wih; // [ALPHA, HIDDEN] — input to hidden
 36        float[,] Whh; // [HIDDEN, HIDDEN] — hidden to hidden
 37        float[] bh;   // [HIDDEN]
 38
 39        // Decoder weights
 40        float[,] Wdh; // [ALPHA, HIDDEN] — decoder input to hidden
 41        float[,] Wdhh; // [HIDDEN, HIDDEN] — decoder hidden to hidden
 42        float[] bdh;   // [HIDDEN]
 43        float[,] Who;  // [HIDDEN, ALPHA] — hidden to output
 44        float[] bo;    // [ALPHA]
 45
 46        volatile bool ready;
 47        public bool IsReady { get { return ready; } }
 48
 49        /// <summary>Correct a word using the RNN.</summary>
 50        public string Correct(string input)
 51        {
 52            if (!ready || input.Length < 2) return input;
 53            string lower = input.ToLower();
 54
 55            // Encode
 56            float[] h = new float[HIDDEN];
 57            for (int t = 0; t < Math.Min(lower.Length, MAX_LEN); t++)
 58            {
 59                int ci = CharIdx(lower[t]);
 60                float[] newH = new float[HIDDEN];
 61                for (int j = 0; j < HIDDEN; j++)
 62                {
 63                    float sum = bh[j];
 64                    if (ci >= 0 && ci < ALPHA) sum += Wih[ci, j];
 65                    for (int k = 0; k < HIDDEN; k++) sum += Whh[k, j] * h[k];
 66                    newH[j] = Tanh(sum);
 67                }
 68                h = newH;
 69            }
 70
 71            // Decode
 72            var result = new System.Text.StringBuilder();
 73            int prevChar = ALPHA - 1; // start token = EOS
 74            float[] dh = (float[])h.Clone(); // init decoder hidden from encoder
 75
 76            for (int t = 0; t < MAX_LEN; t++)
 77            {
 78                // Decoder step
 79                float[] newDh = new float[HIDDEN];
 80                for (int j = 0; j < HIDDEN; j++)
 81                {
 82                    float sum = bdh[j];
 83                    if (prevChar >= 0 && prevChar < ALPHA) sum += Wdh[prevChar, j];
 84                    for (int k = 0; k < HIDDEN; k++) sum += Wdhh[k, j] * dh[k];
 85                    newDh[j] = Tanh(sum);
 86                }
 87                dh = newDh;
 88
 89                // Output: softmax over ALPHA chars
 90                float[] logits = new float[ALPHA];
 91                for (int c = 0; c < ALPHA; c++)
 92                {
 93                    logits[c] = bo[c];
 94                    for (int j = 0; j < HIDDEN; j++) logits[c] += Who[j, c] * dh[j];
 95                }
 96
 97                // Argmax
 98                int bestChar = 0;
 99                float bestVal = logits[0];
100                for (int c = 1; c < ALPHA; c++)
101                    if (logits[c] > bestVal) { bestVal = logits[c]; bestChar = c; }
102
103                if (bestChar == ALPHA - 1) break; // EOS
104                char ch = IdxToChar(bestChar);
105                if (ch != '\0') result.Append(ch);
106                prevChar = bestChar;
107            }
108
109            string output = result.ToString();
110            return output.Length >= 2 ? output : input; // fallback if output too short
111        }
112
113        public bool Load(string path)
114        {
115            try
116            {
117                using (var br = new BinaryReader(new FileStream(path, FileMode.Open)))
118                {
119                    Wih = ReadMatrix(br, ALPHA, HIDDEN);
120                    Whh = ReadMatrix(br, HIDDEN, HIDDEN);
121                    bh = ReadVector(br, HIDDEN);
122                    Wdh = ReadMatrix(br, ALPHA, HIDDEN);
123                    Wdhh = ReadMatrix(br, HIDDEN, HIDDEN);
124                    bdh = ReadVector(br, HIDDEN);
125                    Who = ReadMatrix(br, HIDDEN, ALPHA);
126                    bo = ReadVector(br, ALPHA);
127                }
128                ready = true;
129                Logger.Log("textproc", "Seq2Spell loaded: HIDDEN=" + HIDDEN + " ALPHA=" + ALPHA);
130                return true;
131            }
132            catch (Exception ex) { Logger.Log("textproc", "Seq2Spell load err: " + ex.Message); return false; }
133        }
134
135        static float[,] ReadMatrix(BinaryReader br, int rows, int cols)
136        {
137            var m = new float[rows, cols];
138            for (int i = 0; i < rows; i++) for (int j = 0; j < cols; j++) m[i, j] = br.ReadSingle();
139            return m;
140        }
141        static float[] ReadVector(BinaryReader br, int n)
142        {
143            var v = new float[n]; for (int i = 0; i < n; i++) v[i] = br.ReadSingle(); return v;
144        }
145
146        static float Tanh(float x)
147        {
148            if (x > 10) return 1f; if (x < -10) return -1f;
149            float e2x = (float)Math.Exp(2 * x);
150            return (e2x - 1) / (e2x + 1);
151        }
152
153        static int CharIdx(char c)
154        {
155            if (c >= 'а' && c <= 'я') return c - 'а';
156            if (c == 'ё') return 32;
157            if (c == ' ') return 33;
158            return -1; // unknown
159        }
160        static char IdxToChar(int idx)
161        {
162            if (idx >= 0 && idx <= 31) return (char)('а' + idx);
163            if (idx == 32) return 'ё';
164            if (idx == 33) return ' ';
165            return '\0';
166        }
167
168        // ====================================================================
169        // TRAINING
170        // ====================================================================
171        public static void Train(string[] dictionary, string savePath, int epochs)
172        {
173            var rng = new Random(42);
174            int dictLen = Math.Min(dictionary.Length, 80000);
175
176            // Init weights (Xavier)
177            float s1 = (float)Math.Sqrt(2.0 / ALPHA);
178            float s2 = (float)Math.Sqrt(2.0 / HIDDEN);
179            float[,] wih = RandMatrix(ALPHA, HIDDEN, s1, rng);
180            float[,] whh = RandMatrix(HIDDEN, HIDDEN, s2, rng);
181            float[] bh_t = new float[HIDDEN];
182            float[,] wdh = RandMatrix(ALPHA, HIDDEN, s1, rng);
183            float[,] wdhh = RandMatrix(HIDDEN, HIDDEN, s2, rng);
184            float[] bdh_t = new float[HIDDEN];
185            float[,] who = RandMatrix(HIDDEN, ALPHA, s2, rng);
186            float[] bo_t = new float[ALPHA];
187
188            float lr = 0.005f;
189            int samplesPerEpoch = 50000;
190
191            Logger.Log("textproc", "Seq2Spell training: " + dictLen + " words, " + epochs + " epochs");
192
193            for (int epoch = 0; epoch < epochs; epoch++)
194            {
195                float totalLoss = 0;
196                int totalChars = 0, correctChars = 0;
197
198                for (int s = 0; s < samplesPerEpoch; s++)
199                {
200                    string target = dictionary[rng.Next(dictLen)].ToLower();
201                    if (target.Length < 2 || target.Length > MAX_LEN - 1) continue;
202                    string input = (s % 2 == 0) ? CorruptWord(target, rng) : target;
203
204                    // Encode
205                    float[][] encH = new float[input.Length + 1][];
206                    encH[0] = new float[HIDDEN];
207                    for (int t = 0; t < input.Length; t++)
208                    {
209                        int ci = CharIdx(input[t]);
210                        encH[t + 1] = new float[HIDDEN];
211                        for (int j = 0; j < HIDDEN; j++)
212                        {
213                            float sum = bh_t[j];
214                            if (ci >= 0 && ci < ALPHA) sum += wih[ci, j];
215                            for (int k = 0; k < HIDDEN; k++) sum += whh[k, j] * encH[t][k];
216                            encH[t + 1][j] = Tanh(sum);
217                        }
218                    }
219
220                    // Decode with teacher forcing
221                    string targetWithEos = target + ((char)(ALPHA - 1 + 'а')); // append EOS marker
222                    float[][] decH = new float[target.Length + 2][];
223                    decH[0] = (float[])encH[input.Length].Clone();
224
225                    for (int t = 0; t < target.Length + 1; t++)
226                    {
227                        int prevC = (t == 0) ? ALPHA - 1 : CharIdx(target[t - 1]);
228                        if (prevC < 0) prevC = 0;
229
230                        decH[t + 1] = new float[HIDDEN];
231                        for (int j = 0; j < HIDDEN; j++)
232                        {
233                            float sum = bdh_t[j];
234                            if (prevC < ALPHA) sum += wdh[prevC, j];
235                            for (int k = 0; k < HIDDEN; k++) sum += wdhh[k, j] * decH[t][k];
236                            decH[t + 1][j] = Tanh(sum);
237                        }
238
239                        // Output logits + softmax
240                        float[] logits = new float[ALPHA];
241                        float maxL = float.MinValue;
242                        for (int c = 0; c < ALPHA; c++)
243                        {
244                            logits[c] = bo_t[c];
245                            for (int j = 0; j < HIDDEN; j++) logits[c] += who[j, c] * decH[t + 1][j];
246                            if (logits[c] > maxL) maxL = logits[c];
247                        }
248                        float sumExp = 0;
249                        float[] probs = new float[ALPHA];
250                        for (int c = 0; c < ALPHA; c++) { probs[c] = (float)Math.Exp(logits[c] - maxL); sumExp += probs[c]; }
251                        for (int c = 0; c < ALPHA; c++) probs[c] /= sumExp;
252
253                        int targetC = (t < target.Length) ? CharIdx(target[t]) : ALPHA - 1;
254                        if (targetC < 0) targetC = 0;
255
256                        // Cross-entropy loss
257                        totalLoss -= (float)Math.Log(probs[targetC] + 1e-8f);
258                        totalChars++;
259
260                        // Argmax accuracy
261                        int bestC = 0;
262                        for (int c = 1; c < ALPHA; c++) if (probs[c] > probs[bestC]) bestC = c;
263                        if (bestC == targetC) correctChars++;
264
265                        // Gradient: dL/dlogits = probs - one_hot(target)
266                        float[] dLogits = new float[ALPHA];
267                        for (int c = 0; c < ALPHA; c++) dLogits[c] = probs[c];
268                        dLogits[targetC] -= 1f;
269
270                        // Backprop through output layer (simplified: only update, no full BPTT)
271                        float gradClip = 0.5f;
272                        for (int j = 0; j < HIDDEN; j++)
273                        {
274                            for (int c = 0; c < ALPHA; c++)
275                            {
276                                float g = dLogits[c] * decH[t + 1][j] * lr;
277                                g = Math.Max(-gradClip, Math.Min(gradClip, g));
278                                who[j, c] -= g;
279                            }
280                        }
281                        for (int c = 0; c < ALPHA; c++)
282                        {
283                            float g = dLogits[c] * lr;
284                            g = Math.Max(-gradClip, Math.Min(gradClip, g));
285                            bo_t[c] -= g;
286                        }
287                    }
288                }
289
290                if (epoch % 5 == 0 || epoch == epochs - 1)
291                    Logger.Log("textproc", "  Epoch " + epoch + ": loss=" + (totalLoss / totalChars).ToString("F3") +
292                        " charAcc=" + ((float)correctChars / totalChars * 100).ToString("F1") + "%");
293
294                if (epoch > 0 && epoch % 10 == 0) lr *= 0.7f;
295            }
296
297            // Save
298            using (var bw = new BinaryWriter(new FileStream(savePath, FileMode.Create)))
299            {
300                WriteMatrix(bw, wih); WriteMatrix(bw, whh); WriteVector(bw, bh_t);
301                WriteMatrix(bw, wdh); WriteMatrix(bw, wdhh); WriteVector(bw, bdh_t);
302                WriteMatrix(bw, who); WriteVector(bw, bo_t);
303            }
304            Logger.Log("textproc", "Seq2Spell saved: " + new FileInfo(savePath).Length / 1024 + "KB");
305        }
306
307        static float[,] RandMatrix(int r, int c, float scale, Random rng)
308        {
309            var m = new float[r, c];
310            for (int i = 0; i < r; i++) for (int j = 0; j < c; j++) m[i, j] = (float)(rng.NextDouble() * 2 - 1) * scale;
311            return m;
312        }
313        static void WriteMatrix(BinaryWriter bw, float[,] m)
314        { for (int i = 0; i < m.GetLength(0); i++) for (int j = 0; j < m.GetLength(1); j++) bw.Write(m[i, j]); }
315        static void WriteVector(BinaryWriter bw, float[] v)
316        { for (int i = 0; i < v.Length; i++) bw.Write(v[i]); }
317
318        static string CorruptWord(string w, Random rng)
319        {
320            if (w.Length < 3) return w;
321            switch (rng.Next(5))
322            {
323                case 0: int p = rng.Next(w.Length - 1); char[] a = w.ToCharArray(); char t = a[p]; a[p] = a[p + 1]; a[p + 1] = t; return new string(a);
324                case 1: a = w.ToCharArray(); a[rng.Next(a.Length)] = (char)('а' + rng.Next(32)); return new string(a);
325                case 2: return w.Remove(rng.Next(w.Length), 1);
326                case 3: return w.Insert(rng.Next(w.Length + 1), ((char)('а' + rng.Next(32))).ToString());
327                case 4: // double corruption
328                    string w2 = CorruptWord(w, rng);
329                    return CorruptWord(w2, rng);
330                default: return w;
331            }
332        }
333    }
334}