windowcapture
исходный код / Tools/train_gpu.py

train_gpu.py

227 строк · 7,593 байт · модуль Tools
  1#!/usr/bin/env python3
  2"""
  3Train a character-level spelling corrector using PyTorch.
  4Architecture: Encoder-Decoder GRU with attention.
  5Exports weights to binary file for C# inference.
  6"""
  7import torch
  8import torch.nn as nn
  9import torch.optim as optim
 10import struct
 11import random
 12import os
 13import time
 14
 15# === Config ===
 16ALPHA = 35  # а-я(32) + ё + space + EOS
 17HIDDEN = 96
 18MAX_LEN = 25
 19EMBED_DIM = 32
 20EPOCHS = 20
 21BATCH_SIZE = 128
 22LR = 0.001
 23
 24DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'Data')
 25DICT_PATH = os.path.join(DATA_DIR, 'dict_ru.txt')
 26SAVE_PATH = os.path.join(DATA_DIR, 'seq2spell_gpu.bin')
 27
 28def char_to_idx(c):
 29    if 'а' <= c <= 'я': return ord(c) - ord('а')
 30    if c == 'ё': return 32
 31    if c == ' ': return 33
 32    return ALPHA - 1  # EOS/unknown
 33
 34def idx_to_char(i):
 35    if 0 <= i <= 31: return chr(ord('а') + i)
 36    if i == 32: return 'ё'
 37    if i == 33: return ' '
 38    return ''
 39
 40def corrupt(word, rng):
 41    w = list(word)
 42    if len(w) < 3: return word
 43    ops = rng.randint(0, 4)
 44    if ops == 0 and len(w) > 1:  # swap
 45        p = rng.randint(0, len(w)-2)
 46        w[p], w[p+1] = w[p+1], w[p]
 47    elif ops == 1:  # replace
 48        p = rng.randint(0, len(w)-1)
 49        w[p] = chr(ord('а') + rng.randint(0, 31))
 50    elif ops == 2 and len(w) > 2:  # delete
 51        del w[rng.randint(0, len(w)-1)]
 52    elif ops == 3:  # insert
 53        w.insert(rng.randint(0, len(w)), chr(ord('а') + rng.randint(0, 31)))
 54    elif ops == 4:  # double corrupt
 55        w = list(corrupt(''.join(w), rng))
 56    return ''.join(w)
 57
 58def encode_word(word, max_len=MAX_LEN):
 59    ids = [char_to_idx(c) for c in word[:max_len]]
 60    ids.append(ALPHA - 1)  # EOS
 61    return ids
 62
 63class Seq2SpellModel(nn.Module):
 64    def __init__(self):
 65        super().__init__()
 66        self.embed = nn.Embedding(ALPHA, EMBED_DIM)
 67        self.encoder = nn.GRU(EMBED_DIM, HIDDEN, batch_first=True)
 68        self.decoder = nn.GRU(EMBED_DIM, HIDDEN, batch_first=True)
 69        self.output = nn.Linear(HIDDEN, ALPHA)
 70
 71    def forward(self, src, tgt):
 72        # Encode
 73        src_emb = self.embed(src)
 74        _, h = self.encoder(src_emb)
 75        # Decode with teacher forcing
 76        tgt_emb = self.embed(tgt)
 77        dec_out, _ = self.decoder(tgt_emb, h)
 78        logits = self.output(dec_out)
 79        return logits
 80
 81import sys
 82def flush_print(*args, **kwargs):
 83    print(*args, **kwargs)
 84    sys.stdout.flush()
 85
 86def main():
 87    flush_print("Loading dictionary...")
 88    words = []
 89    with open(DICT_PATH, 'r', encoding='utf-8') as f:
 90        for line in f:
 91            w = line.strip().lower()
 92            if 2 <= len(w) <= 20:
 93                words.append(w)
 94            if len(words) >= 100000:
 95                break
 96    flush_print(f"Loaded {len(words)} words")
 97
 98    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 99    flush_print(f"Device: {device}")
