windowcapture
исходный код / Tools/train_context_nn.py

train_context_nn.py

185 строк · 7,536 байт · модуль Tools
  1"""
  2Neural Context Model: word embedding-based bigram predictor.
  3Trained on RTX 4080.
  4
  5Architecture:
  6  - Word embedding (each word → vector)
  7  - Given prev_word embedding, predict next_word via dot product
  8  - Trained with negative sampling (like Word2Vec)
  9
 10This learns semantic similarity between word pairs:
 11  "на" is close to "карте","столе","работе" in embedding space
 12  "двойными" is close to "буквами","пробелами"
 13
 14Export: word embeddings for top 50k words → C# does dot product at runtime
 15"""
 16import torch, torch.nn as nn, struct, random, os, sys, time, numpy as np
 17
 18EMBED_DIM = 64
 19VOCAB_SIZE = 15000
 20EPOCHS = 30
 21BATCH_SIZE = 2048
 22LR = 0.01
 23NEG_SAMPLES = 5
 24
 25DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'Data')
 26SAVE_PATH = os.path.join(DATA_DIR, 'context_nn.bin')
 27
 28def p(*a,**k): print(*a,**k); sys.stdout.flush()
 29
 30class ContextModel(nn.Module):
 31    def __init__(self, vocab_size, embed_dim):
 32        super().__init__()
 33        self.embed_in = nn.Embedding(vocab_size, embed_dim)   # context word
 34        self.embed_out = nn.Embedding(vocab_size, embed_dim)  # target word
 35        nn.init.xavier_uniform_(self.embed_in.weight)
 36        nn.init.xavier_uniform_(self.embed_out.weight)
 37
 38    def forward(self, context, target, negatives):
 39        # context: [batch]
 40        # target: [batch]
 41        # negatives: [batch, neg_samples]
 42        ctx = self.embed_in(context)          # [batch, embed]
 43        tgt = self.embed_out(target)          # [batch, embed]
 44        neg = self.embed_out(negatives)       # [batch, neg, embed]
 45
 46        # Positive score: dot(context, target)
 47        pos_score = (ctx * tgt).sum(dim=1)    # [batch]
 48        pos_loss = -torch.log(torch.sigmoid(pos_score) + 1e-8)
 49
 50        # Negative score
 51        neg_score = torch.bmm(neg, ctx.unsqueeze(2)).squeeze(2)  # [batch, neg]
 52        neg_loss = -torch.log(torch.sigmoid(-neg_score) + 1e-8).sum(dim=1)
 53
 54        return (pos_loss + neg_loss).mean()
 55
 56def main():
 57    p("Loading dictionary...")
 58    words = []
 59    with open(os.path.join(DATA_DIR, 'dict_ru.txt'), 'r', encoding='utf-8') as f:
 60        for line in f:
 61            w = line.strip().lower()
 62            if 1 <= len(w) <= 20:
 63                words.append(w)
 64            if len(words) >= VOCAB_SIZE:
 65                break
 66    p(f"Loaded {len(words)} words")
 67
 68    word2idx = {w: i for i, w in enumerate(words)}
 69
 70    # Build training pairs from common patterns + random cooccurrence
 71    p("Building training pairs...")
 72    pairs = []
 73    rng = random.Random(42)
 74
 75    # Manual high-quality pairs
 76    manual = [
 77        ("на","карте"),("на","столе"),("на","работе"),("на","улице"),("на","экране"),
 78        ("в","итоге"),("в","общем"),("в","жизни"),("в","школе"),("в","интернете"),
 79        ("я","думаю"),("я","хочу"),("я","знаю"),("я","боюсь"),("я","решил"),("я","могу"),
 80        ("не","могу"),("не","знаю"),("не","хочу"),("не","работает"),("не","понимаю"),
 81        ("но","я"),("но","это"),("но","всё"),("но","он"),
 82        ("что","это"),("что","он"),("что","делать"),("что","бы"),
 83        ("двойными","буквами"),("банковскую","карту"),("на","клавиатуре"),
 84        ("потому","что"),("так","как"),("для","того"),
 85        ("спать","лягу"),("книгу","почитаю"),("кнопку","бэкспейс"),
 86        ("очень","быстро"),("очень","сильно"),("очень","хочется"),
 87        ("нервы","сдали"),("жизнь","боль"),
 88        ("может","быть"),("всё","таки"),("всё","равно"),
 89        ("как","же"),("как","будто"),("как","раз"),
 90        ("средств","на"),("на","карту"),("на","счёт"),
 91        ("исправлять","текст"),("нажимать","кнопку"),
 92        ("написал","сообщение"),("купил","телефон"),
 93    ]
 94    for w1, w2 in manual:
 95        if w1 in word2idx and w2 in word2idx:
 96            for _ in range(500):  # repeat for weight
 97                pairs.append((word2idx[w1], word2idx[w2]))
 98
 99    # Random adjacent pairs from frequency list (simulate natural text)
