windowcapture
исходный код / Helpers/GruSpellNet.cs

GruSpellNet.cs

216 строк · 8,049 байт · модуль Helpers
  1using System;
  2using System.IO;
  3using System.Collections.Generic;
  4
  5namespace WindowCapture.Helpers
  6{
  7    /// <summary>
  8    /// GruSpellNet: GRU-based Seq2Seq spell corrector.
  9    /// Trained with PyTorch, inference in pure C#.
 10    ///
 11    /// Architecture: Encoder GRU → Decoder GRU → Linear output
 12    /// 90.8% character accuracy on Russian words.
 13    /// </summary>
 14    public class GruSpellNet
 15    {
 16        int ALPHA, EMBED_DIM, HIDDEN, MAX_LEN;
 17
 18        // Embedding
 19        float[,] embed; // [ALPHA, EMBED_DIM]
 20
 21        // Encoder GRU weights (PyTorch GRU format: weight_ih, weight_hh, bias_ih, bias_hh)
 22        // Each has 3*HIDDEN rows (reset, update, new gates)
 23        float[,] enc_wih, enc_whh; // [3*HIDDEN, input/hidden]
 24        float[] enc_bih, enc_bhh;  // [3*HIDDEN]
 25
 26        // Decoder GRU
 27        float[,] dec_wih, dec_whh;
 28        float[] dec_bih, dec_bhh;
 29
 30        // Output linear
 31        float[,] outW; // [ALPHA, HIDDEN]
 32        float[] outB;  // [ALPHA]
 33
 34        volatile bool ready;
 35        public bool IsReady { get { return ready; } }
 36
 37        public bool Load(string path)
 38        {
 39            try
 40            {
 41                using (var br = new BinaryReader(new FileStream(path, FileMode.Open)))
 42                {
 43                    ALPHA = br.ReadInt32();
 44                    EMBED_DIM = br.ReadInt32();
 45                    HIDDEN = br.ReadInt32();
 46                    MAX_LEN = br.ReadInt32();
 47
 48                    embed = ReadMat(br, ALPHA, EMBED_DIM);
 49
 50                    // Encoder GRU: 4 params (weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0)
 51                    enc_wih = ReadParamMat(br, 3 * HIDDEN, EMBED_DIM);
 52                    enc_whh = ReadParamMat(br, 3 * HIDDEN, HIDDEN);
 53                    enc_bih = ReadParamVec(br, 3 * HIDDEN);
 54                    enc_bhh = ReadParamVec(br, 3 * HIDDEN);
 55
 56                    // Decoder GRU
 57                    dec_wih = ReadParamMat(br, 3 * HIDDEN, EMBED_DIM);
 58                    dec_whh = ReadParamMat(br, 3 * HIDDEN, HIDDEN);
 59                    dec_bih = ReadParamVec(br, 3 * HIDDEN);
 60                    dec_bhh = ReadParamVec(br, 3 * HIDDEN);
 61
 62                    // Output
 63                    outW = ReadMat(br, ALPHA, HIDDEN);
 64                    outB = new float[ALPHA];
 65                    for (int i = 0; i < ALPHA; i++) outB[i] = br.ReadSingle();
 66                }
 67
 68                ready = true;
 69                Logger.Log("textproc", "GruSpellNet loaded: ALPHA=" + ALPHA + " EMBED=" + EMBED_DIM + " HIDDEN=" + HIDDEN);
 70                return true;
 71            }
 72            catch (Exception ex)
 73            {
 74                Logger.Log("textproc", "GruSpellNet load err: " + ex.Message);
 75                return false;
 76            }
 77        }
 78
 79        float[,] ReadMat(BinaryReader br, int r, int c)
 80        {
 81            var m = new float[r, c];
 82            for (int i = 0; i < r; i++) for (int j = 0; j < c; j++) m[i, j] = br.ReadSingle();
 83            return m;
 84        }
 85
 86        float[,] ReadParamMat(BinaryReader br, int expectedR, int expectedC)
 87        {
 88            int n = br.ReadInt32(); // total elements
 89            int r = expectedR, c = expectedC;
 90            var m = new float[r, c];
 91            for (int i = 0; i < r; i++) for (int j = 0; j < c; j++) m[i, j] = br.ReadSingle();
 92            return m;
 93        }
 94
 95        float[] ReadParamVec(BinaryReader br, int expectedN)
 96        {
 97            int n = br.ReadInt32();
 98            var v = new float[expectedN];
 99            for (int i = 0; i < expectedN; i++) v[i] = br.ReadSingle();
100            return v;
101        }
102
103        /// <summary>Correct a word using the trained GRU model.</summary>
104        public string Correct(string input)
105        {
106            if (!ready || input.Length < 2) return input;
107            string lower = input.ToLower();
108
109            // Encode
110            float[] h = new float[HIDDEN];
111            for (int t = 0; t < Math.Min(lower.Length, MAX_LEN); t++)
112            {
113                int ci = CharIdx(lower[t]);
114                if (ci < 0 || ci >= ALPHA) ci = ALPHA - 1;
115                float[] x = new float[EMBED_DIM];
116                for (int d = 0; d < EMBED_DIM; d++) x[d] = embed[ci, d];
117                h = GruStep(x, h, enc_wih, enc_whh, enc_bih, enc_bhh);
118            }
119            // EOS token
120            float[] eosEmb = new float[EMBED_DIM];
121            for (int d = 0; d < EMBED_DIM; d++) eosEmb[d] = embed[ALPHA - 1, d];
122            h = GruStep(eosEmb, h, enc_wih, enc_whh, enc_bih, enc_bhh);
123
124            // Decode (greedy)
125            var result = new System.Text.StringBuilder();
126            int prevChar = ALPHA - 1; // start token
127            float[] dh = (float[])h.Clone();
128
129            for (int t = 0; t < MAX_LEN; t++)
130            {
131                float[] xd = new float[EMBED_DIM];
132                for (int d = 0; d < EMBED_DIM; d++) xd[d] = embed[prevChar, d];
133                dh = GruStep(xd, dh, dec_wih, dec_whh, dec_bih, dec_bhh);
134
135                // Linear + argmax
136                int bestC = 0;
137                float bestV = float.MinValue;
138                for (int c = 0; c < ALPHA; c++)
139                {
140                    float v = outB[c];
141                    for (int j = 0; j < HIDDEN; j++) v += outW[c, j] * dh[j];
142                    if (v > bestV) { bestV = v; bestC = c; }
143                }
144
145                if (bestC == ALPHA - 1) break; // EOS
146                char ch = IdxToChar(bestC);
147                if (ch != '\0') result.Append(ch);
148                prevChar = bestC;
149            }
150
151            string output = result.ToString();
152            return output.Length >= 2 ? output : input;
153        }
154
155        // GRU step: PyTorch GRU formula
156        // r = σ(W_ir @ x + b_ir + W_hr @ h + b_hr)
157        // z = σ(W_iz @ x + b_iz + W_hz @ h + b_hz)
158        // n = tanh(W_in @ x + b_in + r * (W_hn @ h + b_hn))
159        // h' = (1 - z) * n + z * h
160        float[] GruStep(float[] x, float[] h, float[,] wih, float[,] whh, float[] bih, float[] bhh)
161        {
162            int H = HIDDEN;
163            int inDim = x.Length;
164
165            // Compute gates
166            float[] r = new float[H], z = new float[H], n = new float[H];
167            float[] newH = new float[H];
168
169            for (int j = 0; j < H; j++)
170            {
171                // Reset gate
172                float rVal = bih[j] + bhh[j];
173                for (int k = 0; k < inDim; k++) rVal += wih[j, k] * x[k];
174                for (int k = 0; k < H; k++) rVal += whh[j, k] * h[k];
175                r[j] = Sigmoid(rVal);
176
177                // Update gate
178                float zVal = bih[H + j] + bhh[H + j];
179                for (int k = 0; k < inDim; k++) zVal += wih[H + j, k] * x[k];
180                for (int k = 0; k < H; k++) zVal += whh[H + j, k] * h[k];
181                z[j] = Sigmoid(zVal);
182
183                // New gate
184                float nVal = bih[2 * H + j];
185                for (int k = 0; k < inDim; k++) nVal += wih[2 * H + j, k] * x[k];
186                float hn = bhh[2 * H + j];
187                for (int k = 0; k < H; k++) hn += whh[2 * H + j, k] * h[k];
188                nVal += r[j] * hn;
189                n[j] = Tanh(nVal);
190
191                // Output
192                newH[j] = (1f - z[j]) * n[j] + z[j] * h[j];
193            }
194            return newH;
195        }
196
197        static float Sigmoid(float x) { if (x > 10) return 1f; if (x < -10) return 0f; return 1f / (1f + (float)Math.Exp(-x)); }
198        static float Tanh(float x) { if (x > 10) return 1f; if (x < -10) return -1f; float e = (float)Math.Exp(2 * x); return (e - 1) / (e + 1); }
199
200        static int CharIdx(char c)
201        {
202            if (c >= 'а' && c <= 'я') return c - 'а';
203            if (c == 'ё') return 32;
204            if (c == ' ') return 33;
205            return ALPHA_STATIC - 1;
206        }
207        static char IdxToChar(int i)
208        {
209            if (i >= 0 && i <= 31) return (char)('а' + i);
210            if (i == 32) return 'ё';
211            if (i == 33) return ' ';
212            return '\0';
213        }
214        const int ALPHA_STATIC = 35;
215    }
216}