100
101    model = Seq2SpellModel().to(device)
102    optimizer = optim.Adam(model.parameters(), lr=LR)
103    criterion = nn.CrossEntropyLoss(ignore_index=-1)
104    rng = random.Random(42)
105
106    flush_print(f"Training {EPOCHS} epochs, batch={BATCH_SIZE}, hidden={HIDDEN}, embed={EMBED_DIM}")
107    start = time.time()
108
109    for epoch in range(EPOCHS):
110        model.train()
111        total_loss = 0
112        correct = 0
113        total = 0
114        n_batches = 500
115
116        for _ in range(n_batches):
117            # Build batch
118            src_batch = []
119            tgt_in_batch = []
120            tgt_out_batch = []
121
122            for _ in range(BATCH_SIZE):
123                target = words[rng.randint(0, len(words)-1)]
124                inp = corrupt(target, rng) if rng.random() < 0.5 else target
125
126                src_ids = encode_word(inp)
127                tgt_ids = encode_word(target)
128
129                # Pad
130                while len(src_ids) < MAX_LEN + 1: src_ids.append(ALPHA - 1)
131                while len(tgt_ids) < MAX_LEN + 1: tgt_ids.append(ALPHA - 1)
132
133                src_batch.append(src_ids[:MAX_LEN+1])
134                tgt_in_batch.append([ALPHA-1] + tgt_ids[:MAX_LEN])  # shift right
135                tgt_out_batch.append(tgt_ids[:MAX_LEN+1])
136
137            src = torch.tensor(src_batch, device=device)
138            tgt_in = torch.tensor(tgt_in_batch, device=device)
139            tgt_out = torch.tensor(tgt_out_batch, device=device)
140
141            logits = model(src, tgt_in)
142            loss = criterion(logits.view(-1, ALPHA), tgt_out.view(-1))
143
144            optimizer.zero_grad()
145            loss.backward()
146            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
147            optimizer.step()
148
149            total_loss += loss.item()
150            preds = logits.argmax(dim=-1)
151            mask = tgt_out != ALPHA - 1
152            correct += (preds == tgt_out)[mask].sum().item()
153            total += mask.sum().item()
154
155        acc = correct / total * 100 if total > 0 else 0
156        elapsed = time.time() - start
157        flush_print(f"  Epoch {epoch}: loss={total_loss/n_batches:.3f} acc={acc:.1f}% [{elapsed:.0f}s]")
158
159    # === Export weights for C# inference ===
160    flush_print("Exporting weights...")
161    model.eval()
162
163    with open(SAVE_PATH, 'wb') as f:
164        # Write dimensions
165        f.write(struct.pack('iiii', ALPHA, EMBED_DIM, HIDDEN, MAX_LEN))
166
167        # Embedding weights [ALPHA, EMBED_DIM]
168        emb = model.embed.weight.detach().cpu().numpy()
169        for i in range(ALPHA):
170            for j in range(EMBED_DIM):
171                f.write(struct.pack('f', float(emb[i][j])))
172
173        # Encoder GRU weights
174        for name, param in model.encoder.named_parameters():
175            data = param.detach().cpu().numpy().flatten()
176            f.write(struct.pack('i', len(data)))
177            for v in data:
178                f.write(struct.pack('f', float(v)))
179
180        # Decoder GRU weights
181        for name, param in model.decoder.named_parameters():
182            data = param.detach().cpu().numpy().flatten()
183            f.write(struct.pack('i', len(data)))
184            for v in data:
185                f.write(struct.pack('f', float(v)))
186
187        # Output linear [HIDDEN, ALPHA] + bias [ALPHA]
188        w = model.output.weight.detach().cpu().numpy()
189        b = model.output.bias.detach().cpu().numpy()
190        for i in range(ALPHA):
191            for j in range(HIDDEN):
192                f.write(struct.pack('f', float(w[i][j])))
193        for i in range(ALPHA):
194            f.write(struct.pack('f', float(b[i])))
195
196    fsize = os.path.getsize(SAVE_PATH)
197    flush_print(f"Saved: {SAVE_PATH} ({fsize//1024}KB)")
198
199    # Test
200    print("\n=== Test ===")
201    model.eval()
202    tests = ["привет", "прввет", "компуктер", "здраствуте", "пшоел", "кароче", "тихналогия", "написатьб", "решылб"]
203    for t in tests:
204        src = torch.tensor([encode_word(t)[:MAX_LEN+1] + [ALPHA-1]*(MAX_LEN+1-len(encode_word(t)))], device=device)
205        # Greedy decode
206        h = None
207        src_emb = model.embed(src)
208        _, h = model.encoder(src_emb)
209
210        result = []
211        inp_char = ALPHA - 1  # start with EOS
212        for _ in range(MAX_LEN):
213            inp_t = torch.tensor([[inp_char]], device=device)
214            inp_emb = model.embed(inp_t)
215            out, h = model.decoder(inp_emb, h)
216            logits = model.output(out.squeeze(1))
217            pred = logits.argmax(dim=-1).item()
218            if pred == ALPHA - 1: break  # EOS
219            ch = idx_to_char(pred)
220            if ch: result.append(ch)
221            inp_char = pred
222        flush_print(f"  {t}{''.join(result)}")
223
224    flush_print("\nDone!")
225
if __name__ == "__main__":
    # Train and export only when run as a script, never on import.
    main()