#!/usr/bin/env python3
"""
Train a character-level spelling corrector using PyTorch.

Architecture: encoder-decoder GRU without attention. The encoder's final
hidden state initialises the decoder, which is trained with teacher forcing.
Exports weights to a little-endian binary file for C# inference.
"""
import os
import random
import struct
import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim

# === Config ===
ALPHA = 35  # а-я(32) + ё + space + EOS
HIDDEN = 96
MAX_LEN = 25
EMBED_DIM = 32
EPOCHS = 20
BATCH_SIZE = 128
LR = 0.001
BATCHES_PER_EPOCH = 500
MIN_WORD_LEN = 2
MAX_WORD_LEN = 20
MAX_WORDS = 100000
SEED = 42

# Index ALPHA-1 doubles as EOS, as the decoder start token and as "unknown char".
EOS = ALPHA - 1
# Target value excluded from the loss: marks padding positions after the real EOS.
PAD_TARGET = -1

DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'Data')
DICT_PATH = os.path.join(DATA_DIR, 'dict_ru.txt')
SAVE_PATH = os.path.join(DATA_DIR, 'seq2spell_gpu.bin')

DEMO_WORDS = ["привет", "прввет", "компуктер", "здраствуте", "пшоел",
              "кароче", "тихналогия", "написатьб", "решылб"]


def char_to_idx(c):
    """Map a character to its vocabulary index (unknown characters -> EOS)."""
    if 'а' <= c <= 'я':
        return ord(c) - ord('а')
    if c == 'ё':
        return 32
    if c == ' ':
        return 33
    return EOS  # EOS/unknown


def idx_to_char(i):
    """Inverse of char_to_idx; returns '' for EOS or any out-of-range index."""
    if 0 <= i <= 31:
        return chr(ord('а') + i)
    if i == 32:
        return 'ё'
    if i == 33:
        return ' '
    return ''


def is_trainable_word(w, min_len=MIN_WORD_LEN, max_len=MAX_WORD_LEN):
    """Return True if *w* has an acceptable length and only in-vocabulary chars.

    Characters outside the alphabet would be encoded as EOS, which would teach
    the decoder to stop in the middle of a word, so such words are rejected.
    """
    return min_len <= len(w) <= max_len and all(char_to_idx(c) != EOS for c in w)


def corrupt(word, rng):
    """Return *word* with a random typo applied.

    Picks one of: swap two adjacent letters, replace a letter, delete a letter,
    insert a letter, or apply two corruptions in a row. Words shorter than 3
    characters are returned unchanged. Replacement/insertion letters are drawn
    from а-я (never ё).
    """
    if len(word) < 3:
        return word
    w = list(word)
    op = rng.randint(0, 4)
    if op == 0:  # swap adjacent
        p = rng.randint(0, len(w) - 2)
        w[p], w[p + 1] = w[p + 1], w[p]
    elif op == 1:  # replace
        p = rng.randint(0, len(w) - 1)
        w[p] = chr(ord('а') + rng.randint(0, 31))
    elif op == 2:  # delete
        del w[rng.randint(0, len(w) - 1)]
    elif op == 3:  # insert
        w.insert(rng.randint(0, len(w)), chr(ord('а') + rng.randint(0, 31)))
    else:  # op == 4: double corrupt -- two successive corruptions
        return corrupt(corrupt(word, rng), rng)
    return ''.join(w)


def encode_word(word, max_len=MAX_LEN):
    """Encode *word* (truncated to max_len chars) as indices followed by EOS."""
    ids = [char_to_idx(c) for c in word[:max_len]]
    ids.append(EOS)
    return ids


def _pad(ids, length, value):
    """Right-pad *ids* with *value* up to *length* (no truncation)."""
    return ids + [value] * (length - len(ids))


def make_example(inp, target, max_len=MAX_LEN):
    """Build one (src, tgt_in, tgt_out) training triple, each max_len+1 long.

    - src: encoded input padded with EOS.
    - tgt_in: decoder input = EOS start token + target shifted right (EOS-padded).
    - tgt_out: target ids ending with one real EOS, then PAD_TARGET, so that
      padding positions are ignored by the loss.
    """
    seq_len = max_len + 1
    src_ids = _pad(encode_word(inp, max_len), seq_len, EOS)
    tgt_ids = encode_word(target, max_len)
    tgt_in = [EOS] + _pad(tgt_ids, seq_len, EOS)[:max_len]  # shift right
    tgt_out = _pad(tgt_ids, seq_len, PAD_TARGET)
    return src_ids, tgt_in, tgt_out


