"""
Neural Context Model: word embedding-based bigram predictor.
Trained on RTX 4080.

Architecture:
  - Word embedding (each word → vector)
  - Given prev_word embedding, predict next_word via dot product
  - Trained with negative sampling (like Word2Vec)

This learns semantic similarity between word pairs:
  "на" is close to "карте","столе","работе" in embedding space
  "двойными" is close to "буквами","пробелами"

Export: word embeddings for the first VOCAB_SIZE dictionary words
→ C# does dot product at runtime

Binary layout of the exported file (little-endian):
  int32 vocab_size, int32 embed_dim
  vocab_size × (int32 utf8_byte_length, UTF-8 bytes)
  vocab_size × embed_dim float32  -- context ("in") embeddings, row-major
  vocab_size × embed_dim float32  -- target ("out") embeddings, row-major
"""
import os
import random
import struct
import sys
import time

import numpy as np
import torch
import torch.nn as nn

EMBED_DIM = 64
VOCAB_SIZE = 15000
EPOCHS = 30
BATCH_SIZE = 2048
LR = 0.01
NEG_SAMPLES = 5

DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'Data')
SAVE_PATH = os.path.join(DATA_DIR, 'context_nn.bin')

# Hand-picked (prev_word, next_word) pairs that are over-sampled during training.
_MANUAL_PAIRS = [
    ("на", "карте"), ("на", "столе"), ("на", "работе"), ("на", "улице"), ("на", "экране"),
    ("в", "итоге"), ("в", "общем"), ("в", "жизни"), ("в", "школе"), ("в", "интернете"),
    ("я", "думаю"), ("я", "хочу"), ("я", "знаю"), ("я", "боюсь"), ("я", "решил"), ("я", "могу"),
    ("не", "могу"), ("не", "знаю"), ("не", "хочу"), ("не", "работает"), ("не", "понимаю"),
    ("но", "я"), ("но", "это"), ("но", "всё"), ("но", "он"),
    ("что", "это"), ("что", "он"), ("что", "делать"), ("что", "бы"),
    ("двойными", "буквами"), ("банковскую", "карту"), ("на", "клавиатуре"),
    ("потому", "что"), ("так", "как"), ("для", "того"),
    ("спать", "лягу"), ("книгу", "почитаю"), ("кнопку", "бэкспейс"),
    ("очень", "быстро"), ("очень", "сильно"), ("очень", "хочется"),
    ("нервы", "сдали"), ("жизнь", "боль"),
    ("может", "быть"), ("всё", "таки"), ("всё", "равно"),
    ("как", "же"), ("как", "будто"), ("как", "раз"),
    ("средств", "на"), ("на", "карту"), ("на", "счёт"),
    ("исправлять", "текст"), ("нажимать", "кнопку"),
    ("написал", "сообщение"), ("купил", "телефон"),
]
_MANUAL_REPEATS = 500          # each manual pair is repeated this many times for weight
_RANDOM_PAIR_COUNT = 500000    # number of random filler pairs
_RANDOM_CTX_MAX_INDEX = 10000  # highest dictionary index used as a random context word
_RANDOM_TGT_MAX_INDEX = 20000  # highest dictionary index used as a random target word


def p(*a, **k):
    """Print and flush stdout immediately so progress is visible when output is buffered."""
    print(*a, **k)
    sys.stdout.flush()


class ContextModel(nn.Module):
    """Two-table embedding model scored by dot product, trained with negative sampling.

    ``embed_in`` holds the vectors used for the previous (context) word and
    ``embed_out`` the vectors used for the candidate next (target) word.
    """

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed_in = nn.Embedding(vocab_size, embed_dim)   # context word
        self.embed_out = nn.Embedding(vocab_size, embed_dim)  # target word
        nn.init.xavier_uniform_(self.embed_in.weight)
        nn.init.xavier_uniform_(self.embed_out.weight)

    def forward(self, context, target, negatives):
        """Return the mean negative-sampling loss of a batch.

        Args:
            context: index tensor of shape [batch] (previous words).
            target: index tensor of shape [batch] (true next words).
            negatives: index tensor of shape [batch, neg_samples] (noise words).

        Returns:
            Scalar tensor: mean over the batch of
            ``-log σ(ctx·tgt) - Σ log σ(-ctx·neg)``.
        """
        ctx = self.embed_in(context)      # [batch, embed]
        tgt = self.embed_out(target)      # [batch, embed]
        neg = self.embed_out(negatives)   # [batch, neg, embed]

        pos_score = (ctx * tgt).sum(dim=1)                        # [batch]
        neg_score = torch.bmm(neg, ctx.unsqueeze(2)).squeeze(2)   # [batch, neg]

        # logsigmoid is numerically stable; log(sigmoid(x) + eps) saturates at
        # -log(eps) and loses its gradient for confidently wrong scores.
        log_sigmoid = nn.functional.logsigmoid
        pos_loss = -log_sigmoid(pos_score)
        neg_loss = -log_sigmoid(-neg_score).sum(dim=1)

        return (pos_loss + neg_loss).mean()


def _load_words(path, limit):
    """Read up to ``limit`` lower-cased words of length 1..20 from a one-word-per-line file."""
    words = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            w = line.strip().lower()
            if 1 <= len(w) <= 20:
                words.append(w)
                if len(words) >= limit:
                    break
    return words


