windowcapture
исходный код / Tools/train_cuda.py

train_cuda.py

181 строк · 6,950 байт · модуль Tools
  1#!/usr/bin/env python3
  2"""Train GRU Seq2Seq spell corrector on CUDA (RTX 4080)."""
  3import torch, torch.nn as nn, torch.optim as optim
  4import struct, random, os, sys, time
  5
  6ALPHA = 35
  7HIDDEN = 192       # bigger hidden for GPU
  8EMBED_DIM = 64     # bigger embeddings
  9MAX_LEN = 25
 10EPOCHS = 50        # more epochs on GPU
 11BATCH_SIZE = 512   # bigger batches for GPU
 12LR = 0.001
 13
 14DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'Data')
 15DICT_PATH = os.path.join(DATA_DIR, 'dict_ru.txt')
 16SAVE_PATH = os.path.join(DATA_DIR, 'seq2spell_gpu.bin')
 17
 18def p(*a,**k): print(*a,**k); sys.stdout.flush()
 19
 20def char_to_idx(c):
 21    if 'а' <= c <= 'я': return ord(c) - ord('а')
 22    if c == 'ё': return 32
 23    if c == ' ': return 33
 24    return ALPHA - 1
 25
 26def idx_to_char(i):
 27    if 0 <= i <= 31: return chr(ord('а') + i)
 28    if i == 32: return 'ё'
 29    if i == 33: return ' '
 30    return ''
 31
 32def corrupt(word, rng):
 33    w = list(word)
 34    if len(w) < 3: return word
 35    n = rng.randint(1, 3)  # 1-3 corruptions for harder training
 36    for _ in range(n):
 37        op = rng.randint(0, 3)
 38        if op == 0 and len(w) > 1:
 39            i = rng.randint(0, len(w)-2); w[i], w[i+1] = w[i+1], w[i]
 40        elif op == 1:
 41            w[rng.randint(0, len(w)-1)] = chr(ord('а') + rng.randint(0, 31))
 42        elif op == 2 and len(w) > 2:
 43            del w[rng.randint(0, len(w)-1)]
 44        elif op == 3:
 45            w.insert(rng.randint(0, len(w)), chr(ord('а') + rng.randint(0, 31)))
 46    return ''.join(w)
 47
 48def encode(word):
 49    ids = [char_to_idx(c) for c in word[:MAX_LEN]]
 50    ids.append(ALPHA - 1)
 51    return ids
 52
 53class Model(nn.Module):
 54    def __init__(self):
 55        super().__init__()
 56        self.embed = nn.Embedding(ALPHA, EMBED_DIM)
 57        self.encoder = nn.GRU(EMBED_DIM, HIDDEN, num_layers=2, batch_first=True, dropout=0.1)
 58        self.decoder = nn.GRU(EMBED_DIM, HIDDEN, num_layers=2, batch_first=True, dropout=0.1)
 59        self.output = nn.Linear(HIDDEN, ALPHA)
 60    def forward(self, src, tgt):
 61        _, h = self.encoder(self.embed(src))
 62        logits = self.output(self.decoder(self.embed(tgt), h)[0])
 63        return logits
 64
 65def main():
 66    p("Loading dictionary...")
 67    words = []
 68    with open(DICT_PATH, 'r', encoding='utf-8') as f:
 69        for line in f:
 70            w = line.strip().lower()
 71            if 2 <= len(w) <= 20: words.append(w)
 72            if len(words) >= 200000: break
 73    p(f"Loaded {len(words)} words")
 74
 75    device = torch.device('cuda')
 76    p(f"Device: {device} ({torch.cuda.get_device_name(0)})")
 77
 78    model = Model().to(device)
 79    params = sum(p.numel() for p in model.parameters())
 80    p(f"Model params: {params:,}")
 81    optimizer = optim.Adam(model.parameters(), lr=LR)
 82    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)
 83    criterion = nn.CrossEntropyLoss(ignore_index=-1)
 84    rng = random.Random(42)
 85
 86    p(f"Training {EPOCHS} epochs, batch={BATCH_SIZE}, hidden={HIDDEN}, embed={EMBED_DIM}, layers=2")
 87    t0 = time.time()
 88
 89    for epoch in range(EPOCHS):
 90        model.train()
 91        total_loss = 0; correct = 0; total = 0
 92        n_batches = 200
 93
 94        for _ in range(n_batches):
 95            src_b, ti_b, to_b = [], [], []
 96            for _ in range(BATCH_SIZE):
 97                tgt = words[rng.randint(0, len(words)-1)]
 98                inp = corrupt(tgt, rng) if rng.random() < 0.7 else tgt  # 70% corrupted
 99                s = encode(inp); t = encode(tgt)
