"""Retrain the seq2spell model briefly (10 epochs), then export it to a binary file.

Reads ``dict_ru.txt`` from the sibling ``Data`` directory and trains a small GRU
encoder/decoder to map randomly corrupted words back to the originals. It then
writes the weights to ``Data/seq2spell_gpu.bin`` and prints a few greedy-decoding
samples.

numpy is needed by ``Tensor.numpy()`` during export. It is imported up front so a
missing install fails immediately instead of after training.
"""
import os
import random
import struct
import sys
import time

import numpy as np
import torch
import torch.nn as nn

ALPHA = 35       # 32 letters а..я, ё, space, plus one shared special index (ALPHA - 1)
HIDDEN = 192
EMBED_DIM = 64
MAX_LEN = 25
BATCH_SIZE = 512
EOS = ALPHA - 1  # the special index doubles as start token, end/padding and "unknown char"

DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'Data')
SAVE_PATH = os.path.join(DATA_DIR, 'seq2spell_gpu.bin')
DICT_PATH = os.path.join(DATA_DIR, 'dict_ru.txt')


def p(*a, **k):
    """Print and flush immediately so progress is visible when stdout is piped."""
    print(*a, **k)
    sys.stdout.flush()


def char_to_idx(c):
    """Map a character to its vocabulary index.

    а..я -> 0..31, ё -> 32, space -> 33. Any other character maps to
    ALPHA - 1, the same index used for start/end/padding, so e.g. a hyphen
    inside a word is indistinguishable from end-of-word.
    """
    if 'а' <= c <= 'я':
        return ord(c) - ord('а')
    if c == 'ё':
        return 32
    if c == ' ':
        return 33
    return ALPHA - 1


def idx_to_char(i):
    """Inverse of ``char_to_idx``; returns '' for the special index and out-of-range values."""
    if 0 <= i <= 31:
        return chr(ord('а') + i)
    if i == 32:
        return 'ё'
    if i == 33:
        return ' '
    return ''


def corrupt(word, rng):
    """Return ``word`` with 1-3 random typo edits drawn from ``rng``.

    Each edit is one of: swap two adjacent letters, replace a letter, delete a
    letter (only while more than 2 remain), or insert a letter. Random letters
    come from а..я only (ё is never introduced). Words shorter than 3
    characters are returned unchanged.
    """
    w = list(word)
    if len(w) < 3:
        return word
    for _ in range(rng.randint(1, 3)):
        op = rng.randint(0, 3)
        if op == 0 and len(w) > 1:
            i = rng.randint(0, len(w) - 2)
            w[i], w[i + 1] = w[i + 1], w[i]
        elif op == 1:
            w[rng.randint(0, len(w) - 1)] = chr(ord('а') + rng.randint(0, 31))
        elif op == 2 and len(w) > 2:
            del w[rng.randint(0, len(w) - 1)]
        elif op == 3:
            w.insert(rng.randint(0, len(w)), chr(ord('а') + rng.randint(0, 31)))
    return ''.join(w)


def encode(word):
    """Encode ``word`` (truncated to MAX_LEN chars) as indices followed by one EOS."""
    ids = [char_to_idx(c) for c in word[:MAX_LEN]]
    ids.append(ALPHA - 1)
    return ids


def _pad(ids):
    """Right-pad an ``encode`` result with EOS to exactly MAX_LEN + 1 entries."""
    return ids + [EOS] * (MAX_LEN + 1 - len(ids))