def _build_pairs(words, word2idx, rng):
    """Build (context_idx, target_idx) training pairs: weighted manual pairs plus random filler."""
    pairs = []

    # Manual high-quality pairs; skipped when either word is out of vocabulary.
    for w1, w2 in _MANUAL_PAIRS:
        if w1 in word2idx and w2 in word2idx:
            pairs.extend([(word2idx[w1], word2idx[w2])] * _MANUAL_REPEATS)

    # Random index pairs drawn from the head of the dictionary (randint is inclusive).
    last = len(words) - 1
    ctx_hi = min(_RANDOM_CTX_MAX_INDEX, last)
    tgt_hi = min(_RANDOM_TGT_MAX_INDEX, last)
    for _ in range(_RANDOM_PAIR_COUNT):
        i1 = rng.randint(0, ctx_hi)
        i2 = rng.randint(0, tgt_hi)
        pairs.append((i1, i2))

    return pairs


def _train(model, pairs, rng, dev, vocab_size):
    """Run EPOCHS of Adam over ``pairs`` (reshuffled in place each epoch) with uniform negatives."""
    if not pairs:
        raise ValueError("No training pairs to train on")

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    t0 = time.time()

    for epoch in range(EPOCHS):
        rng.shuffle(pairs)
        # One host-to-device transfer per epoch instead of two list builds per batch.
        epoch_pairs = torch.tensor(pairs, dtype=torch.long, device=dev)  # [n_pairs, 2]
        total_loss = 0.0
        n_batches = 0

        # Step through every pair, including the final (possibly partial) batch.
        for start in range(0, len(pairs), BATCH_SIZE):
            batch = epoch_pairs[start:start + BATCH_SIZE]
            ctx = batch[:, 0]
            tgt = batch[:, 1]
            neg = torch.randint(0, vocab_size, (batch.shape[0], NEG_SAMPLES), device=dev)

            loss = model(ctx, tgt, neg)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1

        if epoch % 5 == 0 or epoch == EPOCHS - 1:
            p(f" Epoch {epoch}: loss={total_loss/n_batches:.4f} [{time.time()-t0:.0f}s]")


def _export(path, words, emb_in, emb_out):
    """Write the vocabulary and both embedding matrices in the layout described in the module docstring."""
    with open(path, 'wb') as f:
        f.write(struct.pack('<ii', len(words), EMBED_DIM))
        # Word strings: length-prefixed UTF-8.
        for w in words:
            wb = w.encode('utf-8')
            f.write(struct.pack('<i', len(wb)))
            f.write(wb)
        # Context embeddings (for prev_word), then target embeddings (for candidate word).
        # A single row-major float32 dump is byte-identical to packing element by element.
        for matrix in (emb_in, emb_out):
            f.write(np.ascontiguousarray(matrix, dtype='<f4').tobytes())


def _report(word2idx, emb_in, emb_out):
    """Print dot-product scores for a fixed list of plausible vs. implausible word pairs."""
    def score(w1, w2):
        # Out-of-vocabulary words score 0.
        if w1 not in word2idx or w2 not in word2idx:
            return 0
        v1 = emb_in[word2idx[w1]]
        v2 = emb_out[word2idx[w2]]
        return float(np.dot(v1, v2))

    test_pairs = [
        ("на", "карте"), ("на", "катере"), ("на", "карту"),
        ("двойными", "буквами"), ("двойными", "буями"),
        ("но", "боюсь"), ("но", "бюст"),
        ("спать", "лягу"), ("спать", "лгу"),
        ("нажимать", "кнопку"), ("нажимать", "провод"),
        ("я", "решил"), ("я", "постирать"),
    ]
    for w1, w2 in test_pairs:
        s = score(w1, w2)
        p(f" score({w1}, {w2}) = {s:.3f}")


def main():
    """Load the dictionary, train the context model, export embeddings and print a sanity report.

    Raises:
        ValueError: if the dictionary file yields no usable words.
    """
    p("Loading dictionary...")
    dict_path = os.path.join(DATA_DIR, 'dict_ru.txt')
    words = _load_words(dict_path, VOCAB_SIZE)
    p(f"Loaded {len(words)} words")
    if not words:
        raise ValueError(f"No usable words found in {dict_path}")

    word2idx = {w: i for i, w in enumerate(words)}

    # Build training pairs from common patterns + random cooccurrence
    p("Building training pairs...")
    rng = random.Random(42)
    pairs = _build_pairs(words, word2idx, rng)
    p(f"Training pairs: {len(pairs)}")

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    p(f"Device: {dev}")

    model = ContextModel(len(words), EMBED_DIM).to(dev)

    p(f"Training {EPOCHS} epochs, batch={BATCH_SIZE}, embed={EMBED_DIM}, neg={NEG_SAMPLES}")
    _train(model, pairs, rng, dev, len(words))

    # Export embeddings
    p("Exporting...")
    model.eval()
    emb_in = model.embed_in.weight.detach().cpu().numpy()    # context embeddings
    emb_out = model.embed_out.weight.detach().cpu().numpy()  # target embeddings
    _export(SAVE_PATH, words, emb_in, emb_out)
    p(f"Saved: {os.path.getsize(SAVE_PATH)//1024}KB")

    # Test
    p("\n=== Context Similarity Test ===")
    _report(word2idx, emb_in, emb_out)

    p("\nDone!")


if __name__ == '__main__':
    main()