100                while len(s) < MAX_LEN+1: s.append(ALPHA-1)
101                while len(t) < MAX_LEN+1: t.append(ALPHA-1)
102                src_b.append(s[:MAX_LEN+1])
103                ti_b.append([ALPHA-1] + t[:MAX_LEN])
104                to_b.append(t[:MAX_LEN+1])
105
106            src = torch.tensor(src_b, device=device)
107            ti = torch.tensor(ti_b, device=device)
108            to = torch.tensor(to_b, device=device)
109
110            logits = model(src, ti)
111            loss = criterion(logits.view(-1, ALPHA), to.view(-1))
112            optimizer.zero_grad(); loss.backward()
113            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
114            optimizer.step()
115
116            total_loss += loss.item()
117            preds = logits.argmax(dim=-1)
118            mask = to != ALPHA - 1
119            correct += (preds == to)[mask].sum().item()
120            total += mask.sum().item()
121
122        scheduler.step()
123        acc = correct / total * 100 if total > 0 else 0
124        p(f"  Epoch {epoch}: loss={total_loss/n_batches:.3f} acc={acc:.1f}% lr={scheduler.get_last_lr()[0]:.5f} [{time.time()-t0:.0f}s]")
125
126    # Export for C# (single-layer GRU format for simplicity)
127    p("Exporting weights...")
128    model.eval()
129    # For C# we export only the last layer of 2-layer GRU
130    with open(SAVE_PATH, 'wb') as f:
131        f.write(struct.pack('iiii', ALPHA, EMBED_DIM, HIDDEN, MAX_LEN))
132        # Embedding
133        emb = model.embed.weight.detach().cpu().numpy()
134        for i in range(ALPHA):
135            for j in range(EMBED_DIM):
136                f.write(struct.pack('f', float(emb[i][j])))
137        # Encoder GRU layer 1 weights
138        for name, param in model.encoder.named_parameters():
139            if '_l0' in name:  # only layer 0
140                data = param.detach().cpu().numpy().flatten()
141                f.write(struct.pack('i', len(data)))
142                for v in data: f.write(struct.pack('f', float(v)))
143        # Decoder GRU layer 1 weights
144        for name, param in model.decoder.named_parameters():
145            if '_l0' in name:
146                data = param.detach().cpu().numpy().flatten()
147                f.write(struct.pack('i', len(data)))
148                for v in data: f.write(struct.pack('f', float(v)))
149        # Output
150        w = model.output.weight.detach().cpu().numpy()
151        b = model.output.bias.detach().cpu().numpy()
152        for i in range(ALPHA):
153            for j in range(HIDDEN):
154                f.write(struct.pack('f', float(w[i][j])))
155        for i in range(ALPHA):
156            f.write(struct.pack('f', float(b[i])))
157
158    fsize = os.path.getsize(SAVE_PATH)
159    p(f"Saved: {SAVE_PATH} ({fsize//1024}KB)")
160
161    # Test
162    p("\n=== Test ===")
163    tests = ["привет","прввет","компуктер","здраствуте","пшоел","кароче","тихналогия","написатьб","мсыли","бсытро"]
164    for t in tests:
165        src = torch.tensor([encode(t)[:MAX_LEN+1] + [ALPHA-1]*(MAX_LEN+1-len(encode(t)))], device=device)
166        _, h = model.encoder(model.embed(src))
167        result = []; ic = ALPHA-1
168        for _ in range(MAX_LEN):
169            it = torch.tensor([[ic]], device=device)
170            out, h = model.decoder(model.embed(it), h)
171            pred = model.output(out.squeeze(1)).argmax(dim=-1).item()
172            if pred == ALPHA-1: break
173            ch = idx_to_char(pred)
174            if ch: result.append(ch)
175            ic = pred
176        p(f"  {t} -> {''.join(result)}")
177
178    p("\nDone!")
179
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()