100    prepositions = {"в","на","по","к","с","у","за","от","из","до","для","без","при"}
101    for _ in range(500000):
102        i1 = rng.randint(0, min(10000, len(words)-1))
103        i2 = rng.randint(0, min(20000, len(words)-1))
104        pairs.append((i1, i2))
105
106    p(f"Training pairs: {len(pairs)}")
107
108    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
109    p(f"Device: {dev}")
110
111    model = ContextModel(len(words), EMBED_DIM).to(dev)
112    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
113
114    p(f"Training {EPOCHS} epochs, batch={BATCH_SIZE}, embed={EMBED_DIM}, neg={NEG_SAMPLES}")
115    t0 = time.time()
116
117    for epoch in range(EPOCHS):
118        rng.shuffle(pairs)
119        total_loss = 0
120        n_batches = 0
121
122        for i in range(0, len(pairs) - BATCH_SIZE, BATCH_SIZE):
123            batch = pairs[i:i+BATCH_SIZE]
124            ctx = torch.tensor([p2[0] for p2 in batch], device=dev)
125            tgt = torch.tensor([p2[1] for p2 in batch], device=dev)
126            neg = torch.randint(0, len(words), (len(batch), NEG_SAMPLES), device=dev)
127
128            loss = model(ctx, tgt, neg)
129            optimizer.zero_grad()
130            loss.backward()
131            optimizer.step()
132            total_loss += loss.item()
133            n_batches += 1
134
135        if epoch % 5 == 0 or epoch == EPOCHS - 1:
136            p(f"  Epoch {epoch}: loss={total_loss/n_batches:.4f} [{time.time()-t0:.0f}s]")
137
138    # Export embeddings
139    p("Exporting...")
140    model.eval()
141    emb_in = model.embed_in.weight.detach().cpu().numpy()  # context embeddings
142    emb_out = model.embed_out.weight.detach().cpu().numpy() # target embeddings
143
144    with open(SAVE_PATH, 'wb') as f:
145        f.write(struct.pack('ii', len(words), EMBED_DIM))
146        # Word strings
147        for w in words:
148            wb = w.encode('utf-8')
149            f.write(struct.pack('i', len(wb)))
150            f.write(wb)
151        # Context embeddings (for prev_word)
152        for i in range(len(words)):
153            for j in range(EMBED_DIM):
154                f.write(struct.pack('f', float(emb_in[i][j])))
155        # Target embeddings (for candidate word)
156        for i in range(len(words)):
157            for j in range(EMBED_DIM):
158                f.write(struct.pack('f', float(emb_out[i][j])))
159
160    p(f"Saved: {os.path.getsize(SAVE_PATH)//1024}KB")
161
162    # Test
163    p("\n=== Context Similarity Test ===")
164    def score(w1, w2):
165        if w1 not in word2idx or w2 not in word2idx: return 0
166        v1 = emb_in[word2idx[w1]]
167        v2 = emb_out[word2idx[w2]]
168        return float(np.dot(v1, v2))
169
170    test_pairs = [
171        ("на","карте"),("на","катере"),("на","карту"),
172        ("двойными","буквами"),("двойными","буями"),
173        ("но","боюсь"),("но","бюст"),
174        ("спать","лягу"),("спать","лгу"),
175        ("нажимать","кнопку"),("нажимать","провод"),
176        ("я","решил"),("я","постирать"),
177    ]
178    for w1, w2 in test_pairs:
179        s = score(w1, w2)
180        p(f"  score({w1}, {w2}) = {s:.3f}")
181
182    p("\nDone!")
183
184if __name__ == '__main__':
185    main()