windowcapture
исходный код / Helpers/CharEmbedNet.cs

CharEmbedNet.cs

344 строки · 14 275 байт · модуль Helpers
  1using System;
  2using System.Collections.Generic;
  3using System.IO;
  4
  5namespace WindowCapture.Helpers
  6{
  7    /// <summary>
  8    /// CharEmbedNet: Character Embedding Neural Network.
  9    ///
 10    /// Architecture:
 11    ///   1. Each character (а-я,ё) gets an embedding vector of size EMB_DIM
 12    ///   2. Word = average of character embeddings (order-aware: position weight)
 13    ///   3. All 1.5M dictionary words pre-computed as vectors → saved to file
 14    ///   4. At runtime: input word → compute vector → find nearest dictionary vector
 15    ///
 16    /// Training:
 17    ///   - Positive pairs: (word, same word) → vectors should be identical
 18    ///   - Negative pairs: (word, corrupted word) → vectors should be different
 19    ///   - Triplet loss: d(anchor, positive) &lt; d(anchor, negative) + margin
 20    ///
 21    /// File format:
 22    ///   charembednet.bin:
 23    ///     [0..ALPHA*EMB_DIM*4] = char embeddings (float[ALPHA][EMB_DIM])
 24    ///     [next..] = precomputed word vectors (float[numWords][VEC_DIM]) + word index
 25    ///
 26    /// This replaces SpellNet for word validity AND candidate ranking.
 27    /// </summary>
 28    public class CharEmbedNet
 29    {
 30        const int ALPHA = 34;    // а-я(32) + ё + padding
 31        const int EMB_DIM = 48;  // embedding per character
 32        const int VEC_DIM = 48;  // final word vector dimension
 33        const int MAX_LEN = 20;
 34
 35        // Trained char embeddings
 36        float[,] charEmb; // [ALPHA, EMB_DIM]
 37
 38        // Precomputed dictionary vectors
 39        float[,] dictVecs;   // [numWords, VEC_DIM]
 40        string[] dictWords;  // word for each vector
 41        int dictCount;
 42
 43        // Position weights (trained): how much each position contributes
 44        float[] posWeights; // [MAX_LEN]
 45
 46        volatile bool ready;
 47        public bool IsReady { get { return ready; } }
 48
 49        /// <summary>Compute word vector from character embeddings.</summary>
 50        public float[] WordToVec(string word)
 51        {
 52            float[] vec = new float[VEC_DIM];
 53            string w = word.ToLower();
 54            int len = Math.Min(w.Length, MAX_LEN);
 55            float totalWeight = 0;
 56
 57            for (int i = 0; i < len; i++)
 58            {
 59                int ci = CharIdx(w[i]);
 60                if (ci < 0) continue;
 61                float pw = (posWeights != null && i < posWeights.Length) ? posWeights[i] : 1f;
 62                for (int d = 0; d < VEC_DIM; d++)
 63                    vec[d] += charEmb[ci, d] * pw;
 64                totalWeight += pw;
 65            }
 66
 67            // Normalize
 68            if (totalWeight > 0)
 69                for (int d = 0; d < VEC_DIM; d++) vec[d] /= totalWeight;
 70
 71            // L2 normalize
 72            float norm = 0;
 73            for (int d = 0; d < VEC_DIM; d++) norm += vec[d] * vec[d];
 74            if (norm > 0) { norm = (float)Math.Sqrt(norm); for (int d = 0; d < VEC_DIM; d++) vec[d] /= norm; }
 75
 76            return vec;
 77        }
 78
 79        /// <summary>Find top N nearest dictionary words to input.</summary>
 80        public List<string> FindNearest(string input, int topN)
 81        {
 82            if (!ready) return new List<string>();
 83            float[] inputVec = WordToVec(input);
 84
 85            // Cosine similarity with all dict vectors
 86            var scores = new List<KeyValuePair<int, float>>();
 87            for (int i = 0; i < dictCount; i++)
 88            {
 89                float dot = 0;
 90                for (int d = 0; d < VEC_DIM; d++)
 91                    dot += inputVec[d] * dictVecs[i, d];
 92                scores.Add(new KeyValuePair<int, float>(i, dot));
 93            }
 94
 95            // Partial sort: find top N
 96            scores.Sort(delegate(KeyValuePair<int, float> a, KeyValuePair<int, float> b)
 97            { return b.Value.CompareTo(a.Value); });
 98
 99            var result = new List<string>();
100            int count = Math.Min(topN, scores.Count);
101            for (int i = 0; i < count; i++)
102                result.Add(dictWords[scores[i].Key]);
103            return result;
104        }
105
106        /// <summary>Cosine similarity between input and a candidate word.</summary>
107        public float Similarity(string input, string candidate)
108        {
109            float[] v1 = WordToVec(input);
110            float[] v2 = WordToVec(candidate);
111            float dot = 0;
112            for (int d = 0; d < VEC_DIM; d++) dot += v1[d] * v2[d];
113            return dot;
114        }
115
116        /// <summary>Load pretrained model from file.</summary>
117        public bool Load(string path)
118        {
119            try
120            {
121                using (var fs = new FileStream(path, FileMode.Open))
122                using (var br = new BinaryReader(fs))
123                {
124                    // Char embeddings
125                    charEmb = new float[ALPHA, EMB_DIM];
126                    for (int i = 0; i < ALPHA; i++)
127                        for (int d = 0; d < EMB_DIM; d++)
128                            charEmb[i, d] = br.ReadSingle();
129
130                    // Position weights
131                    posWeights = new float[MAX_LEN];
132                    for (int i = 0; i < MAX_LEN; i++)
133                        posWeights[i] = br.ReadSingle();
134
135                    // Dict vectors
136                    dictCount = br.ReadInt32();
137                    dictWords = new string[dictCount];
138                    dictVecs = new float[dictCount, VEC_DIM];
139                    for (int i = 0; i < dictCount; i++)
140                    {
141                        dictWords[i] = br.ReadString();
142                        for (int d = 0; d < VEC_DIM; d++)
143                            dictVecs[i, d] = br.ReadSingle();
144                    }
145                }
146                ready = true;
147                Logger.Log("textproc", "CharEmbedNet loaded: " + dictCount + " word vectors, " +
148                    (ALPHA * EMB_DIM + MAX_LEN + dictCount * (VEC_DIM + 20)) * 4 / 1024 / 1024 + "MB");
149                return true;
150            }
151            catch (Exception ex)
152            {
153                Logger.Log("textproc", "CharEmbedNet load err: " + ex.Message);
154                return false;
155            }
156        }
157
158        // ====================================================================
159        // TRAINING (offline)
160        // ====================================================================
161        public static void Train(string[] dictionary, string savePath, int epochs)
162        {
163            var rng = new Random(42);
164            int dictLen = Math.Min(dictionary.Length, 100000); // top 100k for training
165
166            // Initialize embeddings randomly
167            float[,] emb = new float[ALPHA, EMB_DIM];
168            float[] pw = new float[MAX_LEN];
169            for (int i = 0; i < ALPHA; i++)
170                for (int d = 0; d < EMB_DIM; d++)
171                    emb[i, d] = (float)(rng.NextDouble() * 2 - 1) * 0.3f;
172            // Position weights: first 2 chars and last 2 chars are critical
173            for (int i = 0; i < MAX_LEN; i++)
174            {
175                pw[i] = 1.0f;
176                if (i == 0) pw[i] = 3.0f;       // first char = most important
177                else if (i == 1) pw[i] = 2.0f;   // second char
178                else if (i >= MAX_LEN - 2) pw[i] = 1.5f; // last chars (endings)
179            }
180
181            float lr = 0.003f;
182            int samplesPerEpoch = 200000;
183
184            Logger.Log("textproc", "CharEmbedNet training: " + dictLen + " words, " + epochs + " epochs");
185
186            for (int epoch = 0; epoch < epochs; epoch++)
187            {
188                float totalLoss = 0;
189                int correct = 0;
190
191                for (int s = 0; s < samplesPerEpoch; s++)
192                {
193                    // Triplet: anchor (real word), positive (slightly corrupted → should be close),
194                    //          negative (different word → should be far)
195                    string anchor = dictionary[rng.Next(dictLen)].ToLower();
196                    if (anchor.Length < 3 || anchor.Length > MAX_LEN) continue;
197
198                    string positive = CorruptLight(anchor, rng); // light corruption → close
199                    string negative = dictionary[rng.Next(dictLen)].ToLower(); // different word → far
200                    while (negative == anchor) negative = dictionary[rng.Next(dictLen)].ToLower();
201
202                    float[] vAnchor = ComputeVec(anchor, emb, pw);
203                    float[] vPos = ComputeVec(positive, emb, pw);
204                    float[] vNeg = ComputeVec(negative, emb, pw);
205
206                    float dPos = CosDist(vAnchor, vPos);
207                    float dNeg = CosDist(vAnchor, vNeg);
208
209                    // Triplet loss: want dPos < dNeg - margin
210                    float margin = 0.3f;
211                    float loss = Math.Max(0, dPos - dNeg + margin);
212                    totalLoss += loss;
213
214                    if (dPos < dNeg) correct++;
215
216                    if (loss > 0)
217                    {
218                        // Gradient: push anchor closer to positive, farther from negative
219                        // Simplified: adjust char embeddings for anchor/positive to be more similar
220                        float gradScale = lr;
221                        AdjustEmb(anchor, positive, emb, pw, gradScale);  // pull together
222                        AdjustEmb(anchor, negative, emb, pw, -gradScale * 0.5f); // push apart
223                    }
224                }
225
226                if (epoch % 5 == 0 || epoch == epochs - 1)
227                    Logger.Log("textproc", "  Epoch " + epoch + ": loss=" + (totalLoss / samplesPerEpoch).ToString("F4") +
228                        " acc=" + ((float)correct / samplesPerEpoch * 100).ToString("F1") + "%");
229
230                if (epoch > 0 && epoch % 10 == 0) lr *= 0.7f;
231            }
232
233            // Save: embeddings + precomputed vectors for top words
234            int saveCount = Math.Min(dictionary.Length, 80000); // save top 80k vectors
235            Logger.Log("textproc", "CharEmbedNet: precomputing " + saveCount + " word vectors...");
236
237            using (var fs = new FileStream(savePath, FileMode.Create))
238            using (var bw = new BinaryWriter(fs))
239            {
240                // Char embeddings
241                for (int i = 0; i < ALPHA; i++)
242                    for (int d = 0; d < EMB_DIM; d++)
243                        bw.Write(emb[i, d]);
244                // Position weights
245                for (int i = 0; i < MAX_LEN; i++) bw.Write(pw[i]);
246                // Dict vectors
247                bw.Write(saveCount);
248                for (int i = 0; i < saveCount; i++)
249                {
250                    string w = dictionary[i].ToLower();
251                    bw.Write(w);
252                    float[] vec = ComputeVec(w, emb, pw);
253                    for (int d = 0; d < VEC_DIM; d++) bw.Write(vec[d]);
254                }
255            }
256
257            long fileSize = new FileInfo(savePath).Length;
258            Logger.Log("textproc", "CharEmbedNet saved: " + fileSize / 1024 + "KB");
259        }
260
261        static float[] ComputeVec(string word, float[,] emb, float[] pw)
262        {
263            float[] vec = new float[VEC_DIM];
264            int len = Math.Min(word.Length, MAX_LEN);
265            float totalW = 0;
266            for (int i = 0; i < len; i++)
267            {
268                int ci = CharIdx(word[i]);
269                if (ci < 0) continue;
270                float w = (i < pw.Length) ? pw[i] : 1f;
271                for (int d = 0; d < VEC_DIM; d++) vec[d] += emb[ci, d] * w;
272                totalW += w;
273            }
274            if (totalW > 0) for (int d = 0; d < VEC_DIM; d++) vec[d] /= totalW;
275            float norm = 0;
276            for (int d = 0; d < VEC_DIM; d++) norm += vec[d] * vec[d];
277            if (norm > 0) { norm = (float)Math.Sqrt(norm); for (int d = 0; d < VEC_DIM; d++) vec[d] /= norm; }
278            return vec;
279        }
280
281        static float CosDist(float[] a, float[] b)
282        {
283            float dot = 0;
284            for (int d = 0; d < VEC_DIM; d++) dot += a[d] * b[d];
285            return 1f - dot; // 0 = identical, 2 = opposite
286        }
287
288        static void AdjustEmb(string w1, string w2, float[,] emb, float[] pw, float scale)
289        {
290            int len1 = Math.Min(w1.Length, MAX_LEN);
291            int len2 = Math.Min(w2.Length, MAX_LEN);
292            for (int i = 0; i < Math.Min(len1, len2); i++)
293            {
294                int c1 = CharIdx(w1[i]);
295                int c2 = CharIdx(w2[i]);
296                if (c1 < 0 || c2 < 0) continue;
297                for (int d = 0; d < EMB_DIM; d++)
298                {
299                    float diff = emb[c1, d] - emb[c2, d];
300                    // Gradient clipping to prevent NaN
301                    float grad = diff * scale;
302                    if (grad > 0.1f) grad = 0.1f;
303                    if (grad < -0.1f) grad = -0.1f;
304                    emb[c1, d] -= grad;
305                    emb[c2, d] += grad;
306                    // Keep embeddings bounded
307                    if (emb[c1, d] > 2f) emb[c1, d] = 2f;
308                    if (emb[c1, d] < -2f) emb[c1, d] = -2f;
309                    if (emb[c2, d] > 2f) emb[c2, d] = 2f;
310                    if (emb[c2, d] < -2f) emb[c2, d] = -2f;
311                }
312            }
313        }
314
315        static string CorruptLight(string word, Random rng)
316        {
317            if (word.Length < 3) return word;
318            switch (rng.Next(4))
319            {
320                case 0: // swap adjacent
321                    int p = rng.Next(word.Length - 1);
322                    char[] a = word.ToCharArray();
323                    char t = a[p]; a[p] = a[p + 1]; a[p + 1] = t;
324                    return new string(a);
325                case 1: // replace one char
326                    a = word.ToCharArray();
327                    a[rng.Next(a.Length)] = (char)('а' + rng.Next(32));
328                    return new string(a);
329                case 2: // delete one char
330                    return word.Remove(rng.Next(word.Length), 1);
331                case 3: // insert one char
332                    return word.Insert(rng.Next(word.Length + 1), ((char)('а' + rng.Next(32))).ToString());
333                default: return word;
334            }
335        }
336
337        static int CharIdx(char c)
338        {
339            if (c >= 'а' && c <= 'я') return c - 'а' + 1;
340            if (c == 'ё') return 33;
341            return 0;
342        }
343    }
344}