#!/usr/bin/env python3
"""Train a GRU seq2seq spell corrector and export its weights.

The script reads a Russian word list, trains an encoder/decoder GRU to map
randomly corrupted words back to their correct spelling, writes a subset of
the trained weights to a flat binary file, and finally prints greedy
corrections for a handful of sample misspellings.

CUDA is used when it is available; otherwise the script falls back to CPU.
"""
import os
import random
import struct
import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim

ALPHA = 35          # vocabulary size: 32 letters + 'ё' + space + one special index
HIDDEN = 192        # bigger hidden for GPU
EMBED_DIM = 64      # bigger embeddings
NUM_LAYERS = 2
MAX_LEN = 25
EPOCHS = 50         # more epochs on GPU
BATCH_SIZE = 512    # bigger batches for GPU
LR = 0.001

# The last vocabulary index is shared by four roles in this file: unknown
# character, end-of-sequence marker, decoder start token and padding.
EOS_IDX = ALPHA - 1

DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'Data')
DICT_PATH = os.path.join(DATA_DIR, 'dict_ru.txt')
SAVE_PATH = os.path.join(DATA_DIR, 'seq2spell_gpu.bin')


def p(*a, **k):
    """Print like the built-in ``print`` and flush stdout right away."""
    print(*a, **k)
    sys.stdout.flush()


def char_to_idx(c):
    """Map a single character to its vocabulary index.

    'а'..'я' map to 0..31, 'ё' to 32, space to 33; every other character
    maps to ``ALPHA - 1``.
    """
    if 'а' <= c <= 'я':
        return ord(c) - ord('а')
    if c == 'ё':
        return 32
    if c == ' ':
        return 33
    return ALPHA - 1


def idx_to_char(i):
    """Inverse of :func:`char_to_idx`; returns '' for any index above 33."""
    if 0 <= i <= 31:
        return chr(ord('а') + i)
    if i == 32:
        return 'ё'
    if i == 33:
        return ' '
    return ''


def corrupt(word, rng):
    """Return *word* with 1-3 random edit operations applied.

    Words shorter than three characters are returned unchanged. Each
    operation is drawn uniformly from: swap two adjacent characters,
    replace one character, delete one character, insert one character.
    An operation whose length guard fails is skipped, so fewer edits than
    drawn may actually be applied.

    Args:
        word: the correctly spelled word.
        rng: a ``random.Random``-like object providing ``randint``.

    Returns:
        The (possibly) corrupted word as a new string.
    """
    w = list(word)
    if len(w) < 3:
        return word
    n = rng.randint(1, 3)  # 1-3 corruptions for harder training
    for _ in range(n):
        op = rng.randint(0, 3)
        if op == 0 and len(w) > 1:
            # Transpose two neighbouring characters.
            i = rng.randint(0, len(w) - 2)
            w[i], w[i + 1] = w[i + 1], w[i]
        elif op == 1:
            # Substitute with a random letter from 'а'..'я'.
            w[rng.randint(0, len(w) - 1)] = chr(ord('а') + rng.randint(0, 31))
        elif op == 2 and len(w) > 2:
            # Delete, but never shrink below two characters.
            del w[rng.randint(0, len(w) - 1)]
        elif op == 3:
            # Insert a random letter at any position, including the end.
            w.insert(rng.randint(0, len(w)), chr(ord('а') + rng.randint(0, 31)))
    return ''.join(w)


def encode(word):
    """Convert *word* to a list of indices followed by the end marker.

    At most ``MAX_LEN`` characters are kept, so the result has at most
    ``MAX_LEN + 1`` entries.
    """
    ids = [char_to_idx(c) for c in word[:MAX_LEN]]
    ids.append(ALPHA - 1)
    return ids