class Seq2SpellModel(nn.Module):
    """GRU encoder-decoder sharing one character embedding (no attention)."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(ALPHA, EMBED_DIM)
        self.encoder = nn.GRU(EMBED_DIM, HIDDEN, batch_first=True)
        self.decoder = nn.GRU(EMBED_DIM, HIDDEN, batch_first=True)
        self.output = nn.Linear(HIDDEN, ALPHA)

    def forward(self, src, tgt):
        """Teacher-forced pass.

        src: (batch, src_len) and tgt: (batch, tgt_len) index tensors
        (batch_first=True). Returns logits of shape (batch, tgt_len, ALPHA).
        """
        # Encode; only the final hidden state is passed on to the decoder.
        _, h = self.encoder(self.embed(src))
        # Decode with teacher forcing.
        dec_out, _ = self.decoder(self.embed(tgt), h)
        return self.output(dec_out)


def flush_print(*args, **kwargs):
    """print() followed by an explicit stdout flush (useful when piped)."""
    print(*args, **kwargs)
    sys.stdout.flush()


def load_words(path=DICT_PATH, limit=MAX_WORDS):
    """Read up to *limit* trainable, lower-cased words (one per line) from *path*.

    Raises:
        RuntimeError: if no usable word is found.
    """
    flush_print("Loading dictionary...")
    words = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            w = line.strip().lower()
            if is_trainable_word(w):
                words.append(w)
                if len(words) >= limit:
                    break
    flush_print(f"Loaded {len(words)} words")
    if not words:
        raise RuntimeError(f"No usable words found in {path}")
    return words


def build_batch(words, rng, device, batch_size=BATCH_SIZE):
    """Sample a batch; each source is corrupted with probability 0.5.

    Returns (src, tgt_in, tgt_out) int64 tensors of shape (batch_size, MAX_LEN+1).
    """
    src_batch, tgt_in_batch, tgt_out_batch = [], [], []
    for _ in range(batch_size):
        target = rng.choice(words)
        inp = corrupt(target, rng) if rng.random() < 0.5 else target
        src_ids, tgt_in, tgt_out = make_example(inp, target)
        src_batch.append(src_ids)
        tgt_in_batch.append(tgt_in)
        tgt_out_batch.append(tgt_out)
    return (torch.tensor(src_batch, device=device),
            torch.tensor(tgt_in_batch, device=device),
            torch.tensor(tgt_out_batch, device=device))


def train(model, words, device, rng=None):
    """Train *model* in place for EPOCHS epochs of BATCHES_PER_EPOCH batches."""
    rng = random.Random(SEED) if rng is None else rng
    optimizer = optim.Adam(model.parameters(), lr=LR)
    # Padding targets are PAD_TARGET (see make_example) and are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_TARGET)

    flush_print(f"Training {EPOCHS} epochs, batch={BATCH_SIZE}, hidden={HIDDEN}, embed={EMBED_DIM}")
    start = time.time()

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for _ in range(BATCHES_PER_EPOCH):
            src, tgt_in, tgt_out = build_batch(words, rng, device)

            logits = model(src, tgt_in)
            loss = criterion(logits.reshape(-1, ALPHA), tgt_out.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            # Character accuracy over real letters only (EOS and padding excluded).
            preds = logits.argmax(dim=-1)
            mask = (tgt_out != PAD_TARGET) & (tgt_out != EOS)
            correct += (preds == tgt_out)[mask].sum().item()
            total += mask.sum().item()

        acc = correct / total * 100 if total > 0 else 0
        elapsed = time.time() - start
        flush_print(f" Epoch {epoch}: loss={total_loss/BATCHES_PER_EPOCH:.3f} acc={acc:.1f}% [{elapsed:.0f}s]")


def _write_floats(f, tensor):
    """Write *tensor* flattened in row-major order as little-endian float32."""
    values = tensor.detach().cpu().reshape(-1).tolist()
    f.write(struct.pack(f'<{len(values)}f', *values))


def export_weights(model, path=SAVE_PATH):
    """Write model weights to *path* in the binary layout read by the C# side.

    Layout (all little-endian):
      - int32 x4: ALPHA, EMBED_DIM, HIDDEN, MAX_LEN
      - float32: embedding [ALPHA, EMBED_DIM], row-major
      - encoder GRU, then decoder GRU: for each tensor in named_parameters()
        order (for a 1-layer nn.GRU: weight_ih_l0, weight_hh_l0, bias_ih_l0,
        bias_hh_l0), an int32 element count followed by the float32 values
      - float32: output weight [ALPHA, HIDDEN] row-major, then bias [ALPHA]
    """
    with open(path, 'wb') as f:
        f.write(struct.pack('<4i', ALPHA, EMBED_DIM, HIDDEN, MAX_LEN))
        _write_floats(f, model.embed.weight)
        for gru in (model.encoder, model.decoder):
            for _name, param in gru.named_parameters():
                f.write(struct.pack('<i', param.numel()))
                _write_floats(f, param)
        _write_floats(f, model.output.weight)
        _write_floats(f, model.output.bias)


@torch.no_grad()
def greedy_correct(model, word, device):
    """Greedily decode a correction for *word* (at most MAX_LEN characters)."""
    src = torch.tensor([_pad(encode_word(word), MAX_LEN + 1, EOS)], device=device)
    _, h = model.encoder(model.embed(src))

    result = []
    prev = EOS  # decoding starts from the EOS token, as in training
    for _ in range(MAX_LEN):
        inp = torch.tensor([[prev]], device=device)
        out, h = model.decoder(model.embed(inp), h)
        pred = model.output(out.squeeze(1)).argmax(dim=-1).item()
        if pred == EOS:
            break
        result.append(idx_to_char(pred))
        prev = pred
    return ''.join(result)


def run_demo(model, device, tests=DEMO_WORDS):
    """Print greedy corrections for a few sample (mis)spellings."""
    flush_print("\n=== Test ===")
    model.eval()
    for t in tests:
        flush_print(f" {t} → {greedy_correct(model, t, device)}")


def main():
    """Load data, train, export weights for C#, then print sample corrections."""
    words = load_words()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    flush_print(f"Device: {device}")

    torch.manual_seed(SEED)  # reproducible weight init (Python RNG is seeded too)
    model = Seq2SpellModel().to(device)
    train(model, words, device)

    # === Export weights for C# inference ===
    flush_print("Exporting weights...")
    model.eval()
    export_weights(model, SAVE_PATH)
    fsize = os.path.getsize(SAVE_PATH)
    flush_print(f"Saved: {SAVE_PATH} ({fsize//1024}KB)")

    run_demo(model, device)

    flush_print("\nDone!")


if __name__ == '__main__':
    main()