class Model(nn.Module):
    """Character seq2seq: shared embedding, 2-layer GRU encoder and decoder, linear head."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(ALPHA, EMBED_DIM)
        self.encoder = nn.GRU(EMBED_DIM, HIDDEN, num_layers=2, batch_first=True, dropout=0.1)
        self.decoder = nn.GRU(EMBED_DIM, HIDDEN, num_layers=2, batch_first=True, dropout=0.1)
        self.output = nn.Linear(HIDDEN, ALPHA)

    def forward(self, src, tgt):
        """Teacher-forced pass: index tensors (batch first) -> per-step logits over ALPHA."""
        _, h = self.encoder(self.embed(src))
        return self.output(self.decoder(self.embed(tgt), h)[0])


def load_words(path=DICT_PATH, limit=200000):
    """Read up to ``limit`` lowercased, stripped words of length 2..20 from a UTF-8 file."""
    words = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            w = line.strip().lower()
            if 2 <= len(w) <= 20:
                words.append(w)
                if len(words) >= limit:
                    break
    return words


def _make_batch(words, rng, batch_size=BATCH_SIZE):
    """Build one batch of (source, decoder input, decoder target) index lists.

    70% of sources are corrupted copies of the target; the rest are the target
    itself. The decoder input is the target shifted right behind a start token.
    """
    src, tgt_in, tgt_out = [], [], []
    for _ in range(batch_size):
        target = words[rng.randint(0, len(words) - 1)]
        # rng.random() is drawn before corrupt() consumes rng, as in the original loop.
        source = corrupt(target, rng) if rng.random() < 0.7 else target
        t = _pad(encode(target))
        src.append(_pad(encode(source)))
        tgt_in.append([EOS] + t[:MAX_LEN])
        tgt_out.append(t)
    return src, tgt_in, tgt_out


def train(model, words, device, epochs=10, steps=200, rng=None, batch_size=BATCH_SIZE):
    """Train ``model`` in place with Adam; print mean loss and letter accuracy per epoch.

    Accuracy counts only target positions that are not EOS/padding.
    ``rng`` defaults to ``random.Random(42)`` for reproducible batches.
    """
    rng = random.Random(42) if rng is None else rng
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        total_loss, correct, counted = 0.0, 0, 0
        for _ in range(steps):
            sb, ti, to = _make_batch(words, rng, batch_size)
            src = torch.tensor(sb, device=device)
            tin = torch.tensor(ti, device=device)
            tout = torch.tensor(to, device=device)
            logits = model(src, tin)
            loss = criterion(logits.view(-1, ALPHA), tout.view(-1))
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            preds = logits.argmax(-1)
            mask = tout != EOS
            correct += (preds == tout)[mask].sum().item()
            counted += mask.sum().item()
        acc = correct / counted * 100 if counted else 0.0
        p(f"  Epoch {epoch}: loss={total_loss / max(steps, 1):.4f} acc={acc:.1f}%")


def greedy_decode(model, word, device):
    """Greedily decode a correction for ``word`` using the full in-memory model.

    Stops at EOS or after MAX_LEN characters. The caller should put the model
    in eval mode first so dropout is disabled.
    """
    src = torch.tensor([_pad(encode(word))], device=device)
    result = []
    with torch.no_grad():
        _, h = model.encoder(model.embed(src))
        prev = EOS  # start token, matching the training decoder input
        for _ in range(MAX_LEN):
            out, h = model.decoder(model.embed(torch.tensor([[prev]], device=device)), h)
            pred = model.output(out.squeeze(1)).argmax(-1).item()
            if pred == EOS:
                break
            result.append(idx_to_char(pred))
            prev = pred
    return ''.join(result)


def _to_numpy(tensor):
    """Detach ``tensor`` and copy it to a CPU numpy array."""
    return tensor.detach().cpu().numpy()


def _float32_bytes(array):
    """Serialize ``array`` in C order as native-endian float32.

    Matches concatenating ``struct.pack('f', v)`` over the flattened values.
    """
    return np.asarray(array, dtype=np.float32).tobytes()


def export_model(model, path):
    """Write ``model`` weights to ``path`` in the seq2spell binary layout.

    Layout (native byte order): int32 header (ALPHA, EMBED_DIM, HIDDEN, MAX_LEN);
    the embedding matrix; for the encoder and then the decoder, each layer-0 GRU
    tensor as an int32 element count followed by its floats; output weight
    (ALPHA x HIDDEN, row-major); output bias.
    """
    with open(path, 'wb') as f:
        f.write(struct.pack('iiii', ALPHA, EMBED_DIM, HIDDEN, MAX_LEN))
        f.write(_float32_bytes(_to_numpy(model.embed.weight)))
        # NOTE(review): only '_l0' tensors are written although Model uses
        # num_layers=2, and the header records no layer count. The decoder's
        # output (fed to self.output) comes from its last layer, so unless the
        # loader accounts for this, the exported network differs from the one
        # trained and sampled here. Kept as-is to preserve the file format the
        # loader expects -- confirm against the consumer.
        for gru in (model.encoder, model.decoder):
            for name, param in gru.named_parameters():
                if '_l0' in name:
                    data = _to_numpy(param).ravel()
                    f.write(struct.pack('i', data.size))
                    f.write(_float32_bytes(data))
        f.write(_float32_bytes(_to_numpy(model.output.weight)))
        f.write(_float32_bytes(_to_numpy(model.output.bias)))


def main():
    """Load the dictionary, retrain, export to SAVE_PATH and print sample corrections."""
    p("Loading dict...")
    words = load_words()
    p(f"{len(words)} words")
    if not words:
        sys.exit(f"No usable words (2-20 chars) found in {DICT_PATH}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model().to(device)

    p(f"Quick training 10 epochs on {device}...")
    train(model, words, device, epochs=10)

    p("Exporting...")
    model.eval()
    export_model(model, SAVE_PATH)
    p(f"Saved: {os.path.getsize(SAVE_PATH) // 1024}KB")

    p("\n=== Test ===")
    tests = ["привет", "прввет", "компуктер", "здраствуте", "пшоел",
             "кароче", "тихналогия", "мсыли", "бсытро", "написатьб"]
    for t in tests:
        p(f"  {t} -> {greedy_decode(model, t, device)}")
    p("Done!")


if __name__ == "__main__":
    main()