class Model(nn.Module):
    """Encoder/decoder GRU pair with a shared embedding table."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(ALPHA, EMBED_DIM)
        self.encoder = nn.GRU(EMBED_DIM, HIDDEN, num_layers=NUM_LAYERS,
                              batch_first=True, dropout=0.1)
        self.decoder = nn.GRU(EMBED_DIM, HIDDEN, num_layers=NUM_LAYERS,
                              batch_first=True, dropout=0.1)
        self.output = nn.Linear(HIDDEN, ALPHA)

    def forward(self, src, tgt):
        """Return per-position logits over the vocabulary.

        Args:
            src: integer index tensor for the (corrupted) input, batch first.
            tgt: integer index tensor for the decoder input, batch first.

        Returns:
            Logits with a trailing dimension of size ``ALPHA``.
        """
        # The encoder's final hidden state initialises the decoder.
        _, h = self.encoder(self.embed(src))
        decoded, _ = self.decoder(self.embed(tgt), h)
        return self.output(decoded)


def _pad(ids):
    """Truncate or right-pad *ids* with EOS_IDX to exactly MAX_LEN + 1 entries."""
    ids = ids[:MAX_LEN + 1]
    return ids + [EOS_IDX] * (MAX_LEN + 1 - len(ids))


def _load_words(path, limit=200000):
    """Read up to *limit* lower-cased words of length 2..20 from *path*."""
    words = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            w = line.strip().lower()
            if 2 <= len(w) <= 20:
                words.append(w)
            if len(words) >= limit:
                break
    return words


def _make_batch(words, rng, device):
    """Sample one batch and return (source, decoder input, decoder target)."""
    src_b, ti_b, to_b = [], [], []
    for _ in range(BATCH_SIZE):
        tgt = words[rng.randint(0, len(words) - 1)]
        inp = corrupt(tgt, rng) if rng.random() < 0.7 else tgt  # 70% corrupted
        s = _pad(encode(inp))
        t = _pad(encode(tgt))
        src_b.append(s)
        # Decoder input is the target shifted right behind a start token.
        ti_b.append([EOS_IDX] + t[:MAX_LEN])
        to_b.append(t)
    return (torch.tensor(src_b, device=device),
            torch.tensor(ti_b, device=device),
            torch.tensor(to_b, device=device))


def _train_epoch(model, optimizer, criterion, words, rng, device, n_batches=200):
    """Run *n_batches* optimisation steps; return (mean loss, accuracy in %)."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for _ in range(n_batches):
        src, ti, to = _make_batch(words, rng, device)

        logits = model(src, ti)
        loss = criterion(logits.view(-1, ALPHA), to.view(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        preds = logits.argmax(dim=-1)
        # Accuracy counts only positions whose target is not EOS_IDX.
        mask = to != EOS_IDX
        correct += (preds == to)[mask].sum().item()
        total += mask.sum().item()

    acc = correct / total * 100 if total > 0 else 0
    return total_loss / n_batches, acc


def _write_floats(f, tensor):
    """Write *tensor* to *f* as native-order float32 values in row-major order."""
    data = tensor.detach().to(device='cpu', dtype=torch.float32).contiguous()
    f.write(data.numpy().tobytes())


def _export_weights(model, path):
    """Write the weight file: header, embedding, layer-0 GRU weights, output layer.

    Layout (native byte order): four int32 header fields (ALPHA, EMBED_DIM,
    HIDDEN, MAX_LEN); the embedding matrix; for the encoder and then the
    decoder, every layer-0 parameter as an int32 element count followed by
    its float32 values; the output weight matrix; the output bias.

    NOTE(review): the model is trained with NUM_LAYERS GRU layers and the
    output projection reads the top layer, but only layer 0 ('_l0') is
    written here. A consumer that runs just these weights will not
    reproduce Model.forward -- confirm against the reader of this file
    before relying on the export.
    """
    with open(path, 'wb') as f:
        f.write(struct.pack('iiii', ALPHA, EMBED_DIM, HIDDEN, MAX_LEN))
        _write_floats(f, model.embed.weight)
        for gru in (model.encoder, model.decoder):
            for name, param in gru.named_parameters():
                if '_l0' in name:  # only layer 0
                    f.write(struct.pack('i', param.numel()))
                    _write_floats(f, param)
        _write_floats(f, model.output.weight)
        _write_floats(f, model.output.bias)


def _greedy_correct(model, word, device):
    """Greedily decode a correction for *word*, one character at a time."""
    src = torch.tensor([_pad(encode(word))], device=device)
    result = []
    with torch.no_grad():  # inference only; no graph needed
        _, h = model.encoder(model.embed(src))
        ic = EOS_IDX  # start token
        for _ in range(MAX_LEN):
            it = torch.tensor([[ic]], device=device)
            out, h = model.decoder(model.embed(it), h)
            pred = model.output(out.squeeze(1)).argmax(dim=-1).item()
            if pred == EOS_IDX:
                break
            ch = idx_to_char(pred)
            if ch:
                result.append(ch)
            ic = pred
    return ''.join(result)


def main():
    """Load the dictionary, train, export the weights and print sample fixes.

    Raises:
        ValueError: if the dictionary yields no usable words.
    """
    p("Loading dictionary...")
    words = _load_words(DICT_PATH)
    p(f"Loaded {len(words)} words")
    if not words:
        # Fail here with a clear message instead of inside the batch sampler.
        raise ValueError(f"No usable words found in {DICT_PATH}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        p(f"Device: {device} ({torch.cuda.get_device_name(0)})")
    else:
        p(f"Device: {device}")

    model = Model().to(device)
    params = sum(param.numel() for param in model.parameters())
    p(f"Model params: {params:,}")
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)
    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    rng = random.Random(42)

    p(f"Training {EPOCHS} epochs, batch={BATCH_SIZE}, hidden={HIDDEN}, embed={EMBED_DIM}, layers={NUM_LAYERS}")
    t0 = time.time()

    for epoch in range(EPOCHS):
        # Read the rate before stepping so the log shows what this epoch used.
        lr_used = scheduler.get_last_lr()[0]
        avg_loss, acc = _train_epoch(model, optimizer, criterion, words, rng, device)
        scheduler.step()
        p(f" Epoch {epoch}: loss={avg_loss:.3f} acc={acc:.1f}% lr={lr_used:.5f} [{time.time()-t0:.0f}s]")

    # Export for C# (single-layer GRU format for simplicity)
    p("Exporting weights...")
    model.eval()
    _export_weights(model, SAVE_PATH)

    fsize = os.path.getsize(SAVE_PATH)
    p(f"Saved: {SAVE_PATH} ({fsize//1024}KB)")

    # Test
    p("\n=== Test ===")
    tests = ["привет", "прввет", "компуктер", "здраствуте", "пшоел",
             "кароче", "тихналогия", "написатьб", "мсыли", "бсытро"]
    for t in tests:
        p(f" {t} -> {_greedy_correct(model, t, device)}")

    p("\nDone!")


if __name__ == '__main__':
    main()