windowcapture
исходный код / Helpers/SpellNet.cs

SpellNet.cs

231 строк · 10,012 байт · модуль Helpers
  1using System;
  2using System.IO;
  3using System.Text;
  4
  5namespace WindowCapture.Helpers
  6{
  7    /// <summary>
  8    /// SpellNet v2: Character trigram neural network for word validity.
  9    ///
 10    /// Architecture:
 11    ///   Input:  word → character trigrams → hashed bag-of-trigrams vector [TRIG_DIM]
 12    ///   Layer1: FC(TRIG_DIM, HIDDEN) + ReLU
 13    ///   Layer2: FC(HIDDEN, 1) + Sigmoid
 14    ///   Output: P(valid word) ∈ [0.0, 1.0]
 15    ///
 16    /// Trigram approach is MUCH better than one-hot because:
 17    ///   - Captures local character patterns ("при","ров","ать" = common RU patterns)
 18    ///   - Fixed-size input regardless of word length
 19    ///   - Dense representation (many trigrams activated per word)
 20    ///
 21    /// Weight file: (TRIG_DIM*HIDDEN + HIDDEN + HIDDEN + 1) * 4 bytes ≈ 260KB
 22    /// </summary>
 23    public class SpellNet
 24    {
 25        const int TRIG_DIM = 512;  // hashed trigram vector size
 26        const int HIDDEN = 128;    // hidden neurons
 27
 28        float[] w1; // [TRIG_DIM × HIDDEN]
 29        float[] b1; // [HIDDEN]
 30        float[] w2; // [HIDDEN]
 31        float b2;
 32
 33        float[] trigVec;  // reusable [TRIG_DIM]
 34        float[] hidVec;   // reusable [HIDDEN]
 35
 36        volatile bool ready;
 37        public bool IsReady { get { return ready; } }
 38
 39        public bool Load(string path)
 40        {
 41            try
 42            {
 43                byte[] raw = File.ReadAllBytes(path);
 44                int expected = (TRIG_DIM * HIDDEN + HIDDEN + HIDDEN + 1) * 4;
 45                if (raw.Length < expected) { Logger.Log("textproc", "SpellNet: too small " + raw.Length); return false; }
 46
 47                w1 = new float[TRIG_DIM * HIDDEN];
 48                b1 = new float[HIDDEN];
 49                w2 = new float[HIDDEN];
 50
 51                int off = 0;
 52                Buffer.BlockCopy(raw, off, w1, 0, TRIG_DIM * HIDDEN * 4); off += TRIG_DIM * HIDDEN * 4;
 53                Buffer.BlockCopy(raw, off, b1, 0, HIDDEN * 4); off += HIDDEN * 4;
 54                Buffer.BlockCopy(raw, off, w2, 0, HIDDEN * 4); off += HIDDEN * 4;
 55                b2 = BitConverter.ToSingle(raw, off);
 56
 57                trigVec = new float[TRIG_DIM];
 58                hidVec = new float[HIDDEN];
 59                ready = true;
 60                Logger.Log("textproc", "SpellNet v2 loaded: " + raw.Length + " bytes");
 61                return true;
 62            }
 63            catch (Exception ex) { Logger.Log("textproc", "SpellNet load err: " + ex.Message); return false; }
 64        }
 65
 66        /// <summary>Predict word validity: 0.0=garbage, 1.0=valid word.</summary>
 67        public float Predict(string word)
 68        {
 69            if (!ready) return 0.5f;
 70            string w = word.ToLower();
 71
 72            // Encode: bag of character trigrams (hashed)
 73            Array.Clear(trigVec, 0, TRIG_DIM);
 74            string padded = "#" + w + "#";
 75            for (int i = 0; i <= padded.Length - 3; i++)
 76            {
 77                int h = TrigramHash(padded[i], padded[i + 1], padded[i + 2]);
 78                trigVec[h % TRIG_DIM] += 1f;
 79            }
 80            // Normalize
 81            float norm = 0;
 82            for (int i = 0; i < TRIG_DIM; i++) norm += trigVec[i] * trigVec[i];
 83            if (norm > 0) { norm = (float)Math.Sqrt(norm); for (int i = 0; i < TRIG_DIM; i++) trigVec[i] /= norm; }
 84
 85            // Layer 1: ReLU(trigVec × W1 + b1)
 86            for (int h = 0; h < HIDDEN; h++)
 87            {
 88                float sum = b1[h];
 89                for (int i = 0; i < TRIG_DIM; i++)
 90                    if (trigVec[i] != 0f) sum += trigVec[i] * w1[i * HIDDEN + h];
 91                hidVec[h] = sum > 0f ? sum : 0f;
 92            }
 93
 94            // Layer 2: Sigmoid(hidden × W2 + b2)
 95            float output = b2;
 96            for (int h = 0; h < HIDDEN; h++) output += hidVec[h] * w2[h];
 97            return Sigmoid(output);
 98        }
 99
100        static int TrigramHash(char a, char b, char c)
101        {
102            unchecked { return Math.Abs((a * 31 + b) * 31 + c); }
103        }
104
105        static float Sigmoid(float x)
106        {
107            if (x > 10f) return 1f; if (x < -10f) return 0f;
108            return 1f / (1f + (float)Math.Exp(-x));
109        }
110
111        // ====================================================================
112        // TRAINING
113        // ====================================================================
114        public static void Train(string[] dictionary, string savePath, int epochs)
115        {
116            var rng = new Random(42);
117            int dictLen = Math.Min(dictionary.Length, 500000);
118
119            float[] tw1 = new float[TRIG_DIM * HIDDEN];
120            float[] tb1 = new float[HIDDEN];
121            float[] tw2 = new float[HIDDEN];
122            float tb2 = 0f;
123
124            float s1 = (float)Math.Sqrt(2.0 / TRIG_DIM);
125            for (int i = 0; i < tw1.Length; i++) tw1[i] = (float)(rng.NextDouble() * 2 - 1) * s1;
126            float s2 = (float)Math.Sqrt(2.0 / HIDDEN);
127            for (int i = 0; i < tw2.Length; i++) tw2[i] = (float)(rng.NextDouble() * 2 - 1) * s2;
128
129            float[] inp = new float[TRIG_DIM];
130            float[] hid = new float[HIDDEN];
131            float[] gw1 = new float[TRIG_DIM * HIDDEN];
132            float[] gb1 = new float[HIDDEN];
133            float[] gw2 = new float[HIDDEN];
134
135            float lr = 0.01f;
136            int batch = 64;
137            int samplesPerEpoch = 200000;
138
139            Logger.Log("textproc", "SpellNet v2 training: " + dictLen + " words, " + epochs + " epochs, trigDim=" + TRIG_DIM + " hidden=" + HIDDEN);
140
141            for (int epoch = 0; epoch < epochs; epoch++)
142            {
143                float totalLoss = 0; int correct = 0, total = 0;
144
145                for (int s = 0; s < samplesPerEpoch; s++)
146                {
147                    bool isPos = (s % 2) == 0;
148                    string word = isPos ? dictionary[rng.Next(dictLen)] : CorruptWord(dictionary[rng.Next(dictLen)], rng);
149                    float target = isPos ? 1f : 0f;
150                    if (word.Length < 2 || word.Length > 20) continue;
151
152                    // Encode trigrams
153                    Array.Clear(inp, 0, TRIG_DIM);
154                    string padded = "#" + word.ToLower() + "#";
155                    for (int i = 0; i <= padded.Length - 3; i++)
156                    {
157                        int h = TrigramHash(padded[i], padded[i + 1], padded[i + 2]);
158                        inp[h % TRIG_DIM] += 1f;
159                    }
160                    float norm = 0; for (int i = 0; i < TRIG_DIM; i++) norm += inp[i] * inp[i];
161                    if (norm > 0) { norm = (float)Math.Sqrt(norm); for (int i = 0; i < TRIG_DIM; i++) inp[i] /= norm; }
162
163                    // Forward
164                    for (int h = 0; h < HIDDEN; h++)
165                    {
166                        float sum = tb1[h];
167                        for (int i = 0; i < TRIG_DIM; i++) if (inp[i] != 0f) sum += inp[i] * tw1[i * HIDDEN + h];
168                        hid[h] = sum > 0f ? sum : 0f;
169                    }
170                    float outVal = tb2;
171                    for (int h = 0; h < HIDDEN; h++) outVal += hid[h] * tw2[h];
172                    float pred = Sigmoid(outVal);
173
174                    float eps = 1e-7f;
175                    totalLoss -= target * (float)Math.Log(pred + eps) + (1 - target) * (float)Math.Log(1 - pred + eps);
176                    if ((pred >= 0.5f) == (target >= 0.5f)) correct++;
177                    total++;
178
179                    // Backward
180                    float dOut = pred - target;
181                    for (int h = 0; h < HIDDEN; h++) gw2[h] += hid[h] * dOut;
182                    tb2 -= lr * dOut / batch;
183                    for (int h = 0; h < HIDDEN; h++)
184                    {
185                        float dh = dOut * tw2[h]; if (hid[h] <= 0) dh = 0;
186                        gb1[h] += dh;
187                        for (int i = 0; i < TRIG_DIM; i++) if (inp[i] != 0f) gw1[i * HIDDEN + h] += inp[i] * dh;
188                    }
189
190                    if ((s + 1) % batch == 0)
191                    {
192                        float sc = lr / batch;
193                        for (int i = 0; i < tw1.Length; i++) { tw1[i] -= gw1[i] * sc; gw1[i] = 0; }
194                        for (int h = 0; h < HIDDEN; h++) { tb1[h] -= gb1[h] * sc; gb1[h] = 0; tw2[h] -= gw2[h] * sc; gw2[h] = 0; }
195                    }
196                }
197                if (epoch % 2 == 0 || epoch == epochs - 1)
198                    Logger.Log("textproc", "  Epoch " + epoch + ": loss=" + (totalLoss / total).ToString("F4") +
199                        " acc=" + ((float)correct / total * 100).ToString("F1") + "%");
200
201                // Decay learning rate
202                if (epoch > 0 && epoch % 10 == 0) lr *= 0.5f;
203            }
204
205            using (var fs = new FileStream(savePath, FileMode.Create))
206            using (var bw = new BinaryWriter(fs))
207            {
208                for (int i = 0; i < tw1.Length; i++) bw.Write(tw1[i]);
209                for (int i = 0; i < tb1.Length; i++) bw.Write(tb1[i]);
210                for (int i = 0; i < tw2.Length; i++) bw.Write(tw2[i]);
211                bw.Write(tb2);
212            }
213            Logger.Log("textproc", "SpellNet v2 saved: " + ((tw1.Length + tb1.Length + tw2.Length + 1) * 4 / 1024) + "KB");
214        }
215
216        static string CorruptWord(string word, Random rng)
217        {
218            string w = word.ToLower();
219            if (w.Length < 3) return w + "ъ";
220            switch (rng.Next(5))
221            {
222                case 0: int p = rng.Next(w.Length - 1); char[] a = w.ToCharArray(); char t = a[p]; a[p] = a[p + 1]; a[p + 1] = t; return new string(a);
223                case 1: p = rng.Next(w.Length); return w.Insert(p, w[p].ToString());
224                case 2: return w.Remove(rng.Next(w.Length), 1);
225                case 3: a = w.ToCharArray(); a[rng.Next(a.Length)] = (char)('а' + rng.Next(32)); return new string(a);
226                case 4: return w.Insert(rng.Next(w.Length + 1), ((char)('а' + rng.Next(32))).ToString());
227                default: return w;
228            }
229        }
230    }
231}